diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f8a8a259c1f..2cb1550c088 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,12 +7,15 @@ import xarray as xr from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE +from .test_variable import VariableSubclassobjects pint = pytest.importorskip("pint") DimensionalityError = pint.errors.DimensionalityError -unit_registry = pint.UnitRegistry() +# make sure scalars are converted to 0d arrays so quantities can +# always be treated like ndarrays +unit_registry = pint.UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity pytestmark = [ @@ -28,6 +31,25 @@ ] +def is_compatible(unit1, unit2): + def dimensionality(obj): + if isinstance(obj, (unit_registry.Quantity, unit_registry.Unit)): + unit_like = obj + else: + unit_like = unit_registry.dimensionless + + return unit_like.dimensionality + + return dimensionality(unit1) == dimensionality(unit2) + + +def compatible_mappings(first, second): + return { + key: is_compatible(unit1, unit2) + for key, (unit1, unit2) in merge_mappings(first, second) + } + + def array_extract_units(obj): if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): obj = obj.data @@ -113,6 +135,10 @@ def extract_units(obj): } units = {**vars_units, **coords_units} + elif isinstance(obj, xr.Variable): + vars_units = {None: array_extract_units(obj.data)} + + units = {**vars_units} elif isinstance(obj, Quantity): vars_units = {None: array_extract_units(obj)} @@ -148,6 +174,9 @@ def strip_units(obj): new_obj = xr.DataArray( name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims ) + elif isinstance(obj, xr.Variable): + data = array_strip_units(obj.data) + new_obj = obj.copy(data=data) elif isinstance(obj, unit_registry.Quantity): new_obj = obj.magnitude elif isinstance(obj, (list, tuple)): @@ -159,8 +188,9 @@ def strip_units(obj): def attach_units(obj, units): - if not isinstance(obj, (xr.DataArray, xr.Dataset)): - return array_attach_units(obj, units.get("data", 1)) + if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)): + units = units.get("data", None) or units.get(None, None) or 1 + return array_attach_units(obj, units) if isinstance(obj, xr.Dataset): data_vars = { @@ -172,7 +202,7 @@ def attach_units(obj, units): } new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: + elif isinstance(obj, xr.DataArray): # try the array name, "data" and None, then fall back to dimensionless data_units = ( units.get(obj.name, None) @@ -198,6 +228,11 @@ def attach_units(obj, units): new_obj = xr.DataArray( name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims ) + else: + data_units = units.get("data", None) or units.get(None, None) or 1 + + data = array_attach_units(obj.data, data_units) + new_obj = obj.copy(data=data) return new_obj @@ -243,6 +278,10 @@ def convert_units(obj, to): return new_obj +def assert_units_equal(a, b): + assert extract_units(a) == extract_units(b) + + def assert_equal_with_units(a, b): # works like xr.testing.assert_equal, but also explicitly checks units # so, it is more like assert_identical @@ -282,7 +321,27 @@ def dtype(request): return request.param +def merge_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + +def merge_args(default_args, new_args): + from itertools import zip_longest + + fill_value = object() + return [ + second if second is not fill_value else first + for first, second in zip_longest(default_args, new_args, fillvalue=fill_value) + ] + + class method: + """ wrapper class to help with passing methods via parametrize + + This is works a bit similar to using `partial(Class.method, arg, kwarg)` + """ + def __init__(self, name, *args, **kwargs): self.name = name self.args = args @@ -292,7 +351,7 @@ def __call__(self, obj, *args, **kwargs): from collections.abc import Callable from functools import partial - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} func = getattr(obj, self.name, None) @@ -301,7 +360,7 @@ def __call__(self, obj, *args, **kwargs): if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): numpy_func = getattr(np, self.name) func = partial(numpy_func, obj) - # remove typical xr args like "dim" + # remove typical xarray args like "dim" exclude_kwargs = ("dim", "dims") all_kwargs = { key: value @@ -318,12 +377,21 @@ def __repr__(self): class function: - def __init__(self, name_or_function, *args, **kwargs): + """ wrapper class for numpy functions + + Same as method, but the name is used for referencing numpy functions + """ + + def __init__(self, name_or_function, *args, function_label=None, **kwargs): if callable(name_or_function): - self.name = name_or_function.__name__ + self.name = ( + function_label + if function_label is not None + else name_or_function.__name__ + ) self.func = name_or_function else: - self.name = name_or_function + self.name = name_or_function if function_label is None else function_label self.func = getattr(np, name_or_function) if self.func is None: raise AttributeError( @@ -334,7 +402,7 @@ def __init__(self, name_or_function, *args, **kwargs): self.kwargs = kwargs def __call__(self, *args, **kwargs): - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} return self.func(*all_args, **all_kwargs) @@ -343,6 +411,7 @@ def __repr__(self): return f"function_{self.name}" +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataarray(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -358,6 +427,7 @@ def test_apply_ufunc_dataarray(dtype): assert_equal_with_units(expected, actual) +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataset(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -1245,6 +1315,769 @@ def test_dot_dataarray(dtype): assert_equal_with_units(expected, actual) +def delete_attrs(*to_delete): + def wrapper(cls): + for item in to_delete: + setattr(cls, item, None) + + return cls + + return wrapper + + +@delete_attrs( + "test_getitem_with_mask", + "test_getitem_with_mask_nd_indexer", + "test_index_0d_string", + "test_index_0d_datetime", + "test_index_0d_timedelta64", + "test_0d_time_data", + "test_datetime64_conversion", + "test_timedelta64_conversion", + "test_pandas_period_index", + "test_1d_math", + "test_1d_reduce", + "test_array_interface", + "test___array__", + "test_copy_index", + "test_concat_number_strings", + "test_concat_fixed_len_str", + "test_concat_mixed_dtypes", + "test_pandas_datetime64_with_tz", + "test_pandas_data", + "test_multiindex", +) +class TestVariable(VariableSubclassobjects): + @staticmethod + def cls(dims, data, *args, **kwargs): + return xr.Variable( + dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs + ) + + @pytest.mark.parametrize( + "func", + ( + method("all"), + method("any"), + method("argmax"), + method("argmin"), + method("argsort"), + method("cumprod"), + method("cumsum"), + method("max"), + method("mean"), + method("median"), + method("min"), + pytest.param( + method("prod"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + method("std"), + method("sum"), + method("var"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + variable = xr.Variable("x", array) + + units = extract_units(func(array)) + expected = attach_units(func(strip_units(variable)), units) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("astype", np.float32), + method("conj"), + method("conjugate"), + method("clip", min=2, max=7), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit if isinstance(item, (int, float, list)) else item + for item in func.args + ] + kwargs = { + key: value * unit if isinstance(value, (int, float, list)) else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name in ("searchsorted", "clip"): + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("item", 5), method("searchsorted", 5)), ids=repr + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_raw_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + for item in func.args + ] + kwargs = { + key: value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name != "item": + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.degK + ) + variable = xr.Variable(("x", "y"), array) + + expected = func(strip_units(variable)) + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail(reason="uses 0 as a replacement"), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="converts to fill value's unit"), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_missing_value_fillna(self, unit, error): + value = 0 + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.m + ) + variable = xr.Variable(("x", "y"), array) + + fill_value = value * unit + + if error is not None: + with pytest.raises(error): + variable.fillna(value=fill_value) + + return + + expected = attach_units( + strip_units(variable).fillna( + value=fill_value.to(unit_registry.m).magnitude + ), + extract_units(variable), + ) + actual = variable.fillna(value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="checking for identical units does not work properly, yet" + ), + ), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "convert_data", + ( + pytest.param(False, id="no_conversion"), + pytest.param(True, id="with_conversion"), + ), + ) + @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) + def test_comparisons(self, func, unit, convert_data, dtype): + array = np.linspace(0, 1, 9).astype(dtype) + quantity1 = array * unit_registry.m + variable = xr.Variable("x", quantity1) + + if convert_data and is_compatible(unit_registry.m, unit): + quantity2 = convert_units(array * unit_registry.m, {None: unit}) + else: + quantity2 = array * unit + other = xr.Variable("x", quantity2) + + expected = func( + strip_units(variable), + strip_units( + convert_units(other, extract_units(variable)) + if is_compatible(unit_registry.m, unit) + else other + ), + ) + if func.name == "identical": + expected &= extract_units(variable) == extract_units(other) + else: + expected &= all( + compatible_mappings( + extract_units(variable), extract_units(other) + ).values() + ) + + actual = func(variable, other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + base_unit = unit_registry.m + left_array = np.ones(shape=(2, 2), dtype=dtype) * base_unit + value = ( + (1 * base_unit).to(unit).magnitude if is_compatible(unit, base_unit) else 1 + ) + right_array = np.full(shape=(2,), fill_value=value, dtype=dtype) * unit + + left = xr.Variable(("x", "y"), left_array) + right = xr.Variable("x", right_array) + + units = { + **extract_units(left), + **({} if is_compatible(unit, base_unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & is_compatible(unit, base_unit) + actual = left.broadcast_equals(right) + + assert expected == actual + + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.s + variable = xr.Variable("x", array) + + expected = attach_units( + strip_units(variable).isel(x=indices), extract_units(variable) + ) + actual = variable.isel(x=indices) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + function(lambda x, *_: +x, function_label="unary_plus"), + function(lambda x, *_: -x, function_label="unary_minus"), + function(lambda x, *_: abs(x), function_label="absolute"), + function(lambda x, y: x + y, function_label="sum"), + function(lambda x, y: y + x, function_label="commutative_sum"), + function(lambda x, y: x * y, function_label="product"), + function(lambda x, y: y * x, function_label="commutative_product"), + ), + ids=repr, + ) + def test_1d_math(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.arange(5).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + values = np.ones(5) + y = values * unit + + if error is not None and func.name in ("sum", "commutative_sum"): + with pytest.raises(error): + func(variable, y) + + return + + units = extract_units(func(array, y)) + if all(compatible_mappings(units, extract_units(y)).values()): + converted_y = convert_units(y, units) + else: + converted_y = y + + if all(compatible_mappings(units, extract_units(variable)).values()): + converted_variable = convert_units(variable, units) + else: + converted_variable = variable + + expected = attach_units( + func(strip_units(converted_variable), strip_units(converted_y)), units + ) + actual = func(variable, y) + + assert_units_equal(expected, actual) + xr.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="getitem_with_mask converts to the unit of other" + ), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("where"), method("_getitem_with_mask")), ids=repr + ) + def test_masking(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + cond = np.array([True, False] * 5) + + other = -1 * unit + + if error is not None: + with pytest.raises(error): + func(variable, cond, other) + + return + + expected = attach_units( + func( + strip_units(variable), + cond, + strip_units( + convert_units( + other, + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None}, + ) + ), + ), + extract_units(variable), + ) + actual = func(variable, cond, other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_squeeze(self, dtype): + shape = (2, 1, 3, 1, 1, 2) + names = list("abcdef") + array = np.ones(shape=shape) * unit_registry.m + variable = xr.Variable(names, array) + + expected = attach_units( + strip_units(variable).squeeze(), extract_units(variable) + ) + actual = variable.squeeze() + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + names = tuple(name for name, size in zip(names, shape) if shape == 1) + for name in names: + expected = attach_units( + strip_units(variable).squeeze(dim=name), extract_units(variable) + ) + actual = variable.squeeze(dim=name) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("coarsen", windows={"y": 2}, func=np.mean), + method("quantile", q=[0.25, 0.75]), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"), + ), + method("roll", {"x": 2}), + pytest.param( + method("rolling_window", "x", 3, "window"), + marks=pytest.mark.xfail(reason="converts to ndarray"), + ), + method("reduce", np.std, "x"), + method("round", 2), + pytest.param( + method("shift", {"x": -2}), + marks=pytest.mark.xfail( + reason="trying to concatenate ndarray to quantity" + ), + ), + method("transpose", "y", "x"), + ), + ids=repr, + ) + def test_computation(self, func, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit + variable = xr.Variable(("x", "y"), array) + + expected = attach_units(func(strip_units(variable)), extract_units(variable)) + + actual = func(variable) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_searchsorted(self, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + value = 0 * unit + + if error is not None: + with pytest.raises(error): + variable.searchsorted(value) + + return + + expected = strip_units(variable).searchsorted( + strip_units(convert_units(value, {None: base_unit})) + ) + + actual = variable.searchsorted(value) + + assert_units_equal(expected, actual) + np.testing.assert_allclose(expected, actual) + + def test_stack(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + expected = attach_units( + strip_units(variable).stack(z=("x", "y")), extract_units(variable) + ) + actual = variable.stack(z=("x", "y")) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_unstack(self, dtype): + array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m + variable = xr.Variable("z", array) + + expected = attach_units( + strip_units(variable).unstack(z={"x": 3, "y": 10}), extract_units(variable) + ) + actual = variable.unstack(z={"x": 3, "y": 10}) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.xfail(reason="ignores units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_concat(self, unit, error, dtype): + array1 = ( + np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + ) + array2 = np.linspace(5, 10, 10 * 2).reshape(10, 2).astype(dtype) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable(("y", "z"), array2) + + if error is not None: + with pytest.raises(error): + variable.concat(other) + + return + + units = extract_units(variable) + expected = attach_units( + strip_units(variable).concat(strip_units(convert_units(other, units))), + units, + ) + actual = variable.concat(other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_set_dims(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + dims = {"z": 6, "x": 3, "a": 1, "b": 4, "y": 10} + expected = attach_units( + strip_units(variable).set_dims(dims), extract_units(variable) + ) + actual = variable.set_dims(dims) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + def test_copy(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + other = np.arange(10).astype(dtype) * unit_registry.s + + variable = xr.Variable("x", array) + expected = attach_units( + strip_units(variable).copy(data=strip_units(other)), extract_units(other) + ) + actual = variable.copy(data=other) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_no_conflicts(self, unit, dtype): + base_unit = unit_registry.m + array1 = ( + np.array( + [ + [6.3, 0.3, 0.45], + [np.nan, 0.3, 0.3], + [3.7, np.nan, 0.2], + [9.43, 0.3, 0.7], + ] + ) + * base_unit + ) + array2 = np.array([np.nan, 0.3, np.nan]) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable("y", array2) + + expected = strip_units(variable).no_conflicts( + strip_units( + convert_units( + other, {None: base_unit if is_compatible(base_unit, unit) else None} + ) + ) + ) & is_compatible(base_unit, unit) + actual = variable.no_conflicts(other) + + assert expected == actual + + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail( + reason="is not treated the same as dimensionless" + ), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_pad_with_fill_value(self, unit, error, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + fill_value = -100 * unit + + func = method("pad_with_fill_value", x=(2, 3), y=(1, 4)) + if error is not None: + with pytest.raises(error): + func(variable, fill_value=fill_value) + + return + + units = extract_units(variable) + expected = attach_units( + func( + strip_units(variable), + fill_value=strip_units(convert_units(fill_value, units)), + ), + units, + ) + actual = func(variable, fill_value=fill_value) + + assert_units_equal(expected, actual) + xr.testing.assert_identical(expected, actual) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( @@ -3539,6 +4372,7 @@ def test_stacking_stacked(self, func, dtype): assert_equal_with_units(expected, actual) + @pytest.mark.xfail(reason="does not work with quantities yet") def test_to_stacked_array(self, dtype): labels = np.arange(5).astype(dtype) * unit_registry.s arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels}