diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py index 02055f31e3f819..de3870937d43b2 100644 --- a/airflow/providers/sftp/sensors/sftp.py +++ b/airflow/providers/sftp/sensors/sftp.py @@ -30,7 +30,7 @@ from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.sftp.triggers.sftp import SFTPTrigger from airflow.sensors.base import BaseSensorOperator, PokeReturnValue -from airflow.utils.timezone import convert_to_utc +from airflow.utils.timezone import convert_to_utc, parse if TYPE_CHECKING: from airflow.utils.context import Context @@ -57,7 +57,7 @@ def __init__( *, path: str, file_pattern: str = "", - newer_than: datetime | None = None, + newer_than: datetime | str | None = None, sftp_conn_id: str = "sftp_default", python_callable: Callable | None = None, op_args: list | None = None, @@ -70,7 +70,7 @@ def __init__( self.file_pattern = file_pattern self.hook: SFTPHook | None = None self.sftp_conn_id = sftp_conn_id - self.newer_than: datetime | None = newer_than + self.newer_than: datetime | str | None = newer_than self.python_callable: Callable | None = python_callable self.op_args = op_args or [] self.op_kwargs = op_kwargs or {} @@ -105,6 +105,8 @@ def poke(self, context: Context) -> PokeReturnValue | bool: continue if self.newer_than: + if isinstance(self.newer_than, str): + self.newer_than = parse(self.newer_than) _mod_time = convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S")) _newer_than = convert_to_utc(self.newer_than) if _newer_than <= _mod_time: diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py index 6a08b377ce80f6..25add45e153fb8 100644 --- a/tests/providers/sftp/sensors/test_sftp.py +++ b/tests/providers/sftp/sensors/test_sftp.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from datetime import datetime +from datetime import datetime, timezone as stdlib_timezone from unittest import mock from unittest.mock import Mock, call, patch @@ -97,11 +97,25 @@ def test_file_not_new_enough(self, sftp_hook_mock): sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/1970-01-01.txt") assert not output + @pytest.mark.parametrize( + "newer_than", + ( + datetime(2020, 1, 2), + datetime(2020, 1, 2, tzinfo=stdlib_timezone.utc), + "2020-01-02", + "2020-01-02 00:00:00+00:00", + "2020-01-02 00:00:00.001+00:00", + "2020-01-02T00:00:00+00:00", + "2020-01-02T00:00:00Z", + "2020-01-02T00:00:00+04:00", + "2020-01-02T00:00:00.000001+04:00", + ), + ) @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") - def test_naive_datetime(self, sftp_hook_mock): + def test_multiple_datetime_format_in_newer_than(self, sftp_hook_mock, newer_than): sftp_hook_mock.return_value.get_mod_time.return_value = "19700101000000" sftp_sensor = SFTPSensor( - task_id="unit_test", path="/path/to/file/1970-01-01.txt", newer_than=datetime(2020, 1, 2) + task_id="unit_test", path="/path/to/file/1970-01-01.txt", newer_than=newer_than ) context = {"ds": "1970-01-00"} output = sftp_sensor.poke(context)