Skip to content

Commit be8b26c

Browse files
committed
Merge behaviour of indices_min/indices_max into argmin/argmax
When argmin or argmax are called with a sequence for 'dim', they now return a dict with the indices for each dimension in dim.
1 parent deee3f8 commit be8b26c

File tree

7 files changed

+108
-81
lines changed

7 files changed

+108
-81
lines changed

doc/api.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,6 @@ Computation
366366
:py:attr:`~DataArray.argmin`
367367
:py:attr:`~DataArray.idxmax`
368368
:py:attr:`~DataArray.idxmin`
369-
:py:attr:`~DataArray.indices_max`
370-
:py:attr:`~DataArray.indices_min`
371369
:py:attr:`~DataArray.max`
372370
:py:attr:`~DataArray.mean`
373371
:py:attr:`~DataArray.median`

doc/whats-new.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@ Breaking changes
2929

3030
New Features
3131
~~~~~~~~~~~~
32-
- Added :py:meth:`DataArray.indices_min` and :py:meth:`DataArray.indices_max`
33-
to get a dict of the indices for each dimension of the minimum or maximum of
34-
a DataArray. (:pull:`3936`)
32+
- :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support
33+
sequences of 'dim' arguments, and if a sequence is passed return a dict
34+
(which can be passed to :py:meth:`isel` to get the value of the minimum) of
35+
the indices for each dimension of the minimum or maximum of a DataArray.
36+
(:pull:`3936`)
3537
By `John Omotani <https://github.com/johnomotani>`_, thanks to `Keisuke Fujii
3638
<https://github.com/fujiisoup>`_ for work in :pull:`1469`.
3739
- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`)

xarray/core/dataarray.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3730,16 +3730,35 @@ def _unravel_argminmax(
37303730
self,
37313731
argminmax: Hashable,
37323732
dim: Union[Hashable, Sequence[Hashable], None],
3733-
keep_attrs: bool,
3733+
axis: Union[int, None],
3734+
keep_attrs: Optional[bool],
37343735
skipna: Optional[bool],
37353736
) -> Dict[Hashable, "DataArray"]:
37363737
"""Apply argmin or argmax over one or more dimensions, returning the result as a
37373738
dict of DataArray that can be passed directly to isel.
37383739
"""
3739-
if dim is None:
3740+
if dim is None and axis is None:
3741+
warnings.warn(
3742+
"Behaviour of argmin/argmax with neither dim nor axis argument will "
3743+
"change to return a dict of indices of each dimension. To get a "
3744+
"single, flat index, please use np.argmin(da) or np.argmax(da) instead "
3745+
"of da.argmin() or da.argmax().",
3746+
DeprecationWarning,
3747+
)
3748+
if dim is ...:
3749+
# In future, should do this also when (dim is None and axis is None)
37403750
dim = self.dims
3741-
if not isinstance(dim, Sequence) or isinstance(dim, str):
3742-
dim = (dim,)
3751+
if (
3752+
dim is None
3753+
or axis is not None
3754+
or not isinstance(dim, Sequence)
3755+
or isinstance(dim, str)
3756+
):
3757+
# Return int index if single dimension is passed, and is not part of a
3758+
# sequence
3759+
return getattr(self, str(argminmax))(
3760+
dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna
3761+
)
37433762

37443763
# Get a name for the new dimension that does not conflict with any existing
37453764
# dimension
@@ -3771,9 +3790,10 @@ def _unravel_argminmax(
37713790

37723791
return result
37733792

3774-
def indices_min(
3793+
def argmin(
37753794
self,
37763795
dim: Union[Hashable, Sequence[Hashable]] = None,
3796+
axis: Union[int, None] = None,
37773797
keep_attrs: bool = None,
37783798
skipna: bool = None,
37793799
) -> Dict[Hashable, "DataArray"]:
@@ -3788,6 +3808,9 @@ def indices_min(
37883808
dim : hashable or sequence of hashable, optional
37893809
The dimensions over which to find the minimum. By default, finds minimum over
37903810
all dimensions.
3811+
axis : int, optional
3812+
Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments
3813+
can be supplied.
37913814
keep_attrs : bool, optional
37923815
If True, the attributes (`attrs`) will be copied from the original
37933816
object to the new one. If False (default), the new object will be
@@ -3815,10 +3838,10 @@ def indices_min(
38153838
>>> array.argmin()
38163839
<xarray.DataArray ()>
38173840
array(2)
3818-
>>> array.indices_min()
3841+
>>> array.argmin(...)
38193842
{'x': <xarray.DataArray ()>
38203843
array(2)}
3821-
>>> array.isel(array.indices_min())
3844+
>>> array.isel(array.argmin(...))
38223845
array(-1)
38233846
38243847
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
@@ -3836,7 +3859,7 @@ def indices_min(
38363859
[1, 1, 1],
38373860
[0, 0, 1]])
38383861
Dimensions without coordinates: y, z
3839-
>>> array.indices_min(dim="x")
3862+
>>> array.argmin(dim=["x"])
38403863
{'x': <xarray.DataArray (y: 3, z: 3)>
38413864
array([[1, 0, 0],
38423865
[1, 1, 1],
@@ -3846,22 +3869,23 @@ def indices_min(
38463869
<xarray.DataArray (y: 3)>
38473870
array([ 1, -5, 1])
38483871
Dimensions without coordinates: y
3849-
>>> array.indices_min(dim=("x", "z"))
3872+
>>> array.argmin(dim=["x", "z"])
38503873
{'x': <xarray.DataArray (y: 3)>
38513874
array([0, 1, 0])
38523875
Dimensions without coordinates: y, 'z': <xarray.DataArray (y: 3)>
38533876
array([2, 1, 1])
38543877
Dimensions without coordinates: y}
3855-
>>> array.isel(array.indices_min(dim=("x", "z")))
3878+
>>> array.isel(array.argmin(dim=["x", "z"]))
38563879
<xarray.DataArray (y: 3)>
38573880
array([ 1, -5, 1])
38583881
Dimensions without coordinates: y
38593882
"""
3860-
return self._unravel_argminmax("argmin", dim, keep_attrs, skipna)
3883+
return self._unravel_argminmax("_argmin_base", dim, axis, keep_attrs, skipna)
38613884

3862-
def indices_max(
3885+
def argmax(
38633886
self,
38643887
dim: Union[Hashable, Sequence[Hashable]] = None,
3888+
axis: Union[int, None] = None,
38653889
keep_attrs: bool = None,
38663890
skipna: bool = None,
38673891
) -> Dict[Hashable, "DataArray"]:
@@ -3876,6 +3900,9 @@ def indices_max(
38763900
dim : hashable or sequence of hashable, optional
38773901
The dimensions over which to find the maximum. By default, finds maximum over
38783902
all dimensions.
3903+
axis : int, optional
3904+
Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments
3905+
can be supplied.
38793906
keep_attrs : bool, optional
38803907
If True, the attributes (`attrs`) will be copied from the original
38813908
object to the new one. If False (default), the new object will be
@@ -3903,10 +3930,10 @@ def indices_max(
39033930
>>> array.argmax()
39043931
<xarray.DataArray ()>
39053932
array(3)
3906-
>>> array.indices_max()
3933+
>>> array.argmax(...)
39073934
{'x': <xarray.DataArray ()>
39083935
array(3)}
3909-
>>> array.isel(array.indices_max())
3936+
>>> array.isel(array.argmax(...))
39103937
<xarray.DataArray ()>
39113938
array(3)
39123939
@@ -3925,7 +3952,7 @@ def indices_max(
39253952
[0, 1, 0],
39263953
[0, 1, 0]])
39273954
Dimensions without coordinates: y, z
3928-
>>> array.indices_max(dim="x")
3955+
>>> array.argmax(dim=["x"])
39293956
{'x': <xarray.DataArray (y: 3, z: 3)>
39303957
array([[0, 1, 1],
39313958
[0, 1, 0],
@@ -3935,18 +3962,18 @@ def indices_max(
39353962
<xarray.DataArray (y: 3)>
39363963
array([3, 5, 3])
39373964
Dimensions without coordinates: y
3938-
>>> array.indices_max(dim=("x", "z"))
3965+
>>> array.argmax(dim=["x", "z"])
39393966
{'x': <xarray.DataArray (y: 3)>
39403967
array([0, 1, 0])
39413968
Dimensions without coordinates: y, 'z': <xarray.DataArray (y: 3)>
39423969
array([0, 1, 2])
39433970
Dimensions without coordinates: y}
3944-
>>> array.isel(array.indices_max(dim=("x", "z")))
3971+
>>> array.isel(array.argmax(dim=["x", "z"]))
39453972
<xarray.DataArray (y: 3)>
39463973
array([3, 5, 3])
39473974
Dimensions without coordinates: y
39483975
"""
3949-
return self._unravel_argminmax("argmax", dim, keep_attrs, skipna)
3976+
return self._unravel_argminmax("_argmax_base", dim, axis, keep_attrs, skipna)
39503977

39513978

39523979
# priority most be higher than Variable to properly work with binary ufuncs

xarray/core/duck_array_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ def f(values, axis=None, skipna=None, **kwargs):
319319

320320
# Attributes `numeric_only`, `available_min_count` is used for docs.
321321
# See ops.inject_reduce_methods
322-
argmax = _create_nan_agg_method("argmax", coerce_strings=True)
323-
argmin = _create_nan_agg_method("argmin", coerce_strings=True)
322+
_argmax_base = _create_nan_agg_method("argmax", coerce_strings=True)
323+
_argmin_base = _create_nan_agg_method("argmin", coerce_strings=True)
324324
max = _create_nan_agg_method("max", coerce_strings=True)
325325
min = _create_nan_agg_method("min", coerce_strings=True)
326326
sum = _create_nan_agg_method("sum")

xarray/core/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
# methods which remove an axis
4848
REDUCE_METHODS = ["all", "any"]
4949
NAN_REDUCE_METHODS = [
50-
"argmax",
51-
"argmin",
50+
"_argmax_base",
51+
"_argmin_base",
5252
"max",
5353
"min",
5454
"mean",

xarray/core/rolling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def method(self, **kwargs):
130130
method.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name=name)
131131
return method
132132

133-
argmax = _reduce_method("argmax")
134-
argmin = _reduce_method("argmin")
133+
_argmax_base = _reduce_method("_argmax_base")
134+
_argmin_base = _reduce_method("_argmin_base")
135135
max = _reduce_method("max")
136136
min = _reduce_method("min")
137137
mean = _reduce_method("mean")

0 commit comments

Comments
 (0)