-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
(fix): disallow NumpyExtensionArray
#10334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9312d2b
d09c0f5
88e4841
174274d
d9388f0
b87b380
6329964
a29c526
c6ac491
50843ca
b959345
2d33aaa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,6 +15,7 @@ | |||||
import hypothesis.extra.pandas as pdst # isort:skip | ||||||
import hypothesis.strategies as st # isort:skip | ||||||
from hypothesis import given # isort:skip | ||||||
from xarray.tests import has_pyarrow | ||||||
|
||||||
numeric_dtypes = st.one_of( | ||||||
npst.unsigned_integer_dtypes(endianness="="), | ||||||
|
@@ -134,10 +135,39 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None: | |||||
xr.testing.assert_identical(dataset, roundtripped.to_xarray()) | ||||||
|
||||||
|
||||||
def test_roundtrip_1d_pandas_extension_array() -> None: | ||||||
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])}) | ||||||
arr = xr.Dataset.from_dataframe(df)["cat"] | ||||||
@pytest.mark.parametrize( | ||||||
"extension_array", | ||||||
[ | ||||||
pd.Categorical(["a", "b", "c"]), | ||||||
pd.array(["a", "b", "c"], dtype="string"), | ||||||
pd.arrays.IntervalArray( | ||||||
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)] | ||||||
), | ||||||
pd.arrays.TimedeltaArray._from_sequence(pd.TimedeltaIndex(["1h", "2h", "3h"])), | ||||||
pd.arrays.DatetimeArray._from_sequence( | ||||||
pd.DatetimeIndex(["2023-01-01", "2023-01-02", "2023-01-03"], freq="D") | ||||||
), | ||||||
np.array([1, 2, 3], dtype="int64"), | ||||||
] | ||||||
+ ([pd.array([1, 2, 3], dtype="int64[pyarrow]")] if has_pyarrow else []), | ||||||
ids=["cat", "string", "interval", "timedelta", "datetime", "numpy"] | ||||||
+ (["pyarrow"] if has_pyarrow else []), | ||||||
) | ||||||
@pytest.mark.parametrize("is_index", [True, False]) | ||||||
def test_roundtrip_1d_pandas_extension_array(extension_array, is_index) -> None: | ||||||
df = pd.DataFrame({"arr": extension_array}) | ||||||
if is_index: | ||||||
df = df.set_index("arr") | ||||||
arr = xr.Dataset.from_dataframe(df)["arr"] | ||||||
roundtripped = arr.to_pandas() | ||||||
assert (df["cat"] == roundtripped).all() | ||||||
assert df["cat"].dtype == roundtripped.dtype | ||||||
xr.testing.assert_identical(arr, roundtripped.to_xarray()) | ||||||
df_arr_to_test = df.index if is_index else df["arr"] | ||||||
assert (df_arr_to_test == roundtripped).all() | ||||||
# `NumpyExtensionArray` types are not roundtripped, including `StringArray` which subtypes. | ||||||
if isinstance(extension_array, pd.arrays.NumpyExtensionArray): # type: ignore[attr-defined] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's cast the arrow ones for now too and relax it explicitly in a later PR please
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What would the plan be here, to convert to numpy? Why would we do this (previously we were using arrow for testing quite explicitly, so this wish comes as a surprise)? I wouldn't want to do this at first glance as it would represent a regression from our stated goal of supporting as many extension arrays as possible. Would the ultimate goal be to have special handling for certain kinds of arrow arrays? Furthermore, in the case of Lastly, we would lose all of our test cases that rely on arrow, which is quite a lot of our extension array coverage at the moment (and has been a good way to catch bugs). So we'd have to come up with a new extension array for testing that isn't categorical (I think IntervalArray is the only option since we have special handling for datetimes), which feels limiting. |
||||||
assert isinstance(arr.data, np.ndarray) | ||||||
else: | ||||||
assert ( | ||||||
df_arr_to_test.dtype | ||||||
== (roundtripped.index if is_index else roundtripped).dtype | ||||||
) | ||||||
xr.testing.assert_identical(arr, roundtripped.to_xarray()) |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -63,6 +63,11 @@ | |||||||
) | ||||||||
# https://github.com/python/mypy/issues/224 | ||||||||
BASIC_INDEXING_TYPES = integer_types + (slice,) | ||||||||
UNSUPPORTED_EXTENSION_ARRAY_TYPES = ( | ||||||||
pd.arrays.DatetimeArray, | ||||||||
pd.arrays.TimedeltaArray, | ||||||||
pd.arrays.NumpyExtensionArray, # type: ignore[attr-defined] | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
) | ||||||||
|
||||||||
if TYPE_CHECKING: | ||||||||
from xarray.core.types import ( | ||||||||
|
@@ -190,6 +195,8 @@ def _maybe_wrap_data(data): | |||||||
""" | ||||||||
if isinstance(data, pd.Index): | ||||||||
return PandasIndexingAdapter(data) | ||||||||
if isinstance(data, UNSUPPORTED_EXTENSION_ARRAY_TYPES): | ||||||||
return data.to_numpy() | ||||||||
if isinstance(data, pd.api.extensions.ExtensionArray): | ||||||||
return PandasExtensionArray(data) | ||||||||
return data | ||||||||
|
@@ -251,7 +258,14 @@ def convert_non_numpy_type(data): | |||||||
|
||||||||
# we don't want nested self-described arrays | ||||||||
if isinstance(data, pd.Series | pd.DataFrame): | ||||||||
pandas_data = data.values | ||||||||
if ( | ||||||||
isinstance(data, pd.Series) | ||||||||
and pd.api.types.is_extension_array_dtype(data) | ||||||||
and not isinstance(data.array, UNSUPPORTED_EXTENSION_ARRAY_TYPES) | ||||||||
): | ||||||||
pandas_data = data.array | ||||||||
else: | ||||||||
pandas_data = data.values # type: ignore[assignment] | ||||||||
if isinstance(pandas_data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): | ||||||||
return convert_non_numpy_type(pandas_data) | ||||||||
else: | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will need to add the ids to each param individually though
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried this originall with the
requires_pyarrow
mark (which usesskipif
), but the issue is that (annoyingly), it seems the param is still created without pyarrow:b959345
and
https://github.com/pydata/xarray/actions/runs/15346716234/job/43184610561
Maybe there is some other way without my ternary I'm missing but this one looks identical to the above linked broken behavior