Skip to content

Commit c296801

Browse files
committed
implement idxmax and idxmin
1 parent 70e628d commit c296801

File tree

7 files changed

+1237
-1
lines changed

7 files changed

+1237
-1
lines changed

doc/api.rst

+4
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ Computation
180180
:py:attr:`~Dataset.any`
181181
:py:attr:`~Dataset.argmax`
182182
:py:attr:`~Dataset.argmin`
183+
:py:attr:`~Dataset.idxmax`
184+
:py:attr:`~Dataset.idxmin`
183185
:py:attr:`~Dataset.max`
184186
:py:attr:`~Dataset.mean`
185187
:py:attr:`~Dataset.median`
@@ -362,6 +364,8 @@ Computation
362364
:py:attr:`~DataArray.any`
363365
:py:attr:`~DataArray.argmax`
364366
:py:attr:`~DataArray.argmin`
367+
:py:attr:`~DataArray.idxmax`
368+
:py:attr:`~DataArray.idxmin`
365369
:py:attr:`~DataArray.max`
366370
:py:attr:`~DataArray.mean`
367371
:py:attr:`~DataArray.median`

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ New Features
3838
- Limited the length of array items with long string reprs to a
3939
reasonable width (:pull:`3900`)
4040
By `Maximilian Roos <https://github.com/max-sixty>`_
41+
- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
42+
:py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
43+
By `Todd Jennings <https://github.com/toddrjen>`_
4144

4245

4346
Bug fixes

xarray/core/computation.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@
2323

2424
import numpy as np
2525

26-
from . import duck_array_ops, utils
26+
from . import dtypes, duck_array_ops, utils
2727
from .alignment import deep_align
2828
from .merge import merge_coordinates_without_align
29+
from .nanops import dask_array
2930
from .options import OPTIONS
3031
from .pycompat import dask_array_type
3132
from .utils import is_dict_like
@@ -1338,3 +1339,66 @@ def polyval(coord, coeffs, degree_dim="degree"):
13381339
coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
13391340
)
13401341
return (lhs * coeffs).sum(degree_dim)
1342+
1343+
1344+
def _calc_idxminmax(
    *,
    array,
    func: Callable,
    dim: Hashable = None,
    skipna: bool = None,
    fill_value: Any = dtypes.NA,
    keep_attrs: bool = None,
):
    """Apply common operations for idxmin and idxmax.

    Runs ``func`` (an argmin/argmax-style reduction) along ``dim`` and
    translates the resulting integer positions into the coordinate labels of
    ``dim``.

    Parameters
    ----------
    array : DataArray
        Array to reduce. Must have at least one dimension, and ``dim`` must
        have an associated coordinate.
    func : callable
        Called as ``func(array, dim=dim, axis=None, keep_attrs=keep_attrs,
        skipna=skipna)``; must return the integer positions of the selected
        values along ``dim``.
    dim : hashable, optional
        Dimension to reduce over. May be omitted only for 1D arrays.
    skipna : bool, optional
        Forwarded to ``func``. When True, or when None and the dtype kind is
        one of complex/float/object, slices that are entirely null are
        filled with ``fill_value`` in the result.
    fill_value : Any, optional
        Value used in the result for slices that are entirely null.
    keep_attrs : bool, optional
        Forwarded to ``func``; the result carries the attributes ``func``
        returns.

    Returns
    -------
    DataArray
        Coordinate labels of ``dim`` at the positions selected by ``func``,
        with ``dim`` reduced away.

    Raises
    ------
    ValueError
        If ``array`` is a scalar, or if ``dim`` is omitted for an array with
        more than one dimension.
    KeyError
        If ``dim`` is not a dimension of ``array`` or has no coordinate.
    """
    # This function doesn't make sense for scalars so don't try
    if not array.ndim:
        raise ValueError("This function does not apply for scalars")

    if dim is not None:
        pass  # Use the dim if available
    elif array.ndim == 1:
        # it is okay to guess the dim if there is only 1
        dim = array.dims[0]
    else:
        # The dim is not specified and ambiguous.  Don't guess.
        raise ValueError("Must supply 'dim' argument for multidimensional arrays")

    if dim not in array.dims:
        raise KeyError(f'Dimension "{dim}" not in dimension')
    if dim not in array.coords:
        raise KeyError(f'Dimension "{dim}" does not have coordinates')

    # These are dtypes with NaN values argmin and argmax can handle
    na_dtypes = "cfO"

    # Decide once whether all-null slices need special treatment, so the
    # masking step and the un-masking step below can never disagree.
    handle_na = skipna or (skipna is None and array.dtype.kind in na_dtypes)

    if handle_na:
        # Slices that are entirely null give argmin/argmax nothing to pick.
        # Temporarily fill them with 0 and remember where they were.
        allna = array.isnull().all(dim)
        array = array.where(~allna, 0)

    # This will run argmin or argmax.
    indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

    # Handle dask arrays.  The check has to look at the wrapped data: the
    # DataArray wrapper itself is never an instance of a dask array type.
    if isinstance(array.data, dask_array_type):
        # Keep the computation lazy: turn the coordinate labels into a dask
        # array chunked like ``dim`` and gather from it with the (lazy)
        # integer positions.
        # NOTE(review): ``dask_array`` is assumed to be the ``dask.array``
        # module re-exported by ``.nanops`` -- confirm.
        chunks = dict(zip(array.dims, array.chunks))
        dask_coord = dask_array.from_array(array[dim].data, chunks=(chunks[dim],))
        res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
        # ``indx`` is named after the data variable; the result should be
        # named after the coordinate, as in the non-dask branch.
        res.name = dim
    else:
        res = array[dim][(indx,)]
        # The dim is gone but we need to remove the corresponding coordinate.
        del res.coords[dim]

    if handle_na:
        # Put the NaN values back in after removing them
        res = res.where(~allna, fill_value)

    # Copy attributes from argmin/argmax, if any
    res.attrs = indx.attrs

    return res

xarray/core/dataarray.py

+188
Original file line numberDiff line numberDiff line change
@@ -3508,6 +3508,194 @@ def pad(
35083508
)
35093509
return self._from_temp_dataset(ds)
35103510

3511+
def idxmin(
    self,
    dim: Hashable = None,
    skipna: bool = None,
    fill_value: Any = dtypes.NA,
    keep_attrs: bool = None,
) -> "DataArray":
    """Return the coordinate label of the minimum value along a dimension.

    The result is a new `DataArray`, named after the dimension, holding the
    coordinate labels at which the minimum occurs along that dimension.

    Where :py:meth:`~DataArray.argmin` gives the integer position of the
    minimum, this method gives the matching coordinate label.

    Parameters
    ----------
    dim : str, optional
        Dimension over which to apply `idxmin`. May be omitted for 1D
        arrays; required for arrays with 2 or more dimensions.
    skipna : bool or None, default None
        If True, skip missing values (as marked by NaN). By default, only
        skips missing values for ``float``, ``complex``, and ``object``
        dtypes; other dtypes either do not have a sentinel missing value
        (``int``) or ``skipna=True`` has not been implemented
        (``datetime64`` or ``timedelta64``).
    fill_value : Any, default NaN
        Value used where every value along the dimension is null. The fill
        value and result are automatically converted to a compatible dtype
        if possible. Ignored if ``skipna`` is False.
    keep_attrs : bool, default False
        If True, the attributes (``attrs``) are copied from the original
        object to the new one. If False (default), the new object is
        returned without attributes.

    Returns
    -------
    reduced : DataArray
        New `DataArray` object with `idxmin` applied to its data and the
        indicated dimension removed.

    See also
    --------
    Dataset.idxmin, DataArray.idxmax, DataArray.min, DataArray.argmin

    Examples
    --------

    >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
    ...                      coords={"x": ['a', 'b', 'c', 'd', 'e']})
    >>> array.min()
    <xarray.DataArray ()>
    array(-2)
    >>> array.argmin()
    <xarray.DataArray ()>
    array(4)
    >>> array.idxmin()
    <xarray.DataArray 'x' ()>
    array('e', dtype='<U1')

    >>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
    ...                       [-4.0, np.nan, 2.0, np.nan, -2.0],
    ...                       [np.nan, np.nan, 1., np.nan, np.nan]],
    ...                      dims=["y", "x"],
    ...                      coords={"y": [-1, 0, 1],
    ...                              "x": np.arange(5.)**2}
    ...                      )
    >>> array.min(dim="x")
    <xarray.DataArray (y: 3)>
    array([-2., -4.,  1.])
    Coordinates:
      * y        (y) int64 -1 0 1
    >>> array.argmin(dim="x")
    <xarray.DataArray (y: 3)>
    array([4, 0, 2])
    Coordinates:
      * y        (y) int64 -1 0 1
    >>> array.idxmin(dim="x")
    <xarray.DataArray 'x' (y: 3)>
    array([16.,  0.,  4.])
    Coordinates:
      * y        (y) int64 -1 0 1
    """

    def _argmin(obj, *args, **kwargs):
        # Look up ``argmin`` on the object at call time and forward all
        # arguments unchanged.
        return obj.argmin(*args, **kwargs)

    reduction_options = dict(
        dim=dim, skipna=skipna, fill_value=fill_value, keep_attrs=keep_attrs,
    )
    return computation._calc_idxminmax(array=self, func=_argmin, **reduction_options)
3604+
3605+
def idxmax(
    self,
    dim: Hashable = None,
    skipna: bool = None,
    fill_value: Any = dtypes.NA,
    keep_attrs: bool = None,
) -> "DataArray":
    """Return the coordinate label of the maximum value along a dimension.

    The result is a new `DataArray`, named after the dimension, holding the
    coordinate labels at which the maximum occurs along that dimension.

    Where :py:meth:`~DataArray.argmax` gives the integer position of the
    maximum, this method gives the matching coordinate label.

    Parameters
    ----------
    dim : str, optional
        Dimension over which to apply `idxmax`. May be omitted for 1D
        arrays; required for arrays with 2 or more dimensions.
    skipna : bool or None, default None
        If True, skip missing values (as marked by NaN). By default, only
        skips missing values for ``float``, ``complex``, and ``object``
        dtypes; other dtypes either do not have a sentinel missing value
        (``int``) or ``skipna=True`` has not been implemented
        (``datetime64`` or ``timedelta64``).
    fill_value : Any, default NaN
        Value used where every value along the dimension is null. The fill
        value and result are automatically converted to a compatible dtype
        if possible. Ignored if ``skipna`` is False.
    keep_attrs : bool, default False
        If True, the attributes (``attrs``) are copied from the original
        object to the new one. If False (default), the new object is
        returned without attributes.

    Returns
    -------
    reduced : DataArray
        New `DataArray` object with `idxmax` applied to its data and the
        indicated dimension removed.

    See also
    --------
    Dataset.idxmax, DataArray.idxmin, DataArray.max, DataArray.argmax

    Examples
    --------

    >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x",
    ...                      coords={"x": ['a', 'b', 'c', 'd', 'e']})
    >>> array.max()
    <xarray.DataArray ()>
    array(2)
    >>> array.argmax()
    <xarray.DataArray ()>
    array(1)
    >>> array.idxmax()
    <xarray.DataArray 'x' ()>
    array('b', dtype='<U1')

    >>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0],
    ...                       [-4.0, np.nan, 2.0, np.nan, -2.0],
    ...                       [np.nan, np.nan, 1., np.nan, np.nan]],
    ...                      dims=["y", "x"],
    ...                      coords={"y": [-1, 0, 1],
    ...                              "x": np.arange(5.)**2}
    ...                      )
    >>> array.max(dim="x")
    <xarray.DataArray (y: 3)>
    array([2., 2., 1.])
    Coordinates:
      * y        (y) int64 -1 0 1
    >>> array.argmax(dim="x")
    <xarray.DataArray (y: 3)>
    array([0, 2, 2])
    Coordinates:
      * y        (y) int64 -1 0 1
    >>> array.idxmax(dim="x")
    <xarray.DataArray 'x' (y: 3)>
    array([0., 4., 4.])
    Coordinates:
      * y        (y) int64 -1 0 1
    """

    def _argmax(obj, *args, **kwargs):
        # Look up ``argmax`` on the object at call time and forward all
        # arguments unchanged.
        return obj.argmax(*args, **kwargs)

    reduction_options = dict(
        dim=dim, skipna=skipna, fill_value=fill_value, keep_attrs=keep_attrs,
    )
    return computation._calc_idxminmax(array=self, func=_argmax, **reduction_options)
3698+
35113699
# this needs to be at the end, or mypy will confuse with `str`
35123700
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
35133701
str = property(StringAccessor)

0 commit comments

Comments
 (0)