Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent start trigger initialization in scheduler #39585

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9070360
fix(baseoperator): change start_trigger into start_trigger_cls and st…
Lee-W May 13, 2024
12905ab
refactor(taskinstance): extract common logic in defer_task
Lee-W May 14, 2024
22e6760
refactor(baseoperator): refactor start trigger arguments as a datacla…
Lee-W May 14, 2024
f88e60d
docs(deferring): update docs for newly introduced StartTriggerArgs
Lee-W May 14, 2024
14199cc
feat(baseoperator): add start_from_trigger as the flag to decide whet…
Lee-W May 15, 2024
df71b44
fix(taskinstance): fix unexpected commit
Lee-W May 28, 2024
a669068
fix(taskinstance): add _defer_task_from_task_deferred
Lee-W May 28, 2024
6731ab4
docs(deferring): update version added
Lee-W May 29, 2024
e32ff04
refactor: rename defer_task_from_* methods
Lee-W May 30, 2024
80ae6eb
fix(dagrun): remove uncessary conditions on scheduling tasks
Lee-W May 30, 2024
009bcd1
fix(taskinstance): remove unnecessay check on trigger_kwargs
Lee-W May 30, 2024
758b2e8
fix(dagrun): set start_date before deferring task from scheduler
Lee-W May 30, 2024
ed94608
refactor(triggers): move StartTriggerArgs to airflow.triggers.base
Lee-W May 30, 2024
dfa47ec
refactor: merge _defer_task* as _defer_task function
Lee-W May 31, 2024
4568a59
refactor(taskinstance): deduplicate defer_task logic
Lee-W May 31, 2024
f1e4462
style: fix mypy warning
Lee-W May 31, 2024
415993a
refactor: reorder parameter as suggested
Lee-W Jun 11, 2024
91f8b30
docs(deferring): reword description as suggested
Lee-W Jun 11, 2024
e1805e5
refactor: make argument "exception" in defer_task method required for…
Lee-W Jun 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,8 +509,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
# task's expand() contribute to the op_kwargs operator argument, not
# the operator arguments themselves, and should expand against it.
expand_input_attr="op_kwargs_expand_input",
start_trigger=self.operator_class.start_trigger,
next_method=self.operator_class.next_method,
start_trigger_args=self.operator_class.start_trigger_args,
start_from_trigger=self.operator_class.start_from_trigger,
)
return XComArg(operator=operator)

Expand Down
5 changes: 3 additions & 2 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ class AbstractOperator(Templater, DAGNode):
"node_id", # Duplicates task_id
"task_group", # Doesn't have a useful repr, no point showing in UI
"inherits_from_empty_operator", # impl detail
"start_trigger",
"next_method",
# Decide whether to start task execution from triggerer
"start_trigger_args",
"start_from_trigger",
# For compatibility with TG, for operators these are just the current task, no point showing
"roots",
"leaves",
Expand Down
10 changes: 5 additions & 5 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
from airflow.models.operator import Operator
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import BaseTrigger
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import ArgNotSet

Expand Down Expand Up @@ -819,8 +819,8 @@ def say_hello_world(**context):
# Set to True for an operator instantiated by a mapped operator.
__from_mapped = False

start_trigger: BaseTrigger | None = None
next_method: str | None = None
start_trigger_args: StartTriggerArgs | None = None
start_from_trigger: bool = False

def __init__(
self,
Expand Down Expand Up @@ -1679,9 +1679,9 @@ def get_serialized_fields(cls):
"is_teardown",
"on_failure_fail_dagrun",
"map_index_template",
"start_trigger",
"next_method",
"start_trigger_args",
"_needs_expansion",
"start_from_trigger",
}
)
DagContext.pop_context_managed_dag()
Expand Down
16 changes: 4 additions & 12 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred, TaskNotFound
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
from airflow.listeners.listener import get_listener_manager
from airflow.models import Log
from airflow.models.abstractoperator import NotMapped
Expand Down Expand Up @@ -1538,19 +1538,11 @@ def schedule_tis(
and not ti.task.outlets
):
dummy_ti_ids.append((ti.task_id, ti.map_index))
elif (
ti.task.start_trigger is not None
and ti.task.next_method is not None
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
and not ti.task.outlets
):
elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
ti.start_date = timezone.utcnow()
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
ti.defer_task(
exception=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method),
session=session,
)
ti.defer_task(exception=None, session=session)
else:
schedulable_ti_ids.append((ti.task_id, ti.map_index))

Expand Down
17 changes: 6 additions & 11 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import BaseTrigger
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
Expand Down Expand Up @@ -237,8 +237,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
# For classic operators, this points to expand_input because kwargs
# to BaseOperator.expand() contribute to operator arguments.
expand_input_attr="expand_input",
start_trigger=self.operator_class.start_trigger,
next_method=self.operator_class.next_method,
start_trigger_args=self.operator_class.start_trigger_args,
start_from_trigger=self.operator_class.start_from_trigger,
)
return op

Expand Down Expand Up @@ -281,8 +281,8 @@ class MappedOperator(AbstractOperator):
_task_module: str
_task_type: str
_operator_name: str
start_trigger: BaseTrigger | None
next_method: str | None
start_trigger_args: StartTriggerArgs | None
start_from_trigger: bool
_needs_expansion: bool = True

dag: DAG | None
Expand All @@ -309,12 +309,7 @@ class MappedOperator(AbstractOperator):
supports_lineage: bool = False

HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
(
"parse_time_mapped_ti_count",
"operator_class",
"start_trigger",
"next_method",
)
("parse_time_mapped_ti_count", "operator_class", "start_trigger_args", "start_from_trigger")
)

def __hash__(self):
Expand Down
39 changes: 29 additions & 10 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,12 +1575,29 @@ def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Ses
@internal_api_call
@provide_session
def _defer_task(
Lee-W marked this conversation as resolved.
Show resolved Hide resolved
ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION
ti: TaskInstance | TaskInstancePydantic,
exception: TaskDeferred | None = None,
session: Session = NEW_SESSION,
) -> TaskInstancePydantic | TaskInstance:
from airflow.models.trigger import Trigger

if exception is not None:
trigger_row = Trigger.from_object(exception.trigger)
trigger_kwargs = exception.kwargs
next_method = exception.method_name
timeout = exception.timeout
elif ti.task is not None and ti.task.start_trigger_args is not None:
trigger_row = Trigger(
classpath=ti.task.start_trigger_args.trigger_cls,
kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
)
trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs
Lee-W marked this conversation as resolved.
Show resolved Hide resolved
next_method = ti.task.start_trigger_args.next_method
timeout = ti.task.start_trigger_args.timeout
else:
raise AirflowException("exception and ti.task.start_trigger_args cannot both be None")

# First, make the trigger entry
trigger_row = Trigger.from_object(exception.trigger)
session.add(trigger_row)
session.flush()

Expand All @@ -1594,12 +1611,12 @@ def _defer_task(
# depending on self.next_method semantics
ti.state = TaskInstanceState.DEFERRED
ti.trigger_id = trigger_row.id
ti.next_method = exception.method_name
ti.next_kwargs = exception.kwargs or {}
ti.next_method = next_method
ti.next_kwargs = trigger_kwargs or {}

# Calculate timeout too if it was passed
if exception.timeout is not None:
ti.trigger_timeout = timezone.utcnow() + exception.timeout
if timeout is not None:
ti.trigger_timeout = timezone.utcnow() + timeout
else:
ti.trigger_timeout = None

Expand All @@ -1615,8 +1632,10 @@ def _defer_task(
ti.trigger_timeout = ti.start_date + execution_timeout
if ti.test_mode:
_add_log(event=ti.state, task_instance=ti, session=session)
session.merge(ti)
session.commit()

if exception is not None:
dstandish marked this conversation as resolved.
Show resolved Hide resolved
session.merge(ti)
session.commit()
return ti


Expand Down Expand Up @@ -3000,8 +3019,8 @@ def _execute_task(self, context: Context, task_orig: Operator):
return _execute_task(self, context, task_orig)

@provide_session
def defer_task(self, exception: TaskDeferred, session: Session) -> None:
"""Mark the task as deferred and sets up the trigger that is needed to resume it.
def defer_task(self, exception: TaskDeferred | None, session: Session = NEW_SESSION) -> None:
"""Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised.

:meta: private
"""
Expand Down
31 changes: 12 additions & 19 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@
airflow_priority_weight_strategies,
airflow_priority_weight_strategies_classes,
)
from airflow.triggers.base import BaseTrigger
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import Context, OutletEventAccessor, OutletEventAccessors
from airflow.utils.docs import get_docs_url
from airflow.utils.helpers import exactly_one
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
Expand Down Expand Up @@ -1018,11 +1017,10 @@ def _serialize_node(cls, op: BaseOperator | MappedOperator, include_deps: bool)
# Used to determine if an Operator is inherited from EmptyOperator
serialize_op["_is_empty"] = op.inherits_from_empty_operator

if exactly_one(op.start_trigger is not None, op.next_method is not None):
raise AirflowException("start_trigger and next_method should both be set.")

serialize_op["start_trigger"] = op.start_trigger.serialize() if op.start_trigger else None
serialize_op["next_method"] = op.next_method
serialize_op["start_trigger_args"] = (
op.start_trigger_args.serialize() if op.start_trigger_args else None
)
serialize_op["start_from_trigger"] = op.start_from_trigger

if op.operator_extra_links:
serialize_op["_operator_extra_links"] = cls._serialize_operator_extra_links(
Expand Down Expand Up @@ -1206,16 +1204,11 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
# Used to determine if an Operator is inherited from EmptyOperator
setattr(op, "_is_empty", bool(encoded_op.get("_is_empty", False)))

# Deserialize start_trigger
serialized_start_trigger = encoded_op.get("start_trigger")
if serialized_start_trigger:
trigger_cls_name, trigger_kwargs = serialized_start_trigger
trigger_cls = import_string(trigger_cls_name)
start_trigger = trigger_cls(**trigger_kwargs)
setattr(op, "start_trigger", start_trigger)
else:
setattr(op, "start_trigger", None)
setattr(op, "next_method", encoded_op.get("next_method", None))
start_trigger_args = None
if encoded_op.get("start_trigger_args", None):
start_trigger_args = StartTriggerArgs(**encoded_op.get("start_trigger_args", None))
setattr(op, "start_trigger_args", start_trigger_args)
setattr(op, "start_from_trigger", bool(encoded_op.get("start_from_trigger", False)))

@staticmethod
def set_task_dag_references(task: Operator, dag: DAG) -> None:
Expand Down Expand Up @@ -1278,8 +1271,8 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
end_date=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
start_trigger=None,
next_method=None,
start_trigger_args=encoded_op.get("start_trigger_args", None),
start_from_trigger=encoded_op.get("start_from_trigger", False),
)
else:
op = SerializedBaseOperator(task_id=encoded_op["task_id"])
Expand Down
20 changes: 20 additions & 0 deletions airflow/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,31 @@
from __future__ import annotations

import abc
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, AsyncIterator

from airflow.utils.log.logging_mixin import LoggingMixin


@dataclass
class StartTriggerArgs:
    """Arguments needed to start task execution directly from the triggerer.

    An operator exposes an instance of this class (together with setting
    ``start_from_trigger`` to ``True``) so that the task can be deferred
    without first being sent to a worker.
    """

    # Importable path of the trigger class to run.
    trigger_cls: str
    # Name of the operator method to call once the trigger fires.
    next_method: str
    # Keyword arguments used to instantiate the trigger, if any.
    trigger_kwargs: dict[str, Any] | None = None
    # Optional deferral timeout; ``None`` means no timeout.
    timeout: timedelta | None = None

    def serialize(self):
        """Return a plain-dict representation of these arguments."""
        return dict(
            trigger_cls=self.trigger_cls,
            trigger_kwargs=self.trigger_kwargs,
            next_method=self.next_method,
            timeout=self.timeout,
        )


class BaseTrigger(abc.ABC, LoggingMixin):
"""
Base class for all triggers.
Expand Down
36 changes: 26 additions & 10 deletions docs/apache-airflow/authoring-and-scheduling/deferring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,14 @@ The ``self.defer`` call raises the ``TaskDeferred`` exception, so it can work an
Triggering Deferral from Start
Lee-W marked this conversation as resolved.
Show resolved Hide resolved
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If you want to defer your task directly to the triggerer without going into the worker, you can add the class level attributes ``start_trigger`` and ``next_method`` to your deferrable operator.
.. versionadded:: 2.10.0

* ``start_trigger``: An instance of a trigger you want to defer to. It will be serialized into the database.
If you want to defer your task directly to the triggerer without going into the worker, you can set the class-level attribute ``start_from_trigger`` to ``True`` and add a class-level attribute ``start_trigger_args`` with a ``StartTriggerArgs`` object with the following 4 attributes to your deferrable operator:

* ``trigger_cls``: An importable path to your trigger class.
* ``trigger_kwargs``: Keyword arguments to pass to the trigger class when it is instantiated.
* ``next_method``: The method name on your operator that you want Airflow to call when it resumes.
* ``timeout``: (Optional) A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Defaults to ``None``, which means no timeout.


This is particularly useful when deferring is the only thing the ``execute`` method does. Here's a basic refinement of the previous example.
Expand All @@ -156,23 +160,28 @@ This is particularly useful when deferring is the only thing the ``execute`` met
from datetime import timedelta
from typing import Any

from airflow.triggers.base import StartTriggerArgs
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils.context import Context


class WaitOneHourSensor(BaseSensorOperator):
start_trigger = TimeDeltaTrigger(timedelta(hours=1))
next_method = "execute_complete"
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"moment": timedelta(hours=1)},
next_method="execute_complete",
timeout=None,
)
start_from_trigger = True

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
# We have no more work to do here. Mark as complete.
return

``start_trigger`` and ``next_method`` can also be set at the instance level for more flexible configuration.
``start_from_trigger`` and ``trigger_kwargs`` can also be modified at the instance level for more flexible configuration.

.. warning::
Dynamic task mapping is not supported when ``start_trigger`` and ``next_method`` are assigned in instance level.
Dynamic task mapping is not supported when ``trigger_kwargs`` is modified at instance level.

.. code-block:: python

Expand All @@ -184,11 +193,18 @@ This is particularly useful when deferring is the only thing the ``execute`` met
from airflow.utils.context import Context


class WaitOneHourSensor(BaseSensorOperator):
class WaitTwoHourSensor(BaseSensorOperator):
start_trigger_args = StartTriggerArgs(
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={},
next_method="execute_complete",
timeout=None,
)

def __init__(self, *args: list[Any], **kwargs: dict[str, Any]) -> None:
super().__init__(*args, **kwargs)
self.start_trigger = TimeDeltaTrigger(timedelta(hours=1))
self.next_method = "execute_complete"
self.start_trigger_args.trigger_kwargs = {"moment": timedelta(hours=1)}
self.start_from_trigger = True

def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
# We have no more work to do here. Mark as complete.
Expand Down
Loading