Skip to content

Commit

Permalink
Propagate indexes in DataArray binary operations. (#3481)
Browse files Browse the repository at this point in the history
* Propagate indexes in DataArray binary operations.

Works by propagating indexes in DataArray._replace.

xref #2227. Tests pass!

* remove commented code.

* fix roll
  • Loading branch information
dcherian authored Nov 5, 2019
1 parent 46c4931 commit b649846
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 3 deletions.
8 changes: 5 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,14 +386,15 @@ def _replace(
variable: Variable = None,
coords=None,
name: Union[Hashable, None, Default] = _default,
indexes=None,
) -> "DataArray":
if variable is None:
variable = self.variable
if coords is None:
coords = self._coords
if name is _default:
name = self.name
return type(self)(variable, coords, name=name, fastpath=True)
return type(self)(variable, coords, name=name, fastpath=True, indexes=indexes)

def _replace_maybe_drop_dims(
self, variable: Variable, name: Union[Hashable, None, Default] = _default
Expand Down Expand Up @@ -440,7 +441,8 @@ def _from_temp_dataset(
) -> "DataArray":
variable = dataset._variables.pop(_THIS_ARRAY)
coords = dataset._variables
return self._replace(variable, coords, name)
indexes = dataset._indexes
return self._replace(variable, coords, name, indexes=indexes)

def _to_dataset_split(self, dim: Hashable) -> Dataset:
def subset(dim, label):
Expand Down Expand Up @@ -2506,7 +2508,7 @@ def func(self, other):
coords, indexes = self.coords._merge_raw(other_coords)
name = self._result_name(other)

return self._replace(variable, coords, name)
return self._replace(variable, coords, name, indexes=indexes)

return func

Expand Down
2 changes: 2 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4891,6 +4891,8 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs):
(dim,) = self.variables[k].dims
if dim in shifts:
indexes[k] = roll_index(v, shifts[dim])
else:
indexes[k] = v
else:
indexes = dict(self.indexes)

Expand Down
1 change: 1 addition & 0 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ def _maybe_unstack(self, obj):
for dim in self._inserted_dims:
if dim in obj.coords:
del obj.coords[dim]
del obj.indexes[dim]
return obj

def fillna(self, value):
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __contains__(self, key):
def __getitem__(self, key):
return self._indexes[key]

def __delitem__(self, key):
del self._indexes[key]

def __repr__(self):
return formatting.indexes_repr(self)

Expand Down
11 changes: 11 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3953,6 +3953,17 @@ def test_matmul(self):
expected = da.dot(da)
assert_identical(result, expected)

def test_binary_op_propagate_indexes(self):
# regression test for GH2227
self.dv["x"] = np.arange(self.dv.sizes["x"])
expected = self.dv.indexes["x"]

actual = (self.dv * 10).indexes["x"]
assert expected is actual

actual = (self.dv > 10).indexes["x"]
assert expected is actual

def test_binary_op_join_setting(self):
dim = "x"
align_type = "outer"
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4951,6 +4951,14 @@ def test_filter_by_attrs(self):
)
assert not bool(new_ds.data_vars)

def test_binary_op_propagate_indexes(self):
ds = Dataset(
{"d1": DataArray([1, 2, 3], dims=["x"], coords={"x": [10, 20, 30]})}
)
expected = ds.indexes["x"]
actual = (ds * 2).indexes["x"]
assert expected is actual

def test_binary_op_join_setting(self):
# arithmetic_join applies to data array coordinates
missing_2 = xr.Dataset({"x": [0, 1]})
Expand Down

0 comments on commit b649846

Please sign in to comment.