Skip to content

Commit

Permalink
Feature/align in dot (#3699)
Browse files Browse the repository at this point in the history
* add tests

* implement align

* whats new

* fix changes to whats new

* review: fix typos
  • Loading branch information
mathause authored and fujiisoup committed Jan 20, 2020
1 parent 5c97641 commit aa0f963
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 1 deletion.
5 changes: 4 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ v0.15.0 (unreleased)

Breaking changes
~~~~~~~~~~~~~~~~

- Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which
have been deprecated since 0.12. (:pull:`3650`).
Instead, specify the encoding when writing to disk or set
the ``encoding`` attribute directly.
By `Maximilian Roos <https://github.com/max-sixty>`_
- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now
use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``;
:issue:`3694`) by `Mathias Hauser <https://github.com/mathause>`_.


New Features
~~~~~~~~~~~~
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import duck_array_ops, utils
from .alignment import deep_align
from .merge import merge_coordinates_without_align
from .options import OPTIONS
from .pycompat import dask_array_type
from .utils import is_dict_like
from .variable import Variable
Expand Down Expand Up @@ -1175,6 +1176,11 @@ def dot(*arrays, dims=None, **kwargs):
subscripts = ",".join(subscripts_list)
subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]])

join = OPTIONS["arithmetic_join"]
# using "inner" emulates `(a * b).sum()` for all joins (except "exact")
if join != "exact":
join = "inner"

# subscripts should be passed to np.einsum as arg, not as kwargs. We need
# to construct a partial function for apply_ufunc to work.
func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)
Expand All @@ -1183,6 +1189,7 @@ def dot(*arrays, dims=None, **kwargs):
*arrays,
input_core_dims=input_core_dims,
output_core_dims=output_core_dims,
join=join,
dask="allowed",
)
return result.transpose(*[d for d in all_dims if d in result.dims])
Expand Down
54 changes: 54 additions & 0 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,60 @@ def test_dot(use_dask):
pickle.loads(pickle.dumps(xr.dot(da_a)))


@pytest.mark.parametrize("use_dask", [True, False])
def test_dot_align_coords(use_dask):
# GH 3694

if use_dask:
if not has_dask:
pytest.skip("test for dask.")

a = np.arange(30 * 4).reshape(30, 4)
b = np.arange(30 * 4 * 5).reshape(30, 4, 5)

# use partially overlapping coords
coords_a = {"a": np.arange(30), "b": np.arange(4)}
coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)}

da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a)
da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b)

if use_dask:
da_a = da_a.chunk({"a": 3})
da_b = da_b.chunk({"a": 3})

# join="inner" is the default
actual = xr.dot(da_a, da_b)
# `dot` sums over the common dimensions of the arguments
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)

actual = xr.dot(da_a, da_b, dims=...)
expected = (da_a * da_b).sum()
xr.testing.assert_allclose(expected, actual)

with xr.set_options(arithmetic_join="exact"):
with raises_regex(ValueError, "indexes along dimension"):
xr.dot(da_a, da_b)

# NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all
# join method (except "exact")
with xr.set_options(arithmetic_join="left"):
actual = xr.dot(da_a, da_b)
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)

with xr.set_options(arithmetic_join="right"):
actual = xr.dot(da_a, da_b)
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)

with xr.set_options(arithmetic_join="outer"):
actual = xr.dot(da_a, da_b)
expected = (da_a * da_b).sum(["a", "b"])
xr.testing.assert_allclose(expected, actual)


def test_where():
cond = xr.DataArray([True, False], dims="x")
actual = xr.where(cond, 1, 0)
Expand Down
55 changes: 55 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3973,6 +3973,43 @@ def test_dot(self):
with pytest.raises(TypeError):
da.dot(dm.values)

def test_dot_align_coords(self):
# GH 3694

x = np.linspace(-3, 3, 6)
y = np.linspace(-3, 3, 5)
z_a = range(4)
da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"])

z_m = range(2, 6)
dm_vals = range(4)
dm = DataArray(dm_vals, coords=[z_m], dims=["z"])

with xr.set_options(arithmetic_join="exact"):
with raises_regex(ValueError, "indexes along dimension"):
da.dot(dm)

da_aligned, dm_aligned = xr.align(da, dm, join="inner")

# nd dot 1d
actual = da.dot(dm)
expected_vals = np.tensordot(da_aligned.values, dm_aligned.values, [2, 0])
expected = DataArray(expected_vals, coords=[x, da_aligned.y], dims=["x", "y"])
assert_equal(expected, actual)

# multiple shared dims
dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
j = np.linspace(-3, 3, 20)
dm = DataArray(dm_vals, coords=[j, y, z_m], dims=["j", "y", "z"])
da_aligned, dm_aligned = xr.align(da, dm, join="inner")
actual = da.dot(dm)
expected_vals = np.tensordot(
da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2])
)
expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
assert_equal(expected, actual)

def test_matmul(self):

# copied from above (could make a fixture)
Expand All @@ -3986,6 +4023,24 @@ def test_matmul(self):
expected = da.dot(da)
assert_identical(result, expected)

def test_matmul_align_coords(self):
# GH 3694

x_a = np.arange(6)
x_b = np.arange(2, 8)
da_vals = np.arange(6)
da_a = DataArray(da_vals, coords=[x_a], dims=["x"])
da_b = DataArray(da_vals, coords=[x_b], dims=["x"])

# only test arithmetic_join="inner" (=default)
result = da_a @ da_b
expected = da_a.dot(da_b)
assert_identical(result, expected)

with xr.set_options(arithmetic_join="exact"):
with raises_regex(ValueError, "indexes along dimension"):
da_a @ da_b

def test_binary_op_propagate_indexes(self):
# regression test for GH2227
self.dv["x"] = np.arange(self.dv.sizes["x"])
Expand Down

0 comments on commit aa0f963

Please sign in to comment.