Skip to content

Commit bdcfab5

Browse files
johnomotanishoyerkeewisdcherian
authored
Support multiple dimensions in DataArray.argmin() and DataArray.argmax() methods (#3936)
* DataArray.indices_min() and DataArray.indices_max() methods These return dicts of the indices of the minimum or maximum of a DataArray over several dimensions. * Update whats-new.rst and api.rst with indices_min(), indices_max() * Fix type checking in DataArray._unravel_argminmax() * Fix expected results for TestReduce3D.test_indices_max() * Respect global default for keep_attrs * Merge behaviour of indices_min/indices_max into argmin/argmax When argmin or argmax are called with a sequence for 'dim', they now return a dict with the indices for each dimension in dim. * Basic overload of argmin() and argmax() for Dataset If single dim is passed to Dataset.argmin() or Dataset.argmax(), then pass through to _argmin_base or _argmax_base. If a sequence is passed for dim, raise an exception, because the result for each DataArray would be a dict, which cannot be stored in a Dataset. * Update Variable and dask tests with _argmin_base, _argmax_base The basic numpy-style argmin() and argmax() methods were renamed when adding support for handling multiple dimensions in DataArray.argmin() and DataArray.argmax(). Variable.argmin() and Variable.argmax() are therefore renamed as Variable._argmin_base() and Variable._argmax_base(). * Update api-hidden.rst with _argmin_base and _argmax_base * Explicitly defined class methods override injected methods If a method (such as 'argmin') has been explicitly defined on a class (so that hasattr(cls, "argmin")==True), then do not inject that method, as it would override the explicitly defined one. Instead inject a private method, prefixed by "_injected_" (such as '_injected_argmin'), so that the injected method is available to the explicitly defined one. Do not perform the hasattr check on binary ops, because this breaks some operations (e.g. addition between DataArray and int in test_dask.py). * Move StringAccessor back to bottom of DataArray class definition * Revert use of _argmin_base and _argmax_base Now not needed because of change to injection in ops.py. * Move implementation of argmin, argmax from DataArray to Variable Makes use of argmin and argmax more general (they are available for Variable) and is straightforward for DataArray to wrap the Variable version. * Update tests for change to coordinates on result of argmin, argmax * Add 'out' keyword to argmin/argmax methods - allow numpy call signature When np.argmin(da) is called, numpy passes an 'out' keyword argument to argmin/argmax. Need to allow this argument to avoid errors (but an exception is thrown if out is not None). * Update and correct docstrings for argmin and argmax * Correct suggested replacement for da.argmin() and da.argmax() * Remove use of _injected_ methods in argmin/argmax * Fix typo in name of argminmax_func Co-Authored-By: keewis <[email protected]> * Mark argminmax argument to _unravel_argminmax as a string Co-Authored-By: keewis <[email protected]> * Hidden internal methods don't need to appear in docs * Basic docstrings for Dataset.argmin() and Dataset.argmax() * Set stacklevel for DeprecationWarning in argmin/argmax methods * Revert "Explicitly defined class methods override injected methods" This reverts commit 8caf2b8. * Revert "Add 'out' keyword to argmin/argmax methods - allow numpy call signature" This reverts commit ab480b5. * Remove argmin and argmax from ops.py * Use self.reduce() in Dataset.argmin() and Dataset.argmax() Replaces need for "_injected_argmin" and "_injected_argmax". * Whitespace after 'title' lines in docstrings * Remove tests of np.argmax() and np.argmin() functions from test_units.py Applying numpy functions to xarray objects is not necessarily expected to work, and the wrapping of argmin() and argmax() is broken by xarray-specific interface of argmin() and argmax() methods of Variable, DataArray and Dataset. * Clearer deprecation warnings in Dataset.argmin() and Dataset.argmax() Also, previously suggested workaround was not correct. Remove suggestion as there is no workaround (but the removed behaviour is unlikely to be useful). * Add unravel_index to duck_array_ops, use in Variable._unravel_argminmax * Filter argmin/argmax DeprecationWarnings in tests * Correct test for exception for nan in test_argmax * Remove injected argmin and argmax methods from api-hidden.rst * flake8 fixes * Tidy up argmin/argmax following code review Co-authored-by: Deepak Cherian <[email protected]> * Remove filters for warnings from argmin/argmax from tests Pass an explicit axis or dim argument instead to avoid the warning. * Swap order of reduce_dims checks in Dataset.reduce() Prefer to pass reduce_dims=None when possible, including for variables with only one dimension. Avoids an error if an 'axis' keyword was passed. * revert the changes to Dataset.reduce * use dim instead of axis * use dimension instead of Ellipsis * Make passing 'dim=...' to Dataset.argmin() or Dataset.argmax() an error * Better docstrings for Dataset.argmin() and Dataset.argmax() * Update doc/whats-new.rst Co-authored-by: keewis <[email protected]> Co-authored-by: Stephan Hoyer <[email protected]> Co-authored-by: keewis <[email protected]> Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: Keewis <[email protected]>
1 parent a64cf2d commit bdcfab5

11 files changed

+1415
-44
lines changed

doc/api-hidden.rst

-20
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@
4141

4242
core.rolling.DatasetCoarsen.all
4343
core.rolling.DatasetCoarsen.any
44-
core.rolling.DatasetCoarsen.argmax
45-
core.rolling.DatasetCoarsen.argmin
4644
core.rolling.DatasetCoarsen.count
4745
core.rolling.DatasetCoarsen.max
4846
core.rolling.DatasetCoarsen.mean
@@ -68,8 +66,6 @@
6866
core.groupby.DatasetGroupBy.where
6967
core.groupby.DatasetGroupBy.all
7068
core.groupby.DatasetGroupBy.any
71-
core.groupby.DatasetGroupBy.argmax
72-
core.groupby.DatasetGroupBy.argmin
7369
core.groupby.DatasetGroupBy.count
7470
core.groupby.DatasetGroupBy.max
7571
core.groupby.DatasetGroupBy.mean
@@ -85,8 +81,6 @@
8581
core.resample.DatasetResample.all
8682
core.resample.DatasetResample.any
8783
core.resample.DatasetResample.apply
88-
core.resample.DatasetResample.argmax
89-
core.resample.DatasetResample.argmin
9084
core.resample.DatasetResample.assign
9185
core.resample.DatasetResample.assign_coords
9286
core.resample.DatasetResample.bfill
@@ -110,8 +104,6 @@
110104
core.resample.DatasetResample.dims
111105
core.resample.DatasetResample.groups
112106

113-
core.rolling.DatasetRolling.argmax
114-
core.rolling.DatasetRolling.argmin
115107
core.rolling.DatasetRolling.count
116108
core.rolling.DatasetRolling.max
117109
core.rolling.DatasetRolling.mean
@@ -185,8 +177,6 @@
185177

186178
core.rolling.DataArrayCoarsen.all
187179
core.rolling.DataArrayCoarsen.any
188-
core.rolling.DataArrayCoarsen.argmax
189-
core.rolling.DataArrayCoarsen.argmin
190180
core.rolling.DataArrayCoarsen.count
191181
core.rolling.DataArrayCoarsen.max
192182
core.rolling.DataArrayCoarsen.mean
@@ -211,8 +201,6 @@
211201
core.groupby.DataArrayGroupBy.where
212202
core.groupby.DataArrayGroupBy.all
213203
core.groupby.DataArrayGroupBy.any
214-
core.groupby.DataArrayGroupBy.argmax
215-
core.groupby.DataArrayGroupBy.argmin
216204
core.groupby.DataArrayGroupBy.count
217205
core.groupby.DataArrayGroupBy.max
218206
core.groupby.DataArrayGroupBy.mean
@@ -228,8 +216,6 @@
228216
core.resample.DataArrayResample.all
229217
core.resample.DataArrayResample.any
230218
core.resample.DataArrayResample.apply
231-
core.resample.DataArrayResample.argmax
232-
core.resample.DataArrayResample.argmin
233219
core.resample.DataArrayResample.assign_coords
234220
core.resample.DataArrayResample.bfill
235221
core.resample.DataArrayResample.count
@@ -252,8 +238,6 @@
252238
core.resample.DataArrayResample.dims
253239
core.resample.DataArrayResample.groups
254240

255-
core.rolling.DataArrayRolling.argmax
256-
core.rolling.DataArrayRolling.argmin
257241
core.rolling.DataArrayRolling.count
258242
core.rolling.DataArrayRolling.max
259243
core.rolling.DataArrayRolling.mean
@@ -423,8 +407,6 @@
423407

424408
IndexVariable.all
425409
IndexVariable.any
426-
IndexVariable.argmax
427-
IndexVariable.argmin
428410
IndexVariable.argsort
429411
IndexVariable.astype
430412
IndexVariable.broadcast_equals
@@ -564,8 +546,6 @@
564546
CFTimeIndex.all
565547
CFTimeIndex.any
566548
CFTimeIndex.append
567-
CFTimeIndex.argmax
568-
CFTimeIndex.argmin
569549
CFTimeIndex.argsort
570550
CFTimeIndex.asof
571551
CFTimeIndex.asof_locs

doc/whats-new.rst

+7
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ Enhancements
5454

5555
New Features
5656
~~~~~~~~~~~~
57+
- :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support
58+
sequences of 'dim' arguments, and if a sequence is passed return a dict
59+
(which can be passed to :py:meth:`isel` to get the value of the minimum) of
60+
the indices for each dimension of the minimum or maximum of a DataArray.
61+
(:pull:`3936`)
62+
By `John Omotani <https://github.com/johnomotani>`_, thanks to `Keisuke Fujii
63+
<https://github.com/fujiisoup>`_ for work in :pull:`1469`.
5764
- Added :py:meth:`xarray.infer_freq` for extending frequency inferring to CFTime indexes and data (:pull:`4033`).
5865
By `Pascal Bourgault <https://github.com/aulemahal>`_.
5966
- ``chunks='auto'`` is now supported in the ``chunks`` argument of

xarray/core/dataarray.py

+203
Original file line numberDiff line numberDiff line change
@@ -3819,6 +3819,209 @@ def idxmax(
38193819
keep_attrs=keep_attrs,
38203820
)
38213821

3822+
def argmin(
3823+
self,
3824+
dim: Union[Hashable, Sequence[Hashable]] = None,
3825+
axis: int = None,
3826+
keep_attrs: bool = None,
3827+
skipna: bool = None,
3828+
) -> Union["DataArray", Dict[Hashable, "DataArray"]]:
3829+
"""Index or indices of the minimum of the DataArray over one or more dimensions.
3830+
3831+
If a sequence is passed to 'dim', then result returned as dict of DataArrays,
3832+
which can be passed directly to isel(). If a single str is passed to 'dim' then
3833+
returns a DataArray with dtype int.
3834+
3835+
If there are multiple minima, the indices of the first one found will be
3836+
returned.
3837+
3838+
Parameters
3839+
----------
3840+
dim : hashable, sequence of hashable or ..., optional
3841+
The dimensions over which to find the minimum. By default, finds minimum over
3842+
all dimensions - for now returning an int for backward compatibility, but
3843+
this is deprecated, in future will return a dict with indices for all
3844+
dimensions; to return a dict with all dimensions now, pass '...'.
3845+
axis : int, optional
3846+
Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments
3847+
can be supplied.
3848+
keep_attrs : bool, optional
3849+
If True, the attributes (`attrs`) will be copied from the original
3850+
object to the new one. If False (default), the new object will be
3851+
returned without attributes.
3852+
skipna : bool, optional
3853+
If True, skip missing values (as marked by NaN). By default, only
3854+
skips missing values for float dtypes; other dtypes either do not
3855+
have a sentinel missing value (int) or skipna=True has not been
3856+
implemented (object, datetime64 or timedelta64).
3857+
3858+
Returns
3859+
-------
3860+
result : DataArray or dict of DataArray
3861+
3862+
See also
3863+
--------
3864+
Variable.argmin, DataArray.idxmin
3865+
3866+
Examples
3867+
--------
3868+
>>> array = xr.DataArray([0, 2, -1, 3], dims="x")
3869+
>>> array.min()
3870+
<xarray.DataArray ()>
3871+
array(-1)
3872+
>>> array.argmin()
3873+
<xarray.DataArray ()>
3874+
array(2)
3875+
>>> array.argmin(...)
3876+
{'x': <xarray.DataArray ()>
3877+
array(2)}
3878+
>>> array.isel(array.argmin(...))
3879+
array(-1)
3880+
3881+
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
3882+
... [[1, 3, 2], [2, -5, 1], [2, 3, 1]]],
3883+
... dims=("x", "y", "z"))
3884+
>>> array.min(dim="x")
3885+
<xarray.DataArray (y: 3, z: 3)>
3886+
array([[ 1, 2, 1],
3887+
[ 2, -5, 1],
3888+
[ 2, 1, 1]])
3889+
Dimensions without coordinates: y, z
3890+
>>> array.argmin(dim="x")
3891+
<xarray.DataArray (y: 3, z: 3)>
3892+
array([[1, 0, 0],
3893+
[1, 1, 1],
3894+
[0, 0, 1]])
3895+
Dimensions without coordinates: y, z
3896+
>>> array.argmin(dim=["x"])
3897+
{'x': <xarray.DataArray (y: 3, z: 3)>
3898+
array([[1, 0, 0],
3899+
[1, 1, 1],
3900+
[0, 0, 1]])
3901+
Dimensions without coordinates: y, z}
3902+
>>> array.min(dim=("x", "z"))
3903+
<xarray.DataArray (y: 3)>
3904+
array([ 1, -5, 1])
3905+
Dimensions without coordinates: y
3906+
>>> array.argmin(dim=["x", "z"])
3907+
{'x': <xarray.DataArray (y: 3)>
3908+
array([0, 1, 0])
3909+
Dimensions without coordinates: y, 'z': <xarray.DataArray (y: 3)>
3910+
array([2, 1, 1])
3911+
Dimensions without coordinates: y}
3912+
>>> array.isel(array.argmin(dim=["x", "z"]))
3913+
<xarray.DataArray (y: 3)>
3914+
array([ 1, -5, 1])
3915+
Dimensions without coordinates: y
3916+
"""
3917+
result = self.variable.argmin(dim, axis, keep_attrs, skipna)
3918+
if isinstance(result, dict):
3919+
return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
3920+
else:
3921+
return self._replace_maybe_drop_dims(result)
3922+
3923+
def argmax(
3924+
self,
3925+
dim: Union[Hashable, Sequence[Hashable]] = None,
3926+
axis: int = None,
3927+
keep_attrs: bool = None,
3928+
skipna: bool = None,
3929+
) -> Union["DataArray", Dict[Hashable, "DataArray"]]:
3930+
"""Index or indices of the maximum of the DataArray over one or more dimensions.
3931+
3932+
If a sequence is passed to 'dim', then result returned as dict of DataArrays,
3933+
which can be passed directly to isel(). If a single str is passed to 'dim' then
3934+
returns a DataArray with dtype int.
3935+
3936+
If there are multiple maxima, the indices of the first one found will be
3937+
returned.
3938+
3939+
Parameters
3940+
----------
3941+
dim : hashable, sequence of hashable or ..., optional
3942+
The dimensions over which to find the maximum. By default, finds maximum over
3943+
all dimensions - for now returning an int for backward compatibility, but
3944+
this is deprecated, in future will return a dict with indices for all
3945+
dimensions; to return a dict with all dimensions now, pass '...'.
3946+
axis : int, optional
3947+
Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments
3948+
can be supplied.
3949+
keep_attrs : bool, optional
3950+
If True, the attributes (`attrs`) will be copied from the original
3951+
object to the new one. If False (default), the new object will be
3952+
returned without attributes.
3953+
skipna : bool, optional
3954+
If True, skip missing values (as marked by NaN). By default, only
3955+
skips missing values for float dtypes; other dtypes either do not
3956+
have a sentinel missing value (int) or skipna=True has not been
3957+
implemented (object, datetime64 or timedelta64).
3958+
3959+
Returns
3960+
-------
3961+
result : DataArray or dict of DataArray
3962+
3963+
See also
3964+
--------
3965+
Variable.argmax, DataArray.idxmax
3966+
3967+
Examples
3968+
--------
3969+
>>> array = xr.DataArray([0, 2, -1, 3], dims="x")
3970+
>>> array.max()
3971+
<xarray.DataArray ()>
3972+
array(3)
3973+
>>> array.argmax()
3974+
<xarray.DataArray ()>
3975+
array(3)
3976+
>>> array.argmax(...)
3977+
{'x': <xarray.DataArray ()>
3978+
array(3)}
3979+
>>> array.isel(array.argmax(...))
3980+
<xarray.DataArray ()>
3981+
array(3)
3982+
3983+
>>> array = xr.DataArray([[[3, 2, 1], [3, 1, 2], [2, 1, 3]],
3984+
... [[1, 3, 2], [2, 5, 1], [2, 3, 1]]],
3985+
... dims=("x", "y", "z"))
3986+
>>> array.max(dim="x")
3987+
<xarray.DataArray (y: 3, z: 3)>
3988+
array([[3, 3, 2],
3989+
[3, 5, 2],
3990+
[2, 3, 3]])
3991+
Dimensions without coordinates: y, z
3992+
>>> array.argmax(dim="x")
3993+
<xarray.DataArray (y: 3, z: 3)>
3994+
array([[0, 1, 1],
3995+
[0, 1, 0],
3996+
[0, 1, 0]])
3997+
Dimensions without coordinates: y, z
3998+
>>> array.argmax(dim=["x"])
3999+
{'x': <xarray.DataArray (y: 3, z: 3)>
4000+
array([[0, 1, 1],
4001+
[0, 1, 0],
4002+
[0, 1, 0]])
4003+
Dimensions without coordinates: y, z}
4004+
>>> array.max(dim=("x", "z"))
4005+
<xarray.DataArray (y: 3)>
4006+
array([3, 5, 3])
4007+
Dimensions without coordinates: y
4008+
>>> array.argmax(dim=["x", "z"])
4009+
{'x': <xarray.DataArray (y: 3)>
4010+
array([0, 1, 0])
4011+
Dimensions without coordinates: y, 'z': <xarray.DataArray (y: 3)>
4012+
array([0, 1, 2])
4013+
Dimensions without coordinates: y}
4014+
>>> array.isel(array.argmax(dim=["x", "z"]))
4015+
<xarray.DataArray (y: 3)>
4016+
array([3, 5, 3])
4017+
Dimensions without coordinates: y
4018+
"""
4019+
result = self.variable.argmax(dim, axis, keep_attrs, skipna)
4020+
if isinstance(result, dict):
4021+
return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
4022+
else:
4023+
return self._replace_maybe_drop_dims(result)
4024+
38224025
# this needs to be at the end, or mypy will confuse with `str`
38234026
# https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
38244027
str = utils.UncachedAccessor(StringAccessor)

0 commit comments

Comments
 (0)