diff --git a/tests/fixtures.py b/tests/fixtures.py index 838268ea..8113ccde 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -144,37 +144,6 @@ dims=["time", "lat", "lon"], ) -# Special cases, after copying bounds from parent Dataset to data variable -# using xCDAT. -ts_with_bnds_from_parent_cf = xr.DataArray( - name="ts", - data=np.ones((2, 12, 4, 4)), - coords={ - "bnds": np.array([0, 1]), - "time": time_cf.assign_attrs(bounds="time_bnds"), - "lat": lat.assign_attrs(bounds="lat_bnds"), - "lon": lon.assign_attrs(bounds="lon_bnds"), - "lat_bnds": lat_bnds.copy(), - "lon_bnds": lon_bnds.copy(), - "time_bnds": time_bnds.copy(), - }, - dims=["bnds", "time", "lat", "lon"], -) -ts_with_bnds_from_parent_non_cf = xr.DataArray( - name="ts", - data=np.ones((2, 12, 4, 4)), - coords={ - "bnds": np.array([0, 1]), - "time": time_cf.copy(), - "lat": lat.copy(), - "lon": lon.copy(), - "lat_bnds": lat_bnds, - "lon_bnds": lon_bnds, - "time_bnds": time_bnds_non_cf, - }, - dims=["bnds", "time", "lat", "lon"], -) - def generate_dataset(cf_compliant: bool, has_bounds: bool) -> xr.Dataset: """Generates a dataset using coordinate and data variable fixtures. diff --git a/tests/test_bounds.py b/tests/test_bounds.py index 702b0f89..8a5bf659 100644 --- a/tests/test_bounds.py +++ b/tests/test_bounds.py @@ -2,14 +2,7 @@ import pytest import xarray as xr -from tests.fixtures import ( - generate_dataset, - lat_bnds, - lon_bnds, - time_bnds, - ts_cf, - ts_with_bnds_from_parent_cf, -) +from tests.fixtures import generate_dataset, lat_bnds, lon_bnds, time_bnds, ts_cf from xcdat.bounds import DataArrayBoundsAccessor, DatasetBoundsAccessor @@ -29,19 +22,20 @@ def test_decorator_call(self): def test_bounds_property_returns_map_of_coordinate_key_to_bounds_dataarray(self): ds = self.ds_with_bnds.copy() expected = { - "T": ds.time, - "X": ds.lon, - "Y": ds.lat, - "lat": ds.lat, - "latitude": ds.lat, - "lon": ds.lon, - "longitude": ds.lon, - "time": ds.time, + "T": ds.time_bnds, + "X": ds.lon_bnds, + "Y": ds.lat_bnds, + "lat": ds.lat_bnds, + "latitude": ds.lat_bnds, + "lon": ds.lon_bnds, + "longitude": ds.lon_bnds, + "time": ds.time_bnds, } result = ds.bounds.bounds - assert result == expected + for key in expected.keys(): + assert result[key].identical(expected[key]) def test_fill_missing_returns_dataset_with_filled_bounds(self): ds = self.ds_with_bnds.copy() @@ -141,6 +135,7 @@ class TestDataArrayBoundsAccessor: @pytest.fixture(autouse=True) def setUp(self): self.ds = generate_dataset(cf_compliant=True, has_bounds=True) + self.ts = ts_cf.copy() def test__init__(self): obj = DataArrayBoundsAccessor(self.ds.ts) @@ -151,31 +146,33 @@ def test_decorator_call(self): result = self.ds["ts"].bounds._dataarray assert result.identical(expected) - def test__copy_from_parent_copies_bounds(self): - ts_expected = ts_with_bnds_from_parent_cf.copy() - ts_result = ts_cf.copy().bounds.copy_from_parent(self.ds) - - assert ts_result.identical(ts_expected) + def test_copy_returns_exact_copy_with_attrs(self): + ds_bounds = self.ds.drop_vars("ts") + ts = self.ts.copy() - def test__set_bounds_dim_adds_bnds(self): - ts_expected = ts_cf.copy().expand_dims(bnds=np.array([0, 1])) - ts_result = ts_cf.copy().bounds._set_bounds_dim(self.ds) + ts_result = ts.bounds.copy_from_parent(self.ds) + assert ts_result.bounds._bounds.identical(ds_bounds) - assert ts_result.identical(ts_expected) + ts_result_copy = ts_result.bounds.copy() + assert ts_result_copy.bounds._dataarray.identical(ts_result.bounds._dataarray) + assert ts_result_copy.bounds._bounds.identical(ts_result.bounds._bounds) - def test__set_bounds_dim_adds_bounds(self): - ds = self.ds.swap_dims({"bnds": "bounds"}).copy() + def test__copy_from_parent_copies_bounds(self): + ds_bounds = self.ds.drop_vars("ts") + ts = self.ts.copy() - ts_expected = ts_cf.copy().expand_dims(bounds=np.array([0, 1])) - ts_result = ts_cf.copy().bounds._set_bounds_dim(ds) - assert ts_result.identical(ts_expected) + ts_result = ts.bounds.copy_from_parent(self.ds) + assert ts_result.bounds._bounds.identical(ds_bounds) - def test_bounds_property_returns_mapping_of_coordinate_keys_to_bounds_dataarrays( + def test_bounds_property_returns_mapping_of_keys_to_bounds_dataarrays( self, ): - ts = ts_with_bnds_from_parent_cf.copy() + ds = self.ds + ts = self.ts.copy() + ts = ts.bounds.copy_from_parent(ds) + result = ts.bounds.bounds - expected = {"time": ts.time_bnds, "lat": ts.lat_bnds, "lon": ts.lon_bnds} + expected = {"time": ds.time_bnds, "lat": ds.lat_bnds, "lon": ds.lon_bnds} for key in expected.keys(): assert result[key].identical(expected[key]) @@ -183,33 +180,42 @@ def test_bounds_property_returns_mapping_of_coordinate_keys_to_bounds_dataarrays def test_bounds_names_property_returns_mapping_of_coordinate_keys_to_names_of_bounds( self, ): - ts = ts_with_bnds_from_parent_cf.copy() + ts = ts_cf.copy() + ts = ts.bounds.copy_from_parent(self.ds) + result = ts.bounds.bounds_names - expected = {"time": "time_bnds", "lat": "lat_bnds", "lon": "lon_bnds"} + expected = { + "lat": ["lat_bnds"], + "Y": ["lat_bnds"], + "latitude": ["lat_bnds"], + "longitude": ["lon_bnds"], + "lon": ["lon_bnds"], + "time": ["time_bnds"], + "T": ["time_bnds"], + "X": ["lon_bnds"], + } assert result == expected + def test__check_bounds_are_set_raises_error_if_not_set( + self, + ): + ts = self.ts.copy() + + with pytest.raises(ValueError): + ts.bounds.bounds + def test_get_bounds_returns_bounds_dataarray_for_coordinate_key(self): - ts = ts_with_bnds_from_parent_cf.copy() + ts = self.ts.copy() + ts = ts.bounds.copy_from_parent(self.ds) + result = ts.bounds.get_bounds("lat") - expected = ts.lat_bnds + expected = self.ds.lat_bnds assert result.identical(expected) - def test_get_bounds_returns_keyerror_with_incorrect_key(self): - ts = ts_with_bnds_from_parent_cf.copy() + def test_get_bounds_raises_error_with_incorrect_key(self): + ts = self.ts.copy() + ts = ts.bounds.copy_from_parent(self.ds) - with pytest.raises(KeyError): + with pytest.raises(ValueError): ts.bounds.get_bounds("something") - - def test_get_bounds_dim_name_returns_dim_of_coordinate_bounds(self): - ts = ts_with_bnds_from_parent_cf.copy() - result = ts.bounds.get_bounds_dim_name("lat") - expected = "bnds" - - assert result == expected - - def test_get_bounds_dim_name_raises_keyerror_with_incorrect_key(self): - ts = ts_with_bnds_from_parent_cf.copy() - - with pytest.raises(KeyError): - ts.bounds.get_bounds_dim_name("something") diff --git a/tests/test_variable.py b/tests/test_variable.py index 289705df..41c86527 100644 --- a/tests/test_variable.py +++ b/tests/test_variable.py @@ -1,7 +1,7 @@ import pytest -from tests.fixtures import generate_dataset, ts_with_bnds_from_parent_cf -from xcdat.variable import open_variable +from tests.fixtures import generate_dataset +from xcdat.variable import copy_variable, open_variable class TestOpenVariable: @@ -11,12 +11,6 @@ def test_raises_error_if_variable_does_not_exist(self): with pytest.raises(KeyError): open_variable(ds, "invalid_var") - def test_raises_error_if_bounds_dim_is_missing(self): - ds = generate_dataset(cf_compliant=True, has_bounds=False) - - with pytest.raises(KeyError): - open_variable(ds, "ts") - def test_raises_error_if_bounds_are_missing_for_coordinates(self): ds = generate_dataset(cf_compliant=True, has_bounds=True) @@ -31,11 +25,46 @@ def test_raises_error_if_bounds_are_missing_for_coordinates(self): def test_returns_variable_with_bounds(self): ds = generate_dataset(cf_compliant=True, has_bounds=True) - ds.lat.attrs["bounds"] = "lat_bnds" - ds.lon.attrs["bounds"] = "lon_bnds" - ds.time.attrs["bounds"] = "time_bnds" + expected = { + "T": ds.time_bnds, + "X": ds.lon_bnds, + "Y": ds.lat_bnds, + "lat": ds.lat_bnds, + "latitude": ds.lat_bnds, + "lon": ds.lon_bnds, + "longitude": ds.lon_bnds, + "time": ds.time_bnds, + } + + ts = open_variable(ds, "ts") + result = ts.bounds.bounds + + for key in expected.keys(): + assert result[key].identical(expected[key]) + + +class TestCopyVariable: + def test_returns_copy_of_variable_with_bounds(self): + ds = generate_dataset(cf_compliant=True, has_bounds=True) + expected = { + "T": ds.time_bnds, + "X": ds.lon_bnds, + "Y": ds.lat_bnds, + "lat": ds.lat_bnds, + "latitude": ds.lat_bnds, + "lon": ds.lon_bnds, + "longitude": ds.lon_bnds, + "time": ds.time_bnds, + } + + ts = open_variable(ds, "ts") + ts_bounds = ts.bounds.bounds + + for key in expected.keys(): + assert ts_bounds[key].identical(expected[key]) - ts_expected = ts_with_bnds_from_parent_cf.copy() + ts_copy = copy_variable(ts) + ts_copy_bounds = ts_copy.bounds.bounds - ts_result = open_variable(ds, "ts") - assert ts_result.identical(ts_expected) + for key in expected.keys(): + assert ts_copy_bounds[key].identical(ts_bounds[key]) diff --git a/xcdat/bounds.py b/xcdat/bounds.py index 1cc030e3..af0f8ef0 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -5,7 +5,6 @@ import cf_xarray as cfxr # noqa: F401 import numpy as np import xarray as xr -from cf_xarray.accessor import _get_bounds, _single, _variables, apply_mapper from typing_extensions import Literal from xcdat.logger import setup_custom_logger @@ -73,7 +72,8 @@ def bounds(self) -> Dict[str, Optional[xr.DataArray]]: bounds: Dict[str, Optional[xr.DataArray]] = {} for coord, bounds_name in ds.cf.bounds.items(): - bounds[coord] = ds.get(bounds_name, None) + bound = ds.get(bounds_name[0], None) + bounds[coord] = bound return collections.OrderedDict(sorted(bounds.items())) @@ -112,7 +112,7 @@ def get_bounds(self, coord: Coord) -> xr.DataArray: If an incorrect ``coord`` argument is passed. ValueError - If bounds were not found in the dataset. They must be added. + If bounds were not found. They must be added. """ if coord not in SUPPORTED_COORDS: raise ValueError( @@ -123,9 +123,7 @@ def get_bounds(self, coord: Coord) -> xr.DataArray: try: bounds = self._dataset.cf.get_bounds(coord) except KeyError: - raise KeyError( - f"{coord} bounds were not found in the dataset, they must be added." - ) + raise KeyError(f"{coord} bounds were not found, they must be added.") return bounds @@ -283,28 +281,40 @@ class DataArrayBoundsAccessor: >>> tas = ds["tas"] >>> tas.bounds._copy_from_dataset(ds) - Return dictionary of coordinate keys mapped to bounds DataArrays: + Return dictionary of axis and coordinate keys mapped to bounds DataArrays: >>> tas.bounds.bounds Return dictionary of coordinate keys mapped to bounds names: >>> tas.bounds.bounds_names + """ - Get bounds for a coordinate key: - - >>> tas.bounds.get_bounds("lat") + def __init__(self, dataarray: xr.DataArray): + self._dataarray = dataarray - Get name of bounds dimension: + # A Dataset container to store the bounds from the parent Dataset. + # A Dataset is used instead of a dictionary so that it is interoperable + # with cf_xarray and the DatasetBoundsAccessor class. This allows for + # the dynamic generation of a bounds dict that is name-agnostic (e.g., + # "lat", "latitude", "Y" for latitude bounds). Refer to the ``bounds`` + # class property below for more information. + self._bounds = xr.Dataset() - >>> tas.bounds.get_bounds_dim_name("lat") + def copy(self) -> xr.DataArray: + """Copies the DataArray while maintaining accessor class attributes. - """ + This method is invoked when a copy of a variable is made through + ``xcdat.variable.copy_variable()``. - def __init__(self, dataarray: xr.DataArray): - self._dataarray = dataarray + Returns + ------- + xr.DataArray + The DataArray within the accessor class. + """ + return self._dataarray - def copy_from_parent(self, dataset: xr.Dataset) -> xr.DataArray: + def copy_from_parent(self, dataset: xr.Dataset): """Copies coordinate bounds from the parent Dataset to the DataArray. In an xarray.Dataset, variables (e.g., "tas") and coordinate bounds @@ -332,20 +342,15 @@ def copy_from_parent(self, dataset: xr.Dataset) -> xr.DataArray: .. [3] https://github.com/pydata/xarray/issues/1475 """ - da = self._dataarray.copy() - - # The bounds dimension must be set before adding bounds to the DataArray - # coordinates, otherwise the error below is thrown: - # "ValueError: cannot add coordinates with new dimensions to a DataArray" - da = self._set_bounds_dim(dataset) + bounds = self._bounds.copy() coords = [*dataset.coords] boundless_coords = [] for coord in coords: if coord in SUPPORTED_COORDS: try: - bounds = dataset.cf.get_bounds(coord) - da[bounds.name] = bounds.copy() + coord_bounds = dataset.cf.get_bounds(coord) + bounds[coord_bounds.name] = coord_bounds.copy() except KeyError: boundless_coords.append(coord) @@ -356,56 +361,7 @@ def copy_from_parent(self, dataset: xr.Dataset) -> xr.DataArray: "`xcdat.dataset.open_dataset` to auto-generate missing bounds first" ) - self._dataarray = da - return self._dataarray - - def _set_bounds_dim(self, dataset: xr.Dataset) -> xr.DataArray: - """ - Sets the bounds dimension(s) in the DataArray based on the dims of the - parent Dataset. - - This function uses the "bounds" attribute of each coordinate to map to - the bounds, then extracts the dimension from each bounds. - - Parameters - ---------- - dataset : xr.Dataset - The parent Dataset. - - Returns - ------- - xr.DataArray - The data variable with a bounds dimension. - - Raises - ------ - KeyError - When no bounds dimension exists in the parent Dataset. - """ - da = self._dataarray.copy() - coords = dataset.cf.coordinates.keys() - - dims = set() - for coord in coords: - try: - dims.add(dataset.cf.get_bounds_dim_name(coord)) - except KeyError: - logger.warning( - f"{coord} has no bounds, or the `bounds` attribute is missing to " - "link to the bounds." - ) - - if len(dims) == 0: - raise KeyError( - "No bounds dimension in the parent dataset, which indicates that there " - "are probably no coordinate bounds. Try passing the dataset to " - "`xcdat.dataset.open_dataset` to auto-generate them." - ) - - for dim in dims: - da = da.expand_dims(dim={dim: np.array([0, 1])}) - - self._dataarray = da + self._bounds = bounds return self._dataarray @property @@ -423,40 +379,31 @@ def bounds(self) -> Dict[str, Optional[xr.DataArray]]: ----- Based on ``cf_xarray.accessor.CFDatasetAccessor.bounds``. """ - da = self._dataarray - bounds: Dict[str, Optional[xr.DataArray]] = {} - for coord, bounds_name in self.bounds_names.items(): - bounds[coord] = da.coords.get(bounds_name, None) # type: ignore + self._check_bounds_are_set() - return bounds + bounds = DatasetBoundsAccessor(self._bounds) + return bounds.bounds @property def bounds_names(self) -> Dict[Hashable, List[str]]: """Returns a mapping of coordinate keys to the name of their bounds. - Missing coordinates are handled by ``self.copy_from_parent()``. + Wrapper for ``cf_xarray.accessor.CFDatasetAccessor.bounds_names``. Returns ------- Dict[Hashable, List[str]] Dictionary mapping valid keys to the variable names of their bounds. - - Notes - ----- - Based on ``cf_xarray.accessor.CFDatasetAccessor.bounds_names``. """ - da = self._dataarray - keys = da.coords - - vardict = {key: apply_mapper(_get_bounds, da, key, error=False) for key in keys} + self._check_bounds_are_set() - # Each coord should have only one bound, thus select the first index - # TODO: Handle when there is more than one bounds per coordinate. - return {k: v[0] for k, v in vardict.items() if v} + return self._bounds.cf.bounds def get_bounds(self, coord: Coord) -> xr.DataArray: """Get bounds corresponding to a coordinate key. + Wrapper for ``cf_xarray.accessor.CFDatasetAccessor.get_bounds``. + Parameters ---------- coord : Coord @@ -466,35 +413,15 @@ def get_bounds(self, coord: Coord) -> xr.DataArray: ------- DataArray The bounds for a coordinate key. - - Notes - ----- - Based on ``cf_xarray.accessor.CFDatasetAccessor.get_bounds``. """ - # TODO: Handle when there is more than bounds dimension - return apply_mapper(_variables(_single(_get_bounds)), self._dataarray, coord)[0] - - def get_bounds_dim_name(self, coord: Coord) -> Hashable: - """Get bounds dimension name corresponding to coordinate key. + self._check_bounds_are_set() - Parameters - ---------- - coord : Coord - The coordinate key whose bounds dimension is desired. + bounds = DatasetBoundsAccessor(self._bounds) + return bounds.get_bounds(coord) - Returns - ------- - Hashable - The bounds dimension name. - - Notes - ----- - Based on ``cf_xarray.accessor.CFDatasetAccessor.get_bounds_dim_name``. - """ - crd = self._dataarray[coord] - bounds = self.get_bounds(coord) - bounds_dims = set(bounds.dims) - set(crd.dims) - assert len(bounds_dims) == 1 - bounds_dim = bounds_dims.pop() - assert self._dataarray.sizes[bounds_dim] in [2, 4] - return bounds_dim + def _check_bounds_are_set(self): + if len(self._bounds) == 0: + raise ValueError( + "Variable bounds are not set. Copy them from the parent Dataset using " + "`.bounds.copy_from_parent()`" + ) diff --git a/xcdat/variable.py b/xcdat/variable.py index b2e970de..988130cb 100644 --- a/xcdat/variable.py +++ b/xcdat/variable.py @@ -63,3 +63,46 @@ def open_variable(dataset: xr.Dataset, name: str) -> xr.DataArray: data_var = data_var.bounds.copy_from_parent(dataset) return data_var + + +def copy_variable(var: xr.DataArray) -> xr.DataArray: + """Returns a copy of a variable with all accessor class attributes. + + If a copy of a variable is created by invoking the native + `xarray.DataArray.copy()`, all cached accessor attributes will be dropped. + This is an unfortunate issue that can compromise data integrity. Refer to + https://github.com/pydata/xarray/issues/3268 for more information. + + This function works around this issue by invoking each accessor class' + custom `.copy()` method, which just returns the object within the scope of + the accessor class to retain attributes. + + Parameters + ---------- + var : xr.DataArray + The variable being copied. + + Returns + ------- + xr.DataArray + The copied variable with accessor attributes. + + Examples + -------- + Import: + + >>> from xcdat.dataset import open_dataset + >>> from xcdat.variable import copy_variable, open_variable + + Open a variable: + + >>> ds = open_dataset("file_path") + >>> ts = open_variable(ds, "ts") + + Copying a variable: + + >>> ts_copy = copy_variable(ts) + """ + var_copy = var.bounds.copy() + + return var_copy