Skip to content

Commit

Permalink
enable xr.ALL_DIMS in xr.dot (#3424)
Browse files Browse the repository at this point in the history
* enable xr.ALL_DIMS in xr.dot

* trailing whitespace

* move whats new to other ellipsis work

* xr.ALL_DIMS -> Ellipsis
  • Loading branch information
mathause authored and max-sixty committed Oct 29, 2019
1 parent 80e4e89 commit 4d5237b
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 8 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ New Features
to reduce over all dimensions. While we have no plans to remove `xr.ALL_DIMS`, we suggest
using `...`.
By `Maximilian Roos <https://github.com/max-sixty>`_
- :py:func:`~xarray.dot`, and :py:func:`~xarray.DataArray.dot` now support the
`dims=...` option to sum over the union of dimensions of all input arrays
(:issue:`3423`) by `Mathias Hauser <https://github.com/mathause>`_.
- Added integration tests against `pint <https://pint.readthedocs.io/>`_.
(:pull:`3238`) by `Justus Magin <https://github.com/keewis>`_.

Expand Down
20 changes: 15 additions & 5 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,9 +1055,9 @@ def dot(*arrays, dims=None, **kwargs):
----------
arrays: DataArray (or Variable) objects
Arrays to compute.
dims: str or tuple of strings, optional
Which dimensions to sum over.
If not speciified, then all the common dimensions are summed over.
dims: '...', str or tuple of strings, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
**kwargs: dict
Additional keyword arguments passed to numpy.einsum or
dask.array.einsum
Expand All @@ -1070,7 +1070,7 @@ def dot(*arrays, dims=None, **kwargs):
--------
>>> import numpy as np
>>> import xarray as xp
>>> import xarray as xr
>>> da_a = xr.DataArray(np.arange(3 * 2).reshape(3, 2), dims=['a', 'b'])
>>> da_b = xr.DataArray(np.arange(3 * 2 * 2).reshape(3, 2, 2),
... dims=['a', 'b', 'c'])
Expand Down Expand Up @@ -1117,6 +1117,14 @@ def dot(*arrays, dims=None, **kwargs):
[273, 446, 619]])
Dimensions without coordinates: a, d
>>> xr.dot(da_a, da_b)
<xarray.DataArray (c: 2)>
array([110, 125])
Dimensions without coordinates: c
>>> xr.dot(da_a, da_b, dims=...)
<xarray.DataArray ()>
array(235)
"""
from .dataarray import DataArray
from .variable import Variable
Expand All @@ -1141,7 +1149,9 @@ def dot(*arrays, dims=None, **kwargs):
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}

if dims is None:
if dims is ...:
dims = all_dims
elif dims is None:
# find dimensions that occur more than one times
dim_counts = Counter()
for arr in arrays:
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,9 +2742,9 @@ def dot(
----------
other : DataArray
The other array with which the dot product is performed.
dims: hashable or sequence of hashables, optional
Along which dimensions to be summed over. Default all the common
dimensions are summed over.
dims: '...', hashable or sequence of hashables, optional
Which dimensions to sum over. Ellipsis ('...') sums over all dimensions.
If not specified, then all the common dimensions are summed over.
Returns
-------
Expand Down
17 changes: 17 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,23 @@ def test_dot(use_dask):
assert actual.dims == ("b",)
assert (actual.data == np.zeros(actual.shape)).all()

# Ellipsis (...) sums over all dimensions
actual = xr.dot(da_a, da_b, dims=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk->", a, b)).all()

actual = xr.dot(da_a, da_b, da_c, dims=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all()

actual = xr.dot(da_a, dims=...)
assert actual.dims == ()
assert (actual.data == np.einsum("ij-> ", a)).all()

actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...)
assert actual.dims == ()
assert (actual.data == np.zeros(actual.shape)).all()

# Invalid cases
if not use_dask:
with pytest.raises(TypeError):
Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3925,6 +3925,16 @@ def test_dot(self):
expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
assert_equal(expected, actual)

# Ellipsis: all dims are shared
actual = da.dot(da, dims=...)
expected = da.dot(da)
assert_equal(expected, actual)

# Ellipsis: not all dims are shared
actual = da.dot(dm, dims=...)
expected = da.dot(dm, dims=("j", "x", "y", "z"))
assert_equal(expected, actual)

with pytest.raises(NotImplementedError):
da.dot(dm.to_dataset(name="dm"))
with pytest.raises(TypeError):
Expand Down

0 comments on commit 4d5237b

Please sign in to comment.