|
6 | 6 | xr = pytest.importorskip("xarray")
|
7 | 7 | # isort: on
|
8 | 8 |
|
| 9 | +from flox import xrdtypes as dtypes |
9 | 10 | from flox.xarray import rechunk_for_blockwise, xarray_reduce
|
10 | 11 |
|
11 | 12 | from . import (
|
@@ -193,13 +194,25 @@ def test_validate_expected_groups(expected_groups):
|
193 | 194 |
|
194 | 195 |
|
195 | 196 | @requires_cftime
|
| 197 | +@pytest.mark.parametrize("indexer", [slice(None), pytest.param(slice(12), id="missing-group")]) |
| 198 | +@pytest.mark.parametrize("expected_groups", [None, [0, 1, 2, 3]]) |
196 | 199 | @pytest.mark.parametrize("func", ["first", "last", "min", "max", "count"])
|
197 |
| -def test_xarray_reduce_cftime_var(engine, func): |
| 200 | +def test_xarray_reduce_cftime_var(engine, indexer, expected_groups, func): |
198 | 201 | times = xr.date_range("1980-09-01 00:00", "1982-09-18 00:00", freq="ME", calendar="noleap")
|
199 | 202 | ds = xr.Dataset({"var": ("time", times)}, coords={"time": np.repeat(np.arange(4), 6)})
|
| 203 | + ds = ds.isel(time=indexer) |
200 | 204 |
|
201 |
| - actual = xarray_reduce(ds, ds.time, func=func) |
| 205 | + actual = xarray_reduce( |
| 206 | + ds, |
| 207 | + ds.time, |
| 208 | + func=func, |
| 209 | + fill_value=dtypes.NA if func in ["first", "last"] else np.nan, |
| 210 | + engine=engine, |
| 211 | + expected_groups=expected_groups, |
| 212 | + ) |
202 | 213 | expected = getattr(ds.groupby("time"), func)()
|
| 214 | + if expected_groups is not None: |
| 215 | + expected = expected.reindex(time=expected_groups) |
203 | 216 | xr.testing.assert_identical(actual, expected)
|
204 | 217 |
|
205 | 218 |
|
|
0 commit comments