Skip to content

Commit fdf024c

Browse files
dcherian and mathause authored
Use numpy & dask sliding_window_view for rolling (#4977)
Co-authored-by: Mathias Hauser <[email protected]>
1 parent 643e89e commit fdf024c

11 files changed

+417
-270
lines changed

doc/user-guide/duckarrays.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,9 @@ the code will still cast to ``numpy`` arrays:
4242
:py:meth:`DataArray.interp` and :py:meth:`DataArray.interp_like` (uses ``scipy``):
4343
duck arrays in data variables and non-dimension coordinates will be cast in
4444
addition to not supporting duck arrays in dimension coordinates
45+
* :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (requires ``numpy>=1.20``)
4546
* :py:meth:`Dataset.rolling_exp` and :py:meth:`DataArray.rolling_exp` (uses
4647
``numbagg``)
47-
* :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (uses internal functions
48-
of ``numpy``)
4948
* :py:meth:`Dataset.interpolate_na` and :py:meth:`DataArray.interpolate_na` (uses
5049
:py:class:`numpy.vectorize`)
5150
* :py:func:`apply_ufunc` with ``vectorize=True`` (uses :py:class:`numpy.vectorize`)

xarray/core/dask_array_compat.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,132 @@ def nanmedian(a, axis=None, keepdims=False):
9393
)
9494

9595
return result
96+
97+
98+
if LooseVersion(dask_version) > LooseVersion("2.30.0"):
99+
ensure_minimum_chunksize = da.overlap.ensure_minimum_chunksize
100+
else:
101+
102+
# copied from dask
103+
def ensure_minimum_chunksize(size, chunks):
    """Determine new chunks to ensure that every chunk >= size

    Chunks that are smaller than ``size`` are merged with, or topped up
    from, the chunk that precedes them. The total length along the axis
    is preserved.

    Parameters
    ----------
    size : int
        The minimum size of any chunk.
    chunks : tuple
        Chunks along one axis, e.g. ``(3, 3, 2)``

    Returns
    -------
    tuple
        ``chunks`` itself when every chunk is already at least ``size``,
        otherwise a new tuple of chunk sizes, each at least ``size``.

    Raises
    ------
    ValueError
        If the whole axis is shorter than ``size``, so that not even a
        single chunk of the requested size can be formed.

    Examples
    --------
    >>> ensure_minimum_chunksize(10, (20, 20, 1))
    (20, 11, 10)
    >>> ensure_minimum_chunksize(3, (1, 1, 3))
    (5,)

    See Also
    --------
    overlap
    """
    if size <= min(chunks):
        return chunks

    # add too-small chunks to chunks before them
    output = []
    new = 0  # size of the pending chunk that has not been emitted yet
    for c in chunks:
        if c < size:
            if new > size + (size - c):
                # The pending chunk can donate the missing ``size - c``
                # elements and still stay larger than ``size``: emit the
                # shrunken pending chunk and carry a chunk of exactly ``size``.
                output.append(new - (size - c))
                new = size
            else:
                # Otherwise simply merge the small chunk into the pending one.
                new += c
        if new >= size:
            output.append(new)
            new = 0
        if c >= size:
            # Large chunks are kept pending (not emitted immediately) so that
            # they can donate elements to a following too-small chunk.
            new += c
    if new >= size:
        output.append(new)
    elif len(output) >= 1:
        # Leftover smaller than ``size``: fold it into the last emitted chunk.
        output[-1] += new
    else:
        raise ValueError(
            f"The overlapping depth {size} is larger than your "
            f"array {sum(chunks)}."
        )

    return tuple(output)
153+
154+
155+
if LooseVersion(dask_version) > LooseVersion("2021.03.0"):
156+
sliding_window_view = da.lib.stride_tricks.sliding_window_view
157+
else:
158+
159+
def sliding_window_view(x, window_shape, axis=None):
    """Create a sliding window view of a chunked (dask) array.

    Fallback used when the installed dask is not newer than 2021.03.0.
    The NumPy-level ``sliding_window_view`` from ``.npcompat`` is applied to
    every block through ``map_overlap``; each block is extended on its
    trailing side by ``window - 1`` elements along every windowed axis.

    Parameters
    ----------
    x : dask array
        Input array. The code relies on ``x.ndim``, ``x.chunks``,
        ``x.rechunk`` and ``x._meta``.
    window_shape : int or iterable of int
        Window size for each windowed axis; every entry must be positive.
        A single integer ``i`` is treated as ``(i,)``.
    axis : int or tuple of int, optional
        Axis or axes to window over. If None, all axes of ``x`` are used
        and ``window_shape`` must have ``x.ndim`` entries.

    Returns
    -------
    The array returned by ``map_overlap``; the window dimensions are
    requested as new trailing axes (see ``new_axis`` below).

    Raises
    ------
    ValueError
        If ``window_shape`` contains non-positive values or its length does
        not match the number of axes.
    """
    from dask.array.overlap import map_overlap
    from numpy.core.numeric import normalize_axis_tuple  # type: ignore

    from .npcompat import sliding_window_view as _np_sliding_window_view

    window_shape = (
        tuple(window_shape) if np.iterable(window_shape) else (window_shape,)
    )

    window_shape_array = np.array(window_shape)
    if np.any(window_shape_array <= 0):
        raise ValueError("`window_shape` must contain positive values")

    if axis is None:
        axis = tuple(range(x.ndim))
        if len(window_shape) != len(axis):
            raise ValueError(
                f"Since axis is `None`, must provide "
                f"window_shape for all dimensions of `x`; "
                f"got {len(window_shape)} window_shape elements "
                f"and `x.ndim` is {x.ndim}."
            )
    else:
        axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
        if len(window_shape) != len(axis):
            raise ValueError(
                f"Must provide matching length window_shape and "
                f"axis; got {len(window_shape)} window_shape "
                f"elements and {len(axis)} axes elements."
            )

    # Accumulate with ``+=`` because duplicate axes are allowed above, so the
    # same axis may be windowed more than once.
    depths = [0] * x.ndim
    for ax, window in zip(axis, window_shape):
        depths[ax] += window - 1

    # Ensure that each chunk is big enough to leave at least a size-1 chunk
    # after windowing (this is only really necessary for the last chunk).
    safe_chunks = tuple(
        ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks)
    )
    x = x.rechunk(safe_chunks)

    # result.shape = x_shape_trimmed + window_shape,
    # where x_shape_trimmed is x.shape with every entry
    # reduced by one less than the corresponding window size.
    # trim chunks to match x_shape_trimmed
    newchunks = tuple(
        c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks)
    ) + tuple((window,) for window in window_shape)

    # ``window_shape`` and ``axis`` are not map_overlap options themselves;
    # presumably they are forwarded to ``_np_sliding_window_view`` for each
    # block -- confirm against the dask ``map_overlap`` documentation.
    kwargs = dict(
        depth=tuple((0, d) for d in depths),  # Overlap on +ve side only
        boundary="none",
        meta=x._meta,
        new_axis=range(x.ndim, x.ndim + len(axis)),
        chunks=newchunks,
        trim=False,
        window_shape=window_shape,
        axis=axis,
    )
    # map_overlap's signature changed in https://github.com/dask/dask/pull/6165
    if LooseVersion(dask_version) > "2.18.0":
        return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs)
    else:
        return map_overlap(x, _np_sliding_window_view, **kwargs)

xarray/core/dask_array_ops.py

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import numpy as np
2-
31
from . import dtypes, nputils
42

53

@@ -26,92 +24,6 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
2624
return result
2725

2826

29-
def rolling_window(a, axis, window, center, fill_value):
    """Dask counterpart of ``nputils.rolling_window``.

    Pads ``a`` with ``fill_value``, builds overlapping blocks with
    ``da.overlap.overlap``, applies ``nputils._rolling_window`` to every
    block along each requested axis and finally crops the result back to
    the original extent along those axes. The window dimensions are
    requested as new trailing axes of the output.

    Parameters
    ----------
    a : dask array
        Input array; the code relies on ``a.shape``, ``a.ndim``,
        ``a.chunks`` and ``a.dtype``.
    axis : int or sequence of int
        Axis or axes to roll over. A scalar ``axis`` means that ``window``
        and ``center`` are scalars as well; all three are wrapped in lists.
    window : int or sequence of int
        Window size for each axis.
    center : bool or sequence of bool
        Whether the window is centered, for each axis.
    fill_value
        Value used both for the explicit padding and for the overlap
        boundary.

    Raises
    ------
    ValueError
        If ``int(window / 2)`` exceeds the smallest chunk along an axis.
    """
    import dask.array as da

    if not hasattr(axis, "__len__"):
        axis = [axis]
        window = [window]
        center = [center]

    orig_shape = a.shape
    # Per-axis bookkeeping, indexed by (non-negative) axis number.
    depth = {d: 0 for d in range(a.ndim)}
    offset = [0] * a.ndim
    drop_size = [0] * a.ndim
    pad_size = [0] * a.ndim
    for ax, win, cent in zip(axis, window, center):
        if ax < 0:
            ax = a.ndim + ax
        depth[ax] = int(win / 2)
        # For evenly sized window, we need to crop the first point of each block.
        offset[ax] = 1 if win % 2 == 0 else 0

        if depth[ax] > min(a.chunks[ax]):
            raise ValueError(
                "For window size %d, every chunk should be larger than %d, "
                "but the smallest chunk size is %d. Rechunk your array\n"
                "with a larger chunk size or a chunk size that\n"
                "more evenly divides the shape of your array."
                % (win, depth[ax], min(a.chunks[ax]))
            )

        # Although da.overlap pads values to boundaries of the array,
        # the size of the generated array is smaller than what we want
        # if center == False.
        if cent:
            start = int(win / 2)  # 10 -> 5,  9 -> 4
            end = win - 1 - start
        else:
            start, end = win - 1, 0
        pad_size[ax] = max(start, end) + offset[ax] - depth[ax]
        drop_size[ax] = 0
        # pad_size becomes more than 0 when the overlapped array is smaller than
        # needed. In this case, we need to enlarge the original array by padding
        # before overlapping.
        if pad_size[ax] > 0:
            if pad_size[ax] < depth[ax]:
                # overlapping requires each chunk larger than depth. If pad_size is
                # smaller than the depth, we enlarge this and truncate it later.
                drop_size[ax] = depth[ax] - pad_size[ax]
                pad_size[ax] = depth[ax]

    # TODO maybe following two lines can be summarized.
    # Padding is added on the leading side of every axis only.
    a = da.pad(
        a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value
    )
    boundary = {d: fill_value for d in range(a.ndim)}

    # create overlap arrays
    ag = da.overlap.overlap(a, depth=depth, boundary=boundary)

    def func(x, window, axis):
        # Applied per block: build the rolling windows along every requested
        # axis, then drop the leading ``offset`` elements (even-sized windows).
        x = np.asarray(x)
        index = [slice(None)] * x.ndim
        for ax, win in zip(axis, window):
            x = nputils._rolling_window(x, win, ax)
            index[ax] = slice(offset[ax], None)
        return x[tuple(index)]

    # NOTE(review): ``list + window`` requires ``window`` to be a list here;
    # a tuple would raise TypeError -- confirm against callers.
    chunks = list(a.chunks) + window
    new_axis = [a.ndim + i for i in range(len(axis))]
    out = da.map_blocks(
        func,
        ag,
        dtype=a.dtype,
        new_axis=new_axis,
        chunks=chunks,
        window=window,
        axis=axis,
    )

    # crop boundary.
    index = [slice(None)] * a.ndim
    for ax in axis:
        index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax])
    return out[tuple(index)]
113-
114-
11527
def least_squares(lhs, rhs, rcond=None, skipna=False):
11628
import dask.array as da
11729

xarray/core/duck_array_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,15 +614,15 @@ def last(values, axis, skipna=None):
614614
return take(values, -1, axis=axis)
615615

616616

617-
def sliding_window_view(array, window_shape, axis):
    """
    Build a sliding-window view of ``array`` along ``axis``.

    Dispatches to the dask implementation for duck dask arrays and to the
    NumPy implementation otherwise. The window dimension(s) are placed
    last in the result.
    """
    backend = dask_array_compat if is_duck_dask_array(array) else npcompat
    return backend.sliding_window_view(array, window_shape, axis)
626626

627627

628628
def least_squares(lhs, rhs, rcond=None, skipna=False):

xarray/core/npcompat.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3131
import builtins
3232
import operator
33+
from distutils.version import LooseVersion
3334
from typing import Union
3435

3536
import numpy as np
@@ -96,3 +97,99 @@ def __array_function__(self, *args, **kwargs):
9697

9798

9899
IS_NEP18_ACTIVE = _is_nep18_active()
100+
101+
102+
if LooseVersion(np.__version__) >= "1.20.0":
103+
sliding_window_view = np.lib.stride_tricks.sliding_window_view
104+
else:
105+
from numpy.core.numeric import normalize_axis_tuple # type: ignore
106+
from numpy.lib.stride_tricks import as_strided
107+
108+
# copied from numpy.lib.stride_tricks
109+
def sliding_window_view(
    x, window_shape, axis=None, *, subok=False, writeable=False
):
    """
    Create a sliding window view into the array with the given window shape.

    Also known as rolling or moving window, the window slides across all
    dimensions of the array and extracts subsets of the array at all window
    positions.

    .. versionadded:: 1.20.0

    Parameters
    ----------
    x : array_like
        Array to create the sliding window view from.
    window_shape : int or tuple of int
        Size of window over each axis that takes part in the sliding window.
        If `axis` is not present, must have same length as the number of input
        array dimensions. Single integers `i` are treated as if they were the
        tuple `(i,)`.
    axis : int or tuple of int, optional
        Axis or axes along which the sliding window is applied.
        By default, the sliding window is applied to all axes and
        `window_shape[i]` will refer to axis `i` of `x`.
        If `axis` is given as a `tuple of int`, `window_shape[i]` will refer to
        the axis `axis[i]` of `x`.
        Single integers `i` are treated as if they were the tuple `(i,)`.
    subok : bool, optional
        If True, sub-classes will be passed-through, otherwise the returned
        array will be forced to be a base-class array (default).
    writeable : bool, optional
        When true, allow writing to the returned view. The default is false,
        as this should be used with caution: the returned view contains the
        same memory location multiple times, so writing to one location will
        cause others to change.

    Returns
    -------
    view : ndarray
        Sliding window view of the array. The sliding window dimensions are
        inserted at the end, and the original dimensions are trimmed as
        required by the size of the sliding window.
        That is, ``view.shape = x_shape_trimmed + window_shape``, where
        ``x_shape_trimmed`` is ``x.shape`` with every entry reduced by one less
        than the corresponding window size.

    Raises
    ------
    ValueError
        If `window_shape` contains negative values, if its length does not
        match the number of windowed axes, or if a window is larger than the
        input along its axis.
    """
    window_shape = (
        tuple(window_shape) if np.iterable(window_shape) else (window_shape,)
    )
    # first convert input to array, possibly keeping subclass.
    # ``asarray``/``asanyarray`` copy only when needed on every NumPy version,
    # whereas ``np.array(..., copy=False)`` raises on NumPy >= 2 if a copy is
    # required (e.g. for list input).
    x = np.asanyarray(x) if subok else np.asarray(x)

    window_shape_array = np.array(window_shape)
    if np.any(window_shape_array < 0):
        raise ValueError("`window_shape` cannot contain negative values")

    if axis is None:
        axis = tuple(range(x.ndim))
        if len(window_shape) != len(axis):
            raise ValueError(
                f"Since axis is `None`, must provide "
                f"window_shape for all dimensions of `x`; "
                f"got {len(window_shape)} window_shape elements "
                f"and `x.ndim` is {x.ndim}."
            )
    else:
        axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
        if len(window_shape) != len(axis):
            raise ValueError(
                f"Must provide matching length window_shape and "
                f"axis; got {len(window_shape)} window_shape "
                f"elements and {len(axis)} axes elements."
            )

    # Each window dimension reuses the stride of the axis it slides over.
    out_strides = x.strides + tuple(x.strides[ax] for ax in axis)

    # note: same axis can be windowed repeatedly
    x_shape_trimmed = list(x.shape)
    for ax, dim in zip(axis, window_shape):
        if x_shape_trimmed[ax] < dim:
            raise ValueError("window shape cannot be larger than input array shape")
        x_shape_trimmed[ax] -= dim - 1
    out_shape = tuple(x_shape_trimmed) + window_shape
    return as_strided(
        x, strides=out_strides, shape=out_shape, subok=subok, writeable=writeable
    )

0 commit comments

Comments
 (0)