Skip to content

Commit

Permalink
feat: Series.hist (#1859)
Browse files Browse the repository at this point in the history
  • Loading branch information
camriddell authored Feb 10, 2025
1 parent df2b077 commit 4b1d778
Show file tree
Hide file tree
Showing 10 changed files with 858 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
- filter
- gather_every
- head
- hist
- implementation
- is_between
- is_duplicated
Expand Down
124 changes: 124 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import pad_series
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantSeries
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -1005,6 +1006,129 @@ def rank(
result = pc.if_else(null_mask, pa.scalar(None), rank)
return self._from_native_series(result)

def hist( # noqa: PLR0915
    self: Self,
    bins: list[float | int] | None,
    *,
    bin_count: int | None,
    include_breakpoint: bool,
) -> ArrowDataFrame:
    """Compute a histogram of the series values, mirroring Polars semantics.

    Exactly one of ``bins`` or ``bin_count`` is expected to be non-``None``
    (the caller enforces this; a defensive check remains at the bottom).
    Bins are right-closed, matching Polars' ``Series.hist``.

    Arguments:
        bins: Explicit, increasing bin edges; yields ``len(bins) - 1``
            intervals. Fewer than two edges produces an empty result.
        bin_count: Number of equal-width bins spanning ``[min, max]`` of the
            data. Zero produces an empty result.
        include_breakpoint: If True, include a ``"breakpoint"`` column holding
            each bin's right edge alongside the ``"count"`` column.

    Returns:
        An ArrowDataFrame with a ``"count"`` column and, optionally, a
        ``"breakpoint"`` column.

    Raises:
        NotImplementedError: If PyArrow is older than 13.0.0.
        InvalidOperationError: If neither ``bins`` nor ``bin_count`` is given
            (should be unreachable given caller guarantees).
    """
    if self._backend_version < (13,):
        msg = f"`Series.hist` requires PyArrow>=13.0.0, found PyArrow version: {self._backend_version}"
        raise NotImplementedError(msg)
    import numpy as np # ignore-banned-import

    from narwhals._arrow.dataframe import ArrowDataFrame

    def _hist_from_bin_count(
        bin_count: int,
    ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
        # Derive equal-width bins from the data's observed min/max, then
        # count observations per bin entirely with pyarrow.compute kernels.
        d = pc.min_max(self._native_series)
        lower, upper = d["min"], d["max"]
        pad_lowest_bin = False
        if lower == upper:
            # Degenerate range (all values equal): Polars widens to a unit
            # range centered on the single value instead of zero-width bins.
            range_ = pa.scalar(1.0)
            width = pc.divide(range_, bin_count)
            lower = pc.subtract(lower, 0.5)
            upper = pc.add(upper, 0.5)
        else:
            pad_lowest_bin = True
            range_ = pc.subtract(upper, lower)
            width = pc.divide(range_.cast("float"), float(bin_count))

        # Map each value to a bin index via (value - lower) / width.
        bin_proportions = pc.divide(pc.subtract(self._native_series, lower), width)
        bin_indices = pc.floor(bin_proportions)

        bin_indices = pc.if_else( # shift bins so they are right-closed
            pc.and_(
                pc.equal(bin_indices, bin_proportions),
                pc.greater(bin_indices, 0),
            ),
            pc.subtract(bin_indices, 1),
            bin_indices,
        )
        counts = ( # count bin id occurrences
            pa.Table.from_arrays(
                pc.value_counts(bin_indices).flatten(),
                names=["values", "counts"],
            )
            # nan values are implicitly dropped in value_counts
            .filter(~pc.field("values").is_nan())
            .cast(pa.schema([("values", pa.int64()), ("counts", pa.int64())]))
            .join( # align bin ids to all possible bin ids (populate in missing bins)
                pa.Table.from_arrays(
                    [np.arange(bin_count, dtype="int64")], ["values"]
                ),
                keys="values",
                join_type="right outer",
            )
            .sort_by("values")
        )
        counts = counts.set_column( # empty bin intervals should have a 0 count
            0, "counts", pc.coalesce(counts.column("counts"), 0)
        )

        # extract left/right side of the intervals
        bin_left = pc.add(lower, pc.multiply(counts.column("values"), width))
        bin_right = pc.add(bin_left, width)
        if pad_lowest_bin:
            bin_left = pa.chunked_array(
                [ # pad lowest bin by 1% of range
                    [
                        pc.subtract(
                            bin_left[0], pc.multiply(range_.cast("float"), 0.001)
                        )
                    ],
                    bin_left[1:], # pyarrow==11.0 needs to infer
                ]
            )
        return counts.column("counts"), bin_left, bin_right

    def _hist_from_bins(
        bins: Sequence[int | float],
    ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
        # With explicit edges, delegate binning to numpy. side="left" keeps
        # values equal to an edge in the bin to that edge's left (right-closed).
        bin_indices = np.searchsorted(bins, self._native_series, side="left")
        obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
        # Valid bin ids are 1..len(bins)-1; ids 0 and len(bins) correspond to
        # out-of-range observations and are dropped by the isin alignment.
        obj_cats = np.arange(1, len(bins))
        counts = np.zeros_like(obj_cats)
        counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)]

        bin_right = bins[1:]
        bin_left = bins[:-1]
        return counts, bin_left, bin_right

    counts: Sequence[int]
    bin_left: Sequence[int | float]
    bin_right: Sequence[int | float]
    if bins is not None:
        if len(bins) < 2:
            # Fewer than two edges defines no interval: empty histogram.
            counts, bin_left, bin_right = [], [], []
        else:
            counts, bin_left, bin_right = _hist_from_bins(bins)

    elif bin_count is not None:
        if bin_count == 0:
            counts, bin_left, bin_right = [], [], []
        else:
            counts, bin_left, bin_right = _hist_from_bin_count(bin_count)

    else: # pragma: no cover
        # caller guarantees that either bins or bin_count is specified
        msg = "must provide one of `bin_count` or `bins`"
        raise InvalidOperationError(msg)

    data: dict[str, Sequence[int | float | str]] = {}
    if include_breakpoint:
        # Polars reports only the right edge of each bin as "breakpoint".
        data["breakpoint"] = bin_right
    data["count"] = counts

    return ArrowDataFrame(
        pa.Table.from_pydict(data),
        backend_version=self._backend_version,
        version=self._version,
        validate_column_names=True,
    )

def __iter__(self: Self) -> Iterator[Any]:
yield from (
maybe_extract_py_scalar(x, return_py_scalar=True)
Expand Down
89 changes: 89 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,95 @@ def rank(

return self._from_native_series(ranked_series)

def hist(
    self: Self,
    bins: list[float | int] | None,
    *,
    bin_count: int | None,
    include_breakpoint: bool,
) -> PandasLikeDataFrame:
    """Compute a histogram of the series values, mirroring Polars semantics.

    Exactly one of ``bins`` or ``bin_count`` is expected to be non-``None``.
    Three cases are handled in order: trivially empty output, an all-null
    series, and the general path via ``pandas.cut`` + ``value_counts``.

    Arguments:
        bins: Explicit, increasing bin edges; yields ``len(bins) - 1``
            intervals. One edge or fewer produces an empty result.
        bin_count: Number of equal-width bins spanning ``[min, max]`` of the
            data. Zero produces an empty result.
        include_breakpoint: If True, include a ``"breakpoint"`` column holding
            each bin's right edge alongside the ``"count"`` column.

    Returns:
        A PandasLikeDataFrame with a ``"count"`` column and, optionally, a
        ``"breakpoint"`` column.
    """
    from numpy import linspace
    from numpy import zeros

    from narwhals._pandas_like.dataframe import PandasLikeDataFrame

    ns = self.__native_namespace__()
    data: dict[str, Sequence[int | float | str]]

    if bin_count == 0 or (bins is not None and len(bins) <= 1):
        # Nothing to bin: return an empty frame with the right columns.
        data = {}
        if include_breakpoint:
            data["breakpoint"] = []
        data["count"] = []

        return PandasLikeDataFrame(
            ns.DataFrame(data),
            implementation=self._implementation,
            backend_version=self._backend_version,
            version=self._version,
            validate_column_names=True,
        )
    elif self._native_series.count() < 1:
        # No non-null observations: all counts are zero. Breakpoints come
        # from the given edges, or default to linspace(0, 1) like Polars.
        if bins is not None:
            data = {
                "breakpoint": bins[1:],
                "count": zeros(shape=len(bins) - 1),
            }
        else:
            data = {
                "breakpoint": linspace(0, 1, bin_count),
                "count": zeros(shape=bin_count),
            }

        if not include_breakpoint:
            del data["breakpoint"]

        return PandasLikeDataFrame(
            ns.DataFrame(data),
            implementation=self._implementation,
            backend_version=self._backend_version,
            version=self._version,
            validate_column_names=True,
        )

    elif bin_count is not None: # use Polars binning behavior
        # Convert bin_count into explicit edges so pandas.cut does not
        # re-adjust the lowest bin (see note below).
        lower, upper = self._native_series.min(), self._native_series.max()
        pad_lowest_bin = False
        if lower == upper:
            # Degenerate range: widen to a unit range around the value.
            lower -= 0.5
            upper += 0.5
        else:
            pad_lowest_bin = True

        bins = linspace(lower, upper, bin_count + 1)
        if pad_lowest_bin and bins is not None:
            # Nudge the lowest edge down so the minimum value is included
            # in the first (right-closed) bin, as Polars does.
            bins[0] -= 0.001 * abs(bins[0]) if bins[0] != 0 else 0.001
        bin_count = None

    # pandas (2.2.*) .value_counts(bins=int) adjusts the lowest bin twice, result in improper counts.
    # pandas (2.2.*) .value_counts(bins=[...]) adjusts the lowest bin which should not happen since
    # the bins were explicitly passed in.
    categories = ns.cut(
        self._native_series, bins=bins if bin_count is None else bin_count
    )
    # modin (0.32.0) .value_counts(...) silently drops bins with empty observations, .reindex
    # is necessary to restore these bins.
    result = categories.value_counts(dropna=True, sort=False).reindex(
        categories.cat.categories, fill_value=0
    )
    data = {}
    if include_breakpoint:
        # Prefer the explicit edges when available; otherwise use the right
        # endpoints of the intervals pandas.cut produced.
        data["breakpoint"] = bins[1:] if bins is not None else result.index.right
    data["count"] = result.reset_index(drop=True)

    return PandasLikeDataFrame(
        ns.DataFrame(data),
        implementation=self._implementation,
        backend_version=self._backend_version,
        version=self._version,
        validate_column_names=True,
    )

@property
def str(self: Self) -> PandasLikeSeriesStringNamespace:
return PandasLikeSeriesStringNamespace(self)
Expand Down
95 changes: 95 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
from typing import Union
from typing import cast
from typing import overload

import polars as pl
Expand Down Expand Up @@ -459,6 +461,99 @@ def __contains__(self: Self, other: Any) -> bool:
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e, self._backend_version) from None

def hist(
    self: Self,
    bins: list[float | int] | None,
    *,
    bin_count: int | None,
    include_breakpoint: bool,
) -> PolarsDataFrame:
    """Compute a histogram by delegating to ``polars.Series.hist``.

    Most of this body is compatibility shimming so that older Polars
    versions (<1.0, <1.15) produce the same output shape and values as
    modern Polars. Exactly one of ``bins`` or ``bin_count`` is expected to
    be non-``None``.

    Arguments:
        bins: Explicit, increasing bin edges. One edge or fewer produces an
            empty result.
        bin_count: Number of equal-width bins. Zero produces an empty result.
        include_breakpoint: If True, keep the ``"breakpoint"`` column in the
            result alongside ``"count"``.

    Returns:
        A PolarsDataFrame with a ``"count"`` column and, optionally, a
        ``"breakpoint"`` column.
    """
    from narwhals._polars.dataframe import PolarsDataFrame

    if (bins is not None and len(bins) <= 1) or (bin_count == 0): # pragma: no cover
        # Nothing to bin: return a typed empty frame.
        data: list[pl.Series] = []
        if include_breakpoint:
            data.append(pl.Series("breakpoint", [], dtype=pl.Float64))
        data.append(pl.Series("count", [], dtype=pl.UInt32))
        return PolarsDataFrame(
            pl.DataFrame(data),
            backend_version=self._backend_version,
            version=self._version,
        )
    elif (self._backend_version < (1, 15)) and self._native_series.count() < 1:
        # Older Polars mishandles all-null input; construct the zero-count
        # result ourselves to match modern Polars output.
        data_dict: dict[str, list[int | float] | pl.Series | pl.Expr]
        if bins is not None:
            data_dict = {
                "breakpoint": bins[1:],
                "count": pl.zeros(n=len(bins) - 1, dtype=pl.Int64, eager=True),
            }
        elif bin_count is not None:
            data_dict = {
                "breakpoint": pl.int_range(0, bin_count, eager=True) / bin_count,
                "count": pl.zeros(n=bin_count, dtype=pl.Int64, eager=True),
            }

        if not include_breakpoint:
            del data_dict["breakpoint"]

        return PolarsDataFrame(
            pl.DataFrame(data_dict),
            backend_version=self._backend_version,
            version=self._version,
        )

    # polars <1.15 does not adjust the bins when they have equivalent min/max
    # polars <1.5 with bin_count=...
    # returns bins that range from -inf to +inf and has bin_count + 1 bins.
    # for compat: convert `bin_count=` call to `bins=`
    if (
        (self._backend_version < (1, 15))
        and (bin_count is not None)
        and (self._native_series.count() > 0)
    ): # pragma: no cover
        lower = cast(Union[int, float], self._native_series.min())
        upper = cast(Union[int, float], self._native_series.max())
        pad_lowest_bin = False
        if lower == upper:
            # Degenerate range: widen to a unit range around the value.
            width = 1 / bin_count
            lower -= 0.5
            upper += 0.5
        else:
            pad_lowest_bin = True
            width = (upper - lower) / bin_count

        bins = (pl.int_range(0, bin_count + 1, eager=True) * width + lower).to_list()
        if pad_lowest_bin:
            # Nudge the lowest edge down so the minimum value falls inside
            # the first (right-closed) bin, matching modern Polars.
            bins[0] -= 0.001 * abs(bins[0]) if bins[0] != 0 else 0.001
        bin_count = None

    # Polars inconsistently handles NaN values when computing histograms
    # against predefined bins: https://github.com/pola-rs/polars/issues/21082
    series = self._native_series
    if self._backend_version < (1, 15) or bins is not None:
        # Replace NaN with null so they are uniformly excluded from counts.
        series = series.set(series.is_nan(), None)

    df = series.hist(
        bins=bins,
        bin_count=bin_count,
        include_category=False,
        include_breakpoint=include_breakpoint,
    )
    if not include_breakpoint:
        df.columns = ["count"]

    # polars<1.15 implicitly adds -inf and inf to either end of bins
    if self._backend_version < (1, 15) and bins is not None: # pragma: no cover
        # Drop the first and last rows (the spurious infinite bins).
        r = pl.int_range(0, len(df))
        df = df.filter((r > 0) & (r < len(df) - 1))

    if self._backend_version < (1, 0) and include_breakpoint:
        # polars<1.0 named the column "break_point".
        df = df.rename({"break_point": "breakpoint"})

    return PolarsDataFrame(
        df, backend_version=self._backend_version, version=self._version
    )

def to_polars(self: Self) -> pl.Series:
return self._native_series

Expand Down
4 changes: 2 additions & 2 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import polars as pl

from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import ComputeError
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import NarwhalsError
from narwhals.exceptions import ShapeError
Expand Down Expand Up @@ -237,8 +238,7 @@ def catch_polars_exception(
elif isinstance(exception, pl.exceptions.InvalidOperationError):
return InvalidOperationError(str(exception))
elif isinstance(exception, pl.exceptions.ComputeError):
# We don't (yet?) have a Narwhals ComputeError.
return NarwhalsError(str(exception))
return ComputeError(str(exception))
if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError):
# Old versions of Polars didn't have PolarsError.
return NarwhalsError(str(exception))
Expand Down
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def from_missing_and_available_column_names(
return ColumnNotFoundError(message)


class ComputeError(NarwhalsError):
    """Exception raised when the underlying computation could not be evaluated.

    Backend compute-time failures (e.g. a Polars ``ComputeError``) are
    translated into this Narwhals-level exception so callers can catch one
    type regardless of backend.
    """


class ShapeError(NarwhalsError):
    """Exception raised when trying to perform operations on data structures with incompatible shapes.

    Raised, for example, when combining series or frames whose lengths or
    column counts do not line up.
    """

Expand Down
Loading

0 comments on commit 4b1d778

Please sign in to comment.