From 7cf5c1e52e4744d839bf2d846387e8f9f436c672 Mon Sep 17 00:00:00 2001 From: Cameron Riddell Date: Fri, 24 Jan 2025 08:20:16 -0800 Subject: [PATCH] hist as backwards compatible as possible with polars<1.0 --- narwhals/_polars/series.py | 16 ++++++++++++++++ tests/series_only/hist_test.py | 9 +++++++++ 2 files changed, 25 insertions(+) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index ab8e98ef8..4b2f7a939 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -425,6 +425,13 @@ def hist( include_breakpoint: bool = True, ) -> PolarsDataFrame: from narwhals._polars.dataframe import PolarsDataFrame + from narwhals.exceptions import InvalidOperationError + + if bins is not None: + for i in range(1, len(bins)): + if bins[i - 1] >= bins[i]: + msg = "bins must increase monotonically" + raise InvalidOperationError(msg) df = self._native_series.hist( bins=bins, @@ -435,6 +442,15 @@ def hist( if not include_category and not include_breakpoint: df.columns = ["count"] + if self._backend_version < (1, 0): # pragma: no cover + if ( + bins is not None + ): # polars<1.0 implicitly adds -inf and inf to either end of bins + r = pl.int_range(0, len(df)) + df = df.filter((r > 0) & (r < len(df) - 1)) + if include_breakpoint: + df = df.rename({"break_point": "breakpoint"}) + return PolarsDataFrame( df, backend_version=self._backend_version, version=self._version ) diff --git a/tests/series_only/hist_test.py b/tests/series_only/hist_test.py index 5a706aeb5..698906698 100644 --- a/tests/series_only/hist_test.py +++ b/tests/series_only/hist_test.py @@ -8,6 +8,7 @@ import narwhals.stable.v1 as nw from narwhals.exceptions import InvalidOperationError +from tests.utils import POLARS_VERSION from tests.utils import ConstructorEager from tests.utils import assert_equal_data from tests.utils import nwise @@ -109,6 +110,10 @@ def test_hist_bin( assert_equal_data(result, expected) +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), + reason="hist(bin_count=...) behavior significantly changed after 1.0", +) @pytest.mark.parametrize("params", counts_and_expected) @pytest.mark.parametrize("include_breakpoint", [True, False]) @pytest.mark.parametrize("include_category", [True, False]) @@ -229,6 +234,10 @@ def test_hist_bins_hypotheis( ), bin_count=st.integers(min_value=0, max_value=1_000), ) +@pytest.mark.skipif( + POLARS_VERSION < (1, 0), + reason="hist(bin_count=...) behavior significantly changed after 1.0", +) @pytest.mark.filterwarnings( "ignore:`Series.hist` is being called from the stable API although considered an unstable feature." )