Skip to content

Commit

Permalink
hist as backwards compatible as possible with polars<1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
camriddell committed Jan 24, 2025
1 parent 7c3bacb commit 7cf5c1e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
16 changes: 16 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,13 @@ def hist(
include_breakpoint: bool = True,
) -> PolarsDataFrame:
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals.exceptions import InvalidOperationError

if bins is not None:
for i in range(1, len(bins)):
if bins[i - 1] >= bins[i]:
msg = "bins must increase monotonically"
raise InvalidOperationError(msg)

df = self._native_series.hist(
bins=bins,
Expand All @@ -435,6 +442,15 @@ def hist(
if not include_category and not include_breakpoint:
df.columns = ["count"]

if self._backend_version < (1, 0): # pragma: no cover
if (
bins is not None
): # polars<1.0 implicitly adds -inf and inf to either end of bins
r = pl.int_range(0, len(df))
df = df.filter((r > 0) & (r < len(df) - 1))
if include_breakpoint:
df = df.rename({"break_point": "breakpoint"})

return PolarsDataFrame(
df, backend_version=self._backend_version, version=self._version
)
Expand Down
9 changes: 9 additions & 0 deletions tests/series_only/hist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import narwhals.stable.v1 as nw
from narwhals.exceptions import InvalidOperationError
from tests.utils import POLARS_VERSION
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data
from tests.utils import nwise
Expand Down Expand Up @@ -109,6 +110,10 @@ def test_hist_bin(
assert_equal_data(result, expected)


@pytest.mark.skipif(
POLARS_VERSION < (1, 0),
reason="hist(bin_count=...) behavior significantly changed after 1.0",
)
@pytest.mark.parametrize("params", counts_and_expected)
@pytest.mark.parametrize("include_breakpoint", [True, False])
@pytest.mark.parametrize("include_category", [True, False])
Expand Down Expand Up @@ -229,6 +234,10 @@ def test_hist_bins_hypotheis(
),
bin_count=st.integers(min_value=0, max_value=1_000),
)
@pytest.mark.skipif(
POLARS_VERSION < (1, 0),
reason="hist(bin_count=...) behavior significantly changed after 1.0",
)
@pytest.mark.filterwarnings(
"ignore:`Series.hist` is being called from the stable API although considered an unstable feature."
)
Expand Down

0 comments on commit 7cf5c1e

Please sign in to comment.