Skip to content

Commit f5d22a6

Browse files
authored
Fix unstack method when wrapping array api class (#8668)
* add test for unstacking * fix bug with unstack * some other occurrences of reshape * whatsnew
1 parent 6a6404a commit f5d22a6

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Deprecations
3737
Bug fixes
3838
~~~~~~~~~
3939

40+
- Ensure :py:meth:`DataArray.unstack` works when wrapping array API-compliant classes. (:issue:`8666`, :pull:`8668`)
41+
By `Tom Nicholas <https://github.com/TomNicholas>`_.
4042

4143
Documentation
4244
~~~~~~~~~~~~~

xarray/core/missing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from xarray.core import utils
1414
from xarray.core.common import _contains_datetime_like_objects, ones_like
1515
from xarray.core.computation import apply_ufunc
16-
from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric
16+
from xarray.core.duck_array_ops import (
17+
datetime_to_numeric,
18+
push,
19+
reshape,
20+
timedelta_to_numeric,
21+
)
1722
from xarray.core.options import _get_keep_attrs
1823
from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array
1924
from xarray.core.types import Interp1dOptions, InterpOptions
@@ -748,7 +753,7 @@ def _interp1d(var, x, new_x, func, kwargs):
748753
x, new_x = x[0], new_x[0]
749754
rslt = func(x, var, assume_sorted=True, **kwargs)(np.ravel(new_x))
750755
if new_x.ndim > 1:
751-
return rslt.reshape(var.shape[:-1] + new_x.shape)
756+
return reshape(rslt, (var.shape[:-1] + new_x.shape))
752757
if new_x.ndim == 0:
753758
return rslt[..., -1]
754759
return rslt
@@ -767,7 +772,7 @@ def _interpnd(var, x, new_x, func, kwargs):
767772
rslt = func(x, var, xi, **kwargs)
768773
# move back the interpolation axes to the last position
769774
rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
770-
return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)
775+
return reshape(rslt, rslt.shape[:-1] + new_x[0].shape)
771776

772777

773778
def _chunked_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True):

xarray/core/variable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1571,7 +1571,7 @@ def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self:
15711571
reordered = self.transpose(*dim_order)
15721572

15731573
new_shape = reordered.shape[: len(other_dims)] + new_dim_sizes
1574-
new_data = reordered.data.reshape(new_shape)
1574+
new_data = duck_array_ops.reshape(reordered.data, new_shape)
15751575
new_dims = reordered.dims[: len(other_dims)] + new_dim_names
15761576

15771577
return type(self)(

xarray/tests/test_array_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
115115
assert_equal(actual, expected)
116116

117117

118+
def test_unstack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
119+
np_arr, xp_arr = arrays
120+
expected = np_arr.stack(z=("x", "y")).unstack()
121+
actual = xp_arr.stack(z=("x", "y")).unstack()
122+
assert isinstance(actual.data, Array)
123+
assert_equal(actual, expected)
124+
125+
118126
def test_where() -> None:
119127
np_arr = xr.DataArray(np.array([1, 0]), dims="x")
120128
xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")

0 commit comments

Comments
 (0)