Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explicitly keep track of indexes with merging #3234

Merged
merged 9 commits into from
Oct 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def align(
all_indexes[dim].append(index)

if join == "override":
objects = _override_indexes(list(objects), all_indexes, exclude)
objects = _override_indexes(objects, all_indexes, exclude)

# We don't reindex over dimensions with all equal indexes for two reasons:
# - It's faster for the usual case (already aligned objects).
Expand Down Expand Up @@ -365,26 +365,27 @@ def is_alignable(obj):
targets = []
no_key = object()
not_replaced = object()
for n, variables in enumerate(objects):
for position, variables in enumerate(objects):
if is_alignable(variables):
positions.append(n)
positions.append(position)
keys.append(no_key)
targets.append(variables)
out.append(not_replaced)
elif is_dict_like(variables):
current_out = OrderedDict()
for k, v in variables.items():
if is_alignable(v) and k not in indexes:
# Skip variables in indexes for alignment, because these
# should to be overwritten instead:
# https://github.com/pydata/xarray/issues/725
positions.append(n)
if is_alignable(v):
positions.append(position)
keys.append(k)
targets.append(v)
out.append(OrderedDict(variables))
current_out[k] = not_replaced
else:
current_out[k] = v
out.append(current_out)
elif raise_on_invalid:
raise ValueError(
"object to align is neither an xarray.Dataset, "
"an xarray.DataArray nor a dictionary: %r" % variables
"an xarray.DataArray nor a dictionary: {!r}".format(variables)
)
else:
out.append(variables)
Expand All @@ -405,7 +406,10 @@ def is_alignable(obj):
out[position][key] = aligned_obj

# something went wrong: we should have replaced all sentinel values
assert all(arg is not not_replaced for arg in out)
for arg in out:
assert arg is not not_replaced
if is_dict_like(arg):
assert all(value is not not_replaced for value in arg.values())

return out

Expand Down
39 changes: 18 additions & 21 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@

from . import duck_array_ops, utils
from .alignment import deep_align
from .merge import expand_and_merge_variables
from .merge import merge_coordinates_without_align
from .pycompat import dask_array_type
from .utils import is_dict_like
from .variable import Variable

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataset import Dataset

_DEFAULT_FROZEN_SET = frozenset() # type: frozenset
Expand Down Expand Up @@ -152,17 +153,16 @@ def result_name(objects: list) -> Any:
return name


def _get_coord_variables(args):
input_coords = []
def _get_coords_list(args) -> List["Coordinates"]:
coords_list = []
for arg in args:
try:
coords = arg.coords
except AttributeError:
pass # skip this argument
else:
coord_vars = getattr(coords, "variables", coords)
input_coords.append(coord_vars)
return input_coords
coords_list.append(coords)
return coords_list


def build_output_coords(
Expand All @@ -185,32 +185,29 @@ def build_output_coords(
-------
OrderedDict of Variable objects with merged coordinates.
"""
input_coords = _get_coord_variables(args)
coords_list = _get_coords_list(args)

if exclude_dims:
input_coords = [
OrderedDict(
(k, v) for k, v in coord_vars.items() if exclude_dims.isdisjoint(v.dims)
)
for coord_vars in input_coords
]

if len(input_coords) == 1:
if len(coords_list) == 1 and not exclude_dims:
# we can skip the expensive merge
unpacked_input_coords, = input_coords
merged = OrderedDict(unpacked_input_coords)
unpacked_coords, = coords_list
merged_vars = OrderedDict(unpacked_coords.variables)
else:
merged = expand_and_merge_variables(input_coords)
# TODO: save these merged indexes, instead of re-computing them later
merged_vars, unused_indexes = merge_coordinates_without_align(
coords_list, exclude_dims=exclude_dims
)

output_coords = []
for output_dims in signature.output_core_dims:
dropped_dims = signature.all_input_core_dims - set(output_dims)
if dropped_dims:
filtered = OrderedDict(
(k, v) for k, v in merged.items() if dropped_dims.isdisjoint(v.dims)
(k, v)
for k, v in merged_vars.items()
if dropped_dims.isdisjoint(v.dims)
)
else:
filtered = merged
filtered = merged_vars
output_coords.append(filtered)

return output_coords
Expand Down
88 changes: 54 additions & 34 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
Hashable,
Iterator,
Mapping,
Sequence,
Set,
Sequence,
Tuple,
Union,
cast,
Expand All @@ -17,11 +17,7 @@

from . import formatting, indexing
from .indexes import Indexes
from .merge import (
expand_and_merge_variables,
merge_coords,
merge_coords_for_inplace_math,
)
from .merge import merge_coords, merge_coordinates_without_align
from .utils import Frozen, ReprObject, either_dict_or_kwargs
from .variable import Variable

Expand All @@ -34,7 +30,7 @@
_THIS_ARRAY = ReprObject("<this-array>")


class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
class Coordinates(Mapping[Hashable, "DataArray"]):
__slots__ = ()

def __getitem__(self, key: Hashable) -> "DataArray":
Expand All @@ -57,10 +53,10 @@ def indexes(self) -> Indexes:

@property
def variables(self):
raise NotImplementedError()
raise NotImplementedError

def _update_coords(self, coords):
raise NotImplementedError()
def _update_coords(self, coords, indexes):
raise NotImplementedError

def __iter__(self) -> Iterator["Hashable"]:
# needs to be in the same order as the dataset variables
Expand Down Expand Up @@ -116,38 +112,38 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:

def update(self, other: Mapping[Hashable, Any]) -> None:
other_vars = getattr(other, "variables", other)
coords = merge_coords(
coords, indexes = merge_coords(
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
)
self._update_coords(coords)
self._update_coords(coords, indexes)

def _merge_raw(self, other):
"""For use with binary arithmetic."""
if other is None:
variables = OrderedDict(self.variables)
indexes = OrderedDict(self.indexes)
else:
# don't align because we already called xarray.align
variables = expand_and_merge_variables([self.variables, other.variables])
return variables
variables, indexes = merge_coordinates_without_align([self, other])
return variables, indexes

@contextmanager
def _merge_inplace(self, other):
"""For use with in-place binary arithmetic."""
if other is None:
yield
else:
# don't include indexes in priority_vars, because we didn't align
# first
priority_vars = OrderedDict(
kv for kv in self.variables.items() if kv[0] not in self.dims
)
variables = merge_coords_for_inplace_math(
[self.variables, other.variables], priority_vars=priority_vars
# don't include indexes in prioritized, because we didn't align
# first and we want indexes to be checked
prioritized = {
k: (v, None) for k, v in self.variables.items() if k not in self.indexes
}
variables, indexes = merge_coordinates_without_align(
[self, other], prioritized
)
yield
self._update_coords(variables)
self._update_coords(variables, indexes)

def merge(self, other: "AbstractCoordinates") -> "Dataset":
def merge(self, other: "Coordinates") -> "Dataset":
"""Merge two sets of coordinates to create a new Dataset

The method implements the logic used for joining coordinates in the
Expand All @@ -173,13 +169,19 @@ def merge(self, other: "AbstractCoordinates") -> "Dataset":

if other is None:
return self.to_dataset()
else:
other_vars = getattr(other, "variables", other)
coords = expand_and_merge_variables([self.variables, other_vars])
return Dataset._from_vars_and_coord_names(coords, set(coords))

if not isinstance(other, Coordinates):
other = Dataset(coords=other).coords

coords, indexes = merge_coordinates_without_align([self, other])
coord_names = set(coords)
merged = Dataset._construct_direct(
variables=coords, coord_names=coord_names, indexes=indexes
)
return merged


class DatasetCoordinates(AbstractCoordinates):
class DatasetCoordinates(Coordinates):
"""Dictionary like container for Dataset coordinates.

Essentially an immutable OrderedDict with keys given by the array's
Expand Down Expand Up @@ -218,7 +220,11 @@ def to_dataset(self) -> "Dataset":
"""
return self._data._copy_listed(self._names)

def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
def _update_coords(
self,
coords: "OrderedDict[Hashable, Variable]",
indexes: Mapping[Hashable, pd.Index],
) -> None:
from .dataset import calculate_dimensions

variables = self._data._variables.copy()
Expand All @@ -234,7 +240,12 @@ def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
self._data._variables = variables
self._data._coord_names.update(new_coord_names)
self._data._dims = dims
self._data._indexes = None

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = OrderedDict(self._data.indexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes

def __delitem__(self, key: Hashable) -> None:
if key in self:
Expand All @@ -251,7 +262,7 @@ def _ipython_key_completions_(self):
]


class DataArrayCoordinates(AbstractCoordinates):
class DataArrayCoordinates(Coordinates):
"""Dictionary like container for DataArray coordinates.

Essentially an OrderedDict with keys given by the array's
Expand All @@ -274,7 +285,11 @@ def _names(self) -> Set[Hashable]:
def __getitem__(self, key: Hashable) -> "DataArray":
return self._data._getitem_coord(key)

def _update_coords(self, coords) -> None:
def _update_coords(
self,
coords: "OrderedDict[Hashable, Variable]",
indexes: Mapping[Hashable, pd.Index],
) -> None:
from .dataset import calculate_dimensions

coords_plus_data = coords.copy()
Expand All @@ -285,7 +300,12 @@ def _update_coords(self, coords) -> None:
"cannot add coordinates with new dimensions to " "a DataArray"
)
self._data._coords = coords
self._data._indexes = None

# TODO(shoyer): once ._indexes is always populated by a dict, modify
# it to update inplace instead.
original_indexes = OrderedDict(self._data.indexes)
original_indexes.update(indexes)
self._data._indexes = original_indexes

@property
def variables(self):
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2519,7 +2519,7 @@ def func(self, other):
if not reflexive
else f(other_variable, self.variable)
)
coords = self.coords._merge_raw(other_coords)
coords, indexes = self.coords._merge_raw(other_coords)
name = self._result_name(other)

return self._replace(variable, coords, name)
Expand Down
Loading