Skip to content

Commit

Permalink
Add argument check_dims to assert_allclose to allow transposed inputs (
Browse files Browse the repository at this point in the history
…#5733) (#8991)

* Add argument check_dims to assert_allclose to allow transposed inputs

* Update whats-new.rst

* Add `check_dims` argument to assert_equal and assert_identical + tests

* Assert that dimensions match before transposing or comparing values

* Add docstring for check_dims to assert_equal and assert_identical

* Update doc/whats-new.rst

Co-authored-by: Tom Nicholas <[email protected]>

* Undo fat finger

Co-authored-by: Tom Nicholas <[email protected]>

* Add attribution to whats-new.rst

* Replace check_dims with bool argument check_dim_order, rename align_dims to maybe_transpose_dims

* Remove left-over half-made test

* Remove check_dim_order argument from assert_identical

* assert_allclose/equal: emit full diff if dimensions don't match

* Rename check_dim_order test, test Dataset with different dim orders

* Update whats-new.rst

* Hide maybe_transpose_dims from Pytest traceback

Co-authored-by: Maximilian Roos <[email protected]>

* Ignore mypy error due to missing functools.partial.__name__

---------

Co-authored-by: Tom Nicholas <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
3 people authored and andersy005 committed May 10, 2024
1 parent f151a46 commit c4c2fbd
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ New Features
for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray`
then, such as broadcasting.
By `Ilan Gold <https://github.com/ilan-gold>`_.
- :py:func:`testing.assert_allclose`/:py:func:`testing.assert_equal` now accept a new argument `check_dims="transpose"`, controlling whether a transposed array is considered equal. (:issue:`5733`, :pull:`8991`)
By `Ignacio Martinez Vazquez <https://github.com/ignamv>`_.
- Added the option to avoid automatically creating 1D pandas indexes in :py:meth:`Dataset.expand_dims()`, by passing the new kwarg
`create_index=False`. (:pull:`8960`)
By `Tom Nicholas <https://github.com/TomNicholas>`_.
Expand Down
29 changes: 25 additions & 4 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False):
raise TypeError(f"{type(a)} not of type DataTree")


def maybe_transpose_dims(a, b, check_dim_order: bool):
"""Helper for assert_equal/allclose/identical"""
__tracebackhide__ = True
if not isinstance(a, (Variable, DataArray, Dataset)):
return b
if not check_dim_order and set(a.dims) == set(b.dims):
# Ensure transpose won't fail if a dimension is missing
# If this is the case, the difference will be caught by the caller
return b.transpose(*a.dims)
return b


@overload
def assert_equal(a, b): ...

Expand All @@ -104,7 +116,7 @@ def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_equal(a, b, from_root=True):
def assert_equal(a, b, from_root=True, check_dim_order: bool = True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.
Expand All @@ -127,6 +139,8 @@ def assert_equal(a, b, from_root=True):
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
See Also
--------
Expand All @@ -137,6 +151,7 @@ def assert_equal(a, b, from_root=True):
assert (
type(a) == type(b) or isinstance(a, Coordinates) and isinstance(b, Coordinates)
)
b = maybe_transpose_dims(a, b, check_dim_order)
if isinstance(a, (Variable, DataArray)):
assert a.equals(b), formatting.diff_array_repr(a, b, "equals")
elif isinstance(a, Dataset):
Expand Down Expand Up @@ -182,6 +197,8 @@ def assert_identical(a, b, from_root=True):
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
See Also
--------
Expand Down Expand Up @@ -213,7 +230,9 @@ def assert_identical(a, b, from_root=True):


@ensure_warnings
def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
def assert_allclose(
a, b, rtol=1e-05, atol=1e-08, decode_bytes=True, check_dim_order: bool = True
):
"""Like :py:func:`numpy.testing.assert_allclose`, but for xarray objects.
Raises an AssertionError if two objects are not equal up to desired
Expand All @@ -233,23 +252,25 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
Whether byte dtypes should be decoded to strings as UTF-8 or not.
This is useful for testing serialization methods on Python 3 that
return saved strings as bytes.
check_dim_order : bool, optional, default is True
Whether dimensions must be in the same order.
See Also
--------
assert_identical, assert_equal, numpy.testing.assert_allclose
"""
__tracebackhide__ = True
assert type(a) == type(b)
b = maybe_transpose_dims(a, b, check_dim_order)

equiv = functools.partial(
_data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
)
equiv.__name__ = "allclose"
equiv.__name__ = "allclose" # type: ignore[attr-defined]

def compat_variable(a, b):
a = getattr(a, "variable", a)
b = getattr(b, "variable", b)

return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data))

if isinstance(a, Variable):
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ def test_allclose_regression() -> None:
def test_assert_allclose(obj1, obj2) -> None:
with pytest.raises(AssertionError):
xr.testing.assert_allclose(obj1, obj2)
with pytest.raises(AssertionError):
xr.testing.assert_allclose(obj1, obj2, check_dim_order=False)


@pytest.mark.parametrize("func", ["assert_equal", "assert_allclose"])
def test_assert_allclose_equal_transpose(func) -> None:
"""Transposed DataArray raises assertion unless check_dim_order=False."""
obj1 = xr.DataArray([[0, 1, 2], [2, 3, 4]], dims=["a", "b"])
obj2 = xr.DataArray([[0, 2], [1, 3], [2, 4]], dims=["b", "a"])
with pytest.raises(AssertionError):
getattr(xr.testing, func)(obj1, obj2)
getattr(xr.testing, func)(obj1, obj2, check_dim_order=False)
ds1 = obj1.to_dataset(name="varname")
ds1["var2"] = obj1
ds2 = obj1.to_dataset(name="varname")
ds2["var2"] = obj1.transpose()
with pytest.raises(AssertionError):
getattr(xr.testing, func)(ds1, ds2)
getattr(xr.testing, func)(ds1, ds2, check_dim_order=False)


@pytest.mark.filterwarnings("error")
Expand Down

0 comments on commit c4c2fbd

Please sign in to comment.