Skip to content

Commit

Permalink
add next_kwargs to StartTriggerArgs (apache#40376)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee-W authored and romsharon98 committed Jul 26, 2024
1 parent 4ba5ee1 commit 3b192cd
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 4 deletions.
6 changes: 3 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,15 +1606,15 @@ def _defer_task(

if exception is not None:
trigger_row = Trigger.from_object(exception.trigger)
trigger_kwargs = exception.kwargs
next_method = exception.method_name
next_kwargs = exception.kwargs
timeout = exception.timeout
elif ti.task is not None and ti.task.start_trigger_args is not None:
trigger_row = Trigger(
classpath=ti.task.start_trigger_args.trigger_cls,
kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
)
trigger_kwargs = ti.task.start_trigger_args.trigger_kwargs
next_kwargs = ti.task.start_trigger_args.next_kwargs
next_method = ti.task.start_trigger_args.next_method
timeout = ti.task.start_trigger_args.timeout
else:
Expand All @@ -1635,7 +1635,7 @@ def _defer_task(
ti.state = TaskInstanceState.DEFERRED
ti.trigger_id = trigger_row.id
ti.next_method = next_method
ti.next_kwargs = trigger_kwargs or {}
ti.next_kwargs = next_kwargs or {}

# Calculate timeout too if it was passed
if timeout is not None:
Expand Down
2 changes: 2 additions & 0 deletions airflow/triggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@ class StartTriggerArgs:
trigger_cls: str
next_method: str
trigger_kwargs: dict[str, Any] | None = None
next_kwargs: dict[str, Any] | None = None
timeout: timedelta | None = None

def serialize(self):
return {
"trigger_cls": self.trigger_cls,
"trigger_kwargs": self.trigger_kwargs,
"next_method": self.next_method,
"next_kwargs": self.next_kwargs,
"timeout": self.timeout,
}

Expand Down
5 changes: 4 additions & 1 deletion docs/apache-airflow/authoring-and-scheduling/deferring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ Triggering Deferral from Start
If you want to defer your task directly to the triggerer without going into the worker, you can set class level attribute ``start_with_trigger`` to ``True`` and add a class level attribute ``start_trigger_args`` with an ``StartTriggerArgs`` object with the following 4 attributes to your deferrable operator:

* ``trigger_cls``: An importable path to your trigger class.
* ``trigger_kwargs``: Additional keyword arguments to pass to the method when it is called.
* ``trigger_kwargs``: Keyword arguments to pass to the ``trigger_cls`` when it's initialized.
* ``next_method``: The method name on your operator that you want Airflow to call when it resumes.
* ``next_kwargs``: Additional keyword arguments to pass to the ``next_method`` when it is called.
* ``timeout``: (Optional) A timedelta that specifies a timeout after which this deferral will fail, and fail the task instance. Defaults to ``None``, which means no timeout.


Expand All @@ -170,6 +171,7 @@ This is particularly useful when deferring is the only thing the ``execute`` met
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={"moment": timedelta(hours=1)},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = True
Expand Down Expand Up @@ -198,6 +200,7 @@ This is particularly useful when deferring is the only thing the ``execute`` met
trigger_cls="airflow.triggers.temporal.TimeDeltaTrigger",
trigger_kwargs={},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
Expand Down
3 changes: 3 additions & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2199,6 +2199,7 @@ class TestOperator(BaseOperator):
trigger_cls="airflow.triggers.testing.SuccessTrigger",
trigger_kwargs=None,
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = False
Expand All @@ -2216,6 +2217,7 @@ class Test2Operator(BaseOperator):
trigger_cls="airflow.triggers.testing.SuccessTrigger",
trigger_kwargs={},
next_method="execute_complete",
next_kwargs=None,
timeout=None,
)
start_from_trigger = True
Expand All @@ -2239,6 +2241,7 @@ def execute_complete(self):
"trigger_cls": "airflow.triggers.testing.SuccessTrigger",
"trigger_kwargs": {},
"next_method": "execute_complete",
"next_kwargs": None,
"timeout": None,
}
assert task["__var"]["start_from_trigger"] is True
Expand Down

0 comments on commit 3b192cd

Please sign in to comment.