Skip to content

Commit bd40c20

Browse files
Use duck array ops in more places (#8267)
* Use duck array ops for `reshape` * Use duck array ops for `sum` * Use duck array ops for `astype` * Use duck array ops for `ravel` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update what's new --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e09609c commit bd40c20

File tree

8 files changed

+35
-17
lines changed

8 files changed

+35
-17
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Documentation
5656
Internal Changes
5757
~~~~~~~~~~~~~~~~
5858

59+
- More improvements to support the Python `array API standard <https://data-apis.org/array-api/latest/>`_
60+
by using duck array ops in more places in the codebase. (:pull:`8267`)
61+
By `Tom White <https://github.com/tomwhite>`_.
62+
5963

6064
.. _whats-new.2023.09.0:
6165

xarray/core/accessor_dt.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88

99
from xarray.coding.times import infer_calendar_name
10+
from xarray.core import duck_array_ops
1011
from xarray.core.common import (
1112
_contains_datetime_like_objects,
1213
is_np_datetime_like,
@@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name):
5051
from xarray.coding.cftimeindex import CFTimeIndex
5152

5253
if not isinstance(values, CFTimeIndex):
53-
values_as_cftimeindex = CFTimeIndex(values.ravel())
54+
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
5455
else:
5556
values_as_cftimeindex = values
5657
if name == "season":
@@ -69,7 +70,7 @@ def _access_through_series(values, name):
6970
"""Coerce an array of datetime-like values to a pandas Series and
7071
access requested datetime component
7172
"""
72-
values_as_series = pd.Series(values.ravel(), copy=False)
73+
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
7374
if name == "season":
7475
months = values_as_series.dt.month.values
7576
field_values = _season_from_months(months)
@@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq):
148149
from xarray.coding.cftimeindex import CFTimeIndex
149150

150151
if is_np_datetime_like(values.dtype):
151-
values_as_series = pd.Series(values.ravel(), copy=False)
152+
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
152153
method = getattr(values_as_series.dt, name)
153154
else:
154-
values_as_cftimeindex = CFTimeIndex(values.ravel())
155+
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
155156
method = getattr(values_as_cftimeindex, name)
156157

157158
field_values = method(freq=freq).values
@@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str):
195196
"""
196197
from xarray.coding.cftimeindex import CFTimeIndex
197198

198-
values_as_cftimeindex = CFTimeIndex(values.ravel())
199+
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
199200

200201
field_values = values_as_cftimeindex.strftime(date_format)
201202
return field_values.values.reshape(values.shape)
@@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str):
205206
"""Coerce an array of datetime-like values to a pandas Series and
206207
apply string formatting
207208
"""
208-
values_as_series = pd.Series(values.ravel(), copy=False)
209+
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
209210
strs = values_as_series.dt.strftime(date_format)
210211
return strs.values.reshape(values.shape)
211212

xarray/core/computation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2123,7 +2123,8 @@ def _calc_idxminmax(
21232123
chunkmanager = get_chunked_array_type(array.data)
21242124
chunks = dict(zip(array.dims, array.chunks))
21252125
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
2126-
res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
2126+
data = dask_coord[duck_array_ops.ravel(indx.data)]
2127+
res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
21272128
# we need to attach back the dim name
21282129
res.name = dim
21292130
else:

xarray/core/duck_array_ops.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ def reshape(array, shape):
337337
return xp.reshape(array, shape)
338338

339339

340+
def ravel(array):
341+
return reshape(array, (-1,))
342+
343+
340344
@contextlib.contextmanager
341345
def _ignore_warnings_if(condition):
342346
if condition:
@@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, **kwargs):
363367
values = asarray(values)
364368

365369
if coerce_strings and values.dtype.kind in "SU":
366-
values = values.astype(object)
370+
values = astype(values, object)
367371

368372
func = None
369373
if skipna or (skipna is None and values.dtype.kind in "cfO"):

xarray/core/nanops.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from xarray.core import dtypes, nputils, utils
7+
from xarray.core import dtypes, duck_array_ops, nputils, utils
88
from xarray.core.duck_array_ops import (
99
astype,
1010
count,
@@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1):
2121
xarray version of pandas.core.nanops._maybe_null_out
2222
"""
2323
if axis is not None and getattr(result, "ndim", False):
24-
null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
24+
null_mask = (
25+
np.take(mask.shape, axis).prod()
26+
- duck_array_ops.sum(mask, axis)
27+
- min_count
28+
) < 0
2529
dtype, fill_value = dtypes.maybe_promote(result.dtype)
2630
result = where(null_mask, fill_value, astype(result, dtype))
2731

2832
elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
29-
null_mask = mask.size - mask.sum()
33+
null_mask = mask.size - duck_array_ops.sum(mask)
3034
result = where(null_mask < min_count, np.nan, result)
3135

3236
return result

xarray/core/variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2324,7 +2324,7 @@ def coarsen_reshape(self, windows, boundary, side):
23242324
else:
23252325
shape.append(variable.shape[i])
23262326

2327-
return variable.data.reshape(shape), tuple(axes)
2327+
return duck_array_ops.reshape(variable.data, shape), tuple(axes)
23282328

23292329
def isnull(self, keep_attrs: bool | None = None):
23302330
"""Test each value in the array for whether it is a missing value.

xarray/tests/test_coarsen.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import xarray as xr
88
from xarray import DataArray, Dataset, set_options
9+
from xarray.core import duck_array_ops
910
from xarray.tests import (
1011
assert_allclose,
1112
assert_equal,
@@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None:
272273
expected = xr.Dataset(attrs={"foo": "bar"})
273274
expected["vart"] = (
274275
("year", "month"),
275-
ds.vart.data.reshape((-1, 12)),
276+
duck_array_ops.reshape(ds.vart.data, (-1, 12)),
276277
{"a": "b"},
277278
)
278279
expected["varx"] = (
279280
("x", "x_reshaped"),
280-
ds.varx.data.reshape((-1, 5)),
281+
duck_array_ops.reshape(ds.varx.data, (-1, 5)),
281282
{"a": "b"},
282283
)
283284
expected["vartx"] = (
284285
("x", "x_reshaped", "year", "month"),
285-
ds.vartx.data.reshape(2, 5, 4, 12),
286+
duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)),
286287
{"a": "b"},
287288
)
288289
expected["vary"] = ds.vary
289-
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
290+
expected.coords["time"] = (
291+
("year", "month"),
292+
duck_array_ops.reshape(ds.time.data, (-1, 12)),
293+
)
290294

291295
with raise_if_dask_computes():
292296
actual = ds.coarsen(time=12, x=5).construct(

xarray/tests/test_variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg):
916916

917917
actual = v.pad(**xr_arg)
918918
expected = np.pad(
919-
np.array(v.data.astype(float)),
919+
np.array(duck_array_ops.astype(v.data, float)),
920920
np_arg,
921921
mode="constant",
922922
constant_values=np.nan,

0 commit comments

Comments
 (0)