Skip to content

Commit b16a104

Browse files
authored
Use map_overlap for rolling reductions with Dask (#9770)
* Use ``map_overlap`` for rolling reducers with Dask * Enable argmin test * Update
1 parent 568dd6f commit b16a104

File tree

3 files changed

+23
-30
lines changed

3 files changed

+23
-30
lines changed

xarray/core/dask_array_ops.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,15 @@
55

66
def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
77
"""Wrapper to apply bottleneck moving window funcs on dask arrays"""
8-
import dask.array as da
9-
10-
dtype, fill_value = dtypes.maybe_promote(a.dtype)
11-
a = a.astype(dtype)
12-
# inputs for overlap
13-
if axis < 0:
14-
axis = a.ndim + axis
15-
depth = {d: 0 for d in range(a.ndim)}
16-
depth[axis] = (window + 1) // 2
17-
boundary = {d: fill_value for d in range(a.ndim)}
18-
# Create overlap array.
19-
ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
20-
# apply rolling func
21-
out = da.map_blocks(
22-
moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype
8+
dtype, _ = dtypes.maybe_promote(a.dtype)
9+
return a.data.map_overlap(
10+
moving_func,
11+
depth={axis: (window - 1, 0)},
12+
axis=axis,
13+
dtype=dtype,
14+
window=window,
15+
min_count=min_count,
2316
)
24-
# trim array
25-
result = da.overlap.trim_internal(out, depth)
26-
return result
2717

2818

2919
def least_squares(lhs, rhs, rcond=None, skipna=False):

xarray/core/rolling.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import numpy as np
1111
from packaging.version import Version
1212

13-
from xarray.core import dtypes, duck_array_ops, utils
13+
from xarray.core import dask_array_ops, dtypes, duck_array_ops, utils
1414
from xarray.core.arithmetic import CoarsenArithmetic
1515
from xarray.core.options import OPTIONS, _get_keep_attrs
1616
from xarray.core.types import CoarsenBoundaryOptions, SideOptions, T_Xarray
@@ -597,16 +597,18 @@ def _bottleneck_reduce(self, func, keep_attrs, **kwargs):
597597
padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant")
598598

599599
if is_duck_dask_array(padded.data):
600-
raise AssertionError("should not be reachable")
600+
values = dask_array_ops.dask_rolling_wrapper(
601+
func, padded, axis=axis, window=self.window[0], min_count=min_count
602+
)
601603
else:
602604
values = func(
603605
padded.data, window=self.window[0], min_count=min_count, axis=axis
604606
)
605-
# index 0 is at the rightmost edge of the window
606-
# need to reverse index here
607-
# see GH #8541
608-
if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
609-
values = self.window[0] - 1 - values
607+
# index 0 is at the rightmost edge of the window
608+
# need to reverse index here
609+
# see GH #8541
610+
if func in [bottleneck.move_argmin, bottleneck.move_argmax]:
611+
values = self.window[0] - 1 - values
610612

611613
if self.center[0]:
612614
values = values[valid]
@@ -669,12 +671,12 @@ def _array_reduce(
669671
if (
670672
OPTIONS["use_bottleneck"]
671673
and bottleneck_move_func is not None
672-
and not is_duck_dask_array(self.obj.data)
674+
and (
675+
not is_duck_dask_array(self.obj.data)
676+
or module_available("dask", "2024.11.0")
677+
)
673678
and self.ndim == 1
674679
):
675-
# TODO: re-enable bottleneck with dask after the issues
676-
# underlying https://github.com/pydata/xarray/issues/2940 are
677-
# fixed.
678680
return self._bottleneck_reduce(
679681
bottleneck_move_func, keep_attrs=keep_attrs, **kwargs
680682
)

xarray/tests/test_rolling.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,13 @@ def test_rolling_properties(self, da) -> None:
107107
):
108108
da.rolling(foo=2)
109109

110+
@requires_dask
110111
@pytest.mark.parametrize(
111112
"name", ("sum", "mean", "std", "min", "max", "median", "argmin", "argmax")
112113
)
113114
@pytest.mark.parametrize("center", (True, False, None))
114115
@pytest.mark.parametrize("min_periods", (1, None))
115-
@pytest.mark.parametrize("backend", ["numpy"], indirect=True)
116+
@pytest.mark.parametrize("backend", ["numpy", "dask"], indirect=True)
116117
def test_rolling_wrapped_bottleneck(
117118
self, da, name, center, min_periods, compute_backend
118119
) -> None:

0 commit comments

Comments
 (0)