Skip to content

(fix): disallow NumpyExtensionArray #10334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
42 changes: 36 additions & 6 deletions properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import hypothesis.extra.pandas as pdst # isort:skip
import hypothesis.strategies as st # isort:skip
from hypothesis import given # isort:skip
from xarray.tests import has_pyarrow

numeric_dtypes = st.one_of(
npst.unsigned_integer_dtypes(endianness="="),
Expand Down Expand Up @@ -134,10 +135,39 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]
@pytest.mark.parametrize(
"extension_array",
[
pd.Categorical(["a", "b", "c"]),
pd.array(["a", "b", "c"], dtype="string"),
pd.arrays.IntervalArray(
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
),
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])),
pd.arrays.DatetimeArray._from_sequence(
pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D")
),
np.array([1, 2, 3], dtype="int64"),
]
+ ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []),
ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"]
+ (["pyarrow"] if has_pyarrow else []),
Comment on lines +150 to +154
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
np.array([1, 2, 3], dtype="int64"),
]
+ ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []),
ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"]
+ (["pyarrow"] if has_pyarrow else []),
np.array([1, 2, 3], dtype="int64"),
pytest.param(pd.array([1, 2, 3], dtype="int64[pyarrow]"), marks=pytest.mark.skipif(not has_pyarrow)),
]

You will need to add the ids to each param individually, though.

Copy link
Contributor Author

@ilan-gold ilan-gold Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally tried this with the requires_pyarrow mark (which uses skipif), but the issue is that, annoyingly, the param still seems to be created even when pyarrow is not installed:

b959345

and

https://github.com/pydata/xarray/actions/runs/15346716234/job/43184610561

Maybe there is some other way to avoid my ternary that I'm missing, but this suggestion looks identical to the broken behavior linked above.

)
@pytest.mark.parametrize("is_index", [True, False])
def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None:
df = pd.DataFrame({"arr": extension_array})
if is_index:
df = df.set_index("arr")
arr = xr.Dataset.from_dataframe(df)["arr"]
roundtripped = arr.to_pandas()
assert (df["cat"] == roundtripped).all()
assert df["cat"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
df_arr_to_test = df.index if is_index else df["arr"]
assert (df_arr_to_test == roundtripped).all()
# `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes.
if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's cast the arrow ones for now too and relax it explicitly in a later PR please

Suggested change
if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined]
if isinstance(extension_array, pd.arrays.NumpyExtensionArray | pd.arrays.ArrowExtensionArray): # type: ignore[attr-defined]

Copy link
Contributor Author

@ilan-gold ilan-gold Jun 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would the plan be here, to convert to numpy? Why would we do this (previously we were using arrow for testing quite explicitly, so this wish comes as a surprise)? I wouldn't want to do this at first glance as it would represent a regression from our stated goal of supporting as many extension arrays as possible. Would the ultimate goal be to have special handling for certain kinds of arrow arrays?

Furthermore, in the case of as_compatible_data where the incoming object is a series/dataframe with UNSUPPORTED_EXTENSION_ARRAY_TYPES, we call {Series,DataFrame}.values, which actually just gives back the arrow extension array when .array is one (the reason for this is to maintain some datetime compat, which can be seen by changing these lines). So we would now need a new branch there to handle the arrow case.

Lastly, we would lose all of our test cases that rely on arrow, which is quite a lot of our extension array coverage at the moment (and has been a good way to catch bugs). So we'd have to come up with a new extension array for testing that isn't categorical (I think IntervalArray is the only option since we have special handling for datetimes), which feels limiting.

assert isinstance(arr.data, np.ndarray)
else:
assert (
df_arr_to_test.dtype
== (roundtripped.index if is_index else roundtripped).dtype
)
xr.testing.assert_identical(arr, roundtripped.to_xarray())
3 changes: 2 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
parse_dims_as_set,
)
from xarray.core.variable import (
UNSUPPORTED_EXTENSION_ARRAY_TYPES,
IndexVariable,
Variable,
as_variable,
Expand Down Expand Up @@ -7267,7 +7268,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self:
extension_arrays = []
for k, v in dataframe.items():
if not is_extension_array_dtype(v) or isinstance(
v.array, pd.arrays.DatetimeArray | pd.arrays.TimedeltaArray
v.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES
):
arrays.append((k, np.asarray(v)))
else:
Expand Down
7 changes: 7 additions & 0 deletions xarray/core/extension_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
def __post_init__(self):
if not isinstance(self.array, pd.api.extensions.ExtensionArray):
raise TypeError(f"{self.array} is not an pandas ExtensionArray.")
# This does not use the UNSUPPORTED_EXTENSION_ARRAY_TYPES whitelist because
# we do support extension arrays from datetime, for example, that need
# duck array support internally via this class.
if isinstance(self.array, pd.arrays.NumpyExtensionArray):
raise TypeError(
"`NumpyExtensionArray` should be converted to a numpy array in `xarray` internally."
)

def __array_function__(self, func, types, args, kwargs):
def replace_duck_with_extension_array(args) -> list:
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,8 +1802,12 @@ def __array__(

def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
# We return an PandasExtensionArray wrapper type that satisfies
# duck array protocols. This is what's needed for tests to pass.
if pd.api.types.is_extension_array_dtype(self.array):
# duck array protocols.
# `NumpyExtensionArray` is excluded
if pd.api.types.is_extension_array_dtype(self.array) and not isinstance(
self.array.array,
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
):
from xarray.core.extension_array import PandasExtensionArray

return PandasExtensionArray(self.array.array)
Expand Down
16 changes: 15 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@
)
# https://github.com/python/mypy/issues/224
BASIC_INDEXING_TYPES = integer_types + (slice,)
UNSUPPORTED_EXTENSION_ARRAY_TYPES = (
pd.arrays.DatetimeArray,
pd.arrays.TimedeltaArray,
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined]
pd.arrays.ArrowExtensionArray,

)

if TYPE_CHECKING:
from xarray.core.types import (
Expand Down Expand Up @@ -190,6 +195,8 @@ def _maybe_wrap_data(data):
"""
if isinstance(data, pd.Index):
return PandasIndexingAdapter(data)
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES):
return data.to_numpy()
if isinstance(data, pd.api.extensions.ExtensionArray):
return PandasExtensionArray(data)
return data
Expand Down Expand Up @@ -251,7 +258,14 @@ def convert_non_numpy_type(data):

# we don't want nested self-described arrays
if isinstance(data, pd.Series | pd.DataFrame):
pandas_data = data.values
if (
isinstance(data, pd.Series)
and pd.api.types.is_extension_array_dtype(data)
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES)
):
pandas_data = data.array
else:
pandas_data = data.values # type: ignore[assignment]
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
return convert_non_numpy_type(pandas_data)
else:
Expand Down
Loading