Skip to content

Commit 72cb101

Browse files
authored
Backport PR #54591 on branch 2.1.x (Implement any and all for pyarrow numpy strings) (#54796)
1 parent 901b5e6 commit 72cb101

File tree

4 files changed

+39
-1
lines changed

4 files changed

+39
-1
lines changed

pandas/core/arrays/string_arrow.py

+13
Original file line number | Diff line number | Diff line change
@@ -554,3 +554,16 @@ def value_counts(self, dropna: bool = True):
554554
return Series(
555555
result._values.to_numpy(), index=result.index, name=result.name, copy=False
556556
)
557+
558+
def _reduce(
    self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
    """
    Reduce the array, with string-aware handling of ``any`` and ``all``.

    For ``"any"`` and ``"all"`` every element is first mapped to a boolean
    that is true only when the element is neither null nor the empty
    string; that boolean array is then reduced via ``ArrowExtensionArray``.
    Any other reduction name is forwarded unchanged to the parent class.

    Parameters
    ----------
    name : str
        Name of the reduction to perform.
    skipna : bool, default True
        Passed through to the underlying reduction.
    keepdims : bool, default False
        Passed through to the underlying reduction.
    **kwargs
        Passed through to the underlying reduction.
    """
    if name not in ("any", "all"):
        return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)

    values = self._pa_array
    is_present = pc.invert(pc.is_null(values))
    is_non_empty = pc.not_equal(values, "")
    # Kleene AND: for a null element this is (False AND null) -> False, so
    # missing values count as falsy rather than propagating as null.
    truthy = pc.and_kleene(is_present, is_non_empty)
    return ArrowExtensionArray(truthy)._reduce(
        name, skipna=skipna, keepdims=keepdims, **kwargs
    )

pandas/tests/extension/base/reduce.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ def check_reduce(self, s, op_name, skipna):
2525

2626
try:
2727
alt = s.astype("float64")
28-
except TypeError:
28+
except (TypeError, ValueError):
2929
# e.g. Interval can't cast, so let's cast to object and do
3030
# the reduction pointwise
3131
alt = s.astype(object)

pandas/tests/extension/test_string.py

+6
Original file line number | Diff line number | Diff line change
@@ -157,6 +157,12 @@ def test_fillna_no_op_returns_copy(self, data):
157157

158158

159159
class TestReduce(base.BaseReduceTests):
160+
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
161+
return (
162+
ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr]
163+
and op_name in ("any", "all")
164+
)
165+
160166
@pytest.mark.parametrize("skipna", [True, False])
161167
def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
162168
op_name = all_numeric_reductions

pandas/tests/reductions/test_reductions.py

+19
Original file line number | Diff line number | Diff line change
@@ -1078,6 +1078,25 @@ def test_any_all_datetimelike(self):
10781078
assert df.any().all()
10791079
assert not df.all().any()
10801080

1081+
def test_any_all_pyarrow_string(self):
    """Check ``any``/``all`` truthiness on ``string[pyarrow_numpy]`` series."""
    # GH#54591
    pytest.importorskip("pyarrow")

    # (values, expected truthiness of .any(), expected truthiness of .all())
    cases = [
        (["", "a"], True, False),
        ([None, "a"], True, False),
        ([None, ""], False, False),
        (["a", "b"], True, True),
    ]
    for values, expect_any, expect_all in cases:
        ser = Series(values, dtype="string[pyarrow_numpy]")
        # Compare truthiness only, mirroring plain ``assert x`` / ``assert not x``.
        assert bool(ser.any()) == expect_any
        assert bool(ser.all()) == expect_all
1099+
10811100
def test_timedelta64_analytics(self):
10821101
# index min/max
10831102
dti = date_range("2012-1-1", periods=3, freq="D")

0 commit comments

Comments
 (0)