Skip to content

Commit 0f637c9

Browse files
committed
allow all np.ndarray subclasses
1 parent 1a38d93 commit 0f637c9

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

xarray/core/variable.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -321,17 +321,15 @@ def convert_non_numpy_type(data):
321321
else:
322322
data = np.asarray(data)
323323

324-
_is_array_like = isinstance(data, np.ndarray | np.generic)
325-
_is_nep18 = hasattr(data, "__array_function__")
326-
_has_array_api = hasattr(data, "__array_namespace__")
327-
_has_unit = hasattr(data, "_unit")
328-
329-
# Allow `astropy.units.Quantity`
330-
if _is_array_like and (_is_nep18 or _has_array_api) and _has_unit:
331-
return cast("T_DuckArray", data)
332-
333-
# immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars
334-
if not _is_array_like and (_is_nep18 or _has_array_api):
324+
if isinstance(data, np.matrix):
325+
data = np.asarray(data)
326+
327+
# immediately return array-like types except `numpy.ndarray` and `numpy` scalars
328+
# compare types with `is` instead of `isinstance` to allow `numpy.ndarray` subclasses
329+
is_numpy = type(data) is np.ndarray or isinstance(data, np.generic)
330+
if not is_numpy and (
331+
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
332+
):
335333
return cast("T_DuckArray", data)
336334

337335
# validate whether the data is valid data types. Also, explicitly cast `numpy`

xarray/tests/test_variable.py

+13
Original file line numberDiff line numberDiff line change
@@ -2746,6 +2746,19 @@ def test_ones_like(self) -> None:
27462746
assert_identical(ones_like(orig), full_like(orig, 1))
27472747
assert_identical(ones_like(orig, dtype=int), full_like(orig, 1, dtype=int))
27482748

2749+
def test_numpy_ndarray_subclass(self):
2750+
class SubclassedArray(np.ndarray):
2751+
def __new__(cls, array, foo):
2752+
obj = np.asarray(array).view(cls)
2753+
obj.foo = foo
2754+
return obj
2755+
2756+
data = SubclassedArray([1, 2, 3], foo="bar")
2757+
actual = as_compatible_data(data)
2758+
assert isinstance(actual, SubclassedArray)
2759+
assert actual.foo == "bar"
2760+
assert_array_equal(data, actual)
2761+
27492762
def test_unsupported_type(self):
27502763
# Non indexable type
27512764
class CustomArray(NDArrayMixin):

0 commit comments

Comments
 (0)