From d587677bc421c5ee4c2c178c71dd659e1fabcce7 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Wed, 17 Jan 2024 17:01:56 +0530
Subject: [PATCH 1/8] Add deferrable param in FileSensor

---
 airflow/sensors/filesystem.py    | 33 +++++++++++++++++++++++++++++++++++---
 airflow/triggers/file.py         | 21 +++++++++++++++++----
 tests/sensors/test_filesystem.py | 17 ++++++++++++++++-
 tests/triggers/test_file.py      |  2 +-
 4 files changed, 64 insertions(+), 9 deletions(-)

diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py
index 65498557bbb580..5b6a3405d91381 100644
--- a/airflow/sensors/filesystem.py
+++ b/airflow/sensors/filesystem.py
@@ -22,8 +22,10 @@
 from glob import glob
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.exceptions import AirflowException
 from airflow.hooks.filesystem import FSHook
 from airflow.sensors.base import BaseSensorOperator
+from airflow.triggers.file import FileTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -42,6 +44,8 @@ class FileSensor(BaseSensorOperator):
         the base path set within the connection), can be a glob.
     :param recursive: when set to ``True``, enables recursive directory matching behavior
         of ``**`` in glob filepath parameter. Defaults to ``False``.
+    :param deferrable: If waiting for completion, whether to defer the task until done,
+        default is ``False``.
 
     .. seealso::
         For more information on how to use this sensor, take a look at the guide:
@@ -53,19 +57,25 @@ class FileSensor(BaseSensorOperator):
     template_fields: Sequence[str] = ("filepath",)
     ui_color = "#91818a"
 
-    def __init__(self, *, filepath, fs_conn_id="fs_default", recursive=False, **kwargs):
+    def __init__(
+        self, *, filepath, fs_conn_id="fs_default", recursive=False, deferrable: bool = False, **kwargs
+    ):
         super().__init__(**kwargs)
         self.filepath = filepath
         self.fs_conn_id = fs_conn_id
         self.recursive = recursive
+        self.deferrable = deferrable
 
-    def poke(self, context: Context):
+    @property
+    def path(self):
         hook = FSHook(self.fs_conn_id)
         basepath = hook.get_path()
         full_path = os.path.join(basepath, self.filepath)
         self.log.info("Poking for file %s", full_path)
+        return full_path
 
-        for path in glob(full_path, recursive=self.recursive):
+    def poke(self, context: Context):
+        for path in glob(self.path, recursive=self.recursive):
             if os.path.isfile(path):
                 mod_time = datetime.datetime.fromtimestamp(os.path.getmtime(path)).strftime("%Y%m%d%H%M%S")
                 self.log.info("Found File %s last modified: %s", path, mod_time)
@@ -75,3 +85,20 @@ def poke(self, context: Context):
             if files:
                 return True
         return False
+
+    def execute(self, context: Context) -> None:
+        if self.deferrable and not self.poke(context=context):
+            self.defer(
+                timeout=datetime.timedelta(seconds=self.timeout),
+                trigger=FileTrigger(
+                    filepath=self.path,
+                    recursive=self.recursive,
+                    poke_interval=self.poke_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: bool | None = None) -> None:
+        if not event:
+            raise AirflowException(f"{self.task_id} task failed as {self.filepath} not found.")
+        self.log.info("%s completed successfully as %s found.", self.task_id, self.filepath)
diff --git a/airflow/triggers/file.py b/airflow/triggers/file.py
index 93880407e53602..fe5772993b8269 100644
--- a/airflow/triggers/file.py
+++ b/airflow/triggers/file.py
@@ -20,9 +20,11 @@
 import datetime
 import os
 import typing
+import warnings
 from glob import glob
 from typing import Any
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -34,18 +36,29 @@ class FileTrigger(BaseTrigger):
         be a glob.
     :param recursive: when set to ``True``, enables recursive directory matching behavior
         of ``**`` in glob filepath parameter. Defaults to ``False``.
+    param poke_interval: Time that the job should wait in between each try
     """
 
     def __init__(
         self,
         filepath: str,
         recursive: bool = False,
-        poll_interval: float = 5.0,
+        poke_interval: float = 5.0,
+        **kwargs,
     ):
         super().__init__()
         self.filepath = filepath
         self.recursive = recursive
-        self.poll_interval = poll_interval
+        if kwargs.get("poll_interval") is not None:
+            warnings.warn(
+                "`poll_interval` has been deprecated and will be removed in future."
+                "Please use `poke_interval_interval` instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            self.poke_interval: float = kwargs["poll_interval"]
+        else:
+            self.poke_interval = poke_interval
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize FileTrigger arguments and classpath."""
@@ -54,7 +67,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             {
                 "filepath": self.filepath,
                 "recursive": self.recursive,
-                "poll_interval": self.poll_interval,
+                "poke_interval": self.poke_interval,
             },
         )
 
@@ -70,4 +83,4 @@ async def run(self) -> typing.AsyncIterator[TriggerEvent]:
             for _, _, files in os.walk(self.filepath):
                 if files:
                     yield TriggerEvent(True)
-            await asyncio.sleep(self.poll_interval)
+            await asyncio.sleep(self.poke_interval)
diff --git a/tests/sensors/test_filesystem.py b/tests/sensors/test_filesystem.py
index 9a92ce79be4701..9158e801cab57d 100644
--- a/tests/sensors/test_filesystem.py
+++ b/tests/sensors/test_filesystem.py
@@ -23,9 +23,10 @@
 
 import pytest
 
-from airflow.exceptions import AirflowSensorTimeout
+from airflow.exceptions import AirflowSensorTimeout, TaskDeferred
 from airflow.models.dag import DAG
 from airflow.sensors.filesystem import FileSensor
+from airflow.triggers.file import FileTrigger
 from airflow.utils.timezone import datetime
 
 pytestmark = pytest.mark.db_test
@@ -219,3 +220,17 @@ def test_subdirectory_empty(self):
         with pytest.raises(AirflowSensorTimeout):
             task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
         shutil.rmtree(temp_dir)
+
+    def test_task_defer(self):
+        task = FileSensor(
+            task_id="test",
+            filepath="temp_dir",
+            fs_conn_id="fs_default",
+            deferrable=True,
+            dag=self.dag,
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(None)
+
+        assert isinstance(exc.value.trigger, FileTrigger), "Trigger is not a FileTrigger"
diff --git a/tests/triggers/test_file.py b/tests/triggers/test_file.py
index a324574643e422..6fb25dea3f00cb 100644
--- a/tests/triggers/test_file.py
+++ b/tests/triggers/test_file.py
@@ -33,7 +33,7 @@ def test_serialization(self):
         assert classpath == "airflow.triggers.file.FileTrigger"
         assert kwargs == {
             "filepath": self.FILE_PATH,
-            "poll_interval": 5,
+            "poke_interval": 5,
             "recursive": False,
         }
 

From c9c7de9728fd62f016b2b5b739d28c0380de5a79 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Wed, 17 Jan 2024 17:30:24 +0530
Subject: [PATCH 2/8] Use conf value as default

---
 airflow/sensors/filesystem.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py
index 5b6a3405d91381..e6486122b93d93 100644
--- a/airflow/sensors/filesystem.py
+++ b/airflow/sensors/filesystem.py
@@ -22,6 +22,7 @@
 from glob import glob
 from typing import TYPE_CHECKING, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.hooks.filesystem import FSHook
 from airflow.sensors.base import BaseSensorOperator
@@ -58,7 +59,13 @@ class FileSensor(BaseSensorOperator):
     ui_color = "#91818a"
 
     def __init__(
-        self, *, filepath, fs_conn_id="fs_default", recursive=False, deferrable: bool = False, **kwargs
+        self,
+        *,
+        filepath,
+        fs_conn_id="fs_default",
+        recursive=False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.filepath = filepath

From 705137903da9e774c2901d232f434b2c750b2197 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Thu, 18 Jan 2024 00:02:15 +0530
Subject: [PATCH 3/8] Apply review suggestions

---
 airflow/sensors/filesystem.py | 4 +++-
 airflow/triggers/file.py      | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py
index e6486122b93d93..21672919361938 100644
--- a/airflow/sensors/filesystem.py
+++ b/airflow/sensors/filesystem.py
@@ -94,7 +94,9 @@ def poke(self, context: Context):
         return False
 
     def execute(self, context: Context) -> None:
-        if self.deferrable and not self.poke(context=context):
+        if not self.deferrable:
+            super().execute(context=context)
+        if not self.poke(context=context):
             self.defer(
                 timeout=datetime.timedelta(seconds=self.timeout),
                 trigger=FileTrigger(
diff --git a/airflow/triggers/file.py b/airflow/triggers/file.py
index fe5772993b8269..0f47e4ce140d27 100644
--- a/airflow/triggers/file.py
+++ b/airflow/triggers/file.py
@@ -52,7 +52,7 @@ def __init__(
         if kwargs.get("poll_interval") is not None:
             warnings.warn(
                 "`poll_interval` has been deprecated and will be removed in future."
-                "Please use `poke_interval_interval` instead.",
+                "Please use `poke_interval` instead.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )

From f6c0a49bb8bedc0009ab9b293d51cfb04ed0a776 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Thu, 18 Jan 2024 00:07:35 +0530
Subject: [PATCH 4/8] Apply review suggestions

---
 airflow/triggers/file.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/airflow/triggers/file.py b/airflow/triggers/file.py
index 0f47e4ce140d27..19ffa9112200e1 100644
--- a/airflow/triggers/file.py
+++ b/airflow/triggers/file.py
@@ -24,7 +24,6 @@
 import datetime
 import os
 import typing
 from glob import glob
 from typing import Any
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -53,7 +52,7 @@ def __init__(
             warnings.warn(
                 "`poll_interval` has been deprecated and will be removed in future."
                 "Please use `poke_interval` instead.",
-                AirflowProviderDeprecationWarning,
+                DeprecationWarning,
                 stacklevel=2,
             )
             self.poke_interval: float = kwargs["poll_interval"]

From 694366e0fa118ea10baf777480655010cfddc2c8 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Thu, 18 Jan 2024 00:31:53 +0530
Subject: [PATCH 5/8] Fix docs

---
 airflow/triggers/file.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/triggers/file.py b/airflow/triggers/file.py
index 19ffa9112200e1..4a1df581bb0739 100644
--- a/airflow/triggers/file.py
+++ b/airflow/triggers/file.py
@@ -35,7 +35,7 @@ class FileTrigger(BaseTrigger):
         be a glob.
     :param recursive: when set to ``True``, enables recursive directory matching behavior
         of ``**`` in glob filepath parameter. Defaults to ``False``.
-    param poke_interval: Time that the job should wait in between each try
+    :param poke_interval: Time that the job should wait in between each try
     """
 
     def __init__(

From 487cbbad100f87f437bf9942b032192617c99fde Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Thu, 18 Jan 2024 12:03:01 +0530
Subject: [PATCH 6/8] Apply review suggestions

---
 airflow/sensors/filesystem.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/airflow/sensors/filesystem.py b/airflow/sensors/filesystem.py
index 21672919361938..2c803d01f5cb04 100644
--- a/airflow/sensors/filesystem.py
+++ b/airflow/sensors/filesystem.py
@@ -19,6 +19,7 @@
 import datetime
 import os
+from functools import cached_property
 from glob import glob
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.configuration import conf
@@ -73,15 +74,15 @@ def __init__(
         self.recursive = recursive
         self.deferrable = deferrable
 
-    @property
-    def path(self):
+    @cached_property
+    def path(self) -> str:
         hook = FSHook(self.fs_conn_id)
         basepath = hook.get_path()
         full_path = os.path.join(basepath, self.filepath)
-        self.log.info("Poking for file %s", full_path)
         return full_path
 
-    def poke(self, context: Context):
+    def poke(self, context: Context) -> bool:
+        self.log.info("Poking for file %s", self.path)
         for path in glob(self.path, recursive=self.recursive):
             if os.path.isfile(path):
                 mod_time = datetime.datetime.fromtimestamp(os.path.getmtime(path)).strftime("%Y%m%d%H%M%S")

From 92657717d4eff9a96d008624c5999e8102561b40 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Thu, 18 Jan 2024 18:24:24 +0530
Subject: [PATCH 7/8] Add docs

---
 airflow/example_dags/example_sensors.py     | 4 ++++
 docs/apache-airflow/howto/operator/file.rst | 8 ++++++++
 2 files changed, 12 insertions(+)

diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py
index 9e3fc02f9c55cd..904bb57cdf0246 100644
--- a/airflow/example_dags/example_sensors.py
+++ b/airflow/example_dags/example_sensors.py
@@ -98,6 +98,10 @@ def failure_callable():
     t6 = FileSensor(task_id="wait_for_file", filepath="/tmp/temporary_file_for_testing")
     # [END example_file_sensor]
 
+    # [START example_file_sensor_async]
+    t6 = FileSensor(task_id="wait_for_file", filepath="/tmp/temporary_file_for_testing", deferrable=True)
+    # [END example_file_sensor_async]
+
     t7 = BashOperator(
         task_id="create_file_after_3_seconds", bash_command="sleep 3; touch /tmp/temporary_file_for_testing"
     )
diff --git a/docs/apache-airflow/howto/operator/file.rst b/docs/apache-airflow/howto/operator/file.rst
index 81077cc25ff4ef..49ca1c75f60429 100644
--- a/docs/apache-airflow/howto/operator/file.rst
+++ b/docs/apache-airflow/howto/operator/file.rst
@@ -31,3 +31,11 @@ Default connection is ``fs_default``.
     :dedent: 4
     :start-after: [START example_file_sensor]
     :end-before: [END example_file_sensor]
+
+You can also use this sensor in deferrable mode:
+
+.. exampleinclude:: /../../airflow/example_dags/example_sensors.py
+    :language: python
+    :dedent: 4
+    :start-after: [START example_file_sensor_async]
+    :end-before: [END example_file_sensor_async]

From d46150963ba8f6c6307bf754211b51f7f650b420 Mon Sep 17 00:00:00 2001
From: Pankaj
Date: Thu, 18 Jan 2024 18:30:22 +0530
Subject: [PATCH 8/8] Fix duplicate task

---
 airflow/example_dags/example_sensors.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py
index 904bb57cdf0246..9dbe83d6e4c401 100644
--- a/airflow/example_dags/example_sensors.py
+++ b/airflow/example_dags/example_sensors.py
@@ -99,23 +99,25 @@ def failure_callable():
     # [END example_file_sensor]
 
     # [START example_file_sensor_async]
-    t6 = FileSensor(task_id="wait_for_file", filepath="/tmp/temporary_file_for_testing", deferrable=True)
+    t7 = FileSensor(
+        task_id="wait_for_file_async", filepath="/tmp/temporary_file_for_testing", deferrable=True
+    )
     # [END example_file_sensor_async]
 
-    t7 = BashOperator(
+    t8 = BashOperator(
         task_id="create_file_after_3_seconds", bash_command="sleep 3; touch /tmp/temporary_file_for_testing"
     )
 
     # [START example_python_sensors]
-    t8 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)
+    t9 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)
 
-    t9 = PythonSensor(
+    t10 = PythonSensor(
         task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable
     )
     # [END example_python_sensors]
 
     # [START example_day_of_week_sensor]
-    t10 = DayOfWeekSensor(
+    t11 = DayOfWeekSensor(
         task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY
     )
     # [END example_day_of_week_sensor]
@@ -124,7 +126,7 @@ def failure_callable():
     tx.trigger_rule = TriggerRule.NONE_FAILED
 
     [t0, t0a, t1, t1a, t2, t2a, t3, t4] >> tx
-    t5 >> t6 >> tx
-    t7 >> tx
-    [t8, t9] >> tx
-    t10 >> tx
+    t5 >> t6 >> t7 >> tx
+    t8 >> tx
+    [t9, t10] >> tx
+    t11 >> tx
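
Usage sketch (illustrative only, not part of the patch series): a standalone DAG exercising the new flag might look like the snippet below. The dag_id, schedule, file path, and interval values are placeholders; only ``deferrable=True`` and the FileSensor/FileTrigger behaviour come from the changes above.

import datetime

from airflow.models.dag import DAG
from airflow.sensors.filesystem import FileSensor

# Hypothetical example DAG: every id, date, and path here is a placeholder.
with DAG(
    dag_id="example_deferrable_file_sensor",
    start_date=datetime.datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    wait_for_file = FileSensor(
        task_id="wait_for_file_async",
        filepath="/tmp/temporary_file_for_testing",
        fs_conn_id="fs_default",
        deferrable=True,   # frees the worker slot; the triggerer polls via FileTrigger
        poke_interval=5,   # seconds between checks while deferred
        timeout=60 * 60,   # overall deadline for the wait, in seconds
    )

With ``deferrable=True``, FileSensor.execute() performs a single poke and, if the file is absent, defers to FileTrigger with the same poke_interval and a timeout derived from the sensor's timeout; execute_complete() raises an AirflowException when it is invoked without a success event.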