Skip to content

Commit 69950a4

Browse files
dcherianmathausemax-sixty
authored
Support ffill and bfill along chunked dimensions (#5187)
Co-authored-by: Mathias Hauser <[email protected]> Co-authored-by: Maximilian Roos <[email protected]>
1 parent 6bfbaed commit 69950a4

8 files changed

+96
-36
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ Deprecations
100100

101101
Bug fixes
102102
~~~~~~~~~
103+
- Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill` along chunked dimensions.
104+
(:issue:`2699`).By `Deepak Cherian <https://github.com/dcherian>`_.
103105
- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is
104106
2d (:issue:`5097`, :pull:`5099`). By `John Omotani <https://github.com/johnomotani>`_.
105107
- Ensure standard calendar times encoded with large values (i.e. greater than approximately 292 years), can be decoded correctly without silently overflowing (:pull:`5050`). This was a regression in xarray 0.17.0. By `Zeb Nicholls <https://github.com/znicholls>`_.

xarray/core/dask_array_ops.py

+21
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,24 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
5151
# See issue dask/dask#6516
5252
coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
5353
return coeffs, residuals
54+
55+
56+
def push(array, n, axis):
57+
"""
58+
Dask-aware bottleneck.push
59+
"""
60+
from bottleneck import push
61+
62+
if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]:
63+
raise NotImplementedError(
64+
"Cannot fill along a chunked axis when limit is not None."
65+
"Either rechunk to a single chunk along this axis or call .compute() or .load() first."
66+
)
67+
if all(c == 1 for c in array.chunks[axis]):
68+
array = array.rechunk({axis: 2})
69+
pushed = array.map_blocks(push, axis=axis, n=n)
70+
if len(array.chunks[axis]) > 1:
71+
pushed = pushed.map_overlap(
72+
push, axis=axis, n=n, depth={axis: (1, 0)}, boundary="none"
73+
)
74+
return pushed

xarray/core/dataarray.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2515,7 +2515,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "DataArray":
25152515
The maximum number of consecutive NaN values to forward fill. In
25162516
other words, if there is a gap with more than this number of
25172517
consecutive NaNs, it will only be partially filled. Must be greater
2518-
than 0 or None for no limit.
2518+
than 0 or None for no limit. Must be None or greater than or equal
2519+
to axis length if filling along chunked axes (dimensions).
25192520
25202521
Returns
25212522
-------
@@ -2539,7 +2540,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "DataArray":
25392540
The maximum number of consecutive NaN values to backward fill. In
25402541
other words, if there is a gap with more than this number of
25412542
consecutive NaNs, it will only be partially filled. Must be greater
2542-
than 0 or None for no limit.
2543+
than 0 or None for no limit. Must be None or greater than or equal
2544+
to axis length if filling along chunked axes (dimensions).
25432545
25442546
Returns
25452547
-------

xarray/core/dataset.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4654,7 +4654,8 @@ def ffill(self, dim: Hashable, limit: int = None) -> "Dataset":
46544654
The maximum number of consecutive NaN values to forward fill. In
46554655
other words, if there is a gap with more than this number of
46564656
consecutive NaNs, it will only be partially filled. Must be greater
4657-
than 0 or None for no limit.
4657+
than 0 or None for no limit. Must be None or greater than or equal
4658+
to axis length if filling along chunked axes (dimensions).
46584659
46594660
Returns
46604661
-------
@@ -4679,7 +4680,8 @@ def bfill(self, dim: Hashable, limit: int = None) -> "Dataset":
46794680
The maximum number of consecutive NaN values to backward fill. In
46804681
other words, if there is a gap with more than this number of
46814682
consecutive NaNs, it will only be partially filled. Must be greater
4682-
than 0 or None for no limit.
4683+
than 0 or None for no limit. Must be None or greater than or equal
4684+
to axis length if filling along chunked axes (dimensions).
46834685
46844686
Returns
46854687
-------

xarray/core/duck_array_ops.py

+9
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,12 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
641641
return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
642642
else:
643643
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
644+
645+
646+
def push(array, n, axis):
647+
from bottleneck import push
648+
649+
if is_duck_dask_array(array):
650+
return dask_array_ops.push(array, n, axis)
651+
else:
652+
return push(array, n, axis)

xarray/core/missing.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from . import utils
1212
from .common import _contains_datetime_like_objects, ones_like
1313
from .computation import apply_ufunc
14-
from .duck_array_ops import datetime_to_numeric, timedelta_to_numeric
14+
from .duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
1515
from .options import _get_keep_attrs
1616
from .pycompat import is_duck_dask_array
1717
from .utils import OrderedSet, is_scalar
@@ -390,30 +390,26 @@ def func_interpolate_na(interpolator, y, x, **kwargs):
390390

391391
def _bfill(arr, n=None, axis=-1):
392392
"""inverse of ffill"""
393-
import bottleneck as bn
394-
395393
arr = np.flip(arr, axis=axis)
396394

397395
# fill
398-
arr = bn.push(arr, axis=axis, n=n)
396+
arr = push(arr, axis=axis, n=n)
399397

400398
# reverse back to original
401399
return np.flip(arr, axis=axis)
402400

403401

404402
def ffill(arr, dim=None, limit=None):
405403
"""forward fill missing values"""
406-
import bottleneck as bn
407-
408404
axis = arr.get_axis_num(dim)
409405

410406
# work around for bottleneck 178
411407
_limit = limit if limit is not None else arr.shape[axis]
412408

413409
return apply_ufunc(
414-
bn.push,
410+
push,
415411
arr,
416-
dask="parallelized",
412+
dask="allowed",
417413
keep_attrs=True,
418414
output_dtypes=[arr.dtype],
419415
kwargs=dict(n=_limit, axis=axis),
@@ -430,7 +426,7 @@ def bfill(arr, dim=None, limit=None):
430426
return apply_ufunc(
431427
_bfill,
432428
arr,
433-
dask="parallelized",
429+
dask="allowed",
434430
keep_attrs=True,
435431
output_dtypes=[arr.dtype],
436432
kwargs=dict(n=_limit, axis=axis),

xarray/tests/test_duck_array_ops.py

+25
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
mean,
2121
np_timedelta64_to_float,
2222
pd_timedelta_to_float,
23+
push,
2324
py_timedelta_to_float,
2425
stack,
2526
timedelta_to_numeric,
@@ -34,6 +35,7 @@
3435
has_dask,
3536
has_scipy,
3637
raise_if_dask_computes,
38+
requires_bottleneck,
3739
requires_cftime,
3840
requires_dask,
3941
)
@@ -869,3 +871,26 @@ def test_least_squares(use_dask, skipna):
869871

870872
np.testing.assert_allclose(coeffs, [1.5, 1.25])
871873
np.testing.assert_allclose(residuals, [2.0])
874+
875+
876+
@requires_dask
877+
@requires_bottleneck
878+
def test_push_dask():
879+
import bottleneck
880+
import dask.array
881+
882+
array = np.array([np.nan, np.nan, np.nan, 1, 2, 3, np.nan, np.nan, 4, 5, np.nan, 6])
883+
expected = bottleneck.push(array, axis=0)
884+
for c in range(1, 11):
885+
with raise_if_dask_computes():
886+
actual = push(dask.array.from_array(array, chunks=c), axis=0, n=None)
887+
np.testing.assert_equal(actual, expected)
888+
889+
# some chunks of size-1 with NaN
890+
with raise_if_dask_computes():
891+
actual = push(
892+
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
893+
axis=0,
894+
n=None,
895+
)
896+
np.testing.assert_equal(actual, expected)

xarray/tests/test_missing.py

+26-23
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
assert_allclose,
1818
assert_array_equal,
1919
assert_equal,
20+
raise_if_dask_computes,
2021
requires_bottleneck,
2122
requires_cftime,
2223
requires_dask,
@@ -393,37 +394,39 @@ def test_ffill():
393394

394395
@requires_bottleneck
395396
@requires_dask
396-
def test_ffill_dask():
397+
@pytest.mark.parametrize("method", ["ffill", "bfill"])
398+
def test_ffill_bfill_dask(method):
397399
da, _ = make_interpolate_example_data((40, 40), 0.5)
398400
da = da.chunk({"x": 5})
399-
actual = da.ffill("time")
400-
expected = da.load().ffill("time")
401-
assert isinstance(actual.data, dask_array_type)
402-
assert_equal(actual, expected)
403401

404-
# with limit
405-
da = da.chunk({"x": 5})
406-
actual = da.ffill("time", limit=3)
407-
expected = da.load().ffill("time", limit=3)
408-
assert isinstance(actual.data, dask_array_type)
402+
dask_method = getattr(da, method)
403+
numpy_method = getattr(da.compute(), method)
404+
# unchunked axis
405+
with raise_if_dask_computes():
406+
actual = dask_method("time")
407+
expected = numpy_method("time")
409408
assert_equal(actual, expected)
410409

411-
412-
@requires_bottleneck
413-
@requires_dask
414-
def test_bfill_dask():
415-
da, _ = make_interpolate_example_data((40, 40), 0.5)
416-
da = da.chunk({"x": 5})
417-
actual = da.bfill("time")
418-
expected = da.load().bfill("time")
419-
assert isinstance(actual.data, dask_array_type)
410+
# chunked axis
411+
with raise_if_dask_computes():
412+
actual = dask_method("x")
413+
expected = numpy_method("x")
420414
assert_equal(actual, expected)
421415

422416
# with limit
423-
da = da.chunk({"x": 5})
424-
actual = da.bfill("time", limit=3)
425-
expected = da.load().bfill("time", limit=3)
426-
assert isinstance(actual.data, dask_array_type)
417+
with raise_if_dask_computes():
418+
actual = dask_method("time", limit=3)
419+
expected = numpy_method("time", limit=3)
420+
assert_equal(actual, expected)
421+
422+
# limit < axis size
423+
with pytest.raises(NotImplementedError):
424+
actual = dask_method("x", limit=2)
425+
426+
# limit > axis size
427+
with raise_if_dask_computes():
428+
actual = dask_method("x", limit=41)
429+
expected = numpy_method("x", limit=41)
427430
assert_equal(actual, expected)
428431

429432

0 commit comments

Comments
 (0)