Commit 766da34

TomNicholas, dcherian, and Illviljan authored
Generalize cumulative reduction (scan) to non-dask types (#8019)
* add scan to ChunkManager ABC
* implement scan for dask using cumreduction
* generalize push to work for non-dask chunked arrays
* whatsnew
* fix importerror
* Allow arbitrary kwargs
  Co-authored-by: Deepak Cherian <[email protected]>
* Type hint return value of T_ChunkedArray
  Co-authored-by: Illviljan <[email protected]>
* Type hint return value of Dask array
* ffill -> bfill in doc/whats-new.rst
  Co-authored-by: Deepak Cherian <[email protected]>
* hopefully fix docs warning

---------

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Illviljan <[email protected]>
1 parent 2971994 commit 766da34

3 files changed: +63 -0 lines changed

doc/whats-new.rst

+4
@@ -589,6 +589,10 @@ Internal Changes
 
 - :py:func:`as_variable` now consistently includes the variable name in any exceptions
   raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
+- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
+  potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
+  use non-dask chunked array types.
+  (:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
 - :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
   `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
   `By Ian Carroll <https://github.com/itcarroll>`_.

xarray/core/daskmanager.py

+22
@@ -97,6 +97,28 @@ def reduction(
             keepdims=keepdims,
         )
 
+    def scan(
+        self,
+        func: Callable,
+        binop: Callable,
+        ident: float,
+        arr: T_ChunkedArray,
+        axis: int | None = None,
+        dtype: np.dtype | None = None,
+        **kwargs,
+    ) -> DaskArray:
+        from dask.array.reductions import cumreduction
+
+        return cumreduction(
+            func,
+            binop,
+            ident,
+            arr,
+            axis=axis,
+            dtype=dtype,
+            **kwargs,
+        )
+
     def apply_gufunc(
         self,
         func: Callable,
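
To illustrate what the `func`, `binop`, and `ident` arguments forwarded to `cumreduction` are for, here is a minimal pure-Python sketch of a blockwise scan. It is not dask's implementation: `blockwise_scan`, `cumsum`, and `cumprod` are hypothetical helpers, and a list of plain Python lists stands in for a chunked 1D array.

```python
from itertools import accumulate
import operator


def blockwise_scan(func, binop, ident, chunks):
    """Toy 1D scan over a list of chunks (hypothetical helper, not dask's API).

    func  : cumulative function applied to each chunk independently
    binop : binary operator that combines a carried value with chunk results
    ident : identity element for binop, used as the initial carry
    """
    carry = ident
    out = []
    for chunk in chunks:
        local = func(chunk)  # cumulative result within this chunk only
        # Fold everything accumulated from earlier chunks into this chunk.
        out.append([binop(carry, value) for value in local])
        if local:
            # The last local value summarizes the whole chunk.
            carry = binop(carry, local[-1])
    return out


def cumsum(chunk):
    return list(accumulate(chunk, operator.add))


def cumprod(chunk):
    return list(accumulate(chunk, operator.mul))


chunks = [[1, 2, 3], [4, 5], [6]]
sums = blockwise_scan(cumsum, operator.add, 0, chunks)
prods = blockwise_scan(cumprod, operator.mul, 1, chunks)
print(sums)   # [[1, 3, 6], [10, 15], [21]]
print(prods)  # [[1, 2, 6], [24, 120], [720]]
```

The per-chunk calls to `func` are independent of one another, which is what makes this pattern attractive for chunked arrays. Only the carry has to be propagated from one chunk to the next.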

xarray/core/parallelcompat.py

+37
@@ -403,6 +403,43 @@ def reduction(
         """
         raise NotImplementedError()
 
+    def scan(
+        self,
+        func: Callable,
+        binop: Callable,
+        ident: float,
+        arr: T_ChunkedArray,
+        axis: int | None = None,
+        dtype: np.dtype | None = None,
+        **kwargs,
+    ) -> T_ChunkedArray:
+        """
+        General version of a 1D scan, also known as a cumulative array reduction.
+
+        Used in ``ffill`` and ``bfill`` in xarray.
+
+        Parameters
+        ----------
+        func: callable
+            Cumulative function like np.cumsum or np.cumprod
+        binop: callable
+            Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
+        ident: Number
+            Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
+        arr: dask Array
+        axis: int, optional
+        dtype: dtype
+
+        Returns
+        -------
+        Chunked array
+
+        See also
+        --------
+        dask.array.cumreduction
+        """
+        raise NotImplementedError()
+
     @abstractmethod
     def apply_gufunc(
         self,
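
The docstring notes that `scan` is used by `ffill` and `bfill`. As a rough illustration of how a forward fill can be phrased in the `func` / `binop` / `ident` shape, the sketch below uses plain Python lists as chunks and `None` as the missing-value marker. This is an assumption-laden toy, not xarray's code: `push_chunk`, `fill_with_last`, and `blockwise_scan` are hypothetical names.

```python
def push_chunk(chunk):
    """Forward-fill inside a single chunk; leading gaps stay None."""
    out, last = [], None
    for value in chunk:
        if value is not None:
            last = value
        out.append(last)
    return out


def fill_with_last(carried, value):
    """Binary operator: keep `value` unless it is missing, else use the carry."""
    return value if value is not None else carried


def blockwise_scan(func, binop, ident, chunks):
    """Toy 1D scan over a list of chunks (hypothetical, for illustration)."""
    carry = ident
    out = []
    for chunk in chunks:
        local = func(chunk)
        out.append([binop(carry, value) for value in local])
        if local:
            carry = binop(carry, local[-1])
    return out


# None is the identity here: fill_with_last(None, v) returns v unchanged.
chunks = [[None, 1.0, None], [None, None], [2.0, None]]
filled = blockwise_scan(push_chunk, fill_with_last, None, chunks)
print(filled)  # [[None, 1.0, 1.0], [1.0, 1.0], [2.0, 2.0]]
```

A backward fill could plausibly be obtained by reversing the data along the axis, forward filling, and reversing again, though how xarray does this internally is not shown in this commit.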
