Skip to content

Commit

Permalink
Support using a list of callbacks in `on_*_callback/sla_miss_callback…
Browse files Browse the repository at this point in the history
…`s (apache#28469)

* Support using a list of callbacks in `on_*_callback/sla_miss_callback`s

Previously, it was only possible to specify a single callback function when defining DAG/task callbacks.
This change allows users to specify a list of callback functions, which will be invoked in the order they are provided.

This will not affect DAGs/tasks that use a single callback function.

* Apply suggestion from code review
  • Loading branch information
ephraimbuddy authored Dec 23, 2022
1 parent 7eb8054 commit d7bd6d7
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 100 deletions.
24 changes: 17 additions & 7 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,13 +473,23 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
notification_sent = False
if dag.sla_miss_callback:
# Execute the alert callback
self.log.info("Calling SLA miss callback")
try:
dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
notification_sent = True
except Exception:
Stats.incr("sla_callback_notification_failure")
self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
callbacks = (
dag.sla_miss_callback
if isinstance(dag.sla_miss_callback, list)
else [dag.sla_miss_callback]
)
for callback in callbacks:
self.log.info("Calling SLA miss callback %s", callback)
try:
callback(dag, task_list, blocking_task_list, slas, blocking_tis)
notification_sent = True
except Exception:
Stats.incr("sla_callback_notification_failure")
self.log.exception(
"Could not call sla_miss_callback(%s) for DAG %s",
callback.func_name, # type: ignore[attr-defined]
dag.dag_id,
)
email_content = f"""\
Here's a list of tasks that missed their SLAs:
<pre><code>{task_list}\n<code></pre>
Expand Down
8 changes: 4 additions & 4 deletions airflow/example_dags/tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@
# 'wait_for_downstream': False,
# 'sla': timedelta(hours=2),
# 'execution_timeout': timedelta(seconds=300),
# 'on_failure_callback': some_function,
# 'on_success_callback': some_other_function,
# 'on_retry_callback': another_function,
# 'sla_miss_callback': yet_another_function,
# 'on_failure_callback': some_function, # or list of functions
# 'on_success_callback': some_other_function, # or list of functions
# 'on_retry_callback': another_function, # or list of functions
# 'sla_miss_callback': yet_another_function, # or list of functions
# 'trigger_rule': 'all_success'
},
# [END default_args]
Expand Down
18 changes: 9 additions & 9 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def partial(
weight_rule: str = DEFAULT_WEIGHT_RULE,
sla: timedelta | None = None,
max_active_tis_per_dag: int | None = None,
on_execute_callback: TaskStateChangeCallback | None = None,
on_failure_callback: TaskStateChangeCallback | None = None,
on_success_callback: TaskStateChangeCallback | None = None,
on_retry_callback: TaskStateChangeCallback | None = None,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
run_as_user: str | None = None,
executor_config: dict | None = None,
inlets: Any | None = None,
Expand Down Expand Up @@ -538,7 +538,7 @@ class derived from this one results in the creation of a task object,
notification are sent once and only once for each task instance.
:param execution_timeout: max time allowed for the execution of
this task instance, if it goes beyond it will raise and fail.
:param on_failure_callback: a function to be called when a task instance
:param on_failure_callback: a function or list of functions to be called when a task instance
of this task fails. a context dictionary is passed as a single
parameter to this function. Context contains references to related
objects to the task instance and is documented under the macros
Expand Down Expand Up @@ -706,10 +706,10 @@ def __init__(
pool_slots: int = DEFAULT_POOL_SLOTS,
sla: timedelta | None = None,
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
on_execute_callback: TaskStateChangeCallback | None = None,
on_failure_callback: TaskStateChangeCallback | None = None,
on_success_callback: TaskStateChangeCallback | None = None,
on_retry_callback: TaskStateChangeCallback | None = None,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
pre_execute: TaskPreExecuteHook | None = None,
post_execute: TaskPostExecuteHook | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
Expand Down
34 changes: 18 additions & 16 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,15 @@ class DAG(LoggingMixin):
:param dagrun_timeout: specify how long a DagRun should be up before
timing out / failing, so that new DagRuns can be created. The timeout
is only enforced for scheduled DagRuns.
:param sla_miss_callback: specify a function to call when reporting SLA
:param sla_miss_callback: specify a function or list of functions to call when reporting SLA
timeouts. See :ref:`sla_miss_callback<concepts:sla_miss_callback>` for
more information about the function signature and parameters that are
passed to the callback.
:param default_view: Specify DAG default view (grid, graph, duration,
gantt, landing_times), default grid
:param orientation: Specify DAG orientation in graph view (LR, TB, RL, BT), default LR
:param catchup: Perform scheduler catchup (or only run latest)? Defaults to True
:param on_failure_callback: A function to be called when a DagRun of this dag fails.
:param on_failure_callback: A function or list of functions to be called when a DagRun of this dag fails.
A context dictionary is passed as a single parameter to this function.
:param on_success_callback: Much like the ``on_failure_callback`` except
that it is executed when the dag succeeds.
Expand Down Expand Up @@ -396,12 +396,12 @@ def __init__(
max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
dagrun_timeout: timedelta | None = None,
sla_miss_callback: SLAMissCallback | None = None,
sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
on_success_callback: DagStateChangeCallback | None = None,
on_failure_callback: DagStateChangeCallback | None = None,
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
params: dict | None = None,
access_control: dict | None = None,
Expand Down Expand Up @@ -1313,19 +1313,21 @@ def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION
:param reason: Completion reason
:param session: Database session
"""
callback = self.on_success_callback if success else self.on_failure_callback
if callback:
self.log.info("Executing dag callback function: %s", callback)
callbacks = self.on_success_callback if success else self.on_failure_callback
if callbacks:
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
tis = dagrun.get_task_instances(session=session)
ti = tis[-1] # get first TaskInstance of DagRun
ti.task = self.get_task(ti.task_id)
context = ti.get_template_context(session=session)
context.update({"reason": reason})
try:
callback(context)
except Exception:
self.log.exception("failed to invoke dag state update callback")
Stats.incr("dag.callback_exceptions")
for callback in callbacks:
self.log.info("Executing dag callback function: %s", callback)
try:
callback(context)
except Exception:
self.log.exception("failed to invoke dag state update callback")
Stats.incr("dag.callback_exceptions")

def get_active_runs(self):
"""
Expand Down Expand Up @@ -3468,12 +3470,12 @@ def dag(
max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
dagrun_timeout: timedelta | None = None,
sla_miss_callback: SLAMissCallback | None = None,
sla_miss_callback: None | SLAMissCallback | list[SLAMissCallback] = None,
default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
on_success_callback: DagStateChangeCallback | None = None,
on_failure_callback: DagStateChangeCallback | None = None,
on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None,
doc_md: str | None = None,
params: dict | None = None,
access_control: dict | None = None,
Expand Down
8 changes: 4 additions & 4 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,31 +447,31 @@ def resources(self) -> Resources | None:
return self.partial_kwargs.get("resources")

@property
def on_execute_callback(self) -> TaskStateChangeCallback | None:
def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
return self.partial_kwargs.get("on_execute_callback")

@on_execute_callback.setter
def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
self.partial_kwargs["on_execute_callback"] = value

@property
def on_failure_callback(self) -> TaskStateChangeCallback | None:
def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
return self.partial_kwargs.get("on_failure_callback")

@on_failure_callback.setter
def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
self.partial_kwargs["on_failure_callback"] = value

@property
def on_retry_callback(self) -> TaskStateChangeCallback | None:
def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
return self.partial_kwargs.get("on_retry_callback")

@on_retry_callback.setter
def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
self.partial_kwargs["on_retry_callback"] = value

@property
def on_success_callback(self) -> TaskStateChangeCallback | None:
def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
return self.partial_kwargs.get("on_success_callback")

@on_success_callback.setter
Expand Down
42 changes: 27 additions & 15 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
from airflow.utils.email import send_email
from airflow.utils.helpers import render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import qualname
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.platform import getuser
Expand Down Expand Up @@ -1528,14 +1529,22 @@ def signal_handler(signum, frame):
Stats.incr("ti_successes")

def _run_finished_callback(
self, callback: TaskStateChangeCallback | None, context: Context, callback_type: str
self,
callbacks: None | TaskStateChangeCallback | list[TaskStateChangeCallback],
context: Context,
callback_type: str,
) -> None:
"""Run callback after task finishes"""
try:
if callback:
callback(context)
except Exception: # pylint: disable=broad-except
self.log.exception(f"Error when executing {callback_type} callback")
if callbacks:
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
for callback in callbacks:
try:
callback(context)
except Exception: # pylint: disable=broad-except
callback_name = qualname(callback).split(".")[-1]
self.log.exception(
f"Error when executing {callback_name} callback" # type: ignore[attr-defined]
)

def _execute_task(self, context, task_orig):
"""Executes Task (optionally with a Timeout) and pushes Xcom results"""
Expand Down Expand Up @@ -1632,11 +1641,14 @@ def _defer_task(self, session: Session, defer: TaskDeferred) -> None:

def _run_execute_callback(self, context: Context, task: Operator) -> None:
"""Functions that need to be run before a Task is executed"""
try:
if task.on_execute_callback:
task.on_execute_callback(context)
except Exception:
self.log.exception("Failed when executing execute callback")
callbacks = task.on_execute_callback
if callbacks:
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
for callback in callbacks:
try:
callback(context)
except Exception:
self.log.exception("Failed when executing execute callback")

@provide_session
def run(
Expand Down Expand Up @@ -1814,15 +1826,15 @@ def handle_failure(
if force_fail or not self.is_eligible_to_retry():
self.state = State.FAILED
email_for_state = operator.attrgetter("email_on_failure")
callback = task.on_failure_callback if task else None
callbacks = task.on_failure_callback if task else None
callback_type = "on_failure"
else:
if self.state == State.QUEUED:
# We increase the try_number so as to fail the task if it fails to start after sometime
self._try_number += 1
self.state = State.UP_FOR_RETRY
email_for_state = operator.attrgetter("email_on_retry")
callback = task.on_retry_callback if task else None
callbacks = task.on_retry_callback if task else None
callback_type = "on_retry"

self._log_state("Immediate failure requested. " if force_fail else "")
Expand All @@ -1832,8 +1844,8 @@ def handle_failure(
except Exception:
self.log.exception("Failed to send email to: %s", task.email)

if callback and context:
self._run_finished_callback(callback, context, callback_type)
if callbacks and context:
self._run_finished_callback(callbacks, context, callback_type)

if not test_mode:
session.merge(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,10 @@ In the following example, failures in any task call the ``task_failure_alert`` f
task1 = EmptyOperator(task_id="task1")
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3", on_success_callback=dag_success_alert)
task3 = EmptyOperator(task_id="task3", on_success_callback=[dag_success_alert])
task1 >> task2 >> task3
.. note::
As of Airflow 2.6.0, callbacks now support a list of callback functions, allowing users to specify multiple functions
to be executed on the desired event. Simply pass a list of callback functions to the callback args when defining your DAG/task
callbacks, e.g. ``on_failure_callback=[callback_func_1, callback_func_2]``.
44 changes: 25 additions & 19 deletions tests/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,28 +303,34 @@ def test_dag_file_processor_sla_miss_callback_exception(self, mock_stats_incr, c
sla_callback = MagicMock(side_effect=RuntimeError("Could not call function"))

test_start_date = timezone.utcnow() - datetime.timedelta(days=1)
dag, task = create_dummy_dag(
dag_id="test_sla_miss",
task_id="dummy",
sla_miss_callback=sla_callback,
default_args={"start_date": test_start_date, "sla": datetime.timedelta(hours=1)},
)
mock_stats_incr.reset_mock()

session.merge(TaskInstance(task=task, execution_date=test_start_date, state="Success"))
for i, callback in enumerate([[sla_callback], sla_callback]):
dag, task = create_dummy_dag(
dag_id=f"test_sla_miss_{i}",
task_id="dummy",
sla_miss_callback=callback,
default_args={"start_date": test_start_date, "sla": datetime.timedelta(hours=1)},
)
mock_stats_incr.reset_mock()

# Create an SlaMiss where notification was sent, but email was not
session.merge(SlaMiss(task_id="dummy", dag_id="test_sla_miss", execution_date=test_start_date))
session.merge(TaskInstance(task=task, execution_date=test_start_date, state="Success"))

# Now call manage_slas and see if the sla_miss callback gets called
mock_log = mock.MagicMock()
dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log)
dag_file_processor.manage_slas(dag=dag, session=session)
assert sla_callback.called
mock_log.exception.assert_called_once_with(
"Could not call sla_miss_callback for DAG %s", "test_sla_miss"
)
mock_stats_incr.assert_called_once_with("sla_callback_notification_failure")
# Create an SlaMiss where notification was sent, but email was not
session.merge(
SlaMiss(task_id="dummy", dag_id=f"test_sla_miss_{i}", execution_date=test_start_date)
)

# Now call manage_slas and see if the sla_miss callback gets called
mock_log = mock.MagicMock()
dag_file_processor = DagFileProcessor(dag_ids=[], dag_directory=TEST_DAGS_FOLDER, log=mock_log)
dag_file_processor.manage_slas(dag=dag, session=session)
assert sla_callback.called
mock_log.exception.assert_called_once_with(
"Could not call sla_miss_callback(%s) for DAG %s",
sla_callback.func_name, # type: ignore[attr-defined]
f"test_sla_miss_{i}",
)
mock_stats_incr.assert_called_once_with("sla_callback_notification_failure")

@mock.patch("airflow.dag_processing.processor.send_email")
def test_dag_file_processor_only_collect_emails_from_sla_missed_tasks(
Expand Down
Loading

0 comments on commit d7bd6d7

Please sign in to comment.