Skip to content
This repository has been archived by the owner on May 22, 2021. It is now read-only.

Commit

Permalink
[AIRFLOW-6262] add on_execute_callback to operators (apache#6831)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qingping Hou authored and galuszkak committed Mar 5, 2020
1 parent c012dd7 commit 260e4ed
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
6 changes: 6 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ class derived from this one results in the creation of a task object,
objects to the task instance and is documented under the macros
section of the API.
:type on_failure_callback: callable
:param on_execute_callback: much like the ``on_failure_callback`` except
that it is executed right before the task is executed.
:type on_execute_callback: callable
:param on_retry_callback: much like the ``on_failure_callback`` except
that it is executed when retries occur.
:type on_retry_callback: callable
Expand Down Expand Up @@ -274,6 +277,7 @@ class derived from this one results in the creation of a task object,
'priority_weight',
'sla',
'execution_timeout',
'on_execute_callback',
'on_failure_callback',
'on_success_callback',
'on_retry_callback',
Expand Down Expand Up @@ -307,6 +311,7 @@ def __init__(
pool: str = Pool.DEFAULT_POOL_NAME,
sla: Optional[timedelta] = None,
execution_timeout: Optional[timedelta] = None,
on_execute_callback: Optional[Callable] = None,
on_failure_callback: Optional[Callable] = None,
on_success_callback: Optional[Callable] = None,
on_retry_callback: Optional[Callable] = None,
Expand Down Expand Up @@ -374,6 +379,7 @@ def __init__(
self.pool = pool
self.sla = sla
self.execution_timeout = execution_timeout
self.on_execute_callback = on_execute_callback
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
Expand Down
7 changes: 7 additions & 0 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,13 @@ def signal_handler(signum, frame):
self.render_templates(context=context)
task_copy.pre_execute(context=context)

try:
if task.on_execute_callback:
task.on_execute_callback(context)
except Exception as e3:
self.log.error("Failed when executing execute callback")
self.log.exception(e3)

# If a timeout is specified for the task, make it fail
# if it goes beyond
result = None
Expand Down
26 changes: 26 additions & 0 deletions tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,32 @@ def test_template_with_json_variable_missing(self):
with self.assertRaises(KeyError):
task.render_template('{{ var.json.get("missing_variable") }}', context)

def test_execute_callback(self):
called = False

def on_execute_callable(context):
nonlocal called
called = True
self.assertEqual(
context['dag_run'].dag_id,
'test_dagrun_execute_callback'
)

dag = DAG('test_execute_callbak', start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10))
task = DummyOperator(task_id='op', email='[email protected]',
on_execute_callback=on_execute_callable,
dag=dag)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session = settings.Session()
session.merge(ti)
session.commit()
ti._run_raw_task()
assert called
ti.refresh_from_db()
assert ti.state == State.SUCCESS

def test_handle_failure(self):
from unittest import mock

Expand Down

0 comments on commit 260e4ed

Please sign in to comment.