diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py
index 9e3fc02f9c55c..9dbe83d6e4c40 100644
--- a/airflow/example_dags/example_sensors.py
+++ b/airflow/example_dags/example_sensors.py
@@ -98,20 +98,26 @@ def failure_callable():
     t6 = FileSensor(task_id="wait_for_file", filepath="/tmp/temporary_file_for_testing")
     # [END example_file_sensor]
 
-    t7 = BashOperator(
+    # [START example_file_sensor_async]
+    t7 = FileSensor(
+        task_id="wait_for_file_async", filepath="/tmp/temporary_file_for_testing", deferrable=True
+    )
+    # [END example_file_sensor_async]
+
+    t8 = BashOperator(
         task_id="create_file_after_3_seconds", bash_command="sleep 3; touch /tmp/temporary_file_for_testing"
     )
 
     # [START example_python_sensors]
-    t8 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)
+    t9 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)
 
-    t9 = PythonSensor(
+    t10 = PythonSensor(
         task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable
     )
     # [END example_python_sensors]
 
     # [START example_day_of_week_sensor]
-    t10 = DayOfWeekSensor(
+    t11 = DayOfWeekSensor(
         task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY
     )
     # [END example_day_of_week_sensor]
@@ -120,7 +126,7 @@ def failure_callable():
 
     tx.trigger_rule = TriggerRule.NONE_FAILED
     [t0, t0a, t1, t1a, t2, t2a, t3, t4] >> tx
-    t5 >> t6 >> tx
-    t7 >> tx
-    [t8, t9] >> tx
-    t10 >> tx
+    t5 >> t6 >> t7 >> tx
+    t8 >> tx
+    [t9, t10] >> tx
+    t11 >> tx
diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py
index 65498557bbb58..2c803d01f5cb0 100644
--- a/airflow/sensors/filesystem.py
+++ b/airflow/sensors/filesystem.py
@@ -19,11 +19,15 @@
 
 import datetime
 import os
+from functools import cached_property
 from glob import glob
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.hooks.filesystem import FSHook
 from airflow.sensors.base import BaseSensorOperator
+from airflow.triggers.file import FileTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -42,6 +46,8 @@ class FileSensor(BaseSensorOperator):
         the base path set within the connection), can be a glob.
     :param recursive: when set to ``True``, enables recursive directory matching behavior of
         ``**`` in glob filepath parameter. Defaults to ``False``.
+    :param deferrable: when set to ``True``, poke once and defer to ``FileTrigger`` if nothing is found.
+        Defaults to the ``operators.default_deferrable`` config option, or ``False`` if that is unset.
 
     .. seealso::
         For more information on how to use this sensor, take a look at the guide:
@@ -53,19 +59,32 @@ class FileSensor(BaseSensorOperator):
     template_fields: Sequence[str] = ("filepath",)
     ui_color = "#91818a"
 
-    def __init__(self, *, filepath, fs_conn_id="fs_default", recursive=False, **kwargs):
+    def __init__(
+        self,
+        *,
+        filepath,
+        fs_conn_id="fs_default",
+        recursive=False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
         super().__init__(**kwargs)
         self.filepath = filepath
         self.fs_conn_id = fs_conn_id
         self.recursive = recursive
+        self.deferrable = deferrable
 
-    def poke(self, context: Context):
+    @cached_property
+    def path(self) -> str:
+        """Full path to watch: the base path of ``fs_conn_id`` joined with ``filepath``."""
         hook = FSHook(self.fs_conn_id)
         basepath = hook.get_path()
         full_path = os.path.join(basepath, self.filepath)
-        self.log.info("Poking for file %s", full_path)
+        return full_path
 
-        for path in glob(full_path, recursive=self.recursive):
+    def poke(self, context: Context) -> bool:
+        self.log.info("Poking for file %s", self.path)
+        for path in glob(self.path, recursive=self.recursive):
             if os.path.isfile(path):
                 mod_time = datetime.datetime.fromtimestamp(os.path.getmtime(path)).strftime("%Y%m%d%H%M%S")
                 self.log.info("Found File %s last modified: %s", path, mod_time)
@@ -75,3 +94,24 @@ def poke(self, context: Context):
                 if files:
                     return True
         return False
+
+    def execute(self, context: Context) -> None:
+        """Poke via the base sensor, or poke once and defer to ``FileTrigger`` if deferrable."""
+        if not self.deferrable:
+            super().execute(context=context)
+        elif not self.poke(context=context):
+            self.defer(
+                timeout=datetime.timedelta(seconds=self.timeout),
+                trigger=FileTrigger(
+                    filepath=self.path,
+                    recursive=self.recursive,
+                    poke_interval=self.poke_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: bool | None = None) -> None:
+        """Resume after the trigger fires; raise ``AirflowException`` if the event is falsy."""
+        if not event:
+            raise AirflowException(f"{self.task_id} task failed as {self.filepath} not found.")
+        self.log.info("%s completed successfully as %s found.", self.task_id, self.filepath)
diff --git a/airflow/triggers/file.py b/airflow/triggers/file.py
index 93880407e5360..4a1df581bb073 100644
--- a/airflow/triggers/file.py
+++ b/airflow/triggers/file.py
@@ -20,6 +20,7 @@
 import datetime
 import os
 import typing
+import warnings
 from glob import glob
 from typing import Any
 
@@ -34,18 +35,29 @@ class FileTrigger(BaseTrigger):
         be a glob.
     :param recursive: when set to ``True``, enables recursive directory matching behavior of
         ``**`` in glob filepath parameter. Defaults to ``False``.
+    :param poke_interval: Time that the job should wait in between each try
     """
 
     def __init__(
         self,
         filepath: str,
         recursive: bool = False,
-        poll_interval: float = 5.0,
+        poke_interval: float = 5.0,
+        **kwargs,
     ):
         super().__init__()
         self.filepath = filepath
         self.recursive = recursive
-        self.poll_interval = poll_interval
+        if kwargs.get("poll_interval") is not None:
+            warnings.warn(
+                "`poll_interval` has been deprecated and will be removed in future. "
+                "Please use `poke_interval` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self.poke_interval: float = kwargs["poll_interval"]
+        else:
+            self.poke_interval = poke_interval
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize FileTrigger arguments and classpath."""
@@ -54,7 +66,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             {
                 "filepath": self.filepath,
                 "recursive": self.recursive,
-                "poll_interval": self.poll_interval,
+                "poke_interval": self.poke_interval,
             },
         )
 
@@ -70,4 +82,4 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
                 for _, _, files in os.walk(self.filepath):
                     if files:
                         yield TriggerEvent(True)
-            await asyncio.sleep(self.poll_interval)
+            await asyncio.sleep(self.poke_interval)
diff --git a/docs/apache-airflow/howto/operator/file.rst b/docs/apache-airflow/howto/operator/file.rst
index 81077cc25ff4e..49ca1c75f6042 100644
--- a/docs/apache-airflow/howto/operator/file.rst
+++ b/docs/apache-airflow/howto/operator/file.rst
@@ -31,3 +31,11 @@ Default connection is ``fs_default``.
     :dedent: 4
     :start-after: [START example_file_sensor]
     :end-before: [END example_file_sensor]
+
+The sensor can also be run in deferrable mode:
+
+.. exampleinclude:: /../../airflow/example_dags/example_sensors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START example_file_sensor_async]
+    :end-before: [END example_file_sensor_async]
diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py
index 9a92ce79be470..9158e801cab57 100644
--- a/tests/sensors/test_filesystem.py
+++ b/tests/sensors/test_filesystem.py
@@ -23,9 +23,10 @@
 
 import pytest
 
-from airflow.exceptions import AirflowSensorTimeout
+from airflow.exceptions import AirflowSensorTimeout, TaskDeferred
 from airflow.models.dag import DAG
 from airflow.sensors.filesystem import FileSensor
+from airflow.triggers.file import FileTrigger
 from airflow.utils.timezone import datetime
 
 pytestmark = pytest.mark.db_test
@@ -219,3 +220,17 @@ def test_subdirectory_empty(self):
         with pytest.raises(AirflowSensorTimeout):
             task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
             shutil.rmtree(temp_dir)
+
+    def test_task_defer(self):
+        task = FileSensor(
+            task_id="test",
+            filepath="temp_dir",
+            fs_conn_id="fs_default",
+            deferrable=True,
+            dag=self.dag,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(None)
+
+        assert isinstance(exc.value.trigger, FileTrigger), "Trigger is not a FileTrigger"
diff --git a/tests/triggers/test_file.py b/tests/triggers/test_file.py
index a324574643e42..6fb25dea3f00c 100644
--- a/tests/triggers/test_file.py
+++ b/tests/triggers/test_file.py
@@ -33,7 +33,7 @@ def test_serialization(self):
         assert classpath == "airflow.triggers.file.FileTrigger"
         assert kwargs == {
             "filepath": self.FILE_PATH,
-            "poll_interval": 5,
+            "poke_interval": 5,
             "recursive": False,
         }
 