Skip to content

Commit

Permalink
sparse option to reindex and unstack (#3542)
Browse files Browse the repository at this point in the history
* Added fill_value for unstack

* remove sparse option and fix unintended changes

* a bug fix

* Added sparse option to unstack and reindex

* black

* More tests

* black

* Remove sparse option from reindex

* try __array_function__ where

* flake8
  • Loading branch information
fujiisoup authored and shoyer committed Nov 19, 2019
1 parent dc559ea commit 220adbc
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 4 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ Breaking changes

New Features
~~~~~~~~~~~~
- Added the ``sparse`` option to :py:meth:`~xarray.DataArray.unstack`,
:py:meth:`~xarray.Dataset.unstack`, :py:meth:`~xarray.DataArray.reindex`,
:py:meth:`~xarray.Dataset.reindex` (:issue:`3518`).
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- Added the ``max_gap`` kwarg to :py:meth:`DataArray.interpolate_na` and
:py:meth:`Dataset.interpolate_na`. This controls the maximum size of the data
Expand Down
5 changes: 5 additions & 0 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def reindex_variables(
tolerance: Any = None,
copy: bool = True,
fill_value: Optional[Any] = dtypes.NA,
sparse: bool = False,
) -> Tuple[Dict[Hashable, Variable], Dict[Hashable, pd.Index]]:
"""Conform a dictionary of aligned variables onto a new set of variables,
filling in missing values with NaN.
Expand Down Expand Up @@ -503,6 +504,8 @@ def reindex_variables(
the input. In either case, new xarray objects are always returned.
fill_value : scalar, optional
Value to use for newly missing values
sparse: bool, optional
Use an sparse-array
Returns
-------
Expand Down Expand Up @@ -571,6 +574,8 @@ def reindex_variables(

for name, var in variables.items():
if name not in indexers:
if sparse:
var = var._as_sparse(fill_value=fill_value)
key = tuple(
slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None))
for d in var.dims
Expand Down
4 changes: 3 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,7 @@ def unstack(
self,
dim: Union[Hashable, Sequence[Hashable], None] = None,
fill_value: Any = dtypes.NA,
sparse: bool = False,
) -> "DataArray":
"""
Unstack existing dimensions corresponding to MultiIndexes into
Expand All @@ -1742,6 +1743,7 @@ def unstack(
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value: value to be filled. By default, np.nan
sparse: use sparse-array if True
Returns
-------
Expand Down Expand Up @@ -1773,7 +1775,7 @@ def unstack(
--------
DataArray.stack
"""
ds = self._to_temp_dataset().unstack(dim, fill_value)
ds = self._to_temp_dataset().unstack(dim, fill_value, sparse)
return self._from_temp_dataset(ds)

def to_unstacked_dataset(self, dim, level=0):
Expand Down
35 changes: 32 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,7 @@ def reindex(
the input. In either case, a new xarray object is always returned.
fill_value : scalar, optional
Value to use for newly missing values
sparse: use sparse-array. By default, False
**indexers_kwarg : {dim: indexer, ...}, optional
Keyword arguments in the same form as ``indexers``.
One of indexers or indexers_kwargs must be provided.
Expand Down Expand Up @@ -2428,6 +2429,29 @@ def reindex(
the original and desired indexes. If you do want to fill in the `NaN` values present in the
original dataset, use the :py:meth:`~Dataset.fillna()` method.
"""
return self._reindex(
indexers,
method,
tolerance,
copy,
fill_value,
sparse=False,
**indexers_kwargs,
)

def _reindex(
self,
indexers: Mapping[Hashable, Any] = None,
method: str = None,
tolerance: Number = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
sparse: bool = False,
**indexers_kwargs: Any,
) -> "Dataset":
"""
same to _reindex but support sparse option
"""
indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")

Expand All @@ -2444,6 +2468,7 @@ def reindex(
tolerance,
copy=copy,
fill_value=fill_value,
sparse=sparse,
)
coord_names = set(self._coord_names)
coord_names.update(indexers)
Expand Down Expand Up @@ -3327,7 +3352,7 @@ def ensure_stackable(val):

return data_array

def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
def _unstack_once(self, dim: Hashable, fill_value, sparse) -> "Dataset":
index = self.get_index(dim)
index = index.remove_unused_levels()
full_idx = pd.MultiIndex.from_product(index.levels, names=index.names)
Expand All @@ -3336,7 +3361,9 @@ def _unstack_once(self, dim: Hashable, fill_value) -> "Dataset":
if index.equals(full_idx):
obj = self
else:
obj = self.reindex({dim: full_idx}, copy=False, fill_value=fill_value)
obj = self._reindex(
{dim: full_idx}, copy=False, fill_value=fill_value, sparse=sparse
)

new_dim_names = index.names
new_dim_sizes = [lev.size for lev in index.levels]
Expand Down Expand Up @@ -3366,6 +3393,7 @@ def unstack(
self,
dim: Union[Hashable, Iterable[Hashable]] = None,
fill_value: Any = dtypes.NA,
sparse: bool = False,
) -> "Dataset":
"""
Unstack existing dimensions corresponding to MultiIndexes into
Expand All @@ -3379,6 +3407,7 @@ def unstack(
Dimension(s) over which to unstack. By default unstacks all
MultiIndexes.
fill_value: value to be filled. By default, np.nan
sparse: use sparse-array if True
Returns
-------
Expand Down Expand Up @@ -3416,7 +3445,7 @@ def unstack(

result = self.copy(deep=False)
for dim in dims:
result = result._unstack_once(dim, fill_value)
result = result._unstack_once(dim, fill_value, sparse)
return result

def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset":
Expand Down
38 changes: 38 additions & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,36 @@ def chunk(self, chunks=None, name=None, lock=False):

return type(self)(self.dims, data, self._attrs, self._encoding, fastpath=True)

def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA):
"""
use sparse-array as backend.
"""
import sparse

# TODO what to do if dask-backended?
if fill_value is dtypes.NA:
dtype, fill_value = dtypes.maybe_promote(self.dtype)
else:
dtype = dtypes.result_type(self.dtype, fill_value)

if sparse_format is _default:
sparse_format = "coo"
try:
as_sparse = getattr(sparse, "as_{}".format(sparse_format.lower()))
except AttributeError:
raise ValueError("{} is not a valid sparse format".format(sparse_format))

data = as_sparse(self.data.astype(dtype), fill_value=fill_value)
return self._replace(data=data)

def _to_dense(self):
"""
Change backend from sparse to np.array
"""
if hasattr(self._data, "todense"):
return self._replace(data=self._data.todense())
return self.copy(deep=False)

def isel(
self: VariableType,
indexers: Mapping[Hashable, Any] = None,
Expand Down Expand Up @@ -2021,6 +2051,14 @@ def chunk(self, chunks=None, name=None, lock=False):
# Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk()
return self.copy(deep=False)

def _as_sparse(self, sparse_format=_default, fill_value=_default):
# Dummy
return self.copy(deep=False)

def _to_dense(self):
# Dummy
return self.copy(deep=False)

def _finalize_indexing_result(self, dims, data):
if getattr(data, "ndim", 0) != 1:
# returns Variable rather than IndexVariable if multi-dimensional
Expand Down
19 changes: 19 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2811,6 +2811,25 @@ def test_unstack_fill_value(self):
expected = ds["var"].unstack("index").fillna(-1).astype(np.int)
assert actual.equals(expected)

@requires_sparse
def test_unstack_sparse(self):
ds = xr.Dataset(
{"var": (("x",), np.arange(6))},
coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)},
)
# make ds incomplete
ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"])
# test fill_value
actual = ds.unstack("index", sparse=True)
expected = ds.unstack("index")
assert actual["var"].variable._to_dense().equals(expected["var"].variable)
assert actual["var"].data.density < 1.0

actual = ds["var"].unstack("index", sparse=True)
expected = ds["var"].unstack("index")
assert actual.variable._to_dense().equals(expected.variable)
assert actual.data.density < 1.0

def test_stack_unstack_fast(self):
ds = Dataset(
{
Expand Down
12 changes: 12 additions & 0 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
assert_identical,
raises_regex,
requires_dask,
requires_sparse,
source_ndarray,
)

Expand Down Expand Up @@ -1862,6 +1863,17 @@ def test_getitem_with_mask_nd_indexer(self):
)


@requires_sparse
class TestVariableWithSparse:
# TODO inherit VariableSubclassobjects to cover more tests

def test_as_sparse(self):
data = np.arange(12).reshape(3, 4)
var = Variable(("x", "y"), data)._as_sparse(fill_value=-1)
actual = var._to_dense()
assert_identical(var, actual)


class TestIndexVariable(VariableSubclassobjects):
cls = staticmethod(IndexVariable)

Expand Down

0 comments on commit 220adbc

Please sign in to comment.