From 36d5fab167a663c1c5ad3762bca68396866d8eab Mon Sep 17 00:00:00 2001
From: Ephraim Anierobi
Date: Fri, 23 Apr 2021 23:47:20 +0100
Subject: [PATCH] Execute ``on_failure_callback`` when SIGTERM is received
 (#15172)

Currently, on_failure_callback is only called after a task finishes
executing, not while it is executing. When a pod is deleted, a SIGTERM
is sent to the task and the task is stopped immediately. Because the
task is still running when it is killed, on_failure_callback is never
called. This PR ensures that when a pod is marked for deletion and the
running task is killed, the task's on_failure_callback, if one is set,
is called.
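The idea, as a minimal standalone sketch (plain Python, no Airflow
imports; ``on_failure_callback`` and ``SystemExit`` below are
illustrative stand-ins for the user-supplied callback and the
``AirflowException`` raised by ``LocalTaskJob``):

    import os
    import signal

    def on_failure_callback(error):
        # Stand-in for the user callback that the patched handler runs
        # via TaskInstance._run_finished_callback().
        print(f"failure callback fired: {error}")

    def signal_handler(signum, frame):
        # Mirror the patched LocalTaskJob handler: run the failure
        # callback first, then abort the task by raising.
        on_failure_callback(error="task received sigterm")
        raise SystemExit("received SIGTERM")

    signal.signal(signal.SIGTERM, signal_handler)
    os.kill(os.getpid(), signal.SIGTERM)  # handler fires, then we exit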

Closes: #14422

(cherry picked from commit def1e7c5841d89a60f8972a84b83fe362a6a878d)
---
 airflow/jobs/local_task_job.py    |  6 +++
 tests/jobs/test_local_task_job.py | 64 +++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index d68bfc72218c94..b6423106142496 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -76,6 +76,12 @@ def signal_handler(signum, frame):
             """Setting kill signal handler"""
             self.log.error("Received SIGTERM. Terminating subprocesses")
             self.on_kill()
+            self.task_instance.refresh_from_db()
+            if self.task_instance.state not in State.finished:
+                self.task_instance.set_state(State.FAILED)
+            self.task_instance._run_finished_callback(  # pylint: disable=protected-access
+                error="task received sigterm"
+            )
             raise AirflowException("LocalTaskJob received SIGTERM signal")

         # pylint: enable=unused-argument

diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index d8a03866038385..d8776df11fbc82 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -18,6 +18,7 @@
 #
 import multiprocessing
 import os
+import signal
 import time
 import unittest
 import uuid
@@ -495,6 +496,69 @@ def task_function(ti):
         assert task_terminated_externally.value == 1
         assert not process.is_alive()

+    def test_process_kill_call_on_failure_callback(self):
+        """
+        Test that ensures that when a task is killed with sigterm
+        on_failure_callback gets executed
+        """
+        # use shared memory value so we can properly track value change even if
+        # it's been updated across processes.
+        failure_callback_called = Value('i', 0)
+        task_terminated_externally = Value('i', 1)
+        shared_mem_lock = Lock()
+
+        def failure_callback(context):
+            with shared_mem_lock:
+                failure_callback_called.value += 1
+            assert context['dag_run'].dag_id == 'test_mark_failure'
+
+        dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
+
+        def task_function(ti):
+            # pylint: disable=unused-argument
+            time.sleep(60)
+            # This should not happen -- the state change should be noticed and the task should get killed
+            with shared_mem_lock:
+                task_terminated_externally.value = 0
+
+        task = PythonOperator(
+            task_id='test_on_failure',
+            python_callable=task_function,
+            on_failure_callback=failure_callback,
+            dag=dag,
+        )
+
+        session = settings.Session()
+
+        dag.clear()
+        dag.create_dagrun(
+            run_id="test",
+            state=State.RUNNING,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            session=session,
+        )
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.refresh_from_db()
+        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+
+        settings.engine.dispose()
+        process = multiprocessing.Process(target=job1.run)
+        process.start()
+
+        for _ in range(0, 10):
+            ti.refresh_from_db()
+            if ti.state == State.RUNNING:
+                break
+            time.sleep(0.2)
+        assert ti.state == State.RUNNING
+        os.kill(ti.pid, signal.SIGTERM)
+        process.join(timeout=10)
+        assert failure_callback_called.value == 1
+        assert task_terminated_externally.value == 1
+        assert not process.is_alive()
+

 @pytest.fixture()
 def clean_db_helper():
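
To exercise just the new test locally (assuming a pytest-based dev
environment, e.g. the Airflow Breeze container):

    pytest tests/jobs/test_local_task_job.py -k test_process_kill_call_on_failure_callback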