From a342a36725fb62ff0b5db618daeaee790fc4c189 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 27 Dec 2019 23:49:43 +0100 Subject: [PATCH 01/38] add tests for variable by inheriting from VariableSubclassobjects --- xarray/tests/test_units.py | 43 +++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f8a8a259c1f..2224a81f238 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -7,12 +7,13 @@ import xarray as xr from xarray.core import formatting from xarray.core.npcompat import IS_NEP18_ACTIVE +from .test_variable import VariableSubclassobjects pint = pytest.importorskip("pint") DimensionalityError = pint.errors.DimensionalityError -unit_registry = pint.UnitRegistry() +unit_registry = pint.UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity pytestmark = [ @@ -1245,6 +1246,46 @@ def test_dot_dataarray(dtype): assert_equal_with_units(expected, actual) +def delete_attrs(*to_delete): + def wrapper(cls): + for item in to_delete: + setattr(cls, item, None) + + return cls + + return wrapper + + +@delete_attrs( + "test_getitem_with_mask", + "test_getitem_with_mask_nd_indexer", + "test_index_0d_string", + "test_index_0d_datetime", + "test_index_0d_timedelta64", + "test_0d_time_data", + "test_datetime64_conversion", + "test_timedelta64_conversion", + "test_pandas_period_index", + "test_1d_math", + "test_1d_reduce", + "test_array_interface", + "test___array__", + "test_copy_index", + "test_concat", + "test_concat_number_strings", + "test_concat_fixed_len_str", + "test_concat_mixed_dtypes", + "test_pandas_datetime64_with_tz", + "test_multiindex", +) +class TestVariable(VariableSubclassobjects): + @staticmethod + def cls(dims, data, *args, **kwargs): + return xr.Variable( + dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs + ) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( From c27cf4c7d0db9c7317af495468b87502f312b31c Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 28 Dec 2019 21:05:03 +0100 Subject: [PATCH 02/38] make sure the utility functions work with variables --- xarray/tests/test_units.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2224a81f238..d64f4b629d4 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -114,6 +114,10 @@ def extract_units(obj): } units = {**vars_units, **coords_units} + elif isinstance(obj, xr.Variable): + vars_units = {None: array_extract_units(obj.data)} + + units = {**vars_units} elif isinstance(obj, Quantity): vars_units = {None: array_extract_units(obj)} @@ -149,6 +153,9 @@ def strip_units(obj): new_obj = xr.DataArray( name=strip_units(obj.name), data=data, coords=coords, dims=obj.dims ) + elif isinstance(obj, xr.Variable): + data = array_strip_units(obj.data) + new_obj = obj.copy(data=data) elif isinstance(obj, unit_registry.Quantity): new_obj = obj.magnitude elif isinstance(obj, (list, tuple)): @@ -160,7 +167,7 @@ def strip_units(obj): def attach_units(obj, units): - if not isinstance(obj, (xr.DataArray, xr.Dataset)): + if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)): return array_attach_units(obj, units.get("data", 1)) if isinstance(obj, xr.Dataset): @@ -173,7 +180,7 @@ def attach_units(obj, units): } new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: + elif isinstance(obj, xr.DataArray): # try the array name, "data" and None, then fall back to dimensionless data_units = ( units.get(obj.name, None) @@ -199,6 +206,16 @@ def attach_units(obj, units): new_obj = xr.DataArray( name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims ) + else: + data_units = ( + units.get(obj.name, None) + or units.get("data", None) + or units.get(None, None) + or 1 + ) + + data = array_attach_units(obj.data, data_units) + new_obj = obj.copy(data=data) return new_obj From 3a7f2a0c1377a1326e9293ee2588a9e1fae7a032 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 28 Dec 2019 21:10:36 +0100 Subject: [PATCH 03/38] add additional tests for aggregation and numpy methods --- xarray/tests/test_units.py | 100 +++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index d64f4b629d4..26eb8426298 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1302,6 +1302,106 @@ def cls(dims, data, *args, **kwargs): dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs ) + @pytest.mark.parametrize( + "func", + ( + pytest.param( + method("all"), marks=pytest.mark.xfail(reason="not implemented by pint") + ), + pytest.param( + method("any"), marks=pytest.mark.xfail(reason="not implemented by pint") + ), + method("argmax"), + method("argmin"), + method("argsort"), + method("cumprod"), + method("cumsum"), + method("max"), + method("mean"), + method("median"), + method("min"), + pytest.param( + method("prod"), + marks=pytest.mark.xfail(reason="not implemented by pint"), + ), + method("std"), + method("sum"), + method("var"), + ), + ids=repr, + ) + def test_aggregation(self, func, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * ( + unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless + ) + variable = xr.Variable("x", array) + + units = extract_units(func(array)) + expected = attach_units(func(strip_units(variable)), units) + actual = func(variable) + + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "func", + ( + method("astype", np.float32), + method("conj"), + method("conjugate"), + method("searchsorted", 5), + method("clip", min=2, max=7), + ), + ids=repr, + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit if isinstance(item, (int, float, list)) else item + for item in func.args + ] + kwargs = { + key: value * unit if isinstance(value, (int, float, list)) else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name in ("searchsorted", "clip"): + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + for key, value in kwargs.items() + } + + print("running on:", func, args, kwargs) + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_allclose(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From b0b1e2c560d3cef59c07390bc61f664cf03658cf Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 15:02:19 +0100 Subject: [PATCH 04/38] don't assume variables have a name attribute --- xarray/tests/test_units.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 26eb8426298..83ed40ef174 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -207,12 +207,7 @@ def attach_units(obj, units): name=obj.name, data=data, coords=coords, attrs=attrs, dims=dims ) else: - data_units = ( - units.get(obj.name, None) - or units.get("data", None) - or units.get(None, None) - or 1 - ) + data_units = units.get("data", None) or units.get(None, None) or 1 data = array_attach_units(obj.data, data_units) new_obj = obj.copy(data=data) From 1893329e08413bc65ce6fbb007ddaa8e20f0bcde Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 15:10:11 +0100 Subject: [PATCH 05/38] properly merge the used arguments --- xarray/tests/test_units.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 83ed40ef174..d3667b53cef 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -295,6 +295,16 @@ def dtype(request): return request.param +def merge_args(default_args, new_args): + from itertools import zip_longest + + fill_value = object() + return [ + second if second is not fill_value else first + for first, second in zip_longest(default_args, new_args, fillvalue=fill_value) + ] + + class method: def __init__(self, name, *args, **kwargs): self.name = name @@ -305,7 +315,7 @@ def __call__(self, obj, *args, **kwargs): from collections.abc import Callable from functools import partial - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} func = getattr(obj, self.name, None) @@ -347,7 +357,7 @@ def __init__(self, name_or_function, *args, **kwargs): self.kwargs = kwargs def __call__(self, *args, **kwargs): - all_args = list(self.args) + list(args) + all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} return self.func(*all_args, **all_kwargs) From 4058039289d253b6e627dc64586ad68a363b0e63 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 16:00:53 +0100 Subject: [PATCH 06/38] properly attach to non-xarray objects --- xarray/tests/test_units.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index d3667b53cef..32d53b4e51c 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -168,7 +168,8 @@ def strip_units(obj): def attach_units(obj, units): if not isinstance(obj, (xr.DataArray, xr.Dataset, xr.Variable)): - return array_attach_units(obj, units.get("data", 1)) + units = units.get("data", None) or units.get(None, None) or 1 + return array_attach_units(obj, units) if isinstance(obj, xr.Dataset): data_vars = { From 6099585af63104deac69c8d75b969b00b653364a Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 16:04:15 +0100 Subject: [PATCH 07/38] add a test function for searchsorted and item --- xarray/tests/test_units.py | 62 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 32d53b4e51c..c68a01ea780 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1354,7 +1354,6 @@ def test_aggregation(self, func, dtype): method("astype", np.float32), method("conj"), method("conjugate"), - method("searchsorted", 5), method("clip", min=2, max=7), ), ids=repr, @@ -1398,7 +1397,6 @@ def test_numpy_methods(self, func, unit, error, dtype): for key, value in kwargs.items() } - print("running on:", func, args, kwargs) units = extract_units(func(array, *args, **kwargs)) expected = attach_units( func(strip_units(variable), *converted_args, **converted_kwargs), units @@ -1408,6 +1406,66 @@ def test_numpy_methods(self, func, unit, error, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_allclose(expected, actual) + @pytest.mark.parametrize( + "func", (method("item", 5), method("searchsorted", 5)), ids=repr + ) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_raw_numpy_methods(self, func, unit, error, dtype): + array = np.linspace(0, 1, 10).astype(dtype) * unit_registry.m + variable = xr.Variable("x", array) + + args = [ + item * unit + if isinstance(item, (int, float, list)) and func.name != "item" + else item + for item in func.args + ] + kwargs = { + key: value * unit + if isinstance(value, (int, float, list)) and func.name != "item" + else value + for key, value in func.kwargs.items() + } + + if error is not None and func.name != "item": + with pytest.raises(error): + func(variable, *args, **kwargs) + + return + + converted_args = [ + strip_units(convert_units(item, {None: unit_registry.m})) + if func.name != "item" + else item + for item in args + ] + converted_kwargs = { + key: strip_units(convert_units(value, {None: unit_registry.m})) + if func.name != "item" + else value + for key, value in kwargs.items() + } + + units = extract_units(func(array, *args, **kwargs)) + expected = attach_units( + func(strip_units(variable), *converted_args, **converted_kwargs), units + ) + actual = func(variable, *args, **kwargs) + + assert extract_units(expected) == extract_units(actual) + np.testing.assert_allclose(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From a7b0f6a2369a15839704eb6723eed479992e99c4 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 16:10:14 +0100 Subject: [PATCH 08/38] add tests for missing value detection methods --- xarray/tests/test_units.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index c68a01ea780..39812edbdba 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1466,6 +1466,29 @@ def test_raw_numpy_methods(self, func, unit, error, dtype): assert extract_units(expected) == extract_units(actual) np.testing.assert_allclose(expected, actual) + @pytest.mark.parametrize( + "func", (method("isnull"), method("notnull"), method("count")), ids=repr + ) + def test_missing_value_detection(self, func, dtype): + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ).astype(dtype) + * unit_registry.degK + ) + + variable = xr.Variable(("x", "y"), array) + + expected = func(strip_units(variable)) + actual = func(variable) + + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From b9338aef95bf1829a55d9e6761d9dc338290de0c Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 17:03:36 +0100 Subject: [PATCH 09/38] add tests for comparisons --- xarray/tests/test_units.py | 59 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 39812edbdba..9f5afbf6242 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -296,6 +296,11 @@ def dtype(request): return request.param +def merge_mappings(*mappings): + for key in set(mappings[0]).intersection(*mappings[1:]): + yield key, tuple(m[key] for m in mappings) + + def merge_args(default_args, new_args): from itertools import zip_longest @@ -1489,6 +1494,60 @@ def test_missing_value_detection(self, func, dtype): xr.testing.assert_identical(expected, actual) + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "convert_data", + ( + pytest.param(False, id="no_conversion"), + pytest.param(True, id="with_conversion"), + ), + ) + @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) + def test_comparisons(self, func, unit, convert_data, dtype): + def is_compatible(first, second): + return { + key: unit_registry.Quantity(0, unit1).check(unit2) + for key, (unit1, unit2) in merge_mappings(first, second) + } + + array = np.linspace(0, 1, 9).astype(dtype) + quantity1 = array * unit_registry.m + variable = xr.Variable("x", quantity1) + + if convert_data and quantity1.check(unit): + quantity2 = convert_units(array * unit_registry.m, {None: unit}) + else: + quantity2 = array * unit + other = xr.Variable("x", quantity2) + + expected = func( + strip_units(variable), + strip_units( + convert_units(other, extract_units(variable)) + if quantity1.check(unit) + else other + ), + ) + if func.name == "identical": + expected &= extract_units(variable) == extract_units(other) + else: + expected &= all( + is_compatible(extract_units(variable), extract_units(other)).values() + ) + + actual = func(variable, other) + + assert expected == actual + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From 00169c5094866bd4836b71696972d80028ccb046 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 17:37:09 +0100 Subject: [PATCH 10/38] don't try to check missing value behaviour for int arrays --- xarray/tests/test_units.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9f5afbf6242..253464d50a4 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1474,7 +1474,7 @@ def test_raw_numpy_methods(self, func, unit, error, dtype): @pytest.mark.parametrize( "func", (method("isnull"), method("notnull"), method("count")), ids=repr ) - def test_missing_value_detection(self, func, dtype): + def test_missing_value_detection(self, func): array = ( np.array( [ @@ -1483,10 +1483,9 @@ def test_missing_value_detection(self, func, dtype): [2.1, np.nan, np.nan, 4.6], [9.9, np.nan, 7.2, 9.1], ] - ).astype(dtype) + ) * unit_registry.degK ) - variable = xr.Variable(("x", "y"), array) expected = func(strip_units(variable)) From a37fada973992dcc022cf3902688355d5e970fef Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 17:57:33 +0100 Subject: [PATCH 11/38] xfail the compatible unit tests --- xarray/tests/test_units.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 253464d50a4..9deb9aa0697 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1499,7 +1499,13 @@ def test_missing_value_detection(self, func): pytest.param(1, id="no_unit"), pytest.param(unit_registry.dimensionless, id="dimensionless"), pytest.param(unit_registry.s, id="incompatible_unit"), - pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param( + unit_registry.cm, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="checking for identical units does not work properly, yet" + ), + ), pytest.param(unit_registry.m, id="identical_unit"), ), ) From 9081c365b501952395b310eefef577a513402874 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 18:04:09 +0100 Subject: [PATCH 12/38] use an additional check since identical is not sufficient right now --- xarray/tests/test_units.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9deb9aa0697..82c1fbcc303 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1491,6 +1491,7 @@ def test_missing_value_detection(self, func): expected = func(strip_units(variable)) actual = func(variable) + assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( From 41cb4d78baab803a5e3ed6652982b449bbed5190 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 18:05:21 +0100 Subject: [PATCH 13/38] check for compatibility by comparing the dimensionality of units --- xarray/tests/test_units.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 82c1fbcc303..b54d3771227 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1519,9 +1519,16 @@ def test_missing_value_detection(self, func): ) @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) def test_comparisons(self, func, unit, convert_data, dtype): - def is_compatible(first, second): + def is_compatible(unit1, unit2): + return ( + isinstance(unit1, unit_registry.Unit) + and isinstance(unit2, unit_registry.Unit) + and unit1.dimensionality == unit2.dimensionality + ) + + def compatible_mappings(first, second): return { - key: unit_registry.Quantity(0, unit1).check(unit2) + key: is_compatible(unit1, unit2) for key, (unit1, unit2) in merge_mappings(first, second) } @@ -1529,7 +1536,7 @@ def is_compatible(first, second): quantity1 = array * unit_registry.m variable = xr.Variable("x", quantity1) - if convert_data and quantity1.check(unit): + if convert_data and is_compatible(unit_registry.m, unit): quantity2 = convert_units(array * unit_registry.m, {None: unit}) else: quantity2 = array * unit @@ -1539,7 +1546,7 @@ def is_compatible(first, second): strip_units(variable), strip_units( convert_units(other, extract_units(variable)) - if quantity1.check(unit) + if is_compatible(unit_registry.m, unit) else other ), ) @@ -1547,7 +1554,9 @@ def is_compatible(first, second): expected &= extract_units(variable) == extract_units(other) else: expected &= all( - is_compatible(extract_units(variable), extract_units(other)).values() + compatible_mappings( + extract_units(variable), extract_units(other) + ).values() ) actual = func(variable, other) From d2258afd101ebc78569789e67847b2d117826dd8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 29 Dec 2019 18:06:45 +0100 Subject: [PATCH 14/38] add tests for fillna --- xarray/tests/test_units.py | 56 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index b54d3771227..1a418d75101 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1494,6 +1494,62 @@ def test_missing_value_detection(self, func): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail(reason="uses 0 as a replacement"), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail(reason="converts to fill value's unit"), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_missing_value_fillna(self, unit, error): + value = 0 + array = ( + np.array( + [ + [1.4, 2.3, np.nan, 7.2], + [np.nan, 9.7, np.nan, np.nan], + [2.1, np.nan, np.nan, 4.6], + [9.9, np.nan, 7.2, 9.1], + ] + ) + * unit_registry.m + ) + variable = xr.Variable(("x", "y"), array) + + fill_value = value * unit + + if error is not None: + with pytest.raises(error): + print(variable.fillna(value=fill_value)) + + return + + expected = attach_units( + strip_units(variable).fillna( + value=fill_value.to(unit_registry.m).magnitude + ), + extract_units(variable), + ) + actual = variable.fillna(value=fill_value) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + @pytest.mark.parametrize( "unit", ( From 2f6e6df3c37547a62ed4ce0864cef5898c638174 Mon Sep 17 00:00:00 2001 From: Keewis Date: Mon, 30 Dec 2019 00:10:12 +0100 Subject: [PATCH 15/38] add tests for isel --- xarray/tests/test_units.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 1a418d75101..0c68b85063e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1619,6 +1619,24 @@ def compatible_mappings(first, second): assert expected == actual + @pytest.mark.parametrize( + "indices", + ( + pytest.param(4, id="single index"), + pytest.param([5, 2, 9, 1], id="multiple indices"), + ), + ) + def test_isel(self, indices, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.s + variable = xr.Variable("x", array) + + expected = attach_units( + strip_units(variable).isel(x=indices), extract_units(variable) + ) + actual = variable.isel(x=indices) + + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From 69bc684f32918b39a784c8fec377648dbf32a2e1 Mon Sep 17 00:00:00 2001 From: Keewis Date: Tue, 31 Dec 2019 15:38:52 +0100 Subject: [PATCH 16/38] add tests for 1d math --- xarray/tests/test_units.py | 96 +++++++++++++++++++++++++++++++------- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 0c68b85063e..03ce98b1b2f 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -29,6 +29,21 @@ ] +def is_compatible(unit1, unit2): + return ( + isinstance(unit1, unit_registry.Unit) + and isinstance(unit2, unit_registry.Unit) + and unit1.dimensionality == unit2.dimensionality + ) + + +def compatible_mappings(first, second): + return { + key: is_compatible(unit1, unit2) + for key, (unit1, unit2) in merge_mappings(first, second) + } + + def array_extract_units(obj): if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): obj = obj.data @@ -347,12 +362,16 @@ def __repr__(self): class function: - def __init__(self, name_or_function, *args, **kwargs): + def __init__(self, name_or_function, *args, function_label=None, **kwargs): if callable(name_or_function): - self.name = name_or_function.__name__ + self.name = ( + function_label + if function_label is not None + else name_or_function.__name__ + ) self.func = name_or_function else: - self.name = name_or_function + self.name = name_or_function if function_label is None else function_label self.func = getattr(np, name_or_function) if self.func is None: raise AttributeError( @@ -1575,19 +1594,6 @@ def test_missing_value_fillna(self, unit, error): ) @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr) def test_comparisons(self, func, unit, convert_data, dtype): - def is_compatible(unit1, unit2): - return ( - isinstance(unit1, unit_registry.Unit) - and isinstance(unit2, unit_registry.Unit) - and unit1.dimensionality == unit2.dimensionality - ) - - def compatible_mappings(first, second): - return { - key: is_compatible(unit1, unit2) - for key, (unit1, unit2) in merge_mappings(first, second) - } - array = np.linspace(0, 1, 9).astype(dtype) quantity1 = array * unit_registry.m variable = xr.Variable("x", quantity1) @@ -1637,6 +1643,64 @@ def test_isel(self, indices, dtype): xr.testing.assert_identical(expected, actual) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", + ( + function(lambda x, *_: +x, function_label="unary_plus"), + function(lambda x, *_: -x, function_label="unary_minus"), + function(lambda x, *_: abs(x), function_label="absolute"), + function(lambda x, y: x + y, function_label="sum"), + function(lambda x, y: y + x, function_label="commutative_sum"), + function(lambda x, y: x * y, function_label="product"), + function(lambda x, y: y * x, function_label="commutative_product"), + ), + ids=repr, + ) + def test_1d_math(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.arange(5).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + values = np.ones(5) + y = values * unit + + if error is not None and func.name in ("sum", "commutative_sum"): + with pytest.raises(error): + func(variable, y) + + return + + units = extract_units(func(array, y)) + if all(compatible_mappings(units, extract_units(y)).values()): + converted_y = convert_units(y, units) + else: + converted_y = y + + if all(compatible_mappings(units, extract_units(variable)).values()): + converted_variable = convert_units(variable, units) + else: + converted_variable = variable + + expected = attach_units( + func(strip_units(converted_variable), strip_units(converted_y)), units + ) + actual = func(variable, y) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_allclose(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From e81029bbd25683db8f761a238f3d05645292270d Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 16:03:47 +0100 Subject: [PATCH 17/38] add tests for broadcast_equals --- xarray/tests/test_units.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 03ce98b1b2f..90daae953d3 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1625,6 +1625,38 @@ def test_comparisons(self, func, unit, convert_data, dtype): assert expected == actual + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_broadcast_equals(self, unit, dtype): + base_unit = unit_registry.m + left_array = np.ones(shape=(2, 2), dtype=dtype) * base_unit + value = ( + (1 * base_unit).to(unit).magnitude if is_compatible(unit, base_unit) else 1 + ) + right_array = np.full(shape=(2,), fill_value=value, dtype=dtype) * unit + + left = xr.Variable(("x", "y"), left_array) + right = xr.Variable("x", right_array) + + units = { + **extract_units(left), + **({} if is_compatible(unit, base_unit) else {None: None}), + } + expected = strip_units(left).broadcast_equals( + strip_units(convert_units(right, units)) + ) & is_compatible(unit, base_unit) + actual = left.broadcast_equals(right) + + assert expected == actual + @pytest.mark.parametrize( "indices", ( From 1302d12c11ceb673bf57ba135739f4700e3c82a8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 16:18:31 +0100 Subject: [PATCH 18/38] add tests for masking --- xarray/tests/test_units.py | 56 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 90daae953d3..12f851f1650 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1733,6 +1733,62 @@ def test_1d_math(self, func, unit, error, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_allclose(expected, actual) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param( + unit_registry.cm, + None, + id="compatible_unit", + marks=pytest.mark.xfail( + reason="getitem_with_mask converts to the unit of other" + ), + ), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + @pytest.mark.parametrize( + "func", (method("where"), method("_getitem_with_mask")), ids=repr + ) + def test_masking(self, func, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + cond = np.array([True, False] * 5) + + other = -1 * unit + + if error is not None: + with pytest.raises(error): + func(variable, cond, other) + + return + + expected = attach_units( + func( + strip_units(variable), + cond, + strip_units( + convert_units( + other, + {None: base_unit} + if is_compatible(base_unit, unit) + else {None: None}, + ) + ), + ), + extract_units(variable), + ) + actual = func(variable, cond, other) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From 59c9edceac54baa40228704f818b6ed7b2208f8a Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 16:31:12 +0100 Subject: [PATCH 19/38] add tests for squeeze --- xarray/tests/test_units.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 12f851f1650..cbb61e3c88d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1789,6 +1789,29 @@ def test_masking(self, func, unit, error, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + def test_squeeze(self, dtype): + shape = (2, 1, 3, 1, 1, 2) + names = list("abcdef") + array = np.ones(shape=shape) * unit_registry.m + variable = xr.Variable(names, array) + + expected = attach_units( + strip_units(variable).squeeze(), extract_units(variable) + ) + actual = variable.squeeze() + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + + for dim in names: + expected = attach_units( + strip_units(variable).squeeze(), extract_units(variable) + ) + actual = variable.squeeze() + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From bbb4105a612154b4b94ff35e2f011b2b42211d77 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 16:44:08 +0100 Subject: [PATCH 20/38] add tests for coarsen, quantile, roll, round and shift --- xarray/tests/test_units.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index cbb61e3c88d..54d16b6a905 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1812,6 +1812,38 @@ def test_squeeze(self, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + @pytest.mark.parametrize( + "func", + ( + method("coarsen", windows={"x": 2}, func=np.mean), + method("quantile", q=[0.25, 0.75]), + pytest.param( + method("rank", dim="x"), + marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"), + ), + method("roll", {"x": 2}), + pytest.param( + method("shift", {"x": -2}), + marks=pytest.mark.xfail( + reason="trying to concatenate ndarray to quantity" + ), + ), + method("round", 2), + ), + ids=repr, + ) + def test_computation(self, func, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + expected = attach_units(func(strip_units(variable)), extract_units(variable)) + + actual = func(variable) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From a11885b32b76b30fa9c20d718dd3585936b67f0a Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 17:53:42 +0100 Subject: [PATCH 21/38] add tests for searchsorted --- xarray/tests/test_units.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 54d16b6a905..e6cad04a07b 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1844,6 +1844,40 @@ def test_computation(self, func, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_searchsorted(self, unit, error, dtype): + base_unit = unit_registry.m + array = np.linspace(0, 5, 10).astype(dtype) * base_unit + variable = xr.Variable("x", array) + + value = 0 * unit + + if error is not None: + with pytest.raises(error): + variable.searchsorted(value) + + return + + expected = strip_units(variable).searchsorted( + strip_units(convert_units(value, {None: base_unit})) + ) + + actual = variable.searchsorted(value) + + assert extract_units(expected) == extract_units(actual) + np.testing.assert_allclose(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From 6dba6bca179402dd3ad4eb2550681f2af4ea2bff Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 18:04:41 +0100 Subject: [PATCH 22/38] add tests for rolling_window --- xarray/tests/test_units.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index e6cad04a07b..8628abdd4fd 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1822,13 +1822,17 @@ def test_squeeze(self, dtype): marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"), ), method("roll", {"x": 2}), + pytest.param( + method("rolling_window", "x", 3, "window"), + marks=pytest.mark.xfail(reason="converts to ndarray"), + ), + method("round", 2), pytest.param( method("shift", {"x": -2}), marks=pytest.mark.xfail( reason="trying to concatenate ndarray to quantity" ), ), - method("round", 2), ), ids=repr, ) From d27cf8075947113c873fd779cca2615e617229f6 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 18:05:04 +0100 Subject: [PATCH 23/38] add tests for transpose --- xarray/tests/test_units.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 8628abdd4fd..fd163338143 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1815,7 +1815,7 @@ def test_squeeze(self, dtype): @pytest.mark.parametrize( "func", ( - method("coarsen", windows={"x": 2}, func=np.mean), + method("coarsen", windows={"y": 2}, func=np.mean), method("quantile", q=[0.25, 0.75]), pytest.param( method("rank", dim="x"), @@ -1833,13 +1833,14 @@ def test_squeeze(self, dtype): reason="trying to concatenate ndarray to quantity" ), ), + method("transpose", "y", "x"), ), ids=repr, ) def test_computation(self, func, dtype): base_unit = unit_registry.m - array = np.linspace(0, 5, 10).astype(dtype) * base_unit - variable = xr.Variable("x", array) + array = np.linspace(0, 5, 5 * 10).reshape(5, 10).astype(dtype) * base_unit + variable = xr.Variable(("x", "y"), array) expected = attach_units(func(strip_units(variable)), extract_units(variable)) From 5dffc12c983721953b9f26cd93293c96404e4f5c Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 18:17:38 +0100 Subject: [PATCH 24/38] add tests for stack and unstack --- xarray/tests/test_units.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index fd163338143..e6f41158d33 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1883,6 +1883,30 @@ def test_searchsorted(self, unit, error, dtype): assert extract_units(expected) == extract_units(actual) np.testing.assert_allclose(expected, actual) + def test_stack(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + expected = attach_units( + strip_units(variable).stack(z=("x", "y")), extract_units(variable) + ) + actual = variable.stack(z=("x", "y")) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + + def test_unstack(self, dtype): + array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m + variable = xr.Variable("z", array) + + expected = attach_units( + strip_units(variable).unstack(z={"x": 3, "y": 10}), extract_units(variable) + ) + actual = variable.unstack(z={"x": 3, "y": 10}) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From cbb8da3a44531b043c935a004c6cce1626fe5bf5 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 18:31:08 +0100 Subject: [PATCH 25/38] add tests for set_dims --- xarray/tests/test_units.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index e6f41158d33..71e6a69c83f 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1907,6 +1907,19 @@ def test_unstack(self, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + def test_set_dims(self, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + dims = {"z": 6, "x": 3, "a": 1, "b": 4, "y": 10} + expected = attach_units( + strip_units(variable).set_dims(dims), extract_units(variable) + ) + actual = variable.set_dims(dims) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From e2faede82d6bd4021497d236015803909ec857f7 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 23:53:00 +0100 Subject: [PATCH 26/38] add tests for concat, copy and no_conflicts --- xarray/tests/test_units.py | 92 +++++++++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 71e6a69c83f..9c795be29b3 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1318,11 +1318,11 @@ def wrapper(cls): "test_array_interface", "test___array__", "test_copy_index", - "test_concat", "test_concat_number_strings", "test_concat_fixed_len_str", "test_concat_mixed_dtypes", "test_pandas_datetime64_with_tz", + "test_pandas_data", "test_multiindex", ) class TestVariable(VariableSubclassobjects): @@ -1907,6 +1907,44 @@ def test_unstack(self, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + @pytest.mark.xfail(reason="ignores units") + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_concat(self, unit, error, dtype): + array1 = ( + np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + ) + array2 = np.linspace(5, 10, 10 * 2).reshape(10, 2).astype(dtype) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable(("y", "z"), array2) + + if error is not None: + with pytest.raises(error): + variable.concat(other) + + return + + units = extract_units(variable) + expected = attach_units( + strip_units(variable).concat(strip_units(convert_units(other, units))), + units, + ) + actual = variable.concat(other) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + def test_set_dims(self, dtype): array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m variable = xr.Variable(("x", "y"), array) @@ -1920,6 +1958,58 @@ def test_set_dims(self, dtype): assert extract_units(expected) == extract_units(actual) xr.testing.assert_identical(expected, actual) + def test_copy(self, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + other = np.arange(10).astype(dtype) * unit_registry.s + + variable = xr.Variable("x", array) + expected = attach_units( + strip_units(variable).copy(data=strip_units(other)), extract_units(other) + ) + actual = variable.copy(data=other) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + + @pytest.mark.parametrize( + "unit", + ( + pytest.param(1, id="no_unit"), + pytest.param(unit_registry.dimensionless, id="dimensionless"), + pytest.param(unit_registry.s, id="incompatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.m, id="identical_unit"), + ), + ) + def test_no_conflicts(self, unit, dtype): + base_unit = unit_registry.m + array1 = ( + np.array( + [ + [6.3, 0.3, 0.45], + [np.nan, 0.3, 0.3], + [3.7, np.nan, 0.2], + [9.43, 0.3, 0.7], + ] + ) + * base_unit + ) + array2 = np.array([np.nan, 0.3, np.nan]) * unit + + variable = xr.Variable(("x", "y"), array1) + other = xr.Variable("y", array2) + + expected = strip_units(variable).no_conflicts( + strip_units( + convert_units( + other, {None: base_unit if is_compatible(base_unit, unit) else None} + ) + ) + ) & is_compatible(base_unit, unit) + actual = variable.no_conflicts(other) + + assert expected == actual + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From 350214bd69c804124cfbc2f22e4bfbb82fac33b8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 8 Jan 2020 23:54:47 +0100 Subject: [PATCH 27/38] add tests for reduce --- xarray/tests/test_units.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9c795be29b3..33f0251e990 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1826,6 +1826,7 @@ def test_squeeze(self, dtype): method("rolling_window", "x", 3, "window"), marks=pytest.mark.xfail(reason="converts to ndarray"), ), + method("reduce", np.std, "x"), method("round", 2), pytest.param( method("shift", {"x": -2}), From e34a3aa96da0cf8a244e0d83efee1d0f01bc4aac Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 9 Jan 2020 00:12:23 +0100 Subject: [PATCH 28/38] add tests for pad_with_fill_value --- xarray/tests/test_units.py | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 33f0251e990..9899f26d543 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -2011,6 +2011,51 @@ def test_no_conflicts(self, unit, dtype): assert expected == actual + @pytest.mark.parametrize( + "unit,error", + ( + pytest.param( + 1, + DimensionalityError, + id="no_unit", + marks=pytest.mark.xfail( + reason="is not treated the same as dimensionless" + ), + ), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ) + def test_pad_with_fill_value(self, unit, error, dtype): + array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m + variable = xr.Variable(("x", "y"), array) + + fill_value = np.array(-100) * unit + + func = method("pad_with_fill_value", x=(2, 3), y=(1, 4)) + if error is not None: + with pytest.raises(error): + func(variable, fill_value=fill_value) + + return + + units = extract_units(variable) + expected = attach_units( + func( + strip_units(variable), + fill_value=strip_units(convert_units(fill_value, units)), + ), + units, + ) + actual = func(variable, fill_value=fill_value) + + assert extract_units(expected) == extract_units(actual) + xr.testing.assert_identical(expected, actual) + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") From 48df9c551754f7029075ed46e2666e56c053d22f Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 9 Jan 2020 00:14:33 +0100 Subject: [PATCH 29/38] all and any have been implemented by pint --- xarray/tests/test_units.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9899f26d543..b7ba0a81888 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1335,12 +1335,8 @@ def cls(dims, data, *args, **kwargs): @pytest.mark.parametrize( "func", ( - pytest.param( - method("all"), marks=pytest.mark.xfail(reason="not implemented by pint") - ), - pytest.param( - method("any"), marks=pytest.mark.xfail(reason="not implemented by pint") - ), + method("all"), + method("any"), method("argmax"), method("argmin"), method("argsort"), From fa9a80b42a6b7a99efc0f97c5a5d0bb95e81d517 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 9 Jan 2020 00:25:21 +0100 Subject: [PATCH 30/38] remove a unnecessary call to np.array --- xarray/tests/test_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index b7ba0a81888..7d67d5be2fd 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -2030,7 +2030,7 @@ def test_pad_with_fill_value(self, unit, error, dtype): array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m variable = xr.Variable(("x", "y"), array) - fill_value = np.array(-100) * unit + fill_value = -100 * unit func = method("pad_with_fill_value", x=(2, 3), y=(1, 4)) if error is not None: From 5771d4d88e262e6edc60f53e2b0cf92f6831f829 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 9 Jan 2020 01:07:45 +0100 Subject: [PATCH 31/38] mark the unrelated failures as xfail --- xarray/tests/test_units.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 7d67d5be2fd..a53e7b1d7ca 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -391,6 +391,7 @@ def __repr__(self): return f"function_{self.name}" +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataarray(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -406,6 +407,7 @@ def test_apply_ufunc_dataarray(dtype): assert_equal_with_units(expected, actual) +@pytest.mark.xfail(reason="test bug: apply_ufunc should not be called that way") def test_apply_ufunc_dataset(dtype): func = function( xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} @@ -4347,6 +4349,7 @@ def test_stacking_stacked(self, func, dtype): assert_equal_with_units(expected, actual) + @pytest.mark.xfail(reason="does not work with quantities yet") def test_to_stacked_array(self, dtype): labels = np.arange(5).astype(dtype) * unit_registry.s arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels} From 6da9f4bc557f5837bff87d4ab9e25f4b806b7528 Mon Sep 17 00:00:00 2001 From: Keewis Date: Tue, 14 Jan 2020 17:40:16 +0100 Subject: [PATCH 32/38] remove a debug print --- xarray/tests/test_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index a53e7b1d7ca..646f18b503a 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1552,7 +1552,7 @@ def test_missing_value_fillna(self, unit, error): if error is not None: with pytest.raises(error): - print(variable.fillna(value=fill_value)) + variable.fillna(value=fill_value) return From a30f78734e59e93fb98bb21c9b7682f0c8bda610 Mon Sep 17 00:00:00 2001 From: Keewis Date: Tue, 14 Jan 2020 17:53:17 +0100 Subject: [PATCH 33/38] document the need for force_ndarray --- xarray/tests/test_units.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 646f18b503a..704a42d2ddd 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -13,6 +13,8 @@ DimensionalityError = pint.errors.DimensionalityError +# make sure scalars are converted to 0d arrays so quantities can +# always be treated like ndarrays unit_registry = pint.UnitRegistry(force_ndarray=True) Quantity = unit_registry.Quantity From 61dff86e79b14e6353222da8eaba3d5a31989fdd Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 15 Jan 2020 01:48:05 +0100 Subject: [PATCH 34/38] make is_compatible a little bit more robust --- xarray/tests/test_units.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 704a42d2ddd..4d872da79cb 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -32,11 +32,16 @@ def is_compatible(unit1, unit2): - return ( - isinstance(unit1, unit_registry.Unit) - and isinstance(unit2, unit_registry.Unit) - and unit1.dimensionality == unit2.dimensionality - ) + def _valid_type(unit1): + return isinstance(unit1, (int, unit_registry.Unit, Quantity)) + + if not _valid_type(unit1) or not _valid_type(unit2): + raise ValueError("can only pass int, Unit or Quantity instances") + + unit1 = unit_registry.dimensionless if isinstance(unit1, int) else unit1 + unit2 = unit_registry.dimensionless if isinstance(unit2, int) else unit2 + + return unit1.dimensionality == unit2.dimensionality def compatible_mappings(first, second): From f67c29ee9bc6957abfa90dc1ae25bdf8ca497e7b Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 15 Jan 2020 11:48:16 +0100 Subject: [PATCH 35/38] clean up is_compatible --- xarray/tests/test_units.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 4d872da79cb..2f2e2684429 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -32,16 +32,15 @@ def is_compatible(unit1, unit2): - def _valid_type(unit1): - return isinstance(unit1, (int, unit_registry.Unit, Quantity)) - - if not _valid_type(unit1) or not _valid_type(unit2): - raise ValueError("can only pass int, Unit or Quantity instances") + def dimensionality(obj): + if isinstance(obj, (unit_registry.Quantity, unit_registry.Unit)): + unit_like = obj + else: + unit_like = unit_registry.dimensionless - unit1 = unit_registry.dimensionless if isinstance(unit1, int) else unit1 - unit2 = unit_registry.dimensionless if isinstance(unit2, int) else unit2 + return unit_like.dimensionality - return unit1.dimensionality == unit2.dimensionality + return dimensionality(unit1) == dimensionality(unit2) def compatible_mappings(first, second): From b412d86f158cc41ed0de1b4d47f8c4a320f5574e Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 15 Jan 2020 14:11:57 +0100 Subject: [PATCH 36/38] add docstrings explaining the use of the method and function classes --- xarray/tests/test_units.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2f2e2684429..6ba27a63d36 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -333,6 +333,11 @@ def merge_args(default_args, new_args): class method: + """ wrapper class to help with passing methods via parametrize + + This is works a bit similar to using `partial(Class.method, arg, kwarg)` + """ + def __init__(self, name, *args, **kwargs): self.name = name self.args = args @@ -351,7 +356,7 @@ def __call__(self, obj, *args, **kwargs): if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)): numpy_func = getattr(np, self.name) func = partial(numpy_func, obj) - # remove typical xr args like "dim" + # remove typical xarray args like "dim" exclude_kwargs = ("dim", "dims") all_kwargs = { key: value @@ -368,6 +373,11 @@ def __repr__(self): class function: + """ wrapper class for numpy functions + + Same as method, but the name is used for referencing numpy functions + """ + def __init__(self, name_or_function, *args, function_label=None, **kwargs): if callable(name_or_function): self.name = ( From 5375c9d88460f81a2cc7460a64ef6388e48b1c2e Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 15 Jan 2020 17:09:49 +0100 Subject: [PATCH 37/38] put the unit comparisons into a function --- xarray/tests/test_units.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 6ba27a63d36..946f2db5274 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -278,6 +278,10 @@ def convert_units(obj, to): return new_obj +def assert_units_equal(a, b): + assert extract_units(a) == extract_units(b) + + def assert_equal_with_units(a, b): # works like xr.testing.assert_equal, but also explicitly checks units # so, it is more like assert_identical @@ -1384,6 +1388,7 @@ def test_aggregation(self, func, dtype): expected = attach_units(func(strip_units(variable)), units) actual = func(variable) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -1441,7 +1446,7 @@ def test_numpy_methods(self, func, unit, error, dtype): ) actual = func(variable, *args, **kwargs) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_allclose(expected, actual) @pytest.mark.parametrize( @@ -1501,7 +1506,7 @@ def test_raw_numpy_methods(self, func, unit, error, dtype): ) actual = func(variable, *args, **kwargs) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) np.testing.assert_allclose(expected, actual) @pytest.mark.parametrize( @@ -1524,7 +1529,7 @@ def test_missing_value_detection(self, func): expected = func(strip_units(variable)) actual = func(variable) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -1580,7 +1585,7 @@ def test_missing_value_fillna(self, unit, error): ) actual = variable.fillna(value=fill_value) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -1687,6 +1692,7 @@ def test_isel(self, indices, dtype): ) actual = variable.isel(x=indices) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -1744,7 +1750,7 @@ def test_1d_math(self, func, unit, error, dtype): ) actual = func(variable, y) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_allclose(expected, actual) @pytest.mark.parametrize( @@ -1800,7 +1806,7 @@ def test_masking(self, func, unit, error, dtype): ) actual = func(variable, cond, other) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) def test_squeeze(self, dtype): @@ -1814,7 +1820,7 @@ def test_squeeze(self, dtype): ) actual = variable.squeeze() - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) for dim in names: @@ -1823,7 +1829,7 @@ def test_squeeze(self, dtype): ) actual = variable.squeeze() - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -1861,7 +1867,7 @@ def test_computation(self, func, dtype): actual = func(variable) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -1895,7 +1901,7 @@ def test_searchsorted(self, unit, error, dtype): actual = variable.searchsorted(value) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) np.testing.assert_allclose(expected, actual) def test_stack(self, dtype): @@ -1907,7 +1913,7 @@ def test_stack(self, dtype): ) actual = variable.stack(z=("x", "y")) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) def test_unstack(self, dtype): @@ -1919,7 +1925,7 @@ def test_unstack(self, dtype): ) actual = variable.unstack(z={"x": 3, "y": 10}) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.xfail(reason="ignores units") @@ -1957,7 +1963,7 @@ def test_concat(self, unit, error, dtype): ) actual = variable.concat(other) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) def test_set_dims(self, dtype): @@ -1970,7 +1976,7 @@ def test_set_dims(self, dtype): ) actual = variable.set_dims(dims) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) def test_copy(self, dtype): @@ -1983,7 +1989,7 @@ def test_copy(self, dtype): ) actual = variable.copy(data=other) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) @pytest.mark.parametrize( @@ -2067,7 +2073,7 @@ def test_pad_with_fill_value(self, unit, error, dtype): ) actual = func(variable, fill_value=fill_value) - assert extract_units(expected) == extract_units(actual) + assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) From a852aef81fca113ed27e0b348ef90d7c861fa261 Mon Sep 17 00:00:00 2001 From: Keewis Date: Wed, 15 Jan 2020 17:27:50 +0100 Subject: [PATCH 38/38] actually squeeze the dimensions separately --- xarray/tests/test_units.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 946f2db5274..2cb1550c088 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -1823,11 +1823,12 @@ def test_squeeze(self, dtype): assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual) - for dim in names: + names = tuple(name for name, size in zip(names, shape) if shape == 1) + for name in names: expected = attach_units( - strip_units(variable).squeeze(), extract_units(variable) + strip_units(variable).squeeze(dim=name), extract_units(variable) ) - actual = variable.squeeze() + actual = variable.squeeze(dim=name) assert_units_equal(expected, actual) xr.testing.assert_identical(expected, actual)