Skip to content

Commit 3cbf960

Browse files
authored
(fix): pandas extension array repr for int64[pyarrow] (#10317)
1 parent d589df1 commit 3cbf960

File tree

6 files changed, with 55 additions and 22 deletions.

xarray/core/extension_array.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,17 @@ def __extension_duck_array__where(
6666
return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array)
6767

6868

69+
@implements(np.reshape)
70+
def __extension_duck_array__reshape(
71+
arr: T_ExtensionArray, shape: tuple
72+
) -> T_ExtensionArray:
73+
if (shape[0] == len(arr) and len(shape) == 1) or shape == (-1,):
74+
return arr
75+
raise NotImplementedError(
76+
f"Cannot reshape 1d-only pandas extension array to: {shape}"
77+
)
78+
79+
6980
@dataclass(frozen=True)
7081
class PandasExtensionArray(Generic[T_ExtensionArray], NDArrayMixin):
7182
"""NEP-18 compliant wrapper for pandas extension arrays.
@@ -101,10 +112,10 @@ def replace_duck_with_extension_array(args) -> list:
101112

102113
args = tuple(replace_duck_with_extension_array(args))
103114
if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS:
104-
return func(*args, **kwargs)
115+
raise KeyError("Function not registered for pandas extension arrays.")
105116
res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs)
106117
if is_extension_array_dtype(res):
107-
return type(self)[type(res)](res)
118+
return PandasExtensionArray(res)
108119
return res
109120

110121
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):

xarray/core/formatting.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from xarray.core.datatree_render import RenderDataTree
2121
from xarray.core.duck_array_ops import array_all, array_any, array_equiv, astype, ravel
22+
from xarray.core.extension_array import PandasExtensionArray
2223
from xarray.core.indexing import MemoryCachedArray
2324
from xarray.core.options import OPTIONS, _get_boolean_with_default
2425
from xarray.core.treenode import group_subtrees
@@ -176,6 +177,11 @@ def format_timedelta(t, timedelta_format=None):
176177

177178
def format_item(x, timedelta_format=None, quote_strings=True):
178179
"""Returns a succinct summary of an object as a string"""
180+
if isinstance(x, PandasExtensionArray):
181+
# We want to bypass PandasExtensionArray's repr here
182+
# because its __repr__ is PandasExtensionArray(array=[...])
183+
# and this function is only for single elements.
184+
return str(x.array[0])
179185
if isinstance(x, np.datetime64 | datetime):
180186
return format_timestamp(x)
181187
if isinstance(x, np.timedelta64 | timedelta):
@@ -194,7 +200,9 @@ def format_items(x):
194200
"""Returns a succinct summaries of all items in a sequence as strings"""
195201
x = to_duck_array(x)
196202
timedelta_format = "datetime"
197-
if np.issubdtype(x.dtype, np.timedelta64):
203+
if not isinstance(x, PandasExtensionArray) and np.issubdtype(
204+
x.dtype, np.timedelta64
205+
):
198206
x = astype(x, dtype="timedelta64[ns]")
199207
day_part = x[~pd.isnull(x)].astype("timedelta64[D]").astype("timedelta64[ns]")
200208
time_needed = x[~pd.isnull(x)] != day_part

xarray/tests/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,14 @@ def create_test_data(
363363
)
364364
),
365365
)
366+
if has_pyarrow:
367+
obj["var5"] = (
368+
"dim1",
369+
pd.array(
370+
rs.integers(1, 10, size=dim_sizes[0]).tolist(),
371+
dtype="int64[pyarrow]",
372+
),
373+
)
366374
if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
367375
numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
368376
else:

xarray/tests/test_concat.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
assert_equal,
2222
assert_identical,
2323
requires_dask,
24+
requires_pyarrow,
2425
)
2526
from xarray.tests.test_dataset import create_test_data
2627

@@ -154,19 +155,20 @@ def test_concat_missing_var() -> None:
154155
assert_identical(actual, expected)
155156

156157

157-
def test_concat_categorical() -> None:
158+
@pytest.mark.parametrize("var", ["var4", pytest.param("var5", marks=requires_pyarrow)])
159+
def test_concat_extension_array(var) -> None:
158160
data1 = create_test_data(use_extension_array=True)
159161
data2 = create_test_data(use_extension_array=True)
160162
concatenated = concat([data1, data2], dim="dim1")
161-
assert (
162-
concatenated["var4"]
163-
== type(data2["var4"].variable.data)._concat_same_type(
163+
assert pd.Series(
164+
concatenated[var]
165+
== type(data2[var].variable.data)._concat_same_type(
164166
[
165-
data1["var4"].variable.data,
166-
data2["var4"].variable.data,
167+
data1[var].variable.data,
168+
data2[var].variable.data,
167169
]
168170
)
169-
).all()
171+
).all() # need to wrap in series because pyarrow bool does not support `all`
170172

171173

172174
def test_concat_missing_multiple_consecutive_var() -> None:

xarray/tests/test_dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3637,7 +3637,7 @@ def test_series_categorical_index(self) -> None:
36373637

36383638
s = pd.Series(np.arange(5), index=pd.CategoricalIndex(list("aabbc")))
36393639
arr = DataArray(s)
3640-
assert "'a'" in repr(arr) # should not error
3640+
assert "a a b b" in repr(arr) # should not error
36413641

36423642
@pytest.mark.parametrize("use_dask", [True, False])
36433643
@pytest.mark.parametrize("data", ["list", "array", True])

xarray/tests/test_dataset.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
create_test_data,
6161
has_cftime,
6262
has_dask,
63+
has_pyarrow,
6364
raise_if_dask_computes,
6465
requires_bottleneck,
6566
requires_cftime,
@@ -283,26 +284,28 @@ def test_repr(self) -> None:
283284
data = create_test_data(seed=123, use_extension_array=True)
284285
data.attrs["foo"] = "bar"
285286
# need to insert str dtype at runtime to handle different endianness
287+
var5 = (
288+
"\n var5 (dim1) int64[pyarrow] 64B 5 9 7 2 6 2 8 1"
289+
if has_pyarrow
290+
else ""
291+
)
286292
expected = dedent(
287-
"""\
293+
f"""\
288294
<xarray.Dataset> Size: 2kB
289295
Dimensions: (dim2: 9, dim3: 10, time: 20, dim1: 8)
290296
Coordinates:
291297
* dim2 (dim2) float64 72B 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0
292-
* dim3 (dim3) {} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
293-
* time (time) datetime64[{}] 160B 2000-01-01 2000-01-02 ... 2000-01-20
298+
* dim3 (dim3) {data["dim3"].dtype} 40B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
299+
* time (time) datetime64[ns] 160B 2000-01-01 2000-01-02 ... 2000-01-20
294300
numbers (dim3) int64 80B 0 1 2 0 0 1 1 2 2 3
295301
Dimensions without coordinates: dim1
296302
Data variables:
297303
var1 (dim1, dim2) float64 576B -0.9891 -0.3678 1.288 ... -0.2116 0.364
298304
var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423
299305
var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555
300-
var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a'
306+
var4 (dim1) category 32B b c b a c a c a{var5}
301307
Attributes:
302-
foo: bar""".format(
303-
data["dim3"].dtype,
304-
"ns",
305-
)
308+
foo: bar"""
306309
)
307310
actual = "\n".join(x.rstrip() for x in repr(data).split("\n"))
308311

@@ -5884,20 +5887,21 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None:
58845887
def test_reduce_non_numeric(self) -> None:
58855888
data1 = create_test_data(seed=44, use_extension_array=True)
58865889
data2 = create_test_data(seed=44)
5887-
add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]}
5890+
add_vars = {"var6": ["dim1", "dim2"], "var7": ["dim1"]}
58885891
for v, dims in sorted(add_vars.items()):
58895892
size = tuple(data1.sizes[d] for d in dims)
58905893
data = np.random.randint(0, 100, size=size).astype(np.str_)
58915894
data1[v] = (dims, data, {"foo": "variable"})
5892-
# var4 is extension array categorical and should be dropped
5895+
# var4 and var5 are extension arrays and should be dropped
58935896
assert (
58945897
"var4" not in data1.mean()
58955898
and "var5" not in data1.mean()
58965899
and "var6" not in data1.mean()
5900+
and "var7" not in data1.mean()
58975901
)
58985902
assert_equal(data1.mean(), data2.mean())
58995903
assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1"))
5900-
assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2")
5904+
assert "var6" not in data1.mean(dim="dim2") and "var7" in data1.mean(dim="dim2")
59015905

59025906
@pytest.mark.filterwarnings(
59035907
"ignore:Once the behaviour of DataArray:DeprecationWarning"

0 commit comments

Comments
 (0)