# Patch summary (review notes; free text ahead of the first "diff --git" header).
#
# Purpose: make the local cache path work for the `ectmp` file protocol
# (changelog entry references issue #12).
#
# Files touched:
#   CHANGELOG.md
#     - Adds the "Fix ec_cache for `ectmp` file protocol" entry.
#   ecmwfspec/core.py, `_open`
#     - Replaces the former two-line path computation.
#     - If `path` is a `Path`: strips the protocol, makes the result relative
#       to `path.anchor`, and prefixes it with "TMP".
#     - If `path` is a `str`: builds `Path("TMP/" + path)`; this branch does
#       not call `_strip_protocol`.
#     - `local_path` is `os.path.join(self.ec_cache, path.relative_to(path.anchor))`
#       wrapped in `Path`; the new tests pass `ec_cache` both as `Path` and as `str`.
#     - NOTE(review): a `path` that is neither `Path` nor `str` is left
#       unchanged and still reaches `.relative_to(...)` -- confirm callers
#       only pass those two types.
#   ecmwfspec/tests/conftest.py
#     - Adds `ECTMPMock(ECMock)`; its `cp` replaces "ectmp:" with "TMP" in the
#       input path, creates the parent directory of the output path, and runs
#       the external `cp` command; the captured stdout is assigned to `_` and
#       not used, and the return code is not checked.
#     - NOTE(review): `ECTMPMock.__init__` uses a mutable `{}` default for
#       `_cache`; whether sharing it between instances is intended cannot be
#       told from this patch.
#     - Adds the session-scoped fixture `patch_ectmp_dir`, which yields a
#       temporary directory while `ecmwfspec.core.ecfs` is patched with
#       `ECTMPMock()`.
#   ecmwfspec/tests/test_open_dataset.py
#     - Adds `test_protocols` (asserts "ec" and "ectmp" are in
#       `fsspec.available_protocols()`), `test_ectmp` (`ec_cache` given as a
#       `Path`) and `test_ectmp_strpath` (`ec_cache` given as a `str`).
#     - NOTE(review): `test_protocols` uses `fsspec` without a local import,
#       while the two other new tests import it inside the function; a
#       module-level import is not visible in this patch -- verify it exists.
#
# NOTE(review): the hunk headers below describe multi-line hunks, but the patch
# text has its line breaks collapsed into spaces; the original line breaks and
# context-line indentation presumably need restoring before this can be
# applied -- confirm against the upstream commit.
diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e83c38..4a255e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Fix file listing when using additional flags [#9](https://github.com/observingClouds/ecmwfspec/issues/9) - Fix incorrect caching if cache directory is given as absolute path [#10](https://github.com/observingClouds/ecmwfspec/issues/10) - Fix some warnings in test suite. [#11](https://github.com/observingClouds/ecmwfspec/issues/11) +- Fix ec_cache for `ectmp` file protocol [#12](https://github.com/observingClouds/ecmwfspec/issues/12) ## 0.0.1 - Initial release diff --git a/ecmwfspec/core.py b/ecmwfspec/core.py index 43b238b..58f634a 100644 --- a/ecmwfspec/core.py +++ b/ecmwfspec/core.py @@ -419,8 +419,11 @@ def _open( cache_options: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> ECFile: - path = "TMP" / Path(self._strip_protocol(path)) - local_path = self.ec_cache.joinpath(*path.parts) + if isinstance(path, Path): + path = "TMP" / Path(self._strip_protocol(path)).relative_to(path.anchor) + elif isinstance(path, str): + path = Path("TMP/" + path) + local_path = Path(os.path.join(self.ec_cache, path.relative_to(path.anchor))) return ECFile( str(path), str(local_path), diff --git a/ecmwfspec/tests/conftest.py b/ecmwfspec/tests/conftest.py index f45a950..17b4e4b 100644 --- a/ecmwfspec/tests/conftest.py +++ b/ecmwfspec/tests/conftest.py @@ -114,6 +114,24 @@ def ec_retrieve(self, search_id: int, out_dir: str, preserve_path: bool) -> None shutil.copy(inp_file, Path(out_dir) / inp_file.name) +class ECTMPMock(ECMock): + """A mock that emulates what ecfs is doing for temporary directories.""" + + def __init__(self, _cache: dict[int, builtins.list[str]] = {}) -> None: + super().__init__(_cache) + + def cp(self, inp_path: str, out_path: str) -> None: + """Mock the ecp method.""" + inp_path = inp_path.replace("ectmp:", "TMP") + os.makedirs(os.path.dirname(out_path), exist_ok=True) + _ = ( + run(["cp", inp_path, out_path], stdout=PIPE, stderr=PIPE) 
+ .stdout.decode() + .split("\n") + ) + return + + def create_data(variable_name: str, size: int) -> xr.Dataset: """Create a xarray dataset.""" coords: dict[str, np.ndarray] = {} @@ -153,6 +171,13 @@ def patch_dir() -> Generator[Path, None, None]: yield Path(temp_dir) +@pytest.fixture(scope="session") +def patch_ectmp_dir() -> Generator[Path, None, None]: + with TemporaryDirectory() as temp_dir: + with mock.patch("ecmwfspec.core.ecfs", ECTMPMock()): + yield Path(temp_dir) + + @pytest.fixture(scope="session") def save_dir() -> Generator[Path, None, None]: """Create a temporary directory.""" diff --git a/ecmwfspec/tests/test_open_dataset.py b/ecmwfspec/tests/test_open_dataset.py index 6ca4cc0..f0484c8 100644 --- a/ecmwfspec/tests/test_open_dataset.py +++ b/ecmwfspec/tests/test_open_dataset.py @@ -16,6 +16,13 @@ from ecmwfspec import xr_accessor # noqa: F401 +def test_protocols() -> None: + """Test that fsspec protocols are registered.""" + protocols = fsspec.available_protocols() + assert "ec" in protocols, f"ec not found in {protocols}" + assert "ectmp" in protocols, f"ectmp not found in {protocols}" + + def test_xr_accessor(patch_dir: Path, zarr_file: Path) -> None: """Test staging.""" zarr_file1 = [*zarr_file.rglob("*.zarr")][0] @@ -137,6 +144,50 @@ def test_ro_mode(patch_dir: Path) -> None: assert url.writable() is False +def test_ectmp(patch_ectmp_dir: Path) -> None: + """Check if ectmp access works.""" + import fsspec + + with TemporaryDirectory() as temp_dir: + inp_file = Path(temp_dir) / "foo.txt" + write_file = (patch_ectmp_dir / "TMP").joinpath(*inp_file.parts[1:]) + write_file.parent.mkdir(exist_ok=True, parents=True) + print(write_file) + with write_file.open("w") as f_obj: + f_obj.write("foo") + url = fsspec.open( + f"ectmp:///{inp_file}", + ec_cache=patch_ectmp_dir, + override=False, + mode="rt", + ).open() + assert Path(url.name) == write_file + assert url.tell() == 0 + assert url.read() == "foo" + + +def test_ectmp_strpath(patch_ectmp_dir: Path) -> 
None: + """Check if ectmp access works.""" + import fsspec + + with TemporaryDirectory() as temp_dir: + inp_file = Path(temp_dir) / "foo.txt" + write_file = (patch_ectmp_dir / "TMP").joinpath(*inp_file.parts[1:]) + write_file.parent.mkdir(exist_ok=True, parents=True) + print(write_file) + with write_file.open("w") as f_obj: + f_obj.write("foo") + url = fsspec.open( + f"ectmp:///{inp_file}", + ec_cache=str(patch_ectmp_dir), + override=False, + mode="rt", + ).open() + assert Path(url.name) == write_file + assert url.tell() == 0 + assert url.read() == "foo" + + def test_list_files(patch_dir: Path, netcdf_files: Path) -> None: """Test listing the files.""" import fsspec