Commit

Execute on_failure_callback when SIGTERM is received (#15172)
Currently, on_failure_callback is only called after a task finishes
executing, not while it is executing. When a pod is deleted, a SIGTERM
is sent to the task and the task is stopped immediately. The task is
still running at the moment it is killed, so its on_failure_callback
is never invoked.

This PR ensures that when a pod is marked for deletion and the running
task is killed, the task's on_failure_callback, if one is defined, is
called.

Closes: #14422

(cherry picked from commit def1e7c)
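For context, a minimal sketch of how a DAG author attaches on_failure_callback to a task; this is the hook the commit makes SIGTERM-aware. The DAG and task names below are illustrative only, not taken from the commit:

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator


def notify_failure(context):
    # Airflow passes the task context; after this change the callback
    # also fires when the task process receives SIGTERM.
    print(f"Task {context['task_instance'].task_id} failed")


def long_running_work():
    import time

    time.sleep(600)


# Illustrative names only.
with DAG(dag_id="example_dag", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    PythonOperator(
        task_id="long_task",
        python_callable=long_running_work,
        on_failure_callback=notify_failure,
    )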
ephraimbuddy authored and potiuk committed May 9, 2021
1 parent 8e00bc9 commit 36d5fab
Showing 2 changed files with 70 additions and 0 deletions.
6 changes: 6 additions & 0 deletions airflow/jobs/local_task_job.py
@@ -76,6 +76,12 @@ def signal_handler(signum, frame):
            """Setting kill signal handler"""
            self.log.error("Received SIGTERM. Terminating subprocesses")
            self.on_kill()
            self.task_instance.refresh_from_db()
            if self.task_instance.state not in State.finished:
                self.task_instance.set_state(State.FAILED)
            self.task_instance._run_finished_callback(  # pylint: disable=protected-access
                error="task received sigterm"
            )
            raise AirflowException("LocalTaskJob received SIGTERM signal")

        # pylint: enable=unused-argument
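The handler follows the standard Python signal pattern. As a hedged, self-contained illustration of that pattern outside Airflow (on_failure here is a hypothetical stand-in for _run_finished_callback, not the commit's API):

import signal
import time


def on_failure(error):
    # Hypothetical stand-in for Airflow's _run_finished_callback.
    print(f"failure callback fired: {error}")


def signal_handler(signum, frame):
    # Mirror the patched behavior: fire the failure callback first,
    # then raise so the process still terminates with an error.
    on_failure(error="task received sigterm")
    raise RuntimeError("received SIGTERM signal")


signal.signal(signal.SIGTERM, signal_handler)

if __name__ == "__main__":
    # `kill <pid>` now runs the callback before the process exits,
    # instead of the process dying without any cleanup.
    try:
        time.sleep(60)
    except RuntimeError as exc:
        print(exc)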
64 changes: 64 additions & 0 deletions tests/jobs/test_local_task_job.py
@@ -18,6 +18,7 @@
#
import multiprocessing
import os
import signal
import time
import unittest
import uuid
@@ -495,6 +496,69 @@ def task_function(ti):
        assert task_terminated_externally.value == 1
        assert not process.is_alive()

    def test_process_kill_call_on_failure_callback(self):
        """
        Test that ensures that when a task is killed with SIGTERM,
        on_failure_callback gets executed.
        """
        # Use a shared memory value so we can properly track the value change
        # even when it is updated across processes.
        failure_callback_called = Value('i', 0)
        task_terminated_externally = Value('i', 1)
        shared_mem_lock = Lock()

        def failure_callback(context):
            with shared_mem_lock:
                failure_callback_called.value += 1
            assert context['dag_run'].dag_id == 'test_mark_failure'

        dag = DAG(dag_id='test_mark_failure', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})

        def task_function(ti):
            # pylint: disable=unused-argument
            time.sleep(60)
            # This should not happen -- the state change should be noticed and the task should get killed
            with shared_mem_lock:
                task_terminated_externally.value = 0

        task = PythonOperator(
            task_id='test_on_failure',
            python_callable=task_function,
            on_failure_callback=failure_callback,
            dag=dag,
        )

        session = settings.Session()

        dag.clear()
        dag.create_dagrun(
            run_id="test",
            state=State.RUNNING,
            execution_date=DEFAULT_DATE,
            start_date=DEFAULT_DATE,
            session=session,
        )
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.refresh_from_db()
        job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
        job1.task_runner = StandardTaskRunner(job1)

        settings.engine.dispose()
        process = multiprocessing.Process(target=job1.run)
        process.start()

        for _ in range(0, 10):
            ti.refresh_from_db()
            if ti.state == State.RUNNING:
                break
            time.sleep(0.2)
        assert ti.state == State.RUNNING
        os.kill(ti.pid, signal.SIGTERM)
        process.join(timeout=10)
        assert failure_callback_called.value == 1
        assert task_terminated_externally.value == 1
        assert not process.is_alive()
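The test observes the callback across process boundaries with multiprocessing shared memory. A minimal sketch of that Value/Lock counter pattern, standalone and not part of the diff:

import multiprocessing


def worker(counter, lock):
    # Each child increments the shared integer under the lock, the same
    # trick the test uses to count callback invocations from the parent.
    with lock:
        counter.value += 1


if __name__ == "__main__":
    counter = multiprocessing.Value("i", 0)  # 'i' => a C int shared across processes
    lock = multiprocessing.Lock()
    children = [multiprocessing.Process(target=worker, args=(counter, lock)) for _ in range(4)]
    for child in children:
        child.start()
    for child in children:
        child.join()
    assert counter.value == 4  # all increments are visible to the parent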


@pytest.fixture()
def clean_db_helper():
