Skip to content

Commit

Permalink
Speed up isel and __getitem__ (#3375)
Browse files Browse the repository at this point in the history
* Variable.isel cleanup/speedup

* Dataset.isel code cleanup

* Speed up isel

* What's New

* Better error checks

* Speedup

* type annotations

* Update doc/whats-new.rst

Co-Authored-By: Maximilian Roos <[email protected]>

* What's New

* What's New

* Always shallow-copy variables
  • Loading branch information
crusaderky authored and Joe Hamman committed Oct 9, 2019
1 parent 132733a commit 3f0049f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 63 deletions.
10 changes: 8 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,27 @@ Breaking changes

(:issue:`3222`, :issue:`3293`, :issue:`3340`, :issue:`3346`, :issue:`3358`).
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Dropped the 'drop=False' optional parameter from :meth:`Variable.isel`.
It was unused and doesn't make sense for a Variable.
(:pull:`3375`) by `Guido Imperiale <https://github.com/crusaderky>`_.

New functions/methods
~~~~~~~~~~~~~~~~~~~~~

Enhancements
~~~~~~~~~~~~

- Add a repr for :py:class:`~xarray.core.GroupBy` objects (:issue:`3344`).
- Add a repr for :py:class:`~xarray.core.GroupBy` objects.
Example::

>>> da.groupby("time.season")
DataArrayGroupBy, grouped over 'season'
4 groups with labels 'DJF', 'JJA', 'MAM', 'SON'

By `Deepak Cherian <https://github.com/dcherian>`_.
(:issue:`3344`) by `Deepak Cherian <https://github.com/dcherian>`_.
- Speed up :meth:`Dataset.isel` up to 33% and :meth:`DataArray.isel` up to 25% for small
arrays (:issue:`2799`, :pull:`3375`) by
`Guido Imperiale <https://github.com/crusaderky>`_.

Bug fixes
~~~~~~~~~
Expand Down
108 changes: 59 additions & 49 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,8 +1745,8 @@ def maybe_chunk(name, var, chunks):
return self._replace(variables)

def _validate_indexers(
self, indexers: Mapping
) -> List[Tuple[Any, Union[slice, Variable]]]:
self, indexers: Mapping[Hashable, Any]
) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]:
""" Here we make sure
+ indexers have valid keys
+ indexers are of a valid data type
Expand All @@ -1755,50 +1755,61 @@ def _validate_indexers(
"""
from .dataarray import DataArray

invalid = [k for k in indexers if k not in self.dims]
invalid = indexers.keys() - self.dims.keys()
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

# all indexers should be int, slice, np.ndarrays, or Variable
indexers_list: List[Tuple[Any, Union[slice, Variable]]] = []
for k, v in indexers.items():
if isinstance(v, slice):
indexers_list.append((k, v))
continue

if isinstance(v, Variable):
pass
if isinstance(v, (int, slice, Variable)):
yield k, v
elif isinstance(v, DataArray):
v = v.variable
yield k, v.variable
elif isinstance(v, tuple):
v = as_variable(v)
yield k, as_variable(v)
elif isinstance(v, Dataset):
raise TypeError("cannot use a Dataset as an indexer")
elif isinstance(v, Sequence) and len(v) == 0:
v = Variable((k,), np.zeros((0,), dtype="int64"))
yield k, np.empty((0,), dtype="int64")
else:
v = np.asarray(v)

if v.dtype.kind == "U" or v.dtype.kind == "S":
if v.dtype.kind in "US":
index = self.indexes[k]
if isinstance(index, pd.DatetimeIndex):
v = v.astype("datetime64[ns]")
elif isinstance(index, xr.CFTimeIndex):
v = _parse_array_of_cftime_strings(v, index.date_type)

if v.ndim == 0:
v = Variable((), v)
elif v.ndim == 1:
v = Variable((k,), v)
else:
if v.ndim > 1:
raise IndexError(
"Unlabeled multi-dimensional array cannot be "
"used for indexing: {}".format(k)
)
yield k, v

indexers_list.append((k, v))

return indexers_list
def _validate_interp_indexers(
self, indexers: Mapping[Hashable, Any]
) -> Iterator[Tuple[Hashable, Variable]]:
"""Variant of _validate_indexers to be used for interpolation
"""
for k, v in self._validate_indexers(indexers):
if isinstance(v, Variable):
if v.ndim == 1:
yield k, v.to_index_variable()
else:
yield k, v
elif isinstance(v, int):
yield k, Variable((), v)
elif isinstance(v, np.ndarray):
if v.ndim == 0:
yield k, Variable((), v)
elif v.ndim == 1:
yield k, IndexVariable((k,), v)
else:
raise AssertionError() # Already tested by _validate_indexers
else:
raise TypeError(type(v))

def _get_indexers_coords_and_indexes(self, indexers):
"""Extract coordinates and indexes from indexers.
Expand Down Expand Up @@ -1885,10 +1896,10 @@ def isel(
Dataset.sel
DataArray.isel
"""

indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

indexers_list = self._validate_indexers(indexers)
# Note: we need to preserve the original indexers variable in order to merge the
# coords below
indexers_list = list(self._validate_indexers(indexers))

variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
indexes = OrderedDict() # type: OrderedDict[Hashable, pd.Index]
Expand All @@ -1904,19 +1915,21 @@ def isel(
)
if new_index is not None:
indexes[name] = new_index
else:
elif var_indexers:
new_var = var.isel(indexers=var_indexers)
else:
new_var = var.copy(deep=False)

variables[name] = new_var

coord_names = set(variables).intersection(self._coord_names)
coord_names = self._coord_names & variables.keys()
selected = self._replace_with_new_dims(variables, coord_names, indexes)

# Extract coordinates from indexers
coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
variables.update(coord_vars)
indexes.update(new_indexes)
coord_names = set(variables).intersection(self._coord_names).union(coord_vars)
coord_names = self._coord_names & variables.keys() | coord_vars.keys()
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def sel(
Expand Down Expand Up @@ -2478,11 +2491,9 @@ def interp(

if kwargs is None:
kwargs = {}

coords = either_dict_or_kwargs(coords, coords_kwargs, "interp")
indexers = OrderedDict(
(k, v.to_index_variable() if isinstance(v, Variable) and v.ndim == 1 else v)
for k, v in self._validate_indexers(coords)
)
indexers = OrderedDict(self._validate_interp_indexers(coords))

obj = self if assume_sorted else self.sortby([k for k in coords])

Expand All @@ -2507,26 +2518,25 @@ def _validate_interp_indexer(x, new_x):
"strings or datetimes. "
"Instead got\n{}".format(new_x)
)
else:
return (x, new_x)
return x, new_x

variables = OrderedDict() # type: OrderedDict[Hashable, Variable]
for name, var in obj._variables.items():
if name not in indexers:
if var.dtype.kind in "uifc":
var_indexers = {
k: _validate_interp_indexer(maybe_variable(obj, k), v)
for k, v in indexers.items()
if k in var.dims
}
variables[name] = missing.interp(
var, var_indexers, method, **kwargs
)
elif all(d not in indexers for d in var.dims):
# keep unrelated object array
variables[name] = var
if name in indexers:
continue

if var.dtype.kind in "uifc":
var_indexers = {
k: _validate_interp_indexer(maybe_variable(obj, k), v)
for k, v in indexers.items()
if k in var.dims
}
variables[name] = missing.interp(var, var_indexers, method, **kwargs)
elif all(d not in indexers for d in var.dims):
# keep unrelated object array
variables[name] = var

coord_names = set(variables).intersection(obj._coord_names)
coord_names = obj._coord_names & variables.keys()
indexes = OrderedDict(
(k, v) for k, v in obj.indexes.items() if k not in indexers
)
Expand All @@ -2546,7 +2556,7 @@ def _validate_interp_indexer(x, new_x):
variables.update(coord_vars)
indexes.update(new_indexes)

coord_names = set(variables).intersection(obj._coord_names).union(coord_vars)
coord_names = obj._coord_names & variables.keys() | coord_vars.keys()
return self._replace_with_new_dims(variables, coord_names, indexes=indexes)

def interp_like(
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
from typing import Any, Hashable, Iterable, Mapping, Optional, Tuple, Union

import numpy as np
import pandas as pd

from . import formatting
Expand Down Expand Up @@ -63,7 +64,7 @@ def isel_variable_and_index(
name: Hashable,
variable: Variable,
index: pd.Index,
indexers: Mapping[Any, Union[slice, Variable]],
indexers: Mapping[Hashable, Union[int, slice, np.ndarray, Variable]],
) -> Tuple[Variable, Optional[pd.Index]]:
"""Index a Variable and pandas.Index together."""
if not indexers:
Expand Down
35 changes: 24 additions & 11 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import OrderedDict, defaultdict
from datetime import timedelta
from distutils.version import LooseVersion
from typing import Any, Hashable, Mapping, Union
from typing import Any, Hashable, Mapping, Union, TypeVar

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -41,6 +41,18 @@
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore

VariableType = TypeVar("VariableType", bound="Variable")
"""Type annotation to be used when methods of Variable return self or a copy of self.
When called from an instance of a subclass, e.g. IndexVariable, mypy identifies the
output as an instance of the subclass.
Usage::
class Variable:
def f(self: VariableType, ...) -> VariableType:
...
"""


class MissingDimensionsError(ValueError):
"""Error class used when we can't safely guess a dimension name.
Expand Down Expand Up @@ -663,8 +675,8 @@ def _broadcast_indexes_vectorized(self, key):

return out_dims, VectorizedIndexer(tuple(out_key)), new_order

def __getitem__(self, key):
"""Return a new Array object whose contents are consistent with
def __getitem__(self: VariableType, key) -> VariableType:
"""Return a new Variable object whose contents are consistent with
getting the provided key from the underlying data.
NB. __getitem__ and __setitem__ implement xarray-style indexing,
Expand All @@ -682,7 +694,7 @@ def __getitem__(self, key):
data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order)
return self._finalize_indexing_result(dims, data)

def _finalize_indexing_result(self, dims, data):
def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType:
"""Used by IndexVariable to return IndexVariable objects when possible.
"""
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
Expand Down Expand Up @@ -957,7 +969,11 @@ def chunk(self, chunks=None, name=None, lock=False):

return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)

def isel(self, indexers=None, drop=False, **indexers_kwargs):
def isel(
self: VariableType,
indexers: Mapping[Hashable, Any] = None,
**indexers_kwargs: Any
) -> VariableType:
"""Return a new array indexed along the specified dimension(s).
Parameters
Expand All @@ -976,15 +992,12 @@ def isel(self, indexers=None, drop=False, **indexers_kwargs):
"""
indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")

invalid = [k for k in indexers if k not in self.dims]
invalid = indexers.keys() - set(self.dims)
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

key = [slice(None)] * self.ndim
for i, dim in enumerate(self.dims):
if dim in indexers:
key[i] = indexers[dim]
return self[tuple(key)]
key = tuple(indexers.get(dim, slice(None)) for dim in self.dims)
return self[key]

def squeeze(self, dim=None):
"""Return a new object with squeezed data.
Expand Down

0 comments on commit 3f0049f

Please sign in to comment.