From 962b41e109f0a44d9f4005ed944669b32cdcce38 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sat, 16 Dec 2023 20:36:24 +0100 Subject: [PATCH 1/3] Some improvements for Airflow IO code --- airflow/io/path.py | 3 ++- airflow/io/store/__init__.py | 19 ++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/airflow/io/path.py b/airflow/io/path.py index ed5ab664ebe59f..a3b98819cd16bc 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -52,6 +52,7 @@ def __init__( conn_id: str | None = None, **kwargs: typing.Any, ) -> None: + super().__init__(**kwargs) if parsed_url and parsed_url.scheme: self._store = attach(parsed_url.scheme, conn_id) else: @@ -147,7 +148,7 @@ def __new__( conn_id = conn_id or userinfo or None parsed_url = parsed_url._replace(netloc=hostinfo) - return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore + return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) @functools.lru_cache def __hash__(self) -> int: diff --git a/airflow/io/store/__init__.py b/airflow/io/store/__init__.py index bc9e0cd87dfdd7..6f93190d6511ac 100644 --- a/airflow/io/store/__init__.py +++ b/airflow/io/store/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from functools import cached_property from typing import TYPE_CHECKING, ClassVar from airflow.io import get_fs, has_fs @@ -48,15 +49,17 @@ def __init__( ): self.conn_id = conn_id self.protocol = protocol - self._fs = fs + if fs is not None: + self.fs = fs self.storage_options = storage_options def __str__(self): return f"{self.protocol}-{self.conn_id}" if self.conn_id else self.protocol - @property + @cached_property def fs(self) -> AbstractFileSystem: - return self._connect() + # if the fs is provided in init, the next statement will be ignored + return get_fs(self.protocol, self.conn_id) @property def fsid(self) -> str: @@ -68,9 +71,8 @@ def fsid(self) -> str: :return: deterministic the filesystem ID """ - fs = self._connect() try: - return fs.fsid + return self.fs.fsid except NotImplementedError: return f"{self.fs.protocol}-{self.conn_id or 'env'}" @@ -78,7 +80,7 @@ def serialize(self): return { "protocol": self.protocol, "conn_id": self.conn_id, - "filesystem": qualname(self._fs) if self._fs else None, + "filesystem": qualname(self.fs) if self.fs else None, "storage_options": self.storage_options, } @@ -103,11 +105,6 @@ def deserialize(cls, data: dict[str, str], version: int): return attach(protocol=protocol, conn_id=conn_id, storage_options=data["storage_options"]) - def _connect(self) -> AbstractFileSystem: - if self._fs is None: - self._fs = get_fs(self.protocol, self.conn_id) - return self._fs - def __eq__(self, other): return isinstance(other, type(self)) and other.conn_id == self.conn_id and other._fs == self._fs From a0dba36935fa41ae9afe4e111261f70a3fcf69ec Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Mon, 25 Dec 2023 15:15:11 +0100 Subject: [PATCH 2/3] Revert changes on airflow/io/path.py --- airflow/io/path.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/io/path.py b/airflow/io/path.py index 47cf5aaa8f0dac..bd7c320653aaee 100644 --- a/airflow/io/path.py +++ b/airflow/io/path.py @@ -52,7 +52,6 @@ def __init__( conn_id: str | None = None, **kwargs: typing.Any, ) -> None: - super().__init__(**kwargs) if parsed_url and parsed_url.scheme: self._store = attach(parsed_url.scheme, conn_id) else: @@ -148,7 +147,7 @@ def __new__( conn_id = conn_id or userinfo or None parsed_url = parsed_url._replace(netloc=hostinfo) - return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) + return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore @functools.lru_cache def __hash__(self) -> int: From 157174cf96b21916aea313048b37912ab7579cae Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Mon, 25 Dec 2023 16:06:33 +0100 Subject: [PATCH 3/3] Some refactor --- airflow/io/store/__init__.py | 11 ++++++++--- tests/io/test_path.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/airflow/io/store/__init__.py b/airflow/io/store/__init__.py index 6f93190d6511ac..d1e897c5de01cc 100644 --- a/airflow/io/store/__init__.py +++ b/airflow/io/store/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from contextlib import suppress from functools import cached_property from typing import TYPE_CHECKING, ClassVar @@ -38,8 +39,6 @@ class ObjectStore: protocol: str storage_options: Properties | None - _fs: AbstractFileSystem | None = None - def __init__( self, protocol: str, @@ -106,7 +105,13 @@ def deserialize(cls, data: dict[str, str], version: int): return attach(protocol=protocol, conn_id=conn_id, storage_options=data["storage_options"]) def __eq__(self, other): - return isinstance(other, type(self)) and other.conn_id == self.conn_id and other._fs == self._fs + self_fs = None + other_fs = None + with suppress(ValueError): + self_fs = self.fs + with suppress(ValueError): + other_fs = other.fs + return isinstance(other, type(self)) and other.conn_id == self.conn_id and self_fs == other_fs _STORE_CACHE: dict[str, ObjectStore] = {} diff --git a/tests/io/test_path.py b/tests/io/test_path.py index 7c97c7b2b912bd..af6044c1501925 100644 --- a/tests/io/test_path.py +++ b/tests/io/test_path.py @@ -279,7 +279,7 @@ def test_serde_store(self): assert s["protocol"] == "file" assert s["conn_id"] == "mock" - assert s["filesystem"] is None + assert s["filesystem"] == qualname(LocalFileSystem) assert store == d store = attach("localfs", fs=LocalFileSystem())