Commit bf06e12

Merge branch 'main' into map-blocks-indexes-fix

* main:
  Adapt map_blocks to use new Coordinates API (pydata#8560)
  add xeofs to ecosystem.rst (pydata#8561)
  Offer a fixture for unifying DataArray & Dataset tests (pydata#8533)
  Generalize cumulative reduction (scan) to non-dask types (pydata#8019)

2 parents 84ba745 + b444438

File tree: 7 files changed (+170, -68 lines)

doc/ecosystem.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -78,6 +78,7 @@ Extend xarray capabilities
 - `xarray-dataclasses <https://github.com/astropenguin/xarray-dataclasses>`_: xarray extension for typed DataArray and Dataset creation.
 - `xarray_einstats <https://xarray-einstats.readthedocs.io>`_: Statistics, linear algebra and einops for xarray
 - `xarray_extras <https://github.com/crusaderky/xarray_extras>`_: Advanced algorithms for xarray objects (e.g. integrations/interpolations).
+- `xeofs <https://github.com/nicrie/xeofs>`_: PCA/EOF analysis and related techniques, integrated with xarray and Dask for efficient handling of large-scale data.
 - `xpublish <https://xpublish.readthedocs.io/>`_: Publish Xarray Datasets via a Zarr compatible REST API.
 - `xrft <https://github.com/rabernat/xrft>`_: Fourier transforms for xarray data.
 - `xr-scipy <https://xr-scipy.readthedocs.io>`_: A lightweight scipy wrapper for xarray.
```

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions

```diff
@@ -589,6 +589,10 @@ Internal Changes
 
 - :py:func:`as_variable` now consistently includes the variable name in any exceptions
   raised. (:pull:`7995`). By `Peter Hill <https://github.com/ZedThree>`_
+- Redirect cumulative reduction functions internally through the :py:class:`ChunkManagerEntryPoint`,
+  potentially allowing :py:meth:`~xarray.DataArray.ffill` and :py:meth:`~xarray.DataArray.bfill` to
+  use non-dask chunked array types.
+  (:pull:`8019`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
 - :py:func:`encode_dataset_coordinates` now sorts coordinates automatically assigned to
   `coordinates` attributes during serialization (:issue:`8026`, :pull:`8034`).
   `By Ian Carroll <https://github.com/itcarroll>`_.
```
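This entry is the user-visible effect of the ``scan`` changes below: ``ffill`` and ``bfill`` now reach their cumulative reduction through the chunk manager rather than hard-coded dask calls. A minimal sketch of the behaviour, assuming dask and bottleneck are installed:

```python
import numpy as np
import xarray as xr

# Forward-fill a chunked array; the gap-filling scan is dispatched through
# the chunk manager entrypoint instead of calling dask directly.
da = xr.DataArray([0.0, np.nan, np.nan, 3.0, np.nan], dims="x").chunk(2)
print(da.ffill("x").compute().values)  # [0. 0. 0. 3. 3.]
```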

xarray/core/daskmanager.py

Lines changed: 22 additions & 0 deletions

```diff
@@ -97,6 +97,28 @@ def reduction(
             keepdims=keepdims,
         )
 
+    def scan(
+        self,
+        func: Callable,
+        binop: Callable,
+        ident: float,
+        arr: T_ChunkedArray,
+        axis: int | None = None,
+        dtype: np.dtype | None = None,
+        **kwargs,
+    ) -> DaskArray:
+        from dask.array.reductions import cumreduction
+
+        return cumreduction(
+            func,
+            binop,
+            ident,
+            arr,
+            axis=axis,
+            dtype=dtype,
+            **kwargs,
+        )
+
     def apply_gufunc(
         self,
         func: Callable,
```
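For reference, ``dask.array.reductions.cumreduction`` builds a cumulative reduction from a per-chunk cumulative function, a combining binary operator, and an identity element. A quick sketch of calling it the way ``scan`` does, assuming dask is installed:

```python
import numpy as np
import dask.array as da
from dask.array.reductions import cumreduction

x = da.arange(1, 9, chunks=4)

# func=np.cumsum runs inside each chunk; binop=np.add folds the running
# total across chunks; ident=0 seeds the first chunk. This reproduces cumsum.
result = cumreduction(np.cumsum, np.add, 0, x, axis=0, dtype=x.dtype)
print(result.compute())  # [ 1  3  6 10 15 21 28 36]
```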

xarray/core/parallel.py

Lines changed: 38 additions & 26 deletions

```diff
@@ -4,14 +4,15 @@
 import itertools
 import operator
 from collections.abc import Hashable, Iterable, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
 
 import numpy as np
 
 from xarray.core.alignment import align
 from xarray.core.coordinates import Coordinates
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset
+from xarray.core.indexes import Index
 from xarray.core.merge import merge
 from xarray.core.pycompat import is_dask_collection
 from xarray.core.variable import Variable
@@ -20,6 +21,13 @@
     from xarray.core.types import T_Xarray
 
 
+class ExpectedDict(TypedDict):
+    shapes: dict[Hashable, int]
+    coords: set[Hashable]
+    data_vars: set[Hashable]
+    indexes: dict[Hashable, Index]
+
+
 def unzip(iterable):
     return zip(*iterable)
 
@@ -34,7 +42,9 @@ def assert_chunks_compatible(a: Dataset, b: Dataset):
 
 
 def check_result_variables(
-    result: DataArray | Dataset, expected: Mapping[str, Any], kind: str
+    result: DataArray | Dataset,
+    expected: ExpectedDict,
+    kind: Literal["coords", "data_vars"],
 ):
     if kind == "coords":
         nice_str = "coordinate"
@@ -326,7 +336,7 @@ def _wrapper(
        args: list,
        kwargs: dict,
        arg_is_array: Iterable[bool],
-       expected: dict,
+       expected: ExpectedDict,
     ):
         """
         Wrapper function that receives datasets in args; converts to dataarrays when necessary;
@@ -429,6 +439,8 @@ def _wrapper(
 
     merged_coordinates = merge([arg.coords for arg in aligned]).coords
 
+    merged_coordinates = merge([arg.coords for arg in aligned]).coords
+
     _, npargs = unzip(
         sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
     )
@@ -444,11 +456,11 @@ def _wrapper(
     # infer template by providing zero-shaped arrays
     template = infer_template(func, aligned[0], *args, **kwargs)
     template_coords = set(template.coords)
-    preserved_coord_names = template_coords & set(merged_coordinates)
-    new_indexes = set(template.xindexes) - set(merged_coordinates)
+    preserved_coord_vars = template_coords & set(merged_coordinates)
+    new_coord_vars = template_coords - set(merged_coordinates)
 
-    preserved_coords = merged_coordinates.to_dataset()[preserved_coord_names]
-    # preserved_coords contains all coordinate variables that share a dimension
+    preserved_coords = merged_coordinates.to_dataset()[preserved_coord_vars]
+    # preserved_coords contains all coordinate variables that share a dimension
     # with any index variable in preserved_indexes
     # Drop any unneeded vars in a second pass, this is required for e.g.
     # if the mapped function were to drop a non-dimension coordinate variable.
@@ -457,7 +469,7 @@ def _wrapper(
     )
 
     coordinates = merge(
-        (preserved_coords, template.coords.to_dataset()[new_indexes])
+        (preserved_coords, template.coords.to_dataset()[new_coord_vars])
     ).coords
     output_chunks: Mapping[Hashable, tuple[int, ...]] = {
         dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
@@ -520,7 +532,7 @@ def _wrapper(
         dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
     }
 
-    include_variables = set(template.variables) - set(coordinates.indexes)
+    computed_variables = set(template.variables) - set(coordinates.indexes)
     # iterate over all possible chunk combinations
     for chunk_tuple in itertools.product(*ichunk.values()):
         # mapping from dimension name to chunk index
@@ -533,31 +545,31 @@ def _wrapper(
             for isxr, arg in zip(is_xarray, npargs)
         ]
 
-        # expected["shapes", "coords", "data_vars", "indexes"] are used to
         # raise nice error messages in _wrapper
-        expected: dict[Hashable, dict] = {}
-        # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
-        # even if length of dimension is changed by the applied function
-        expected["shapes"] = {
-            k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
-        }
-        expected["data_vars"] = set(template.data_vars.keys())  # type: ignore[assignment]
-        expected["coords"] = set(template.coords.keys())  # type: ignore[assignment]
-        # Minimize duplication due to broadcasting by only including any new or modified indexes
-        # Others can be inferred by inputs to wrapper (GH8412)
-        expected["indexes"] = {
-            name: coordinates.xindexes[name][
-                _get_chunk_slicer(name, chunk_index, output_chunk_bounds)
-            ]
-            for name in (new_indexes | modified_indexes)
+        expected: ExpectedDict = {
+            # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
+            # even if length of dimension is changed by the applied function
+            "shapes": {
+                k: output_chunks[k][v]
+                for k, v in chunk_index.items()
+                if k in output_chunks
+            },
+            "data_vars": set(template.data_vars.keys()),
+            "coords": set(template.coords.keys()),
+            "indexes": {
+                dim: coordinates.xindexes[dim][
+                    _get_chunk_slicer(dim, chunk_index, output_chunk_bounds)
+                ]
+                for dim in (new_indexes | modified_indexes)
+            },
         }
 
         from_wrapper = (gname,) + chunk_tuple
         graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
 
     # mapping from variable name to dask graph key
     var_key_map: dict[Hashable, str] = {}
-    for name in include_variables:
+    for name in computed_variables:
         variable = template.variables[name]
         gname_l = f"{name}-{gname}"
         var_key_map[name] = gname_l
```
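The switch from a plain ``dict`` to a ``TypedDict`` lets a static type checker verify the keys and value types at the point where ``expected`` is built. A minimal self-contained sketch of the pattern (simplified here: ``indexes`` holds plain objects rather than ``xarray.core.indexes.Index``):

```python
from collections.abc import Hashable
from typing import TypedDict


class ExpectedDict(TypedDict):
    shapes: dict[Hashable, int]
    coords: set[Hashable]
    data_vars: set[Hashable]
    indexes: dict[Hashable, object]  # Index in the real code


# A checker such as mypy now flags a misspelled key or a wrongly typed value
# here, which the previous dict[Hashable, dict] annotation could not catch.
expected: ExpectedDict = {
    "shapes": {"x": 10},
    "coords": {"lat"},
    "data_vars": {"temperature"},
    "indexes": {},
}
```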

xarray/core/parallelcompat.py

Lines changed: 37 additions & 0 deletions

```diff
@@ -403,6 +403,43 @@ def reduction(
         """
         raise NotImplementedError()
 
+    def scan(
+        self,
+        func: Callable,
+        binop: Callable,
+        ident: float,
+        arr: T_ChunkedArray,
+        axis: int | None = None,
+        dtype: np.dtype | None = None,
+        **kwargs,
+    ) -> T_ChunkedArray:
+        """
+        General version of a 1D scan, also known as a cumulative array reduction.
+
+        Used in ``ffill`` and ``bfill`` in xarray.
+
+        Parameters
+        ----------
+        func: callable
+            Cumulative function like np.cumsum or np.cumprod
+        binop: callable
+            Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul``
+        ident: Number
+            Associated identity like ``np.cumsum->0`` or ``np.cumprod->1``
+        arr: dask Array
+        axis: int, optional
+        dtype: dtype
+
+        Returns
+        -------
+        Chunked array
+
+        See also
+        --------
+        dask.array.cumreduction
+        """
+        raise NotImplementedError()
+
     @abstractmethod
     def apply_gufunc(
         self,
```
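The docstring above states the contract in terms of dask, but the pattern is generic. A pure-NumPy illustration of what a chunked scan has to do (eager and illustrative only; a real chunk manager would build a lazy graph instead):

```python
import numpy as np


def chunked_scan(func, binop, ident, chunks):
    """Apply a cumulative ``func`` per chunk, threading a carry via ``binop``."""
    out, carry = [], ident
    for chunk in chunks:
        partial = func(chunk)              # cumulative result within the chunk
        out.append(binop(carry, partial))  # shift by the total of earlier chunks
        carry = binop(carry, partial[-1])  # last value seeds the next chunk
    return np.concatenate(out)


chunks = [np.array([1, 2, 3]), np.array([4, 5, 6])]
print(chunked_scan(np.cumsum, np.add, 0, chunks))  # [ 1  3  6 10 15 21]
```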

xarray/tests/conftest.py

Lines changed: 43 additions & 0 deletions

```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 import pandas as pd
 import pytest
@@ -77,3 +79,44 @@ def da(request, backend):
         return da
     else:
         raise ValueError
+
+
+@pytest.fixture(params=[Dataset, DataArray])
+def type(request):
+    return request.param
+
+
+@pytest.fixture(params=[1])
+def d(request, backend, type) -> DataArray | Dataset:
+    """
+    For tests which can test either a DataArray or a Dataset.
+    """
+    result: DataArray | Dataset
+    if request.param == 1:
+        ds = Dataset(
+            dict(
+                a=(["x", "z"], np.arange(24).reshape(2, 12)),
+                b=(["y", "z"], np.arange(100, 136).reshape(3, 12).astype(np.float64)),
+            ),
+            dict(
+                x=("x", np.linspace(0, 1.0, 2)),
+                y=range(3),
+                z=("z", pd.date_range("2000-01-01", periods=12)),
+                w=("x", ["a", "b"]),
+            ),
+        )
+        if type == DataArray:
+            result = ds["a"].assign_coords(w=ds.coords["w"])
+        elif type == Dataset:
+            result = ds
+        else:
+            raise ValueError
+    else:
+        raise ValueError
+
+    if backend == "dask":
+        return result.chunk()
+    elif backend == "numpy":
+        return result
+    else:
+        raise ValueError
```
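A hedged sketch of how a test consumes the new fixture (the test name and assertion are illustrative, not part of this commit); pytest expands ``d`` into every Dataset/DataArray and numpy/dask combination automatically:

```python
from xarray.tests import assert_identical


def test_roundtrip_transpose(d) -> None:
    # ``d`` may be a DataArray or a Dataset, numpy- or dask-backed;
    # transposing twice should restore the original object.
    assert_identical(d.transpose().transpose(), d)
```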

xarray/tests/test_rolling.py

Lines changed: 25 additions & 42 deletions

```diff
@@ -36,6 +36,31 @@ def compute_backend(request):
     yield request.param
 
 
+@pytest.mark.parametrize("func", ["mean", "sum"])
+@pytest.mark.parametrize("min_periods", [1, 10])
+def test_cumulative(d, func, min_periods) -> None:
+    # One dim
+    result = getattr(d.cumulative("z", min_periods=min_periods), func)()
+    expected = getattr(d.rolling(z=d["z"].size, min_periods=min_periods), func)()
+    assert_identical(result, expected)
+
+    # Multiple dim
+    result = getattr(d.cumulative(["z", "x"], min_periods=min_periods), func)()
+    expected = getattr(
+        d.rolling(z=d["z"].size, x=d["x"].size, min_periods=min_periods),
+        func,
+    )()
+    assert_identical(result, expected)
+
+
+def test_cumulative_vs_cum(d) -> None:
+    result = d.cumulative("z").sum()
+    expected = d.cumsum("z")
+    # cumsum drops the coord of the dimension; cumulative doesn't
+    expected = expected.assign_coords(z=result["z"])
+    assert_identical(result, expected)
+
+
 class TestDataArrayRolling:
     @pytest.mark.parametrize("da", (1, 2), indirect=True)
     @pytest.mark.parametrize("center", [True, False])
@@ -485,29 +510,6 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
         ):
             da.rolling_exp(time=10, keep_attrs=True)
 
-    @pytest.mark.parametrize("func", ["mean", "sum"])
-    @pytest.mark.parametrize("min_periods", [1, 20])
-    def test_cumulative(self, da, func, min_periods) -> None:
-        # One dim
-        result = getattr(da.cumulative("time", min_periods=min_periods), func)()
-        expected = getattr(
-            da.rolling(time=da.time.size, min_periods=min_periods), func
-        )()
-        assert_identical(result, expected)
-
-        # Multiple dim
-        result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
-        expected = getattr(
-            da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
-            func,
-        )()
-        assert_identical(result, expected)
-
-    def test_cumulative_vs_cum(self, da) -> None:
-        result = da.cumulative("time").sum()
-        expected = da.cumsum("time")
-        assert_identical(result, expected)
-
 
 class TestDatasetRolling:
     @pytest.mark.parametrize(
@@ -832,25 +834,6 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
         expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
         assert_allclose(actual, expected)
 
-    @pytest.mark.parametrize("func", ["mean", "sum"])
-    @pytest.mark.parametrize("ds", (2,), indirect=True)
-    @pytest.mark.parametrize("min_periods", [1, 10])
-    def test_cumulative(self, ds, func, min_periods) -> None:
-        # One dim
-        result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
-        expected = getattr(
-            ds.rolling(time=ds.time.size, min_periods=min_periods), func
-        )()
-        assert_identical(result, expected)
-
-        # Multiple dim
-        result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
-        expected = getattr(
-            ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
-            func,
-        )()
-        assert_identical(result, expected)
-
 
 @requires_numbagg
 class TestDatasetRollingExp:
```
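The consolidated tests rely on the identity that ``cumulative`` is a full-length rolling window with ``min_periods=1`` by default. A standalone sketch, assuming an xarray version that provides ``DataArray.cumulative``:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0), dims="z", coords={"z": np.arange(6)})

# cumulative("z") behaves like a rolling window spanning the whole dimension.
cumulative = da.cumulative("z").sum()
rolling = da.rolling(z=da.sizes["z"], min_periods=1).sum()
assert cumulative.equals(rolling)
```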
