-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rolling window with as_strided
#1837
Changes from 17 commits
789134c
fa4e857
52915f3
b622007
36a1fe9
71fed0f
3960134
4bd38f3
af8362e
76db6b5
87f53af
c23cedb
9547c57
1f71cff
724776f
73862eb
859bb5c
d5fc24e
05c72f0
d55e498
9393eb2
9c71a50
54975b4
e907fdf
6482536
b8def4f
ff31589
6c011cb
684145a
3a7526e
a0968d6
ac4f00e
fbfc262
c757986
8fd5fa3
ade5ba2
2d6897f
6461f84
aece1c4
d5ad4a0
4189d71
081c928
75c1d7d
452b219
c5490c4
ab91394
9fa0812
0c1d49a
9463937
dce4e37
b3050cb
22f6d4a
19e0fca
d3b1e2b
2d06ec9
734da93
1a000b8
27ff67c
a2c7141
35dee9d
137709f
cc82cdc
b246411
b3a2105
b80fbfd
3c010ae
ab82f75
b9f10cd
cc9c3d6
52cc48d
2954cdf
f19e531
a074df3
f6f78a5
0ec8aba
0261cfe
a91c27f
c83d588
3bb4668
d0d89ce
eaba563
aeabdf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,13 @@ Documentation | |
|
||
Enhancements | ||
~~~~~~~~~~~~ | ||
- Improve :py:func:`~xarray.DataArray.rooling` logic for speed up. | ||
:py:func:`~xarray.DataArrayRolling` object now support ``to_dataarray`` | ||
method that returns a view of the DataArray object with the rolling-window | ||
dimension added to the last position. This enables more flexible operation, | ||
such as strided rolling, windowed rolling, ND-rolling, and convolution. | ||
(:issue:`1831`, :issue:`1142`, :issue:`819`) | ||
By `Keisuke Fujii <https://github.com/fujiisoup>`_. | ||
- Added nodatavals attribute to DataArray when using :py:func:`~xarray.open_rasterio`. (:issue:`1736`). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a bug fix note for the aggregations of the last element with |
||
By `Alan Snow <https://github.com/snowman2>`_. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,6 +34,8 @@ def maybe_promote(dtype): | |
fill_value = np.datetime64('NaT') | ||
elif np.issubdtype(dtype, np.timedelta64): | ||
fill_value = np.timedelta64('NaT') | ||
elif dtype.kind == 'b': | ||
fill_value = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is convenient for me, but it is not very clear whether There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, let's consider other options here. This is used for the default value when reindexing/aligning. |
||
else: | ||
dtype = object | ||
fill_value = np.nan | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import numpy as np | ||
import pandas as pd | ||
import warnings | ||
from . import npcompat | ||
|
||
|
||
def _validate_axis(data, axis): | ||
|
@@ -133,3 +134,52 @@ def __setitem__(self, key, value): | |
mixed_positions, vindex_positions = _advanced_indexer_subspaces(key) | ||
self._array[key] = np.moveaxis(value, vindex_positions, | ||
mixed_positions) | ||
|
||
|
||
def rolling_window(a, axis, window): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a small point, but can you swap the arguments for this function? That would let you set a default axis. Bottleneck uses default arguments like |
||
""" | ||
Make an ndarray with a rolling window along axis. | ||
|
||
Parameters | ||
---------- | ||
a : array_like | ||
Array to add rolling window to | ||
axis: int | ||
axis position along which rolling window will be applied. | ||
window : int | ||
Size of rolling window | ||
|
||
Returns | ||
------- | ||
Array that is a view of the original array with a added dimension | ||
of size w. | ||
|
||
Examples | ||
-------- | ||
>>> x=np.arange(10).reshape((2,5)) | ||
>>> np.rolling_window(x, 3, axis=-1) | ||
array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]], | ||
[[5, 6, 7], [6, 7, 8], [7, 8, 9]]]) | ||
|
||
Calculate rolling mean of last dimension: | ||
>>> np.mean(np.rolling_window(x, 3, axis=-1), -1) | ||
array([[ 1., 2., 3.], | ||
[ 6., 7., 8.]]) | ||
|
||
This function is taken from https://github.com/numpy/numpy/pull/31 | ||
but slightly modified to accept axis option. | ||
""" | ||
axis = _validate_axis(a, axis) | ||
a = np.swapaxes(a, axis, -1) | ||
|
||
if window < 1: | ||
raise ValueError( | ||
"`window` must be at least 1. Given : {}".format(window)) | ||
if window > a.shape[-1]: | ||
raise ValueError("`window` is too long. Given : {}".format(window)) | ||
|
||
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) | ||
strides = a.strides + (a.strides[-1],) | ||
rolling = npcompat.as_strided(a, shape=shape, strides=strides, | ||
writeable=False) | ||
return np.swapaxes(rolling, -2, axis) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,8 +6,6 @@ | |
from distutils.version import LooseVersion | ||
|
||
from .pycompat import OrderedDict, zip, dask_array_type | ||
from .common import full_like | ||
from .combine import concat | ||
from .ops import (inject_bottleneck_rolling_methods, | ||
inject_datasetrolling_methods, has_bottleneck, bn) | ||
from .dask_array_ops import dask_rolling_wrapper | ||
|
@@ -127,62 +125,75 @@ class DataArrayRolling(Rolling): | |
def __init__(self, obj, min_periods=None, center=False, **windows): | ||
super(DataArrayRolling, self).__init__(obj, min_periods=min_periods, | ||
center=center, **windows) | ||
self._windows = None | ||
self._valid_windows = None | ||
self.window_indices = None | ||
self.window_labels = None | ||
|
||
self._setup_windows() | ||
|
||
@property | ||
def windows(self): | ||
if self._windows is None: | ||
self._windows = OrderedDict(zip(self.window_labels, | ||
self.window_indices)) | ||
return self._windows | ||
|
||
def __iter__(self): | ||
for (label, indices, valid) in zip(self.window_labels, | ||
self.window_indices, | ||
self._valid_windows): | ||
|
||
for (label, indices) in zip(self.window_labels, self.window_indices): | ||
window = self.obj.isel(**{self.dim: indices}) | ||
|
||
if not valid: | ||
window = full_like(window, fill_value=True, dtype=bool) | ||
counts = window.count(dim=self.dim) | ||
window = window.where(counts >= self._min_periods) | ||
|
||
yield (label, window) | ||
|
||
def _setup_windows(self): | ||
""" | ||
Find the indices and labels for each window | ||
""" | ||
from .dataarray import DataArray | ||
|
||
self.window_labels = self.obj[self.dim] | ||
|
||
window = int(self.window) | ||
|
||
dim_size = self.obj[self.dim].size | ||
|
||
stops = np.arange(dim_size) + 1 | ||
starts = np.maximum(stops - window, 0) | ||
|
||
if self._min_periods > 1: | ||
valid_windows = (stops - starts) >= self._min_periods | ||
else: | ||
# No invalid windows | ||
valid_windows = np.ones(dim_size, dtype=bool) | ||
self._valid_windows = DataArray(valid_windows, dims=(self.dim, ), | ||
coords=self.obj[self.dim].coords) | ||
|
||
self.window_indices = [slice(start, stop) | ||
for start, stop in zip(starts, stops)] | ||
|
||
def _center_result(self, result): | ||
"""center result""" | ||
shift = (-self.window // 2) + 1 | ||
return result.shift(**{self.dim: shift}) | ||
def to_dataarray(self, window_dim): | ||
""" | ||
Convert this rolling object to xr.DataArray, | ||
where the window dimension is stacked as a new dimension | ||
|
||
Parameters | ||
---------- | ||
window_dim: str | ||
New name of the window dimension. | ||
|
||
Returns | ||
------- | ||
DataArray that is a view of the original array. | ||
|
||
Note | ||
---- | ||
The return array is not writeable. | ||
|
||
Examples | ||
-------- | ||
>>> da = DataArray(np.arange(8).reshape(2, 4), dims=('a', 'b')) | ||
|
||
>>> da.rolling_window(x, 'b', 4, 'window_dim') | ||
<xarray.DataArray (a: 2, b: 4, window_dim: 3)> | ||
array([[[np.nan, np.nan, 0], [np.nan, 0, 1], [0, 1, 2], [1, 2, 3]], | ||
[[np.nan, np.nan, 4], [np.nan, 4, 5], [4, 5, 6], [5, 6, 7]]]) | ||
Dimensions without coordinates: a, b, window_dim | ||
|
||
>>> da.rolling_window(x, 'b', 4, 'window_dim', center=True) | ||
<xarray.DataArray (a: 2, b: 4, window_dim: 3)> | ||
array([[[np.nan, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, np.nan]], | ||
[[np.nan, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, np.nan]]]) | ||
Dimensions without coordinates: a, b, window_dim | ||
""" | ||
|
||
from .dataarray import DataArray | ||
|
||
window = self.obj.variable.rolling_window(self.dim, self.window, | ||
window_dim, self.center) | ||
return DataArray(window, dims=self.obj.dims + (window_dim,), | ||
coords=self.obj.coords) | ||
|
||
def reduce(self, func, **kwargs): | ||
"""Reduce the items in this group by applying `func` along some | ||
|
@@ -203,26 +214,18 @@ def reduce(self, func, **kwargs): | |
Array with summarized data. | ||
""" | ||
|
||
windows = [window.reduce(func, dim=self.dim, **kwargs) | ||
for _, window in self] | ||
|
||
# Find valid windows based on count | ||
if self.dim in self.obj.coords: | ||
concat_dim = self.window_labels | ||
else: | ||
concat_dim = self.dim | ||
counts = concat([window.count(dim=self.dim) for _, window in self], | ||
dim=concat_dim) | ||
result = concat(windows, dim=concat_dim) | ||
# restore dim order | ||
result = result.transpose(*self.obj.dims) | ||
windows = self.to_dataarray('_rolling_window_dim') | ||
result = windows.reduce(func, dim='_rolling_window_dim', **kwargs) | ||
|
||
# Find valid windows based on count. | ||
# We do not use `reduced.count()` because it constructs a larger array | ||
# (notice that `windows` is just a view) | ||
counts = (~self.obj.isnull()).rolling( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For formatting long chains of method calls, I like to add extra parentheses and break every operation at the start of the line, e.g., counts = ((~self.obj.isnull())
.rolling(center=self.center, **{self.dim: self.window})
.to_dataarray('_rolling_window_dim')
.sum(dim='_rolling_window_dim')) I find this makes it easier to read There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should add a short-cut here that doesn't bother to compute You could add a utility function to determine this based on whether the result of |
||
center=self.center, **{self.dim: self.window}).to_dataarray( | ||
'_rolling_window_dim').sum(dim='_rolling_window_dim') | ||
result = result.where(counts >= self._min_periods) | ||
|
||
if self.center: | ||
result = self._center_result(result) | ||
|
||
return result | ||
# restore dim order | ||
return result.transpose(*self.obj.dims) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need to restore dimension order any more. The result should already be calculated correctly. |
||
|
||
@classmethod | ||
def _reduce_method(cls, func): | ||
|
@@ -254,19 +257,24 @@ def wrapped_func(self, **kwargs): | |
|
||
axis = self.obj.get_axis_num(self.dim) | ||
|
||
if isinstance(self.obj.data, dask_array_type): | ||
padded = self.obj.variable | ||
if self.center: | ||
shift = (-self.window // 2) + 1 | ||
padded = padded.pad_with_fill_value(**{self.dim: (0, -shift)}) | ||
valid = (slice(None), ) * axis + (slice(-shift, None), ) | ||
|
||
if isinstance(padded.data, dask_array_type): | ||
values = dask_rolling_wrapper(func, self.obj.data, | ||
window=self.window, | ||
min_count=min_count, | ||
axis=axis) | ||
else: | ||
values = func(self.obj.data, window=self.window, | ||
values = func(padded.data, window=self.window, | ||
min_count=min_count, axis=axis) | ||
|
||
result = DataArray(values, self.obj.coords) | ||
|
||
if self.center: | ||
result = self._center_result(result) | ||
values = values[valid] | ||
result = DataArray(values, self.obj.coords) | ||
|
||
return result | ||
return wrapped_func | ||
|
@@ -373,6 +381,31 @@ def wrapped_func(self, **kwargs): | |
return Dataset(reduced, coords=self.obj.coords) | ||
return wrapped_func | ||
|
||
def to_dataset(self, window_dim): | ||
""" | ||
Convert this rolling object to xr.Dataset, | ||
where the window dimension is stacked as a new dimension | ||
|
||
Parameters | ||
---------- | ||
window_dim: str | ||
New name of the window dimension. | ||
|
||
Returns | ||
------- | ||
Dataset with variables converted from rolling object. | ||
""" | ||
|
||
from .dataset import Dataset | ||
|
||
dataset = OrderedDict() | ||
for key, da in self.obj.data_vars.items(): | ||
if self.dim in da.dims: | ||
dataset[key] = self.rollings[key].to_dataarray(window_dim) | ||
else: | ||
dataset[key] = da | ||
return Dataset(dataset, coords=self.obj.coords) | ||
|
||
|
||
inject_bottleneck_rolling_methods(DataArrayRolling) | ||
inject_datasetrolling_methods(DatasetRolling) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe add: (
to_dataset
for Rolling objects from Dataset)