diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0f79b648187..9a4601a776f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,8 @@ New Features for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` then, such as broadcasting. By `Ilan Gold `_. +- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`) + By `Ignacio Martinez Vazquez `_. - Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg `create_index=False`. (:pull:`8960`) By `Tom Nicholas `_. diff --git a/xarray/testing/assertions.py b/xarray/testing/assertions.py index 018874c169e..69885868f83 100644 --- a/xarray/testing/assertions.py +++ b/xarray/testing/assertions.py @@ -95,6 +95,18 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False): raise TypeError(f"{type(a)} not of type DataTree") +def maybe_transpose_dims(a, b, check_dim_order: bool): + """Helper for assert_equal/allclose/identical""" + __tracebackhide__ = True + if not isinstance(a, (Variable, DataArray, Dataset)): + return b + if not check_dim_order and set(a.dims) == set(b.dims): + # Ensure transpose won't fail if a dimension is missing + # If this is the case, the difference will be caught by the caller + return b.transpose(*a.dims) + return b + + @overload def assert_equal(a, b): ... @@ -104,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ... @ensure_warnings -def assert_equal(a, b, from_root=True): +def assert_equal(a, b, from_root=True, check_dim_order: bool = True): """Like :py:func:`numpy.testing.assert_array_equal`, but for xarray objects. @@ -127,6 +139,8 @@ def assert_equal(a, b, from_root=True): Only used when comparing DataTree objects. Indicates whether or not to first traverse to the root of the trees before checking for isomorphism. If a & b have no parents then this has no effect. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. See Also -------- @@ -137,6 +151,7 @@ def assert_equal(a, b, from_root=True): assert ( type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates) ) + b = maybe_transpose_dims(a, b, check_dim_order) if isinstance(a, (Variable, DataArray)): assert a.equals(b), formatting.diff_array_repr(a, b, "equals") elif isinstance(a, Dataset): @@ -182,6 +197,8 @@ def assert_identical(a, b, from_root=True): Only used when comparing DataTree objects. Indicates whether or not to first traverse to the root of the trees before checking for isomorphism. If a & b have no parents then this has no effect. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. See Also -------- @@ -213,7 +230,9 @@ def assert_identical(a, b, from_root=True): @ensure_warnings -def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): +def assert_allclose( + a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True +): """Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects. Raises an AssertionError if two objects are not equal up to desired @@ -233,6 +252,8 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): Whether byte dtypes should be decoded to strings as UTF-8 or not. This is useful for testing serialization methods on Python 3 that return saved strings as bytes. + check_dim_order : bool, optional, default is True + Whether dimensions must be in the same order. See Also -------- @@ -240,16 +261,16 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): """ __tracebackhide__ = True assert type(a) == type(b) + b = maybe_transpose_dims(a, b, check_dim_order) equiv = functools.partial( _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes ) - equiv.__name__ = "allclose" + equiv.__name__ = "allclose" # type: ignore[attr-defined] def compat_variable(a, b): a = getattr(a, "variable", a) b = getattr(b, "variable", b) - return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) if isinstance(a, Variable): diff --git a/xarray/tests/test_assertions.py b/xarray/tests/test_assertions.py index f7e49a0f3de..aa0ea46f7db 100644 --- a/xarray/tests/test_assertions.py +++ b/xarray/tests/test_assertions.py @@ -57,6 +57,25 @@ def test_allclose_regression() -> None: def test_assert_allclose(obj1, obj2) -> None: with pytest.raises(AssertionError): xr.testing.assert_allclose(obj1, obj2) + with pytest.raises(AssertionError): + xr.testing.assert_allclose(obj1, obj2, check_dim_order=False) + + +@pytest.mark.parametrize("func", ["assert_equal", "assert_allclose"]) +def test_assert_allclose_equal_transpose(func) -> None: + """Transposed DataArray raises assertion unless check_dim_order=False.""" + obj1 = xr.DataArray([[0, 1, 2], [2, 3, 4]], dims=["a", "b"]) + obj2 = xr.DataArray([[0, 2], [1, 3], [2, 4]], dims=["b", "a"]) + with pytest.raises(AssertionError): + getattr(xr.testing, func)(obj1, obj2) + getattr(xr.testing, func)(obj1, obj2, check_dim_order=False) + ds1 = obj1.to_dataset(name="varname") + ds1["var2"] = obj1 + ds2 = obj1.to_dataset(name="varname") + ds2["var2"] = obj1.transpose() + with pytest.raises(AssertionError): + getattr(xr.testing, func)(ds1, ds2) + getattr(xr.testing, func)(ds1, ds2, check_dim_order=False) @pytest.mark.filterwarnings("error")