From 7bc21059a48f003099665ac08e66e2286416cc81 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 7 Nov 2019 20:51:49 +0100 Subject: [PATCH 01/21] add tests for replication functions --- xarray/tests/test_units.py | 79 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 9d14104bb50..24474d1be05 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -278,7 +278,7 @@ def __repr__(self): @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) -def test_replication(func, dtype): +def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s data_array = xr.DataArray(data=array, dims="x") @@ -289,8 +289,33 @@ def test_replication(func, dtype): assert_equal_with_units(expected, result) +@pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) +def test_replication_dataset(func, dtype): + array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s + array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa + x = np.arange(20).astype(dtype) * unit_registry.m + y = np.arange(10).astype(dtype) * unit_registry.m + z = y.to(unit_registry.mm) + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, + coords={"x": x, "y": y, "z": ("y", z)}, + ) + + numpy_func = getattr(np, func.__name__) + expected = ds.copy( + data={name: numpy_func(array.data) for name, array in ds.data_vars.items()} + ) + result = func(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail( - reason="np.full_like on Variable strips the unit and pint does not allow mixed args" + reason=( + "pint is undecided on how `full_like` should work, so incorrect errors " + "may be thrown: hgrecco/pint#882" + ) ) @pytest.mark.parametrize( "unit,error", @@ -304,7 +329,7 @@ def test_replication(func, dtype): pytest.param(unit_registry.s, None, id="identical_unit"), ), ) -def test_replication_full_like(unit, error, dtype): +def test_replication_full_like_dataarray(unit, error, dtype): array = np.linspace(0, 5, 10) * unit_registry.s data_array = xr.DataArray(data=array, dims="x") @@ -319,6 +344,54 @@ def test_replication_full_like(unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail( + reason=( + "pint is undecided on how `full_like` should work, so incorrect errors " + "may be thrown: hgrecco/pint#882" + ) +) +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.ms, None, id="compatible_unit"), + pytest.param(unit_registry.s, None, id="identical_unit"), + ), +) +def test_replication_full_like_dataset(unit, error, dtype): + array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s + array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa + x = np.arange(20).astype(dtype) * unit_registry.m + y = np.arange(10).astype(dtype) * unit_registry.m + z = y.to(unit_registry.mm) + + ds = xr.Dataset( + data_vars={"a": ("x", array1), "b": ("y", array2)}, + coords={"x": x, "y": y, "z": ("y", z)}, + ) + + fill_value = -1 * unit + if error is not None: + with pytest.raises(error): + xr.full_like(ds, fill_value=fill_value) + + return + + expected = ds.copy( + data={ + name: np.full_like(array, fill_value=fill_value) + for name, array in ds.data_vars.items() + } + ) + result = xr.full_like(ds, fill_value=fill_value) + + assert_equal_with_units(expected, result) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( From 007ad7b2cb373ef65f47272fc5ff2fa0cb30c9a7 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 7 Nov 2019 21:08:10 +0100 Subject: [PATCH 02/21] add tests for `xarray.dot` --- xarray/tests/test_units.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 24474d1be05..6ce81301d13 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -392,6 +392,26 @@ def test_replication_full_like_dataset(unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="pint does not implement `np.einsum`") +def test_dot_dataarray(dtype): + array1 = ( + np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) + * unit_registry.m + / unit_registry.s + ) + array2 = ( + np.linspace(10, 20, 10 * 20).reshape(10, 20).astype(dtype) * unit_registry.s + ) + + arr1 = xr.DataArray(data=array1, dims=("x", "y")) + arr2 = xr.DataArray(data=array2, dims=("y", "z")) + + expected = array1.dot(array2) + result = xr.dot(arr1, arr2) + + assert_equal_with_units(expected, result) + + class TestDataArray: @pytest.mark.filterwarnings("error:::pint[.*]") @pytest.mark.parametrize( From c1407e2770c9c6bf7942b36f0f7b59893f976c62 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 7 Nov 2019 22:06:13 +0100 Subject: [PATCH 03/21] add tests for apply_ufunc --- xarray/tests/test_units.py | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 94a96261d86..b2bf3879c8e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -266,17 +266,46 @@ def __repr__(self): class function: - def __init__(self, name): - self.name = name - self.func = getattr(np, name) + def __init__(self, name_or_function, *args, **kwargs): + if callable(name_or_function): + self.name = name_or_function.__name__ + self.func = name_or_function + else: + self.name = name_or_function + self.func = getattr(np, name_or_function) + if self.func is None: + raise AttributeError( + f"module 'numpy' has no attribute named '{self.name}'" + ) + + self.args = args + self.kwargs = kwargs def __call__(self, *args, **kwargs): - return self.func(*args, **kwargs) + all_args = list(self.args) + list(args) + all_kwargs = {**self.kwargs, **kwargs} + + return self.func(*all_args, **all_kwargs) def __repr__(self): return f"function_{self.name}" +def test_apply_ufunc_dataarray(dtype): + func = function( + xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1} + ) + + array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.m + x = np.arange(20) * unit_registry.s + data_array = xr.DataArray(data=array, dims="x", coords={"x": x}) + + expected = attach_units(func(strip_units(data_array)), extract_units(data_array)) + result = func(data_array) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s From a21c87209d8e9368a5c895626ed026c10167cfa1 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 7 Nov 2019 22:59:17 +0100 Subject: [PATCH 04/21] explicitly set the test ids to repr --- xarray/tests/test_units.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index b2bf3879c8e..990d8adf333 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -357,6 +357,7 @@ def test_replication_dataset(func, dtype): pytest.param(unit_registry.ms, None, id="compatible_unit"), pytest.param(unit_registry.s, None, id="identical_unit"), ), + ids=repr, ) def test_replication_full_like_dataarray(unit, error, dtype): array = np.linspace(0, 5, 10) * unit_registry.s @@ -390,6 +391,7 @@ def test_replication_full_like_dataarray(unit, error, dtype): pytest.param(unit_registry.ms, None, id="compatible_unit"), pytest.param(unit_registry.s, None, id="identical_unit"), ), + ids=repr, ) def test_replication_full_like_dataset(unit, error, dtype): array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s From 9abb7298dc6071c50d054bae27b14545770b1237 Mon Sep 17 00:00:00 2001 From: Keewis Date: Thu, 7 Nov 2019 23:00:40 +0100 Subject: [PATCH 05/21] add tests for align --- xarray/tests/test_units.py | 94 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 990d8adf333..af8d599b2e5 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -306,6 +306,100 @@ def test_apply_ufunc_dataarray(dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="pint does not implement `np.result_type`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) +def test_align_dataarray(fill_value, unit, error, dtype): + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m + array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m + x = np.arange(2) * unit_registry.m + y = np.arange(5) * unit_registry.m + z = np.arange(2, 7) * unit_registry.m + + data_array1 = xr.DataArray(data=array1, coords={"x": x, "y": y}, dims=("x", "y")) + data_array2 = xr.DataArray(data=array2, coords={"x": x, "y": z}, dims=("x", "y")) + + fill_value = fill_value * unit + if isinstance(fill_value, unit_registry.Quantity) and fill_value.check( + unit_registry.m + ): + fill_value = fill_value.to(unit_registry.m) + + func = function(xr.align, join="outer", fill_value=fill_value) + if error is not None: + with pytest.raises(error): + func(data_array1, data_array2) + + return + + stripped_kwargs = {key: strip_units(value) for key, value in func.kwargs.items()} + expected_a, expected_b = func( + strip_units(data_array1), strip_units(data_array2), **stripped_kwargs + ) + result_a, result_b = func(data_array1, data_array2) + + assert_equal_with_units(expected_a, result_a) + assert_equal_with_units(expected_b, result_b) + + +@pytest.mark.xfail(reason="pint does not implement `np.result_type`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) +def test_align_dataset(fill_value, unit, error, dtype): + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m + array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m + x = np.arange(2) * unit_registry.m + y = np.arange(5) * unit_registry.m + z = np.arange(2, 7) * unit_registry.m + + ds1 = xr.Dataset(data_vars={"a": (("x", "y"), array1)}, coords={"x": x, "y": y}) + ds2 = xr.Dataset(data_vars={"a": (("x", "y"), array2)}, coords={"x": x, "y": z}) + + fill_value = fill_value * unit + if isinstance(fill_value, unit_registry.Quantity) and fill_value.check( + unit_registry.m + ): + fill_value = fill_value.to(unit_registry.m) + + func = function(xr.align, join="outer", fill_value=fill_value) + if error is not None: + with pytest.raises(error): + func(ds1, ds2) + + return + + stripped_kwargs = {key: strip_units(value) for key, value in func.kwargs.items()} + expected_a, expected_b = func(strip_units(ds1), strip_units(ds2), **stripped_kwargs) + result_a, result_b = func(ds1, ds2) + + assert_equal_with_units(expected_a, result_a) + assert_equal_with_units(expected_b, result_b) + + @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s From 4664630066c22b599f57e8097159bcbc71e8d746 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 8 Nov 2019 18:33:04 +0100 Subject: [PATCH 06/21] cover a bit more of align --- xarray/tests/test_units.py | 90 ++++++++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 22 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index af8d599b2e5..659d9f5f5e2 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -306,7 +306,9 @@ def test_apply_ufunc_dataarray(dtype): assert_equal_with_units(expected, result) -@pytest.mark.xfail(reason="pint does not implement `np.result_type`") +@pytest.mark.xfail( + reason="pint does not implement `np.result_type` and align strips units" +) @pytest.mark.parametrize( "unit,error", ( @@ -320,18 +322,36 @@ def test_apply_ufunc_dataarray(dtype): ), ids=repr, ) +@pytest.mark.parametrize("variant", ("data", "dims", "coords")) @pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) -def test_align_dataarray(fill_value, unit, error, dtype): - array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m - array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m - x = np.arange(2) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(2, 7) * unit_registry.m - - data_array1 = xr.DataArray(data=array1, coords={"x": x, "y": y}, dims=("x", "y")) - data_array2 = xr.DataArray(data=array2, coords={"x": x, "y": z}, dims=("x", "y")) +def test_align_dataarray(fill_value, variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit + array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit + x = np.arange(2) * original_unit + x_a1 = np.array([10, 5]) * original_unit + x_a2 = np.array([10, 5]) * coord_unit + + y1 = np.arange(5) * original_unit + y2 = np.arange(2, 7) * dim_unit + + data_array1 = xr.DataArray( + data=array1, coords={"x": x, "x_a": ("x", x_a1), "y": y1}, dims=("x", "y") + ) + data_array2 = xr.DataArray( + data=array2, coords={"x": x, "x_a": ("x", x_a2), "y": y2}, dims=("x", "y") + ) - fill_value = fill_value * unit + # FIXME: convert x_a2 and y2 in data_array2, too + fill_value = fill_value * data_unit if isinstance(fill_value, unit_registry.Quantity) and fill_value.check( unit_registry.m ): @@ -354,7 +374,9 @@ def test_align_dataarray(fill_value, unit, error, dtype): assert_equal_with_units(expected_b, result_b) -@pytest.mark.xfail(reason="pint does not implement `np.result_type`") +@pytest.mark.xfail( + reason="pint does not implement `np.result_type` and align strips units" +) @pytest.mark.parametrize( "unit,error", ( @@ -368,18 +390,39 @@ def test_align_dataarray(fill_value, unit, error, dtype): ), ids=repr, ) +@pytest.mark.parametrize("variant", ("data", "dims", "coords")) @pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) -def test_align_dataset(fill_value, unit, error, dtype): - array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m - array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * unit_registry.m - x = np.arange(2) * unit_registry.m - y = np.arange(5) * unit_registry.m - z = np.arange(2, 7) * unit_registry.m +def test_align_dataset(fill_value, unit, variant, error, dtype): + original_unit = unit_registry.m - ds1 = xr.Dataset(data_vars={"a": (("x", "y"), array1)}, coords={"x": x, "y": y}) - ds2 = xr.Dataset(data_vars={"a": (("x", "y"), array2)}, coords={"x": x, "y": z}) + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) - fill_value = fill_value * unit + array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit + array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit + + x = np.arange(2) * original_unit + x_a1 = np.array([10, 5]) * original_unit + x_a2 = np.array([10, 5]) * coord_unit + + y1 = np.arange(5) * original_unit + y2 = np.arange(2, 7) * dim_unit + + ds1 = xr.Dataset( + data_vars={"a": (("x", "y"), array1)}, + coords={"x": x, "x_a": ("x", x_a1), "y": y1}, + ) + ds2 = xr.Dataset( + data_vars={"a": (("x", "y"), array2)}, + coords={"x": x, "x_a": ("x", x_a2), "y": y2}, + ) + + # FIXME: convert x_a2 and y2 in ds2, too + fill_value = fill_value * data_unit if isinstance(fill_value, unit_registry.Quantity) and fill_value.check( unit_registry.m ): @@ -393,7 +436,10 @@ def test_align_dataset(fill_value, unit, error, dtype): return stripped_kwargs = {key: strip_units(value) for key, value in func.kwargs.items()} - expected_a, expected_b = func(strip_units(ds1), strip_units(ds2), **stripped_kwargs) + expected_a, expected_b = tuple( + attach_units(result, extract_units(ds1)) + for result in func(strip_units(ds1), strip_units(ds2), **stripped_kwargs) + ) result_a, result_b = func(ds1, ds2) assert_equal_with_units(expected_a, result_a) From 3ecdb356e20e388880150c61793a64ec0f19627d Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 8 Nov 2019 19:16:14 +0100 Subject: [PATCH 07/21] add tests for broadcast --- xarray/tests/test_units.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 659d9f5f5e2..70dfb071bff 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -446,6 +446,37 @@ def test_align_dataset(fill_value, unit, variant, error, dtype): assert_equal_with_units(expected_b, result_b) +def test_broadcast_dataarray(dtype): + array1 = np.linspace(0, 10, 2) * unit_registry.Pa + array2 = np.linspace(0, 10, 3) * unit_registry.Pa + + a = xr.DataArray(data=array1, dims="x") + b = xr.DataArray(data=array2, dims="y") + + expected_a, expected_b = tuple( + attach_units(elem, extract_units(a)) + for elem in xr.broadcast(strip_units(a), strip_units(b)) + ) + result_a, result_b = xr.broadcast(a, b) + + assert_equal_with_units(expected_a, result_a) + assert_equal_with_units(expected_b, result_b) + + +def test_broadcast_dataset(dtype): + array1 = np.linspace(0, 10, 2) * unit_registry.Pa + array2 = np.linspace(0, 10, 3) * unit_registry.Pa + + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("y", array2)}) + + expected, = tuple( + attach_units(elem, extract_units(ds)) for elem in xr.broadcast(strip_units(ds)) + ) + result, = xr.broadcast(ds) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s From cf75525d7008724399cf86e42016582c041d2c71 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 8 Nov 2019 19:26:31 +0100 Subject: [PATCH 08/21] black changed how tuple unpacking should look like --- xarray/tests/test_units.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 70dfb071bff..59fa983afda 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -469,10 +469,10 @@ def test_broadcast_dataset(dtype): ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("y", array2)}) - expected, = tuple( + (expected,) = tuple( attach_units(elem, extract_units(ds)) for elem in xr.broadcast(strip_units(ds)) ) - result, = xr.broadcast(ds) + (result,) = xr.broadcast(ds) assert_equal_with_units(expected, result) From b79d9610d441ebf1fced78a9dc00bc32b44132cd Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 8 Nov 2019 23:16:05 +0100 Subject: [PATCH 09/21] correct the xfail message for full_like tests --- xarray/tests/test_units.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 59fa983afda..a11116733fb 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -514,7 +514,7 @@ def test_replication_dataset(func, dtype): @pytest.mark.xfail( reason=( "pint is undecided on how `full_like` should work, so incorrect errors " - "may be thrown: hgrecco/pint#882" + "may be expected: hgrecco/pint#882" ) ) @pytest.mark.parametrize( @@ -548,7 +548,7 @@ def test_replication_full_like_dataarray(unit, error, dtype): @pytest.mark.xfail( reason=( "pint is undecided on how `full_like` should work, so incorrect errors " - "may be thrown: hgrecco/pint#882" + "may be expected: hgrecco/pint#882" ) ) @pytest.mark.parametrize( From 6b6729e2ee4e1abd26b341e7efb8d22e17c4568f Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 8 Nov 2019 23:39:30 +0100 Subject: [PATCH 10/21] add tests for where --- xarray/tests/test_units.py | 88 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index a11116733fb..d558e415779 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -594,6 +594,94 @@ def test_replication_full_like_dataset(unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="`where` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.nan, 10.2)) +def test_where_dataarray(fill_value, unit, error, dtype): + array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + + x = xr.DataArray(data=array, dims="x") + cond = x < 5 * unit_registry.m + # FIXME: this should work without wrapping in array() + fill_value = np.array(fill_value) * unit + + if error is not None: + with pytest.raises(error): + xr.where(cond, x, fill_value) + + return + + fill_value_ = ( + fill_value.to(unit_registry.m) + if isinstance(fill_value, unit_registry.Quantity) + and fill_value.check(unit_registry.m) + else fill_value + ) + expected = attach_units( + xr.where(cond, strip_units(x), strip_units(fill_value_)), extract_units(x) + ) + result = xr.where(cond, x, fill_value) + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(reason="`where` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize("fill_value", (np.nan, 10.2)) +def test_where_dataset(fill_value, unit, error, dtype): + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.m + x = np.arange(10) * unit_registry.s + + ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}) + cond = ds.x < 5 * unit_registry.s + # FIXME: this should work without wrapping in array() + fill_value = np.array(fill_value) * unit + + if error is not None: + with pytest.raises(error): + xr.where(cond, ds, fill_value) + + return + + fill_value_ = ( + fill_value.to(unit_registry.m) + if isinstance(fill_value, unit_registry.Quantity) + and fill_value.check(unit_registry.m) + else fill_value + ) + expected = attach_units( + xr.where(cond, strip_units(ds), strip_units(fill_value_)), extract_units(ds) + ) + result = xr.where(cond, ds, fill_value) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="pint does not implement `np.einsum`") def test_dot_dataarray(dtype): array1 = ( From 93b003c4ecb866096719adf9166e2bb9e49fdda4 Mon Sep 17 00:00:00 2001 From: Keewis Date: Fri, 8 Nov 2019 23:57:46 +0100 Subject: [PATCH 11/21] add tests for concat --- xarray/tests/test_units.py | 98 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index d558e415779..5e0dfc64567 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -477,6 +477,104 @@ def test_broadcast_dataset(dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="`concat` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + ), +) +def test_concat_dataarray(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = {"data": (unit, original_unit), "dims": (original_unit, unit)} + data_unit, dims_unit = variants.get(variant) + + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit + x1 = np.arange(5, 15) * original_unit + x2 = np.arange(5) * dims_unit + + arr1 = xr.DataArray(data=array1, coords={"x": x1}, dims="x") + arr2 = xr.DataArray(data=array2, coords={"x": x2}, dims="x") + + if error is not None: + with pytest.raises(error): + xr.concat([arr1, arr2], dim="x") + + return + + expected = attach_units( + xr.concat([strip_units(arr1), strip_units(arr2)], dim="x"), extract_units(arr1) + ) + result = xr.concat([arr1, arr2], dim="x") + + assert_equal_with_units(expected, result) + + +@pytest.mark.xfail(False, reason="`concat` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + ), +) +def test_concat_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = {"data": (unit, original_unit), "dims": (original_unit, unit)} + data_unit, dims_unit = variants.get(variant) + + array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m + array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit + x1 = np.arange(5, 15) * original_unit + x2 = np.arange(5) * dims_unit + + ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1}) + ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2}) + + if error is not None: + with pytest.raises(error): + xr.concat([ds1, ds2], dim="x") + + return + + expected = attach_units( + xr.concat([strip_units(ds1), strip_units(ds2)], dim="x"), extract_units(ds1) + ) + result = xr.concat([ds1, ds2], dim="x") + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s From f8351ec60bbf5afa9ca9aa05d882393c399e948d Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 00:16:50 +0100 Subject: [PATCH 12/21] add tests for combine_by_coords --- xarray/tests/test_units.py | 69 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 5e0dfc64567..83b42b66cc6 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -477,6 +477,75 @@ def test_broadcast_dataset(dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="`combine_by_coords` strips units") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_combine_by_coords_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + x = np.arange(1, 4) * 10 * original_unit + y = np.arange(2) * original_unit + z = np.arange(3) * original_unit + + other_array1 = np.ones_like(array1) * data_unit + other_array2 = np.ones_like(array2) * data_unit + other_x = np.arange(1, 4) * 10 * dim_unit + other_y = np.arange(2, 4) * dim_unit + other_z = np.arange(3, 6) * coord_unit + + ds = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + other = xr.Dataset( + data_vars={"a": (("y", "x"), other_array1), "b": (("y", "x"), other_array2)}, + coords={"x": other_x, "y": other_y, "z": ("x", other_z)}, + ) + + if error is not None: + with pytest.raises(error): + xr.combine_by_coords([ds, other]) + + return + + units = extract_units(ds) + # FIXME: convert other to `units` before `strip_units` + expected = attach_units( + xr.combine_by_coords([strip_units(ds), strip_units(other)]), units + ) + result = xr.combine_by_coords([ds, other]) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="`concat` strips units") @pytest.mark.parametrize( "unit,error", From eb8fe4ebc4595053b86ef57de079acecca87fff5 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 13:58:04 +0100 Subject: [PATCH 13/21] fix a bug in convert_units --- xarray/tests/test_units.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 2c70ac0738c..c2876d7c190 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -222,7 +222,9 @@ def convert_units(obj, to): if name != obj.name } - new_obj = xr.DataArray(name=name, data=data, coords=coords, attrs=obj.attrs) + new_obj = xr.DataArray( + name=name, data=data, coords=coords, attrs=obj.attrs, dims=obj.dims + ) elif isinstance(obj, unit_registry.Quantity): units = to.get(None) new_obj = obj.to(units) if units is not None else obj From f9f727e079d2a2962392d4188b9e2fe346109f10 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 14:12:29 +0100 Subject: [PATCH 14/21] convert the align results to the same units --- xarray/tests/test_units.py | 68 +++++++++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index c2876d7c190..e00fc0b9c56 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -365,7 +365,14 @@ def test_apply_ufunc_dataarray(dtype): ), ids=repr, ) -@pytest.mark.parametrize("variant", ("data", "dims", "coords")) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) @pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) def test_align_dataarray(fill_value, variant, unit, error, dtype): original_unit = unit_registry.m @@ -393,13 +400,7 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): data=array2, coords={"x": x, "x_a": ("x", x_a2), "y": y2}, dims=("x", "y") ) - # FIXME: convert x_a2 and y2 in data_array2, too fill_value = fill_value * data_unit - if isinstance(fill_value, unit_registry.Quantity) and fill_value.check( - unit_registry.m - ): - fill_value = fill_value.to(unit_registry.m) - func = function(xr.align, join="outer", fill_value=fill_value) if error is not None: with pytest.raises(error): @@ -407,9 +408,24 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): return - stripped_kwargs = {key: strip_units(value) for key, value in func.kwargs.items()} - expected_a, expected_b = func( - strip_units(data_array1), strip_units(data_array2), **stripped_kwargs + stripped_kwargs = { + key: strip_units( + convert_units(value, {None: original_unit}) + if isinstance(value, unit_registry.Quantity) + else value + ) + for key, value in func.kwargs.items() + } + units = extract_units(data_array1) + # FIXME: should the expected_b have the same units as data_array1 + # or data_array2? + expected_a, expected_b = tuple( + attach_units(elem, units) + for elem in func( + strip_units(data_array1), + strip_units(convert_units(data_array2, units)), + **stripped_kwargs, + ) ) result_a, result_b = func(data_array1, data_array2) @@ -433,7 +449,14 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype): ), ids=repr, ) -@pytest.mark.parametrize("variant", ("data", "dims", "coords")) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) @pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan))) def test_align_dataset(fill_value, unit, variant, error, dtype): original_unit = unit_registry.m @@ -464,13 +487,7 @@ def test_align_dataset(fill_value, unit, variant, error, dtype): coords={"x": x, "x_a": ("x", x_a2), "y": y2}, ) - # FIXME: convert x_a2 and y2 in ds2, too fill_value = fill_value * data_unit - if isinstance(fill_value, unit_registry.Quantity) and fill_value.check( - unit_registry.m - ): - fill_value = fill_value.to(unit_registry.m) - func = function(xr.align, join="outer", fill_value=fill_value) if error is not None: with pytest.raises(error): @@ -478,10 +495,21 @@ def test_align_dataset(fill_value, unit, variant, error, dtype): return - stripped_kwargs = {key: strip_units(value) for key, value in func.kwargs.items()} + stripped_kwargs = { + key: strip_units( + convert_units(value, {None: original_unit}) + if isinstance(value, unit_registry.Quantity) + else value + ) + for key, value in func.kwargs.items() + } + units = extract_units(ds1) + # FIXME: should the expected_b have the same units as ds1 or ds2? expected_a, expected_b = tuple( - attach_units(result, extract_units(ds1)) - for result in func(strip_units(ds1), strip_units(ds2), **stripped_kwargs) + attach_units(elem, units) + for elem in func( + strip_units(ds1), strip_units(convert_units(ds2, units)), **stripped_kwargs + ) ) result_a, result_b = func(ds1, ds2) From 3d0dfb131caa12c4451f418a887f0a25d940a002 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 14:43:24 +0100 Subject: [PATCH 15/21] rename the combine_by_coords test --- xarray/tests/test_units.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index e00fc0b9c56..29afd8a1fb0 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -570,7 +570,7 @@ def test_broadcast_dataset(dtype): "coords", ), ) -def test_combine_by_coords_dataset(variant, unit, error, dtype): +def test_combine_by_coords(variant, unit, error, dtype): original_unit = unit_registry.m variants = { From 2e426a3cd464d31b88110935833fa525ed642c4c Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 14:44:24 +0100 Subject: [PATCH 16/21] convert the units for expected in combine_by_coords --- xarray/tests/test_units.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 29afd8a1fb0..be9f699bacb 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -608,9 +608,11 @@ def test_combine_by_coords(variant, unit, error, dtype): return units = extract_units(ds) - # FIXME: convert other to `units` before `strip_units` expected = attach_units( - xr.combine_by_coords([strip_units(ds), strip_units(other)]), units + xr.combine_by_coords( + [strip_units(ds), strip_units(convert_units(other, units))] + ), + units, ) result = xr.combine_by_coords([ds, other]) From 341ffbc5063f47fb3f92ce4f3e546d0a65b71aed Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 14:45:24 +0100 Subject: [PATCH 17/21] add tests for combine_nested --- xarray/tests/test_units.py | 100 +++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index be9f699bacb..f4d45fe5bf0 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -619,6 +619,106 @@ def test_combine_by_coords(variant, unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="blocked by `where`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_combine_nested(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + + x = np.arange(1, 4) * 10 * original_unit + y = np.arange(2) * original_unit + z = np.arange(3) * original_unit + + ds1 = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + ds2 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.ones_like(array1) * data_unit), + "b": (("y", "x"), np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(3) * dim_unit, + "y": np.arange(2, 4) * dim_unit, + "z": ("x", np.arange(-3, 0) * coord_unit), + }, + ) + ds3 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit), + "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit), + }, + coords={ + "x": np.arange(3, 6) * dim_unit, + "y": np.arange(4, 6) * dim_unit, + "z": ("x", np.arange(3, 6) * coord_unit), + }, + ) + ds4 = xr.Dataset( + data_vars={ + "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit), + "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(6, 9) * dim_unit, + "y": np.arange(6, 8) * dim_unit, + "z": ("x", np.arange(6, 9) * coord_unit), + }, + ) + + func = function(xr.combine_nested, concat_dim=["x", "y"]) + if error is not None: + with pytest.raises(error): + func([[ds1, ds2], [ds3, ds4]]) + + return + + units = extract_units(ds1) + convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) + expected = attach_units( + func( + [ + [strip_units(ds1), convert_and_strip(ds2)], + [convert_and_strip(ds3), convert_and_strip(ds4)], + ] + ), + units, + ) + result = func([[ds1, ds2], [ds3, ds4]]) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="`concat` strips units") @pytest.mark.parametrize( "unit,error", From a474203ef8aa6127867f28db6f15dfea58279b83 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sat, 9 Nov 2019 15:08:23 +0100 Subject: [PATCH 18/21] add tests for merge with datasets --- xarray/tests/test_units.py | 104 ++++++++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f4d45fe5bf0..f489beff8f9 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -768,7 +768,7 @@ def test_concat_dataarray(variant, unit, error, dtype): assert_equal_with_units(expected, result) -@pytest.mark.xfail(False, reason="`concat` strips units") +@pytest.mark.xfail(reason="`concat` strips units") @pytest.mark.parametrize( "unit,error", ( @@ -817,6 +817,108 @@ def test_concat_dataset(variant, unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="blocked by `where`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_merge_dataset(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit + + x = np.arange(11, 14) * original_unit + y = np.arange(2) * original_unit + z = np.arange(3) * original_unit + + ds1 = xr.Dataset( + data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)}, + coords={"x": x, "y": y, "z": ("x", z)}, + ) + ds2 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.ones_like(array1) * data_unit), + "b": (("y", "x"), np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(3) * dim_unit, + "y": np.arange(2, 4) * dim_unit, + "z": ("x", np.arange(-3, 0) * coord_unit), + }, + ) + ds3 = xr.Dataset( + data_vars={ + "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit), + "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit), + }, + coords={ + "x": np.arange(3, 6) * dim_unit, + "y": np.arange(4, 6) * dim_unit, + "z": ("x", np.arange(3, 6) * coord_unit), + }, + ) + ds4 = xr.Dataset( + data_vars={ + "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit), + "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit), + }, + coords={ + "x": np.arange(6, 9) * dim_unit, + "y": np.arange(6, 8) * dim_unit, + "z": ("x", np.arange(6, 9) * coord_unit), + }, + ) + + func = function(xr.merge) + if error is not None: + with pytest.raises(error): + func([ds1, ds2, ds3, ds4]) + + return + + units = extract_units(ds1) + convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) + expected = attach_units( + func( + [ + strip_units(ds1), + convert_and_strip(ds2), + convert_and_strip(ds3), + convert_and_strip(ds4), + ] + ), + units, + ) + result = func([ds1, ds2, ds3, ds4]) + + assert_equal_with_units(expected, result) + + @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like)) def test_replication_dataarray(func, dtype): array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s From 61627f051fa2d6928eb9b76d7258c1b72a2a58c8 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 10 Nov 2019 02:38:26 +0100 Subject: [PATCH 19/21] only use three datasets for merging --- xarray/tests/test_units.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f489beff8f9..0cff8229733 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -882,39 +882,20 @@ def test_merge_dataset(variant, unit, error, dtype): "z": ("x", np.arange(3, 6) * coord_unit), }, ) - ds4 = xr.Dataset( - data_vars={ - "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit), - "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit), - }, - coords={ - "x": np.arange(6, 9) * dim_unit, - "y": np.arange(6, 8) * dim_unit, - "z": ("x", np.arange(6, 9) * coord_unit), - }, - ) func = function(xr.merge) if error is not None: with pytest.raises(error): - func([ds1, ds2, ds3, ds4]) + func([ds1, ds2, ds3]) return units = extract_units(ds1) convert_and_strip = lambda ds: strip_units(convert_units(ds, units)) expected = attach_units( - func( - [ - strip_units(ds1), - convert_and_strip(ds2), - convert_and_strip(ds3), - convert_and_strip(ds4), - ] - ), - units, + func([strip_units(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]), units ) - result = func([ds1, ds2, ds3, ds4]) + result = func([ds1, ds2, ds3]) assert_equal_with_units(expected, result) From d989ae825993ce4e3affb3de268df73b6dd6fc88 Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 10 Nov 2019 02:40:57 +0100 Subject: [PATCH 20/21] add tests for merge with dataarrays --- xarray/tests/test_units.py | 90 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 0cff8229733..b19ce90967b 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -817,6 +817,96 @@ def test_concat_dataset(variant, unit, error, dtype): assert_equal_with_units(expected, result) +@pytest.mark.xfail(reason="blocked by `where`") +@pytest.mark.parametrize( + "unit,error", + ( + pytest.param(1, DimensionalityError, id="no_unit"), + pytest.param( + unit_registry.dimensionless, DimensionalityError, id="dimensionless" + ), + pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), + pytest.param(unit_registry.mm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), + ), + ids=repr, +) +@pytest.mark.parametrize( + "variant", + ( + "data", + pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")), + "coords", + ), +) +def test_merge_dataarray(variant, unit, error, dtype): + original_unit = unit_registry.m + + variants = { + "data": (unit, original_unit, original_unit), + "dims": (original_unit, unit, original_unit), + "coords": (original_unit, original_unit, unit), + } + data_unit, dim_unit, coord_unit = variants.get(variant) + + array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * original_unit + array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit + array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit + + x = np.arange(2) * original_unit + y = np.arange(3) * original_unit + z = np.arange(4) * original_unit + u = np.linspace(10, 20, 2) * original_unit + v = np.linspace(10, 20, 3) * original_unit + w = np.linspace(10, 20, 4) * original_unit + + arr1 = xr.DataArray( + name="a", + data=array1, + coords={"x": x, "y": y, "u": ("x", u), "v": ("y", v)}, + dims=("x", "y"), + ) + arr2 = xr.DataArray( + name="b", + data=array2, + coords={ + "x": np.arange(2, 4) * dim_unit, + "z": z, + "u": ("x", np.linspace(20, 30, 2) * coord_unit), + "w": ("z", w), + }, + dims=("x", "z"), + ) + arr3 = xr.DataArray( + name="c", + data=array3, + coords={ + "y": np.arange(3, 6) * dim_unit, + "z": np.arange(4, 8) * dim_unit, + "v": ("y", np.linspace(10, 20, 3) * coord_unit), + "w": ("z", np.linspace(10, 20, 4) * coord_unit), + }, + dims=("y", "z"), + ) + + func = function(xr.merge) + if error is not None: + with pytest.raises(error): + func([arr1, arr2, arr3]) + + return + + units = {name: original_unit for name in list("abcuvwxyz")} + convert_and_strip = lambda arr: strip_units(convert_units(arr, units)) + expected = attach_units( + func([strip_units(arr1), convert_and_strip(arr2), convert_and_strip(arr3)]), + units, + ) + result = func([arr1, arr2, arr3]) + + assert_equal_with_units(expected, result) + + @pytest.mark.xfail(reason="blocked by `where`") @pytest.mark.parametrize( "unit,error", From c1d8e9207a238a0535e8a96c90bc24d3611614ef Mon Sep 17 00:00:00 2001 From: Keewis Date: Sun, 10 Nov 2019 02:42:15 +0100 Subject: [PATCH 21/21] update whats-new.rst --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d2a4b32a71f..c5c47801404 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -104,7 +104,7 @@ Internal Changes ~~~~~~~~~~~~~~~~ - Added integration tests against `pint `_. - (:pull:`3238`, :pull:`3447`) by `Justus Magin `_. + (:pull:`3238`, :pull:`3447`, :pull:`3493`) by `Justus Magin `_. .. note::