Skip to content

Commit

Permalink
Fix on_failure_callback when task receives a SIGTERM (#29743)
Browse files Browse the repository at this point in the history
This fixes on_failure_callback when a task receives a SIGTERM. The
signal handler now raises a different exception, which is caught
during task execution so that we can run the failure callback
directly.
  • Loading branch information
ephraimbuddy authored Feb 24, 2023
1 parent 38b901e commit 671b88e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
6 changes: 6 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
from airflow.models import DagRun


class AirflowTermSignal(Exception):
    """Raised when a TERM signal is received.

    Derives directly from ``Exception``.
    """

    # HTTP status code attached to this exception type (500 Internal Server Error).
    status_code = HTTPStatus.INTERNAL_SERVER_ERROR


class AirflowException(Exception):
"""
Base class for all Airflow's errors.
Expand Down
17 changes: 11 additions & 6 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
AirflowSensorTimeout,
AirflowSkipException,
AirflowTaskTimeout,
AirflowTermSignal,
DagRunNotFound,
RemovedInAirflow3Warning,
TaskDeferralError,
Expand Down Expand Up @@ -1527,8 +1528,7 @@ def signal_handler(signum, frame):
os._exit(1)
return
self.log.error("Received SIGTERM. Terminating subprocesses.")
self.task.on_kill()
raise AirflowException("Task received SIGTERM signal")
raise AirflowTermSignal("Task received SIGTERM signal")

signal.signal(signal.SIGTERM, signal_handler)

Expand Down Expand Up @@ -1567,10 +1567,15 @@ def signal_handler(signum, frame):

# Execute the task
with set_current_context(context):
result = self._execute_task(context, task_orig)

# Run post_execute callback
self.task.post_execute(context=context, result=result)
try:
result = self._execute_task(context, task_orig)
# Run post_execute callback
self.task.post_execute(context=context, result=result)
except AirflowTermSignal:
self.task.on_kill()
if self.task.on_failure_callback:
self._run_finished_callback(self.task.on_failure_callback, context, "on_failure")
raise AirflowException("Task received SIGTERM signal")

Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
Stats.incr(
Expand Down
24 changes: 22 additions & 2 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,28 @@ def task_function(ti):
ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY

def test_task_sigterm_calls_on_failure_callack(self, dag_maker, caplog):
    """Check that a task receiving SIGTERM still calls its ``on_failure_callback``.

    The task's callable sends SIGTERM to ``ti.pid``. Running the task instance
    must raise ``AirflowException``, and the message logged by the failure
    callback must show up in the captured log text.
    """
    # NOTE(review): "callack" in the test name looks like a typo for "callback";
    # the name is kept unchanged here so the test id stays the same.

    def log_failure(context):
        context["ti"].log.info("on_failure_callback called")

    def send_sigterm(ti):
        os.kill(ti.pid, signal.SIGTERM)

    with dag_maker():
        operator = PythonOperator(
            task_id="test_on_failure",
            python_callable=send_sigterm,
            on_failure_callback=log_failure,
        )

    task_instance = dag_maker.create_dagrun().task_instances[0]
    task_instance.task = operator

    with pytest.raises(AirflowException):
        task_instance.run()

    assert "on_failure_callback called" in caplog.text

@pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED, State.SKIPPED])
def test_task_sigterm_doesnt_change_state_of_finished_tasks(self, state, dag_maker):
session = settings.Session()
Expand Down Expand Up @@ -2332,7 +2354,6 @@ def on_execute_callable(context):
assert context["dag_run"].dag_id == "test_dagrun_execute_callback"

for i, callback_input in enumerate([[on_execute_callable], on_execute_callable]):

ti = create_task_instance(
dag_id=f"test_execute_callback_{i}",
on_execute_callback=callback_input,
Expand Down Expand Up @@ -2369,7 +2390,6 @@ def on_finish_callable(context):
completed = True

for i, callback_input in enumerate([[on_finish_callable], on_finish_callable]):

ti = create_task_instance(
dag_id=f"test_finish_callback_{i}",
end_date=timezone.utcnow() + datetime.timedelta(days=10),
Expand Down

0 comments on commit 671b88e

Please sign in to comment.