Create a new method used to resume the task in order to implement specific logic for operators #33424

Merged · 6 commits · Aug 21, 2023
Changes from 4 commits
17 changes: 17 additions & 0 deletions airflow/models/baseoperator.py
@@ -58,6 +58,7 @@
     AirflowException,
     FailStopDagInvalidTriggerRule,
     RemovedInAirflow3Warning,
+    TaskDeferralError,
     TaskDeferred,
 )
 from airflow.lineage import apply_lineage, prepare_lineage
@@ -1590,6 +1591,22 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
 
+    def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
+        """This method is called when a deferred task is resumed."""
+        # __fail__ is a special signal value for next_method that indicates
+        # this task was scheduled specifically to fail.
+        if next_method == "__fail__":
+            next_kwargs = next_kwargs or {}
+            traceback = next_kwargs.get("traceback")
+            if traceback is not None:
+                self.log.error("Trigger failed:\n%s", "\n".join(traceback))
+            raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
+        # Grab the callable off the Operator/Task and add in any kwargs
+        execute_callable = getattr(self, next_method)
+        if next_kwargs:
+            execute_callable = functools.partial(execute_callable, **next_kwargs)
+        return execute_callable(context)
+
     def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
         """Get the "normal" operator from the current operator.
 
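For context, here is a rough sketch (not part of this diff) of how an operator could use the new hook to run custom logic whenever it resumes from deferral, then delegate to the default dispatch. The operator name, the `TimeDeltaTrigger` choice, and the logging step are illustrative assumptions.

```python
from __future__ import annotations

from datetime import timedelta
from typing import Any

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils.context import Context


class MyDeferrableOperator(BaseOperator):
    """Hypothetical operator, used only to illustrate resume_execution()."""

    def execute(self, context: Context):
        # Hand control to the triggerer; the task resumes in execute_complete().
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(minutes=5)),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event=None):
        return event

    def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
        # Operator-specific behavior that should run on every resume,
        # before falling back to the default dispatch in BaseOperator.
        self.log.info("Resuming %s via %s", self.task_id, next_method)
        return super().resume_execution(next_method, next_kwargs, context)
```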
22 changes: 6 additions & 16 deletions airflow/models/taskinstance.py
@@ -29,7 +29,6 @@
 from collections import defaultdict
 from datetime import datetime, timedelta
 from enum import Enum
-from functools import partial
 from pathlib import PurePath
 from types import TracebackType
 from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Tuple
@@ -81,7 +80,6 @@
     AirflowTaskTimeout,
     DagRunNotFound,
     RemovedInAirflow3Warning,
-    TaskDeferralError,
     TaskDeferred,
     UnmappableXComLengthPushed,
     UnmappableXComTypePushed,
@@ -1710,19 +1708,11 @@ def _execute_task(self, context, task_orig):
         # If the task has been deferred and is being executed due to a trigger,
         # then we need to pick the right method to come back to, otherwise
         # we go for the default execute
+        execute_callable_kwargs = {}
         if self.next_method:
-            # __fail__ is a special signal value for next_method that indicates
-            # this task was scheduled specifically to fail.
-            if self.next_method == "__fail__":
-                next_kwargs = self.next_kwargs or {}
-                traceback = self.next_kwargs.get("traceback")
-                if traceback is not None:
-                    self.log.error("Trigger failed:\n%s", "\n".join(traceback))
-                raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
-            # Grab the callable off the Operator/Task and add in any kwargs
-            execute_callable = getattr(task_to_execute, self.next_method)
-            if self.next_kwargs:
-                execute_callable = partial(execute_callable, **self.next_kwargs)
+            execute_callable = task_to_execute.resume_execution
+            execute_callable_kwargs["next_method"] = self.next_method
+            execute_callable_kwargs["next_kwargs"] = self.next_kwargs
         else:
             execute_callable = task_to_execute.execute
         # If a timeout is specified for the task, make it fail
@@ -1742,12 +1732,12 @@
                     raise AirflowTaskTimeout()
                 # Run task in timeout wrapper
                 with timeout(timeout_seconds):
-                    result = execute_callable(context=context)
+                    result = execute_callable(context=context, **execute_callable_kwargs)
             except AirflowTaskTimeout:
                 task_to_execute.on_kill()
                 raise
         else:
-            result = execute_callable(context=context)
+            result = execute_callable(context=context, **execute_callable_kwargs)
         with create_session() as session:
             if task_to_execute.do_xcom_push:
                 xcom_value = result
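Net effect of this change: `_execute_task` no longer handles the `__fail__` signal or builds the partial itself; it just forwards `next_method` and `next_kwargs` to the operator. A simplified stand-in for that branch (the helper name here is invented for illustration):

```python
from typing import Any


def choose_execution_path(task, next_method: str | None, next_kwargs: dict[str, Any] | None, context):
    """Simplified stand-in for the deferral branch in TaskInstance._execute_task."""
    if next_method:
        # Resuming after a deferral: the operator decides how to run next_method,
        # including the special "__fail__" signal handled in BaseOperator.
        return task.resume_execution(next_method=next_method, next_kwargs=next_kwargs, context=context)
    # First execution: the usual path.
    return task.execute(context=context)
```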
11 changes: 11 additions & 0 deletions airflow/sensors/base.py
@@ -35,6 +35,7 @@
     AirflowSensorTimeout,
     AirflowSkipException,
     AirflowTaskTimeout,
+    TaskDeferralError,
 )
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.models.baseoperator import BaseOperator
@@ -281,6 +282,16 @@ def run_duration() -> float:
         self.log.info("Success criteria met. Exiting.")
         return xcom_value
 
+    def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
+        try:
+            return super().resume_execution(next_method, next_kwargs, context)
+        except (AirflowException, TaskDeferralError) as e:
+            if self.soft_fail:
+                raise AirflowSkipException(e)
+            raise
+        except Exception:
+            raise
+
     def _get_next_poke_interval(
         self,
         started_at: datetime.datetime | float,
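What this override buys in practice: a deferrable sensor with `soft_fail=True` is now skipped instead of failed when the method it resumes into raises. A minimal sketch, assuming a time-based trigger; the sensor name and failure message are illustrative only.

```python
from __future__ import annotations

from datetime import timedelta

from airflow.exceptions import AirflowException
from airflow.sensors.base import BaseSensorOperator
from airflow.triggers.temporal import TimeDeltaTrigger
from airflow.utils.context import Context


class ExampleDeferrableSensor(BaseSensorOperator):
    """Hypothetical deferrable sensor, for illustration only."""

    def execute(self, context: Context):
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(minutes=10)),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event=None):
        # With soft_fail=True this now surfaces as AirflowSkipException via
        # BaseSensorOperator.resume_execution, so the task ends up skipped.
        raise AirflowException("condition was not met before the trigger fired")


# soft_fail=True: a failure raised after resuming marks the task as skipped.
wait_for_condition = ExampleDeferrableSensor(task_id="wait_for_condition", soft_fail=True)
```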
34 changes: 32 additions & 2 deletions tests/sensors/test_base.py
@@ -23,7 +23,12 @@
 import pytest
 import time_machine
 
-from airflow.exceptions import AirflowException, AirflowRescheduleException, AirflowSensorTimeout
+from airflow.exceptions import (
+    AirflowException,
+    AirflowRescheduleException,
+    AirflowSensorTimeout,
+    AirflowSkipException,
+)
 from airflow.executors.debug_executor import DebugExecutor
 from airflow.executors.executor_constants import (
     CELERY_EXECUTOR,
@@ -37,7 +42,7 @@
 )
 from airflow.executors.local_executor import LocalExecutor
 from airflow.executors.sequential_executor import SequentialExecutor
-from airflow.models import TaskReschedule
+from airflow.models import TaskInstance, TaskReschedule
 from airflow.models.xcom import XCom
 from airflow.operators.empty import EmptyOperator
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
@@ -70,6 +75,15 @@ def poke(self, context: Context):
         return self.return_value
 
 
+class DummyAsyncSensor(BaseSensorOperator):
+    def __init__(self, return_value=False, **kwargs):
+        super().__init__(**kwargs)
+        self.return_value = return_value
+
+    def execute_complete(self, context, event=None):
+        raise AirflowException("Should be skipped")
+
+
 class DummySensorWithXcomValue(BaseSensorOperator):
     def __init__(self, return_value=False, xcom_value=None, **kwargs):
         super().__init__(**kwargs)
@@ -910,3 +924,19 @@ def test_poke_mode_only_bad_poke(self):
         sensor = DummyPokeOnlySensor(task_id="foo", mode="poke", poke_changes_mode=True)
         with pytest.raises(ValueError, match="Cannot set mode to 'reschedule'. Only 'poke' is acceptable"):
             sensor.poke({})
+
+
+class TestAsyncSensor:
+    @pytest.mark.parametrize(
+        "soft_fail, expected_exception",
+        [
+            (True, AirflowSkipException),
+            (False, AirflowException),
+        ],
+    )
+    def test_fail_after_resuming_deffered_sensor(self, soft_fail, expected_exception):
+        async_sensor = DummyAsyncSensor(task_id="dummy_async_sensor", soft_fail=soft_fail)
+        ti = TaskInstance(task=async_sensor)
+        ti.next_method = "execute_complete"
+        with pytest.raises(expected_exception):
+            ti._execute_task({}, None)