Skip to content

Commit 23b546f

Browse files
Implement any and all for pyarrow numpy strings (#54591)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 94dcf24 commit 23b546f

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

pandas/core/arrays/string_arrow.py

+13
Original file line numberDiff line numberDiff line change
@@ -554,3 +554,16 @@ def value_counts(self, dropna: bool = True):
554554
return Series(
555555
result._values.to_numpy(), index=result.index, name=result.name, copy=False
556556
)
557+
558+
def _reduce(
559+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
560+
):
561+
if name in ["any", "all"]:
562+
arr = pc.and_kleene(
563+
pc.invert(pc.is_null(self._pa_array)), pc.not_equal(self._pa_array, "")
564+
)
565+
return ArrowExtensionArray(arr)._reduce(
566+
name, skipna=skipna, keepdims=keepdims, **kwargs
567+
)
568+
else:
569+
return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)

pandas/tests/extension/test_string.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,11 @@ def test_fillna_no_op_returns_copy(self, data):
158158

159159
class TestReduce(base.BaseReduceTests):
160160
def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
161-
return op_name in ["min", "max"]
161+
return (
162+
op_name in ["min", "max"]
163+
or ser.dtype.storage == "pyarrow_numpy" # type: ignore[union-attr]
164+
and op_name in ("any", "all")
165+
)
162166

163167

164168
class TestMethods(base.BaseMethodsTests):

pandas/tests/reductions/test_reductions.py

+19
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,25 @@ def test_any_all_datetimelike(self):
10781078
assert df.any().all()
10791079
assert not df.all().any()
10801080

1081+
def test_any_all_pyarrow_string(self):
1082+
# GH#54591
1083+
pytest.importorskip("pyarrow")
1084+
ser = Series(["", "a"], dtype="string[pyarrow_numpy]")
1085+
assert ser.any()
1086+
assert not ser.all()
1087+
1088+
ser = Series([None, "a"], dtype="string[pyarrow_numpy]")
1089+
assert ser.any()
1090+
assert not ser.all()
1091+
1092+
ser = Series([None, ""], dtype="string[pyarrow_numpy]")
1093+
assert not ser.any()
1094+
assert not ser.all()
1095+
1096+
ser = Series(["a", "b"], dtype="string[pyarrow_numpy]")
1097+
assert ser.any()
1098+
assert ser.all()
1099+
10811100
def test_timedelta64_analytics(self):
10821101
# index min/max
10831102
dti = date_range("2012-1-1", periods=3, freq="D")

0 commit comments

Comments
 (0)