Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

String dtype: implement sum reduction #59853

Merged
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ enhancement1
Other enhancements
^^^^^^^^^^^^^^^^^^

-
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
-

.. ---------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/array_algos/masked_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def _reductions(
):
return libmissing.NA

if values.dtype == np.dtype(object):
# object dtype does not support `where` without passing an `initial` value,
# so drop the masked entries up front instead
values = values[~mask]
return func(values, axis=axis, **kwargs)
return func(values, where=~mask, axis=axis, **kwargs)


Expand Down
32 changes: 32 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
unpack_tuple_and_ellipses,
validate_indices,
)
from pandas.core.nanops import check_below_min_count
from pandas.core.strings.base import BaseStringArrayMethods

from pandas.io._util import _arrow_dtype_mapping
Expand Down Expand Up @@ -1705,6 +1706,37 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
denominator = pc.sqrt_checked(pc.count(self._pa_array))
return pc.divide_checked(numerator, denominator)

elif name == "sum" and (
pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type)
):

def pyarrow_meth(data, skip_nulls, min_count=0): # type: ignore[misc]
mask = pc.is_null(data) if data.null_count > 0 else None
if skip_nulls:
if min_count > 0 and check_below_min_count(
(len(data),),
None if mask is None else mask.to_numpy(),
min_count,
):
return pa.scalar(None, type=data.type)
if data.null_count > 0:
# binary_join returns null if there is any null ->
# have to filter out any nulls
data = data.filter(pc.invert(mask))
else:
if mask is not None or check_below_min_count(
(len(data),), None, min_count
):
return pa.scalar(None, type=data.type)

if pa.types.is_large_string(data.type):
# binary_join only supports string, not large_string
data = data.cast(pa.string())
Comment on lines +1732 to +1734
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not too familiar here, can this cause unexpected results? If so, should it be documented?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it can cause an overflow error if a single chunk doesn't fit into the string dtype. I suppose this will be very rare, because we are summing here, and that would mean that the single scalar string resulting from the sum would be bigger than 2GB (I am not fully sure how well Python will handle such a large str object).

We can indeed document it.
In theory it could also be circumvented by splitting the chunk into multiple chunks (although I would have to verify that pyarrow does not then concatenate them again under the hood in the binary_join implementation).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a note about this in the issue listing the behavioral changes -> #59328

data_list = pa.ListArray.from_arrays(
[0, len(data)], data.combine_chunks()
)[0]
return pc.binary_join(data_list, "")

else:
pyarrow_name = {
"median": "quantile",
Expand Down
18 changes: 16 additions & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,8 @@ def _reduce(
else:
return nanops.nanall(self._ndarray, skipna=skipna)

if name in ["min", "max"]:
result = getattr(self, name)(skipna=skipna, axis=axis)
if name in ["min", "max", "sum"]:
result = getattr(self, name)(skipna=skipna, axis=axis, **kwargs)
if keepdims:
return self._from_sequence([result], dtype=self.dtype)
return result
Expand All @@ -840,6 +840,20 @@ def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
)
return self._wrap_reduction_result(axis, result)

def sum(
    self,
    *,
    axis: AxisInt | None = None,
    skipna: bool = True,
    min_count: int = 0,
    **kwargs,
) -> Scalar:
    """
    Reduce the array with ``sum`` via ``masked_reductions.sum``.

    Parameters
    ----------
    axis : int or None, default None
        Passed through to ``_wrap_reduction_result``.
    skipna : bool, default True
        Forwarded to ``masked_reductions.sum`` together with the
        ``isna()`` mask of this array.
    min_count : int, default 0
        Forwarded to ``masked_reductions.sum``. Previously this argument
        was accepted but silently dropped, so it had no effect here.
    **kwargs
        Only accepted for numpy compatibility; checked by
        ``nv.validate_sum``.

    Returns
    -------
    Scalar
        The reduction result after ``_wrap_reduction_result``.
    """
    nv.validate_sum((), kwargs)
    result = masked_reductions.sum(
        values=self._ndarray,
        mask=self.isna(),
        skipna=skipna,
        # Forward min_count so it is honoured instead of being ignored.
        min_count=min_count,
    )
    return self._wrap_reduction_result(axis, result)

def value_counts(self, dropna: bool = True) -> Series:
from pandas.core.algorithms import value_counts_internal as value_counts

Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,11 @@ def _reduce(
return result.astype(np.bool_)
return result

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("min", "max", "sum", "argmin", "argmax"):
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
else:
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")

if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_result(result)
elif isinstance(result, pa.Array):
Expand Down
10 changes: 0 additions & 10 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.dtypes import CategoricalDtype

import pandas as pd
Expand Down Expand Up @@ -1218,7 +1214,6 @@ def test_agg_with_name_as_column_name():
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_agg_multiple_mixed():
# GH 20909
mdf = DataFrame(
Expand Down Expand Up @@ -1247,9 +1242,6 @@ def test_agg_multiple_mixed():
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_agg_multiple_mixed_raises():
# GH 20909
mdf = DataFrame(
Expand Down Expand Up @@ -1347,7 +1339,6 @@ def test_named_agg_reduce_axis1_raises(float_frame):
float_frame.agg(row1=(name1, "sum"), row2=(name2, "max"), axis=axis)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_nuiscance_columns():
# GH 15015
df = DataFrame(
Expand Down Expand Up @@ -1524,7 +1515,6 @@ def test_apply_datetime_tz_issue(engine, request):
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize("df", [DataFrame({"A": ["a", None], "B": ["c", "d"]})])
@pytest.mark.parametrize("method", ["min", "max", "sum"])
def test_mixed_column_raises(df, method, using_infer_string):
Expand Down
39 changes: 20 additions & 19 deletions pandas/tests/apply/test_invalid_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW
from pandas.errors import SpecificationError

from pandas import (
Expand Down Expand Up @@ -212,10 +209,6 @@ def transform(row):
data.apply(transform, axis=1)


# we should raise a proper TypeError instead of propagating the pyarrow error
@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
@pytest.mark.parametrize(
"df, func, expected",
tm.get_cython_table_params(
Expand All @@ -225,21 +218,25 @@ def transform(row):
def test_agg_cython_table_raises_frame(df, func, expected, axis, using_infer_string):
# GH 21224
if using_infer_string:
import pyarrow as pa
if df.dtypes.iloc[0].storage == "pyarrow":
import pyarrow as pa

expected = (expected, pa.lib.ArrowNotImplementedError)
# TODO(infer_string)
# should raise a proper TypeError instead of propagating the pyarrow error

msg = "can't multiply sequence by non-int of type 'str'|has no kernel"
expected = (expected, pa.lib.ArrowNotImplementedError)
else:
expected = (expected, NotImplementedError)

msg = (
"can't multiply sequence by non-int of type 'str'|has no kernel|cannot perform"
)
warn = None if isinstance(func, str) else FutureWarning
with pytest.raises(expected, match=msg):
with tm.assert_produces_warning(warn, match="using DataFrame.cumprod"):
df.agg(func, axis=axis)


# we should raise a proper TypeError instead of propagating the pyarrow error
@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
@pytest.mark.parametrize(
"series, func, expected",
chain(
Expand All @@ -263,11 +260,15 @@ def test_agg_cython_table_raises_series(series, func, expected, using_infer_stri
msg = r"Cannot convert \['a' 'b' 'c'\] to numeric"

if using_infer_string:
import pyarrow as pa

expected = (expected, pa.lib.ArrowNotImplementedError)

msg = msg + "|does not support|has no kernel"
if series.dtype.storage == "pyarrow":
import pyarrow as pa

# TODO(infer_string)
# should raise a proper TypeError instead of propagating the pyarrow error
expected = (expected, pa.lib.ArrowNotImplementedError)
else:
expected = (expected, NotImplementedError)
msg = msg + "|does not support|has no kernel|Cannot perform|cannot perform"
warn = None if isinstance(func, str) else FutureWarning

with pytest.raises(expected, match=msg):
Expand Down
2 changes: 0 additions & 2 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,14 +444,12 @@ def test_astype_float(dtype, any_float_dtype):
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(reason="Not implemented StringArray.sum")
def test_reduce(skipna, dtype):
    # Summing a Series of strings without missing values should yield
    # "abc" for either value of ``skipna``.
    ser = pd.Series(["a", "b", "c"], dtype=dtype)
    total = ser.sum(skipna=skipna)
    assert total == "abc"


@pytest.mark.xfail(reason="Not implemented StringArray.sum")
def test_reduce_missing(skipna, dtype):
arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
result = arr.sum(skipna=skipna)
Expand Down
25 changes: 4 additions & 21 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,10 +461,11 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
pass
else:
return False
elif pa.types.is_binary(pa_dtype) and op_name == "sum":
return False
elif (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
) and op_name in [
"sum",
"mean",
"median",
"prod",
Expand Down Expand Up @@ -563,6 +564,8 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
cmp_dtype = "float64[pyarrow]"
elif op_name in ["sum", "prod"] and pa.types.is_boolean(pa_type):
cmp_dtype = "uint64[pyarrow]"
elif op_name == "sum" and pa.types.is_string(pa_type):
cmp_dtype = arr.dtype
else:
cmp_dtype = {
"i": "int64[pyarrow]",
Expand Down Expand Up @@ -594,26 +597,6 @@ def test_median_not_approximate(self, typ):
result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
assert result == 1.5

def test_in_numeric_groupby(self, data_for_grouping):
    """
    Check ``groupby(...).sum()`` on a frame holding the grouping data.

    For non-string dtypes the check is delegated to the parent class.
    For string dtypes, a plain ``sum`` is expected to raise ``TypeError``,
    and ``numeric_only=True`` is expected to keep only column "C".
    """
    dtype = data_for_grouping.dtype

    # Everything that is not a string dtype uses the inherited test.
    if not is_string_dtype(dtype):
        super().test_in_numeric_groupby(data_for_grouping)
        return

    frame = pd.DataFrame(
        {
            "A": [1, 1, 2, 2, 3, 3, 1, 4],
            "B": data_for_grouping,
            "C": [1, 1, 1, 1, 1, 1, 1, 1],
        }
    )

    # The message contains regex metacharacters, so escape it before matching.
    pattern = re.escape(f"agg function failed [how->sum,dtype->{dtype}")
    with pytest.raises(TypeError, match=pattern):
        frame.groupby("A").sum()

    remaining = frame.groupby("A").sum(numeric_only=True).columns
    tm.assert_index_equal(remaining, pd.Index(["C"]))

def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _get_expected_exception(

def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
return (
op_name in ["min", "max"]
op_name in ["min", "max", "sum"]
or ser.dtype.na_value is np.nan # type: ignore[union-attr]
and op_name in ("any", "all")
)
Expand Down
Loading