diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 4bf946fd8e3ac4..e6ef9bd4e1ea6f 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -29,6 +29,12 @@
     from airflow.models import DagRun
 
 
+class AirflowTermSignal(Exception):
+    """Raise when we receive a TERM signal"""
+
+    status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+
+
 class AirflowException(Exception):
     """
     Base class for all Airflow's errors.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c74ff8b6b05671..3f6ebefd952d21 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -77,6 +77,7 @@
     AirflowSensorTimeout,
     AirflowSkipException,
     AirflowTaskTimeout,
+    AirflowTermSignal,
     DagRunNotFound,
     RemovedInAirflow3Warning,
     TaskDeferralError,
@@ -1527,8 +1528,7 @@ def signal_handler(signum, frame):
                 os._exit(1)
                 return
             self.log.error("Received SIGTERM. Terminating subprocesses.")
-            self.task.on_kill()
-            raise AirflowException("Task received SIGTERM signal")
+            raise AirflowTermSignal("Task received SIGTERM signal")
 
         signal.signal(signal.SIGTERM, signal_handler)
 
@@ -1567,10 +1567,15 @@ def signal_handler(signum, frame):
 
         # Execute the task
         with set_current_context(context):
-            result = self._execute_task(context, task_orig)
-
-            # Run post_execute callback
-            self.task.post_execute(context=context, result=result)
+            try:
+                result = self._execute_task(context, task_orig)
+                # Run post_execute callback
+                self.task.post_execute(context=context, result=result)
+            except AirflowTermSignal:
+                self.task.on_kill()
+                if self.task.on_failure_callback:
+                    self._run_finished_callback(self.task.on_failure_callback, context, "on_failure")
+                raise AirflowException("Task received SIGTERM signal")
 
         Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
         Stats.incr(
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e540202e1fb077..54884f633e537e 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -472,6 +472,28 @@ def task_function(ti):
         ti.refresh_from_db()
         assert ti.state == State.UP_FOR_RETRY
 
+    def test_task_sigterm_calls_on_failure_callback(self, dag_maker, caplog):
+        """
+        Test that tasks call on_failure_callback when they receive SIGTERM.
+        """
+
+        def task_function(ti):
+            os.kill(ti.pid, signal.SIGTERM)
+
+        with dag_maker():
+            task_ = PythonOperator(
+                task_id="test_on_failure",
+                python_callable=task_function,
+                on_failure_callback=lambda context: context["ti"].log.info("on_failure_callback called"),
+            )
+
+        dr = dag_maker.create_dagrun()
+        ti = dr.task_instances[0]
+        ti.task = task_
+        with pytest.raises(AirflowException):
+            ti.run()
+        assert "on_failure_callback called" in caplog.text
+
     @pytest.mark.parametrize("state", [State.SUCCESS, State.FAILED, State.SKIPPED])
     def test_task_sigterm_doesnt_change_state_of_finished_tasks(self, state, dag_maker):
         session = settings.Session()
@@ -2332,7 +2354,6 @@ def on_execute_callable(context):
             assert context["dag_run"].dag_id == "test_dagrun_execute_callback"
 
         for i, callback_input in enumerate([[on_execute_callable], on_execute_callable]):
-
             ti = create_task_instance(
                 dag_id=f"test_execute_callback_{i}",
                 on_execute_callback=callback_input,
@@ -2369,7 +2390,6 @@ def on_finish_callable(context):
             completed = True
 
         for i, callback_input in enumerate([[on_finish_callable], on_finish_callable]):
-
             ti = create_task_instance(
                 dag_id=f"test_finish_callback_{i}",
                 end_date=timezone.utcnow() + datetime.timedelta(days=10),
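
Outside the patch itself, here is a minimal self-contained sketch of the pattern the change relies on: the SIGTERM handler raises a dedicated exception instead of calling on_kill() directly, so the caller can run on_kill and the on_failure_callback in ordinary control flow before re-raising a generic failure. All names in the sketch (TermSignal, run_task, execute, on_kill, on_failure_callback) are hypothetical simplifications, not Airflow APIs.

# Illustrative sketch only -- not part of the patch above. TermSignal, run_task,
# and the callback parameters stand in for AirflowTermSignal,
# TaskInstance._execute_task, BaseOperator.on_kill, and the task's
# on_failure_callback from the diff.
import os
import signal


class TermSignal(Exception):
    """Raised by the handler so SIGTERM becomes ordinary control flow."""


def run_task(execute, on_kill, on_failure_callback=None):
    def handler(signum, frame):
        # Turn the asynchronous signal into an exception; cleanup then happens
        # in the except block below rather than inside the signal handler.
        raise TermSignal("Task received SIGTERM signal")

    signal.signal(signal.SIGTERM, handler)
    try:
        return execute()
    except TermSignal:
        on_kill()  # stop subprocesses first, as before the patch
        if on_failure_callback:
            on_failure_callback()  # new behaviour: the failure callback still fires
        raise RuntimeError("Task received SIGTERM signal")


if __name__ == "__main__":
    try:
        run_task(
            execute=lambda: os.kill(os.getpid(), signal.SIGTERM),
            on_kill=lambda: print("on_kill called"),
            on_failure_callback=lambda: print("on_failure_callback called"),
        )
    except RuntimeError as exc:
        print(exc)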