Skip to content

Commit c73e958

Browse files
FIX: correct dask array handling in _calc_idxminmax (#3922)
* FIX: correct dask array handling in _calc_idxminmax * FIX: remove unneeded import, reformat via black * fix idxmax, idxmin with dask arrays * FIX: use array[dim].data in `_calc_idxminmax` as per @keewis suggestion, attach dim name to result * ADD: add dask tests to `idxmin`/`idxmax` dataarray tests * FIX: add back fixture line removed by accident * ADD: complete dask handling in `idxmin`/`idxmax` tests in test_dataarray, xfail dask tests for dtype dateime64 (M) * ADD: add "support dask handling for idxmin/idxmax" in whats-new.rst * MIN: reintroduce changes added by #3953 * MIN: change if-clause to use `and` instead of `&` as per review-comment * MIN: change if-clause to use `and` instead of `&` as per review-comment * WIP: remove dask handling entirely for debugging purposes * Test for dask computes * WIP: re-add dask handling (map_blocks-approach), add `with raise_if_dask_computes()` context to idxmin-tests * Use dask indexing instead of map_blocks. * Better chunk choice. * Return -1 for _nan_argminmax_object if all NaNs along dim * Revert "Return -1 for _nan_argminmax_object if all NaNs along dim" This reverts commit 58901b9. * Raise error for object arrays * No error for object arrays. Instead expect 1 compute in tests. Co-authored-by: dcherian <[email protected]>
1 parent bd84186 commit c73e958

File tree

3 files changed

+110
-36
lines changed

3 files changed

+110
-36
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ New Features
5353
- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
5454
:py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
5555
By `Todd Jennings <https://github.com/toddrjen>`_
56+
- Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
57+
:py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`)
58+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
5659
- More support for unit aware arrays with pint (:pull:`3643`)
5760
By `Justus Magin <https://github.com/keewis>`_.
5861
- Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even

xarray/core/computation.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from . import dtypes, duck_array_ops, utils
2727
from .alignment import deep_align
2828
from .merge import merge_coordinates_without_align
29-
from .nanops import dask_array
3029
from .options import OPTIONS
3130
from .pycompat import dask_array_type
3231
from .utils import is_dict_like
@@ -1380,24 +1379,24 @@ def _calc_idxminmax(
13801379
# This will run argmin or argmax.
13811380
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
13821381

1383-
# Get the coordinate we want.
1384-
coordarray = array[dim]
1385-
13861382
# Handle dask arrays.
1387-
if isinstance(array, dask_array_type):
1388-
res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype)
1383+
if isinstance(array.data, dask_array_type):
1384+
import dask.array
1385+
1386+
chunks = dict(zip(array.dims, array.chunks))
1387+
dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim])
1388+
res = indx.copy(data=dask_coord[(indx.data,)])
1389+
# we need to attach back the dim name
1390+
res.name = dim
13891391
else:
1390-
res = coordarray[
1391-
indx,
1392-
]
1392+
res = array[dim][(indx,)]
1393+
# The dim is gone but we need to remove the corresponding coordinate.
1394+
del res.coords[dim]
13931395

13941396
if skipna or (skipna is None and array.dtype.kind in na_dtypes):
13951397
# Put the NaN values back in after removing them
13961398
res = res.where(~allna, fill_value)
13971399

1398-
# The dim is gone but we need to remove the corresponding coordinate.
1399-
del res.coords[dim]
1400-
14011400
# Copy attributes from argmin/argmax, if any
14021401
res.attrs = indx.attrs
14031402

xarray/tests/test_dataarray.py

+96-24
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
source_ndarray,
3535
)
3636

37+
from .test_dask import raise_if_dask_computes
38+
3739

3840
class TestDataArray:
3941
@pytest.fixture(autouse=True)
@@ -4524,11 +4526,21 @@ def test_argmax(self, x, minindex, maxindex, nanindex):
45244526

45254527
assert_identical(result2, expected2)
45264528

4527-
def test_idxmin(self, x, minindex, maxindex, nanindex):
4528-
ar0 = xr.DataArray(
4529+
@pytest.mark.parametrize("use_dask", [True, False])
4530+
def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask):
4531+
if use_dask and not has_dask:
4532+
pytest.skip("requires dask")
4533+
if use_dask and x.dtype.kind == "M":
4534+
pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)")
4535+
ar0_raw = xr.DataArray(
45294536
x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs
45304537
)
45314538

4539+
if use_dask:
4540+
ar0 = ar0_raw.chunk({})
4541+
else:
4542+
ar0 = ar0_raw
4543+
45324544
# dim doesn't exist
45334545
with pytest.raises(KeyError):
45344546
ar0.idxmin(dim="spam")
@@ -4620,11 +4632,21 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
46204632
result7 = ar0.idxmin(fill_value=-1j)
46214633
assert_identical(result7, expected7)
46224634

4623-
def test_idxmax(self, x, minindex, maxindex, nanindex):
4624-
ar0 = xr.DataArray(
4635+
@pytest.mark.parametrize("use_dask", [True, False])
4636+
def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask):
4637+
if use_dask and not has_dask:
4638+
pytest.skip("requires dask")
4639+
if use_dask and x.dtype.kind == "M":
4640+
pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)")
4641+
ar0_raw = xr.DataArray(
46254642
x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs
46264643
)
46274644

4645+
if use_dask:
4646+
ar0 = ar0_raw.chunk({})
4647+
else:
4648+
ar0 = ar0_raw
4649+
46284650
# dim doesn't exist
46294651
with pytest.raises(KeyError):
46304652
ar0.idxmax(dim="spam")
@@ -4944,14 +4966,31 @@ def test_argmax(self, x, minindex, maxindex, nanindex):
49444966

49454967
assert_identical(result3, expected2)
49464968

4947-
def test_idxmin(self, x, minindex, maxindex, nanindex):
4948-
ar0 = xr.DataArray(
4969+
@pytest.mark.parametrize("use_dask", [True, False])
4970+
def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask):
4971+
if use_dask and not has_dask:
4972+
pytest.skip("requires dask")
4973+
if use_dask and x.dtype.kind == "M":
4974+
pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)")
4975+
4976+
if x.dtype.kind == "O":
4977+
# TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices.
4978+
max_computes = 1
4979+
else:
4980+
max_computes = 0
4981+
4982+
ar0_raw = xr.DataArray(
49494983
x,
49504984
dims=["y", "x"],
49514985
coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])},
49524986
attrs=self.attrs,
49534987
)
49544988

4989+
if use_dask:
4990+
ar0 = ar0_raw.chunk({})
4991+
else:
4992+
ar0 = ar0_raw
4993+
49554994
assert_identical(ar0, ar0)
49564995

49574996
# No dimension specified
@@ -4982,15 +5021,18 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
49825021
expected0.name = "x"
49835022

49845023
# Default fill value (NaN)
4985-
result0 = ar0.idxmin(dim="x")
5024+
with raise_if_dask_computes(max_computes=max_computes):
5025+
result0 = ar0.idxmin(dim="x")
49865026
assert_identical(result0, expected0)
49875027

49885028
# Manually specify NaN fill_value
4989-
result1 = ar0.idxmin(dim="x", fill_value=np.NaN)
5029+
with raise_if_dask_computes(max_computes=max_computes):
5030+
result1 = ar0.idxmin(dim="x", fill_value=np.NaN)
49905031
assert_identical(result1, expected0)
49915032

49925033
# keep_attrs
4993-
result2 = ar0.idxmin(dim="x", keep_attrs=True)
5034+
with raise_if_dask_computes(max_computes=max_computes):
5035+
result2 = ar0.idxmin(dim="x", keep_attrs=True)
49945036
expected2 = expected0.copy()
49955037
expected2.attrs = self.attrs
49965038
assert_identical(result2, expected2)
@@ -5008,11 +5050,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
50085050
expected3.name = "x"
50095051
expected3.attrs = {}
50105052

5011-
result3 = ar0.idxmin(dim="x", skipna=False)
5053+
with raise_if_dask_computes(max_computes=max_computes):
5054+
result3 = ar0.idxmin(dim="x", skipna=False)
50125055
assert_identical(result3, expected3)
50135056

50145057
# fill_value should be ignored with skipna=False
5015-
result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j)
5058+
with raise_if_dask_computes(max_computes=max_computes):
5059+
result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j)
50165060
assert_identical(result4, expected3)
50175061

50185062
# Float fill_value
@@ -5024,7 +5068,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
50245068
expected5 = xr.concat(expected5, dim="y")
50255069
expected5.name = "x"
50265070

5027-
result5 = ar0.idxmin(dim="x", fill_value=-1.1)
5071+
with raise_if_dask_computes(max_computes=max_computes):
5072+
result5 = ar0.idxmin(dim="x", fill_value=-1.1)
50285073
assert_identical(result5, expected5)
50295074

50305075
# Integer fill_value
@@ -5036,7 +5081,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
50365081
expected6 = xr.concat(expected6, dim="y")
50375082
expected6.name = "x"
50385083

5039-
result6 = ar0.idxmin(dim="x", fill_value=-1)
5084+
with raise_if_dask_computes(max_computes=max_computes):
5085+
result6 = ar0.idxmin(dim="x", fill_value=-1)
50405086
assert_identical(result6, expected6)
50415087

50425088
# Complex fill_value
@@ -5048,17 +5094,35 @@ def test_idxmin(self, x, minindex, maxindex, nanindex):
50485094
expected7 = xr.concat(expected7, dim="y")
50495095
expected7.name = "x"
50505096

5051-
result7 = ar0.idxmin(dim="x", fill_value=-5j)
5097+
with raise_if_dask_computes(max_computes=max_computes):
5098+
result7 = ar0.idxmin(dim="x", fill_value=-5j)
50525099
assert_identical(result7, expected7)
50535100

5054-
def test_idxmax(self, x, minindex, maxindex, nanindex):
5055-
ar0 = xr.DataArray(
5101+
@pytest.mark.parametrize("use_dask", [True, False])
5102+
def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask):
5103+
if use_dask and not has_dask:
5104+
pytest.skip("requires dask")
5105+
if use_dask and x.dtype.kind == "M":
5106+
pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)")
5107+
5108+
if x.dtype.kind == "O":
5109+
# TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices.
5110+
max_computes = 1
5111+
else:
5112+
max_computes = 0
5113+
5114+
ar0_raw = xr.DataArray(
50565115
x,
50575116
dims=["y", "x"],
50585117
coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])},
50595118
attrs=self.attrs,
50605119
)
50615120

5121+
if use_dask:
5122+
ar0 = ar0_raw.chunk({})
5123+
else:
5124+
ar0 = ar0_raw
5125+
50625126
# No dimension specified
50635127
with pytest.raises(ValueError):
50645128
ar0.idxmax()
@@ -5090,15 +5154,18 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
50905154
expected0.name = "x"
50915155

50925156
# Default fill value (NaN)
5093-
result0 = ar0.idxmax(dim="x")
5157+
with raise_if_dask_computes(max_computes=max_computes):
5158+
result0 = ar0.idxmax(dim="x")
50945159
assert_identical(result0, expected0)
50955160

50965161
# Manually specify NaN fill_value
5097-
result1 = ar0.idxmax(dim="x", fill_value=np.NaN)
5162+
with raise_if_dask_computes(max_computes=max_computes):
5163+
result1 = ar0.idxmax(dim="x", fill_value=np.NaN)
50985164
assert_identical(result1, expected0)
50995165

51005166
# keep_attrs
5101-
result2 = ar0.idxmax(dim="x", keep_attrs=True)
5167+
with raise_if_dask_computes(max_computes=max_computes):
5168+
result2 = ar0.idxmax(dim="x", keep_attrs=True)
51025169
expected2 = expected0.copy()
51035170
expected2.attrs = self.attrs
51045171
assert_identical(result2, expected2)
@@ -5116,11 +5183,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
51165183
expected3.name = "x"
51175184
expected3.attrs = {}
51185185

5119-
result3 = ar0.idxmax(dim="x", skipna=False)
5186+
with raise_if_dask_computes(max_computes=max_computes):
5187+
result3 = ar0.idxmax(dim="x", skipna=False)
51205188
assert_identical(result3, expected3)
51215189

51225190
# fill_value should be ignored with skipna=False
5123-
result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j)
5191+
with raise_if_dask_computes(max_computes=max_computes):
5192+
result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j)
51245193
assert_identical(result4, expected3)
51255194

51265195
# Float fill_value
@@ -5132,7 +5201,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
51325201
expected5 = xr.concat(expected5, dim="y")
51335202
expected5.name = "x"
51345203

5135-
result5 = ar0.idxmax(dim="x", fill_value=-1.1)
5204+
with raise_if_dask_computes(max_computes=max_computes):
5205+
result5 = ar0.idxmax(dim="x", fill_value=-1.1)
51365206
assert_identical(result5, expected5)
51375207

51385208
# Integer fill_value
@@ -5144,7 +5214,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
51445214
expected6 = xr.concat(expected6, dim="y")
51455215
expected6.name = "x"
51465216

5147-
result6 = ar0.idxmax(dim="x", fill_value=-1)
5217+
with raise_if_dask_computes(max_computes=max_computes):
5218+
result6 = ar0.idxmax(dim="x", fill_value=-1)
51485219
assert_identical(result6, expected6)
51495220

51505221
# Complex fill_value
@@ -5156,7 +5227,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex):
51565227
expected7 = xr.concat(expected7, dim="y")
51575228
expected7.name = "x"
51585229

5159-
result7 = ar0.idxmax(dim="x", fill_value=-5j)
5230+
with raise_if_dask_computes(max_computes=max_computes):
5231+
result7 = ar0.idxmax(dim="x", fill_value=-5j)
51605232
assert_identical(result7, expected7)
51615233

51625234

0 commit comments

Comments
 (0)