From 4d5237ba2d56c316cbc12b25572164afdbaef541 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 29 Oct 2019 20:12:50 +0100 Subject: [PATCH] enable xr.ALL_DIMS in xr.dot (#3424) * enable xr.ALL_DIMS in xr.dot * trailing whitespace * move whats new to other ellipsis work * xr.ALL_DIMS -> Ellipsis --- doc/whats-new.rst | 3 +++ xarray/core/computation.py | 20 +++++++++++++++----- xarray/core/dataarray.py | 6 +++--- xarray/tests/test_computation.py | 17 +++++++++++++++++ xarray/tests/test_dataarray.py | 10 ++++++++++ 5 files changed, 48 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 82355a6bda4..6bcf4b61436 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,6 +39,9 @@ New Features to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest using `...`. By `Maximilian Roos <https://github.com/max-sixty>`_ +- :py:func:`~xarray.dot`, and :py:func:`~xarray.DataArray.dot` now support the + `dims=...` option to sum over the union of dimensions of all input arrays + (:issue:`3423`) by `Mathias Hauser <https://github.com/mathause>`_. - Added integration tests against `pint <https://pint.readthedocs.io/>`_. (:pull:`3238`) by `Justus Magin <https://github.com/keewis>`_. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 2ab2ab78416..2c87f378762 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1055,9 +1055,9 @@ def dot(*arrays, dims=None, **kwargs): ---------- arrays: DataArray (or Variable) objects Arrays to compute. - dims: str or tuple of strings, optional - Which dimensions to sum over. - If not speciified, then all the common dimensions are summed over. + dims: '...', str or tuple of strings, optional + Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. + If not specified, then all the common dimensions are summed over. 
**kwargs: dict Additional keyword arguments passed to numpy.einsum or dask.array.einsum @@ -1070,7 +1070,7 @@ def dot(*arrays, dims=None, **kwargs): -------- >>> import numpy as np - >>> import xarray as xp + >>> import xarray as xr >>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=['a', 'b']) >>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2), ... dims=['a', 'b', 'c']) @@ -1117,6 +1117,14 @@ def dot(*arrays, dims=None, **kwargs): [273, 446, 619]]) Dimensions without coordinates: a, d + >>> xr.dot(da_a, da_b) + <xarray.DataArray (c: 2)> + array([110, 125]) + Dimensions without coordinates: c + + >>> xr.dot(da_a, da_b, dims=...) + <xarray.DataArray ()> + array(235) """ from .dataarray import DataArray from .variable import Variable @@ -1141,7 +1149,9 @@ def dot(*arrays, dims=None, **kwargs): einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - if dims is None: + if dims is ...: + dims = all_dims + elif dims is None: # find dimensions that occur more than one times dim_counts = Counter() for arr in arrays: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0c220acaee0..62890f9cefa 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2742,9 +2742,9 @@ def dot( ---------- other : DataArray The other array with which the dot product is performed. - dims: hashable or sequence of hashables, optional - Along which dimensions to be summed over. Default all the common - dimensions are summed over. + dims: '...', hashable or sequence of hashables, optional + Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. + If not specified, then all the common dimensions are summed over. 
Returns ------- diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 383427b479b..1f2634cc9b0 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -998,6 +998,23 @@ def test_dot(use_dask): assert actual.dims == ("b",) assert (actual.data == np.zeros(actual.shape)).all() + # Ellipsis (...) sums over all dimensions + actual = xr.dot(da_a, da_b, dims=...) + assert actual.dims == () + assert (actual.data == np.einsum("ij,ijk->", a, b)).all() + + actual = xr.dot(da_a, da_b, da_c, dims=...) + assert actual.dims == () + assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all() + + actual = xr.dot(da_a, dims=...) + assert actual.dims == () + assert (actual.data == np.einsum("ij-> ", a)).all() + + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...) + assert actual.dims == () + assert (actual.data == np.zeros(actual.shape)).all() + # Invalid cases if not use_dask: with pytest.raises(TypeError): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4b3ffdc021a..5114d13b0dc 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3925,6 +3925,16 @@ def test_dot(self): expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"]) assert_equal(expected, actual) + # Ellipsis: all dims are shared + actual = da.dot(da, dims=...) + expected = da.dot(da) + assert_equal(expected, actual) + + # Ellipsis: not all dims are shared + actual = da.dot(dm, dims=...) + expected = da.dot(dm, dims=("j", "x", "y", "z")) + assert_equal(expected, actual) + with pytest.raises(NotImplementedError): da.dot(dm.to_dataset(name="dm")) with pytest.raises(TypeError):