Skip to content

Commit

Permalink
ENH(string dtype): Implement cumsum for Python-backed strings (#60938)
Browse files Browse the repository at this point in the history
* ENH(string dtype): Implement cumsum for Python-backed strings

* cleanups

* cleanups

* type-hint fixup

* More type fixes

* Use quotes for cast

* Refinements

* type-ignore
  • Loading branch information
rhshadrach authored Feb 19, 2025
1 parent d4dff29 commit 4e20195
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 20 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Other enhancements
- :meth:`Series.str.decode` result now has ``StringDtype`` when ``future.infer_string`` is True (:issue:`60709`)
- :meth:`~Series.to_hdf` and :meth:`~DataFrame.to_hdf` now round-trip with ``StringDtype`` (:issue:`60663`)
- The :meth:`Series.str.decode` has gained the argument ``dtype`` to control the dtype of the result (:issue:`60940`)
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns when backed by PyArrow (:issue:`60633`)
- The :meth:`~Series.cumsum`, :meth:`~Series.cummin`, and :meth:`~Series.cummax` reductions are now implemented for ``StringDtype`` columns (:issue:`60633`)
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)

.. ---------------------------------------------------------------------------
Expand Down
83 changes: 83 additions & 0 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
)

from pandas.core import (
missing,
nanops,
ops,
)
Expand Down Expand Up @@ -870,6 +871,88 @@ def _reduce(

raise TypeError(f"Cannot perform reduction '{name}' with string dtype")

def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArray:
"""
Return an ExtensionArray performing an accumulation operation.
The underlying data type might change.
Parameters
----------
name : str
Name of the function, supported values are:
- cummin
- cummax
- cumsum
- cumprod
skipna : bool, default True
If True, skip NA values.
**kwargs
Additional keyword arguments passed to the accumulation function.
Currently, there is no supported kwarg.
Returns
-------
array
Raises
------
NotImplementedError : subclass does not define accumulations
"""
if name == "cumprod":
msg = f"operation '{name}' not supported for dtype '{self.dtype}'"
raise TypeError(msg)

# We may need to strip out trailing NA values
tail: np.ndarray | None = None
na_mask: np.ndarray | None = None
ndarray = self._ndarray
np_func = {
"cumsum": np.cumsum,
"cummin": np.minimum.accumulate,
"cummax": np.maximum.accumulate,
}[name]

if self._hasna:
na_mask = cast("npt.NDArray[np.bool_]", isna(ndarray))
if np.all(na_mask):
return type(self)(ndarray)
if skipna:
if name == "cumsum":
ndarray = np.where(na_mask, "", ndarray)
else:
# We can retain the running min/max by forward/backward filling.
ndarray = ndarray.copy()
missing.pad_or_backfill_inplace(
ndarray,
method="pad",
axis=0,
)
missing.pad_or_backfill_inplace(
ndarray,
method="backfill",
axis=0,
)
else:
# When not skipping NA values, the result should be null from
# the first NA value onward.
idx = np.argmax(na_mask)
tail = np.empty(len(ndarray) - idx, dtype="object")
tail[:] = self.dtype.na_value
ndarray = ndarray[:idx]

# mypy: Cannot call function of unknown type
np_result = np_func(ndarray) # type: ignore[operator]

if tail is not None:
np_result = np.hstack((np_result, tail))
elif na_mask is not None:
# Argument 2 to "where" has incompatible type "NAType | float"
np_result = np.where(na_mask, self.dtype.na_value, np_result) # type: ignore[arg-type]

result = type(self)(np_result)
return result

def _wrap_reduction_result(self, axis: AxisInt | None, result) -> Any:
if self.dtype.na_value is np.nan and result is libmissing.NA:
# the masked_reductions use pd.NA -> convert to np.nan
Expand Down
10 changes: 1 addition & 9 deletions pandas/tests/apply/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import pytest

from pandas.compat import (
HAS_PYARROW,
WASM,
)

Expand Down Expand Up @@ -162,17 +161,10 @@ def test_agg_cython_table_series(series, func, expected):
),
),
)
def test_agg_cython_table_transform_series(request, series, func, expected):
def test_agg_cython_table_transform_series(series, func, expected):
# GH21224
# test transforming functions in
# pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum)
if series.dtype == "string" and func == "cumsum" and not HAS_PYARROW:
request.applymarker(
pytest.mark.xfail(
raises=NotImplementedError,
reason="TODO(infer_string) cumsum not yet implemented for string",
)
)
warn = None if isinstance(func, str) else FutureWarning
with tm.assert_produces_warning(warn, match="is currently using Series.*"):
result = series.agg(func)
Expand Down
6 changes: 1 addition & 5 deletions pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,7 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:

def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
assert isinstance(ser.dtype, StorageExtensionDtype)
return ser.dtype.storage == "pyarrow" and op_name in [
"cummin",
"cummax",
"cumsum",
]
return op_name in ["cummin", "cummax", "cumsum"]

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
dtype = cast(StringDtype, tm.get_dtype(obj))
Expand Down
11 changes: 6 additions & 5 deletions pandas/tests/series/test_cumulative.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,13 +265,14 @@ def test_cumprod_timedelta(self):
([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]),
],
)
def test_cum_methods_pyarrow_strings(
self, pyarrow_string_dtype, data, op, skipna, expected_data
def test_cum_methods_ea_strings(
self, string_dtype_no_object, data, op, skipna, expected_data
):
# https://github.com/pandas-dev/pandas/pull/60633
ser = pd.Series(data, dtype=pyarrow_string_dtype)
# https://github.com/pandas-dev/pandas/pull/60633 - pyarrow
# https://github.com/pandas-dev/pandas/pull/60938 - Python
ser = pd.Series(data, dtype=string_dtype_no_object)
method = getattr(ser, op)
expected = pd.Series(expected_data, dtype=pyarrow_string_dtype)
expected = pd.Series(expected_data, dtype=string_dtype_no_object)
result = method(skipna=skipna)
tm.assert_series_equal(result, expected)

Expand Down

0 comments on commit 4e20195

Please sign in to comment.