From 260e4edebd3473ff603244d0d0e5c33b2d291791 Mon Sep 17 00:00:00 2001
From: Qingping Hou
Date: Tue, 17 Dec 2019 05:23:13 -0800
Subject: [PATCH] [AIRFLOW-6262] add on_execute_callback to operators (#6831)

---
Review notes (text between the separator above and the first file diff is
not part of the commit message):

* Purpose: adds an optional ``on_execute_callback`` argument to
  BaseOperator (documented, stored on the instance, listed next to the
  other callbacks) and invokes it with the task context right after
  ``pre_execute`` in the task-instance run path.  Any exception raised by
  the callback is logged and does not abort the task.
* Test fix: the DAG in ``test_execute_callback`` was created with the
  misspelled id 'test_execute_callbak' while the callback compared against
  'test_dagrun_execute_callback'.  Both now use 'test_execute_callback'.
  The blob ids on the "index" line of the test file therefore no longer
  describe the post-image.
* NOTE(review): the ``assertEqual`` inside the callback runs inside the new
  try/except in taskinstance.py, so a failure there is logged and
  swallowed; only ``assert called`` and the final state check can fail the
  test.  Also confirm that ``context['dag_run']`` is populated here -- the
  test does not visibly create a DagRun.
* NOTE(review): this patch was restored from a whitespace-collapsed copy.
  Indentation of context lines was reconstructed from Python syntax; the
  absolute depth of the taskinstance.py hunk is an assumption.  Verify
  against the target tree before applying.

 airflow/models/baseoperator.py    |  6 ++++++
 airflow/models/taskinstance.py    |  7 +++++++
 tests/models/test_taskinstance.py | 26 ++++++++++++++++++++++++++
 3 files changed, 39 insertions(+)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index b5e92ae81bbb0c..b179d67d00d49d 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -195,6 +195,9 @@ class derived from this one results in the creation of a task object,
         objects to the task instance and is documented under the macros
         section of the API.
     :type on_failure_callback: callable
+    :param on_execute_callback: much like the ``on_failure_callback`` except
+        that it is executed right before the task is executed.
+    :type on_execute_callback: callable
     :param on_retry_callback: much like the ``on_failure_callback`` except
         that it is executed when retries occur.
     :type on_retry_callback: callable
@@ -274,6 +277,7 @@ class derived from this one results in the creation of a task object,
         'priority_weight',
         'sla',
         'execution_timeout',
+        'on_execute_callback',
         'on_failure_callback',
         'on_success_callback',
         'on_retry_callback',
@@ -307,6 +311,7 @@ def __init__(
         pool: str = Pool.DEFAULT_POOL_NAME,
         sla: Optional[timedelta] = None,
         execution_timeout: Optional[timedelta] = None,
+        on_execute_callback: Optional[Callable] = None,
         on_failure_callback: Optional[Callable] = None,
         on_success_callback: Optional[Callable] = None,
         on_retry_callback: Optional[Callable] = None,
@@ -374,6 +379,7 @@ def __init__(
         self.pool = pool
         self.sla = sla
         self.execution_timeout = execution_timeout
+        self.on_execute_callback = on_execute_callback
         self.on_failure_callback = on_failure_callback
         self.on_success_callback = on_success_callback
         self.on_retry_callback = on_retry_callback
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0c674d8a59432a..cdda50ff98eea1 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -923,6 +923,13 @@ def signal_handler(signum, frame):
                 self.render_templates(context=context)
                 task_copy.pre_execute(context=context)
 
+                try:
+                    if task.on_execute_callback:
+                        task.on_execute_callback(context)
+                except Exception as e3:
+                    self.log.error("Failed when executing execute callback")
+                    self.log.exception(e3)
+
                 # If a timeout is specified for the task, make it fail
                 # if it goes beyond
                 result = None
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 91b343f088a0f1..e4c9537a659bbb 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1415,6 +1415,32 @@ def test_template_with_json_variable_missing(self):
         with self.assertRaises(KeyError):
             task.render_template('{{ var.json.get("missing_variable") }}', context)
 
+    def test_execute_callback(self):
+        called = False
+
+        def on_execute_callable(context):
+            nonlocal called
+            called = True
+            self.assertEqual(
+                context['dag_run'].dag_id,
+                'test_execute_callback'
+            )
+
+        dag = DAG('test_execute_callback', start_date=DEFAULT_DATE,
+                  end_date=DEFAULT_DATE + datetime.timedelta(days=10))
+        task = DummyOperator(task_id='op', email='test@test.test',
+                             on_execute_callback=on_execute_callable,
+                             dag=dag)
+        ti = TI(task=task, execution_date=datetime.datetime.now())
+        ti.state = State.RUNNING
+        session = settings.Session()
+        session.merge(ti)
+        session.commit()
+        ti._run_raw_task()
+        assert called
+        ti.refresh_from_db()
+        assert ti.state == State.SUCCESS
+
     def test_handle_failure(self):
         from unittest import mock
 