Skip to content

Commit 2cd8f96

Browse files
authored
Allow a function in .sortby method (#8273)
1 parent 938579d commit 2cd8f96

File tree

4 files changed

+49
-16
lines changed

4 files changed

+49
-16
lines changed

doc/whats-new.rst

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ New Features
2323
~~~~~~~~~~~~
2424

2525
- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for
26-
the ``other`` parameter, passing the object as the first argument. Previously,
26+
the ``other`` parameter, passing the object as the only argument. Previously,
2727
this was only valid for the ``cond`` parameter. (:issue:`8255`)
2828
By `Maximilian Roos <https://github.com/max-sixty>`_.
29+
- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
30+
the ``variables`` parameter, passing the object as the only argument.
31+
By `Maximilian Roos <https://github.com/max-sixty>`_.
2932

3033
Breaking changes
3134
~~~~~~~~~~~~~~~~

xarray/core/common.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1073,7 +1073,8 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
10731073
----------
10741074
cond : DataArray, Dataset, or callable
10751075
Locations at which to preserve this object's values. dtype must be `bool`.
1076-
If a callable, it must expect this object as its only parameter.
1076+
If a callable, the callable is passed this object, and the result is used as
1077+
the value for cond.
10771078
other : scalar, DataArray, Dataset, or callable, optional
10781079
Value to use for locations in this object where ``cond`` is False.
10791080
By default, these locations are filled with NA. If a callable, it must

xarray/core/dataarray.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -4941,7 +4941,10 @@ def dot(
49414941

49424942
def sortby(
49434943
self,
4944-
variables: Hashable | DataArray | Sequence[Hashable | DataArray],
4944+
variables: Hashable
4945+
| DataArray
4946+
| Sequence[Hashable | DataArray]
4947+
| Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]],
49454948
ascending: bool = True,
49464949
) -> Self:
49474950
"""Sort object by labels or values (along an axis).
@@ -4962,9 +4965,10 @@ def sortby(
49624965
49634966
Parameters
49644967
----------
4965-
variables : Hashable, DataArray, or sequence of Hashable or DataArray
4966-
1D DataArray objects or name(s) of 1D variable(s) in
4967-
coords whose values are used to sort this array.
4968+
variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
4969+
1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
4970+
used to sort this array. If a callable, the callable is passed this object,
4971+
and the result is used as the value for cond.
49684972
ascending : bool, default: True
49694973
Whether to sort by ascending or descending order.
49704974
@@ -4984,22 +4988,33 @@ def sortby(
49844988
Examples
49854989
--------
49864990
>>> da = xr.DataArray(
4987-
... np.random.rand(5),
4991+
... np.arange(5, 0, -1),
49884992
... coords=[pd.date_range("1/1/2000", periods=5)],
49894993
... dims="time",
49904994
... )
49914995
>>> da
49924996
<xarray.DataArray (time: 5)>
4993-
array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ])
4997+
array([5, 4, 3, 2, 1])
49944998
Coordinates:
49954999
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05
49965000
49975001
>>> da.sortby(da)
49985002
<xarray.DataArray (time: 5)>
4999-
array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937])
5003+
array([1, 2, 3, 4, 5])
50005004
Coordinates:
5001-
* time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02
5005+
* time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
5006+
5007+
>>> da.sortby(lambda x: x)
5008+
<xarray.DataArray (time: 5)>
5009+
array([1, 2, 3, 4, 5])
5010+
Coordinates:
5011+
* time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
50025012
"""
5013+
# We need to convert the callable here rather than pass it through to the
5014+
# dataset method, since otherwise the dataset method would try to call the
5015+
# callable with the dataset as the object
5016+
if callable(variables):
5017+
variables = variables(self)
50035018
ds = self._to_temp_dataset().sortby(variables, ascending=ascending)
50045019
return self._from_temp_dataset(ds)
50055020

xarray/core/dataset.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -7838,7 +7838,10 @@ def roll(
78387838

78397839
def sortby(
78407840
self,
7841-
variables: Hashable | DataArray | list[Hashable | DataArray],
7841+
variables: Hashable
7842+
| DataArray
7843+
| Sequence[Hashable | DataArray]
7844+
| Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]],
78427845
ascending: bool = True,
78437846
) -> Self:
78447847
"""
@@ -7860,9 +7863,10 @@ def sortby(
78607863
78617864
Parameters
78627865
----------
7863-
variables : Hashable, DataArray, or list of hashable or DataArray
7864-
1D DataArray objects or name(s) of 1D variable(s) in
7865-
coords/data_vars whose values are used to sort the dataset.
7866+
kariables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
7867+
1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
7868+
used to sort this array. If a callable, the callable is passed this object,
7869+
and the result is used as the value for cond.
78667870
ascending : bool, default: True
78677871
Whether to sort by ascending or descending order.
78687872
@@ -7888,8 +7892,7 @@ def sortby(
78887892
... },
78897893
... coords={"x": ["b", "a"], "y": [1, 0]},
78907894
... )
7891-
>>> ds = ds.sortby("x")
7892-
>>> ds
7895+
>>> ds.sortby("x")
78937896
<xarray.Dataset>
78947897
Dimensions: (x: 2, y: 2)
78957898
Coordinates:
@@ -7898,9 +7901,20 @@ def sortby(
78987901
Data variables:
78997902
A (x, y) int64 3 4 1 2
79007903
B (x, y) int64 7 8 5 6
7904+
>>> ds.sortby(lambda x: -x["y"])
7905+
<xarray.Dataset>
7906+
Dimensions: (x: 2, y: 2)
7907+
Coordinates:
7908+
* x (x) <U1 'b' 'a'
7909+
* y (y) int64 1 0
7910+
Data variables:
7911+
A (x, y) int64 1 2 3 4
7912+
B (x, y) int64 5 6 7 8
79017913
"""
79027914
from xarray.core.dataarray import DataArray
79037915

7916+
if callable(variables):
7917+
variables = variables(self)
79047918
if not isinstance(variables, list):
79057919
variables = [variables]
79067920
else:

0 commit comments

Comments
 (0)