Skip to content

Commit 6d9d509

Browse files
committed
Basic overload of argmin() and argmax() for Dataset
If single dim is passed to Dataset.argmin() or Dataset.argmax(), then pass through to _argmin_base or _argmax_base. If a sequence is passed for dim, raise an exception, because the result for each DataArray would be a dict, which cannot be stored in a Dataset.
1 parent be8b26c commit 6d9d509

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

xarray/core/dataset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6294,5 +6294,55 @@ def idxmax(
62946294
),
62956295
)
62966296

6297+
def argmin(self, dim=None, axis=None, **kwargs):
6298+
if dim is None and axis is None:
6299+
warnings.warn(
6300+
"Behaviour of DataArray.argmin() with neither dim nor axis argument "
6301+
"will change to return a dict of indices of each dimension, and then it "
6302+
"will be an error to call Dataset.argmin() with no argument. To get a "
6303+
"single, flat index, please use np.argmin(ds) instead of ds.argmin().",
6304+
DeprecationWarning,
6305+
)
6306+
if (
6307+
dim is None
6308+
or axis is not None
6309+
or not isinstance(dim, Sequence)
6310+
or isinstance(dim, str)
6311+
):
6312+
# Return int index if single dimension is passed, and is not part of a
6313+
# sequence
6314+
return getattr(self, "_argmin_base")(dim=dim, axis=axis, **kwargs)
6315+
else:
6316+
raise ValueError(
6317+
"When dim is a sequence, DataArray.argmin() returns a "
6318+
"dict. dicts cannot be contained in a Dataset, so cannot "
6319+
"call Dataset.argmin() with a sequence for dim"
6320+
)
6321+
6322+
def argmax(self, dim=None, axis=None, **kwargs):
6323+
if dim is None and axis is None:
6324+
warnings.warn(
6325+
"Behaviour of DataArray.argmin() with neither dim nor axis argument "
6326+
"will change to return a dict of indices of each dimension, and then it "
6327+
"will be an error to call Dataset.argmin() with no argument. To get a "
6328+
"single, flat index, please use np.argmin(ds) instead of ds.argmin().",
6329+
DeprecationWarning,
6330+
)
6331+
if (
6332+
dim is None
6333+
or axis is not None
6334+
or not isinstance(dim, Sequence)
6335+
or isinstance(dim, str)
6336+
):
6337+
# Return int index if single dimension is passed, and is not part of a
6338+
# sequence
6339+
return getattr(self, "_argmax_base")(dim=dim, axis=axis, **kwargs)
6340+
else:
6341+
raise ValueError(
6342+
"When dim is a sequence, DataArray.argmax() returns a "
6343+
"dict. dicts cannot be contained in a Dataset, so cannot "
6344+
"call Dataset.argmax() with a sequence for dim"
6345+
)
6346+
62976347

62986348
ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)

0 commit comments

Comments
 (0)