diff --git a/airflow/io/store/__init__.py b/airflow/io/store/__init__.py
index bc9e0cd87dfdd7..d1e897c5de01cc 100644
--- a/airflow/io/store/__init__.py
+++ b/airflow/io/store/__init__.py
@@ -16,6 +16,8 @@
 # under the License.
 from __future__ import annotations
 
+from contextlib import suppress
+from functools import cached_property
 from typing import TYPE_CHECKING, ClassVar
 
 from airflow.io import get_fs, has_fs
@@ -37,8 +39,6 @@ class ObjectStore:
     protocol: str
     storage_options: Properties | None
 
-    _fs: AbstractFileSystem | None = None
-
     def __init__(
         self,
         protocol: str,
@@ -48,15 +48,17 @@ def __init__(
     ):
         self.conn_id = conn_id
         self.protocol = protocol
-        self._fs = fs
+        if fs is not None:
+            self.fs = fs
         self.storage_options = storage_options
 
     def __str__(self):
         return f"{self.protocol}-{self.conn_id}" if self.conn_id else self.protocol
 
-    @property
+    @cached_property
     def fs(self) -> AbstractFileSystem:
-        return self._connect()
+        # if the fs is provided in init, the next statement will be ignored
+        return get_fs(self.protocol, self.conn_id)
 
     @property
     def fsid(self) -> str:
@@ -68,9 +70,8 @@ def fsid(self) -> str:
 
         :return: deterministic the filesystem ID
         """
-        fs = self._connect()
         try:
-            return fs.fsid
+            return self.fs.fsid
         except NotImplementedError:
             return f"{self.fs.protocol}-{self.conn_id or 'env'}"
 
@@ -78,7 +79,7 @@ def serialize(self):
         return {
             "protocol": self.protocol,
             "conn_id": self.conn_id,
-            "filesystem": qualname(self._fs) if self._fs else None,
+            "filesystem": qualname(self.fs) if self.fs else None,
             "storage_options": self.storage_options,
         }
 
@@ -103,13 +104,25 @@ def deserialize(cls, data: dict[str, str], version: int):
 
         return attach(protocol=protocol, conn_id=conn_id, storage_options=data["storage_options"])
 
-    def _connect(self) -> AbstractFileSystem:
-        if self._fs is None:
-            self._fs = get_fs(self.protocol, self.conn_id)
-        return self._fs
-
     def __eq__(self, other):
-        return isinstance(other, type(self)) and other.conn_id == self.conn_id and other._fs == self._fs
+        """
+        Compare stores by connection ID and resolved filesystem.
+
+        A ``ValueError`` raised while resolving either filesystem is suppressed
+        and that filesystem is compared as ``None``.
+        """
+        # Check the type before touching ``other.fs``: an unrelated object has no
+        # such attribute and would raise AttributeError instead of comparing unequal.
+        if not isinstance(other, type(self)) or other.conn_id != self.conn_id:
+            return False
+
+        self_fs = None
+        other_fs = None
+        with suppress(ValueError):
+            self_fs = self.fs
+        with suppress(ValueError):
+            other_fs = other.fs
+        return self_fs == other_fs
 
 
 _STORE_CACHE: dict[str, ObjectStore] = {}
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index 7c97c7b2b912bd..af6044c1501925 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -279,7 +279,7 @@ def test_serde_store(self):
 
         assert s["protocol"] == "file"
         assert s["conn_id"] == "mock"
-        assert s["filesystem"] is None
+        assert s["filesystem"] == qualname(LocalFileSystem)
         assert store == d
 
         store = attach("localfs", fs=LocalFileSystem())