From 9f8d828a482ad0eca5ef451e7186205e7bf23e7e Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:29:29 +0200 Subject: [PATCH 1/7] feat(rust, python): in rolling aggregation functions, sort under-the-hood for user if data is unsorted (with warning) --- crates/polars-plan/src/dsl/mod.rs | 53 ++++++++-- .../src/chunkedarray/rolling_window/mod.rs | 3 + py-polars/polars/expr/expr.py | 51 +++++++-- py-polars/src/expr/rolling.rs | 37 +++++-- .../tests/parametric/test_groupby_rolling.py | 100 +++++++++++++++++- .../unit/operations/rolling/test_rolling.py | 12 +-- 6 files changed, 229 insertions(+), 27 deletions(-) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index a16fd25358bb..ddbb4606c96b 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1,7 +1,5 @@ #![allow(ambiguous_glob_reexports)] //! Domain specific language for the Lazy API. -#[cfg(feature = "rolling_window")] -use polars_core::utils::ensure_sorted_arg; #[cfg(feature = "dtype-categorical")] pub mod cat; #[cfg(feature = "dtype-categorical")] @@ -1230,19 +1228,52 @@ impl Expr { move |s| { let mut by = s[1].clone(); by = by.rechunk(); - let s = &s[0]; + let series: Series; polars_ensure!( options.weights.is_none(), ComputeError: "`weights` is not supported in 'rolling by' expression" ); - let (by, tz) = match by.dtype() { + let (mut by, tz) = match by.dtype() { DataType::Datetime(tu, tz) => { (by.cast(&DataType::Datetime(*tu, None))?, tz) }, _ => (by.clone(), &None), }; - ensure_sorted_arg(&by, expr_name)?; + let sorting_indices; + let original_indices; + let by_flag = by.is_sorted_flag(); + match by_flag { + IsSorted::Ascending => { + series = s[0].clone(); + original_indices = None; + }, + IsSorted::Descending => { + series = s[0].reverse(); + by = by.reverse(); + original_indices = None; + }, + IsSorted::Not => { + eprintln!( + "PolarsPerformanceWarning: Series is not known to be \ + sorted by `by` column, so Polars is temporarily \ + sorting it for you.\n\ + You can silence this warning by:\n\ + - passing `warn_if_unsorted=False`;\n\ + - sorting your data by your `by` column beforehand;\n\ + - setting `.set_sorted()` if you already know your data is sorted\n\ + before passing it to the rolling aggregation function" + ); + sorting_indices = by.arg_sort(Default::default()); + unsafe { by = by.take_unchecked(&sorting_indices)? }; + unsafe { series = s[0].take_unchecked(&sorting_indices)? }; + let int_range = + UInt32Chunked::from_iter_values("", 0..s[0].len() as u32) + .into_series(); + original_indices = + unsafe { int_range.take_unchecked(&sorting_indices) }.ok() + }, + }; let by = by.datetime().unwrap(); let by_values = by.cont_slice().map_err(|_| { polars_err!( @@ -1264,7 +1295,17 @@ impl Expr { fn_params: options.fn_params.clone(), }; - rolling_fn(s, options).map(Some) + match by_flag { + IsSorted::Ascending => rolling_fn(&series, options).map(Some), + IsSorted::Descending => { + Ok(rolling_fn(&series, options)?.reverse()).map(Some) + }, + IsSorted::Not => { + let res = rolling_fn(&series, options)?; + let indices = &original_indices.unwrap().arg_sort(Default::default()); + unsafe { res.take_unchecked(indices) }.map(Some) + }, + } }, &[col(by)], output_type, diff --git a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs index dbb3e07d18e6..15208983ed98 100644 --- a/crates/polars-time/src/chunkedarray/rolling_window/mod.rs +++ b/crates/polars-time/src/chunkedarray/rolling_window/mod.rs @@ -38,6 +38,8 @@ pub struct RollingOptions { pub closed_window: Option, /// Optional parameters for the rolling function pub fn_params: DynArgs, + /// Warn if data is not known to be sorted by `by` column (if passed) + pub warn_if_unsorted: bool, } #[cfg(feature = "rolling_window")] @@ -51,6 +53,7 @@ impl Default for RollingOptions { by: None, closed_window: None, fn_params: None, + warn_if_unsorted: true, } } } diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index c1fc424b91bd..5c7e315d1e87 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -5279,6 +5279,7 @@ def rolling_min( center: bool = False, by: str | None = None, closed: ClosedInterval = "left", + warn_if_unsorted: bool = True, ) -> Self: """ Apply a rolling min (moving min) over the values in this array. @@ -5350,6 +5351,8 @@ def rolling_min( closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive); only applicable if `by` has been set. + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -5475,7 +5478,7 @@ def rolling_min( ) return self._from_pyexpr( self._pyexpr.rolling_min( - window_size, weights, min_periods, center, by, closed + window_size, weights, min_periods, center, by, closed, warn_if_unsorted ) ) @@ -5489,6 +5492,7 @@ def rolling_max( center: bool = False, by: str | None = None, closed: ClosedInterval = "left", + warn_if_unsorted: bool = True, ) -> Self: """ Apply a rolling max (moving max) over the values in this array. @@ -5556,6 +5560,8 @@ def rolling_max( closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive); only applicable if `by` has been set. + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -5708,7 +5714,7 @@ def rolling_max( ) return self._from_pyexpr( self._pyexpr.rolling_max( - window_size, weights, min_periods, center, by, closed + window_size, weights, min_periods, center, by, closed, warn_if_unsorted ) ) @@ -5722,6 +5728,7 @@ def rolling_mean( center: bool = False, by: str | None = None, closed: ClosedInterval = "left", + warn_if_unsorted: bool = True, ) -> Self: """ Apply a rolling mean (moving mean) over the values in this array. @@ -5793,6 +5800,8 @@ def rolling_mean( closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive); only applicable if `by` has been set. + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -5945,7 +5954,13 @@ def rolling_mean( ) return self._from_pyexpr( self._pyexpr.rolling_mean( - window_size, weights, min_periods, center, by, closed + window_size, + weights, + min_periods, + center, + by, + closed, + warn_if_unsorted, ) ) @@ -5959,6 +5974,7 @@ def rolling_sum( center: bool = False, by: str | None = None, closed: ClosedInterval = "left", + warn_if_unsorted: bool = True, ) -> Self: """ Apply a rolling sum (moving sum) over the values in this array. @@ -6026,6 +6042,8 @@ def rolling_sum( closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive); only applicable if `by` has been set. + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -6178,7 +6196,7 @@ def rolling_sum( ) return self._from_pyexpr( self._pyexpr.rolling_sum( - window_size, weights, min_periods, center, by, closed + window_size, weights, min_periods, center, by, closed, warn_if_unsorted ) ) @@ -6193,6 +6211,7 @@ def rolling_std( by: str | None = None, closed: ClosedInterval = "left", ddof: int = 1, + warn_if_unsorted: bool = True, ) -> Self: """ Compute a rolling standard deviation. @@ -6262,6 +6281,8 @@ def rolling_std( applicable if `by` has been set. ddof "Delta Degrees of Freedom": The divisor for a length N window is N - ddof + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -6414,7 +6435,14 @@ def rolling_std( ) return self._from_pyexpr( self._pyexpr.rolling_std( - window_size, weights, min_periods, center, by, closed, ddof + window_size, + weights, + min_periods, + center, + by, + closed, + ddof, + warn_if_unsorted, ) ) @@ -6429,6 +6457,7 @@ def rolling_var( by: str | None = None, closed: ClosedInterval = "left", ddof: int = 1, + warn_if_unsorted: bool = True, ) -> Self: """ Compute a rolling variance. @@ -6498,6 +6527,8 @@ def rolling_var( applicable if `by` has been set. ddof "Delta Degrees of Freedom": The divisor for a length N window is N - ddof + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -6657,6 +6688,7 @@ def rolling_var( by, closed, ddof, + warn_if_unsorted, ) ) @@ -6670,6 +6702,7 @@ def rolling_median( center: bool = False, by: str | None = None, closed: ClosedInterval = "left", + warn_if_unsorted: bool = True, ) -> Self: """ Compute a rolling median. @@ -6737,6 +6770,8 @@ def rolling_median( closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive); only applicable if `by` has been set. + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -6815,7 +6850,7 @@ def rolling_median( ) return self._from_pyexpr( self._pyexpr.rolling_median( - window_size, weights, min_periods, center, by, closed + window_size, weights, min_periods, center, by, closed, warn_if_unsorted ) ) @@ -6831,6 +6866,7 @@ def rolling_quantile( center: bool = False, by: str | None = None, closed: ClosedInterval = "left", + warn_if_unsorted: bool = True, ) -> Self: """ Compute a rolling quantile. @@ -6902,6 +6938,8 @@ def rolling_quantile( closed : {'left', 'right', 'both', 'none'} Define which sides of the temporal interval are closed (inclusive); only applicable if `by` has been set. + warn_if_unsorted + Warn if data is not known to be sorted by `by` column (if passed). Warnings -------- @@ -7016,6 +7054,7 @@ def rolling_quantile( center, by, closed, + warn_if_unsorted, ) ) diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index dbce0d294203..5ee2055ba81d 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -11,7 +11,8 @@ use crate::{PyExpr, PySeries}; #[pymethods] impl PyExpr { - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[allow(clippy::too_many_arguments)] fn rolling_sum( &self, window_size: &str, @@ -20,6 +21,7 @@ impl PyExpr { center: bool, by: Option, closed: Option>, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -28,12 +30,14 @@ impl PyExpr { center, by, closed_window: closed.map(|c| c.0), + warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_sum(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[allow(clippy::too_many_arguments)] fn rolling_min( &self, window_size: &str, @@ -42,6 +46,7 @@ impl PyExpr { center: bool, by: Option, closed: Option>, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -50,12 +55,14 @@ impl PyExpr { center, by, closed_window: closed.map(|c| c.0), + warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_min(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[allow(clippy::too_many_arguments)] fn rolling_max( &self, window_size: &str, @@ -64,6 +71,7 @@ impl PyExpr { center: bool, by: Option, closed: Option>, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -72,12 +80,14 @@ impl PyExpr { center, by, closed_window: closed.map(|c| c.0), + warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_max(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[allow(clippy::too_many_arguments)] fn rolling_mean( &self, window_size: &str, @@ -86,6 +96,7 @@ impl PyExpr { center: bool, by: Option, closed: Option>, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -94,13 +105,14 @@ impl PyExpr { center, by, closed_window: closed.map(|c| c.0), + warn_if_unsorted, ..Default::default() }; self.inner.clone().rolling_mean(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof, warn_if_unsorted))] #[allow(clippy::too_many_arguments)] fn rolling_std( &self, @@ -111,6 +123,7 @@ impl PyExpr { by: Option, closed: Option>, ddof: u8, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -120,12 +133,13 @@ impl PyExpr { by, closed_window: closed.map(|c| c.0), fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), + warn_if_unsorted, }; self.inner.clone().rolling_std(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, ddof, warn_if_unsorted))] #[allow(clippy::too_many_arguments)] fn rolling_var( &self, @@ -136,6 +150,7 @@ impl PyExpr { by: Option, closed: Option>, ddof: u8, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -145,12 +160,14 @@ impl PyExpr { by, closed_window: closed.map(|c| c.0), fn_params: Some(Arc::new(RollingVarParams { ddof }) as Arc), + warn_if_unsorted, }; self.inner.clone().rolling_var(options).into() } - #[pyo3(signature = (window_size, weights, min_periods, center, by, closed))] + #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] + #[allow(clippy::too_many_arguments)] fn rolling_median( &self, window_size: &str, @@ -159,6 +176,7 @@ impl PyExpr { center: bool, by: Option, closed: Option>, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -171,11 +189,12 @@ impl PyExpr { prob: 0.5, interpol: QuantileInterpolOptions::Linear, }) as Arc), + warn_if_unsorted, }; self.inner.clone().rolling_quantile(options).into() } - #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center, by, closed))] + #[pyo3(signature = (quantile, interpolation, window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] #[allow(clippy::too_many_arguments)] fn rolling_quantile( &self, @@ -187,6 +206,7 @@ impl PyExpr { center: bool, by: Option, closed: Option>, + warn_if_unsorted: bool, ) -> Self { let options = RollingOptions { window_size: Duration::parse(window_size), @@ -199,6 +219,7 @@ impl PyExpr { prob: quantile, interpol: interpolation.0, }) as Arc), + warn_if_unsorted, }; self.inner.clone().rolling_quantile(options).into() diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index c4c62b36a250..38536578ebe9 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -7,7 +7,7 @@ from hypothesis import assume, given, reject import polars as pl -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal from polars.testing.parametric.primitives import column, dataframes from polars.testing.parametric.strategies import strategy_closed, strategy_time_unit from polars.utils.convert import _timedelta_to_pl_duration @@ -75,3 +75,101 @@ def test_group_by_rolling( pl.col("value").cast(pl.List(pl.Int64)), ) assert_frame_equal(result, expected) + + +@given( + window_size=st.timedeltas(min_value=timedelta(microseconds=0)).map( + _timedelta_to_pl_duration + ), + closed=strategy_closed, + data=st.data(), + time_unit=strategy_time_unit, + aggregation=st.sampled_from( + [ + "min", + "max", + "mean", + "sum", + # "std", blocked by https://github.com/pola-rs/polars/issues/11140 + # "var", blocked by https://github.com/pola-rs/polars/issues/11140 + "median", + ] + ), +) +def test_rolling_aggs( + window_size: str, + closed: ClosedInterval, + data: st.DataObject, + time_unit: TimeUnit, + aggregation: str, +) -> None: + # Check: + # - that we get the same results whether we sort the data beforehand, + # or whether polars sorts it for us under-the-hood + # - that even if polars temporarily sorts the data under-the-hood, the + # order that the user passed the data in is restored + assume(window_size != "") + dataframe = data.draw( + dataframes( + [ + column("ts", dtype=pl.Datetime(time_unit)), + column("value", dtype=pl.Int64), + ], + ) + ) + # take unique because of https://github.com/pola-rs/polars/issues/11150 + df = dataframe.unique("ts") + func = f"rolling_{aggregation}" + try: + result = df.with_columns( + getattr(pl.col("value"), func)( + window_size=window_size, by="ts", closed=closed, warn_if_unsorted=False + ) + ) + except pl.exceptions.PolarsPanicError as exc: + assert any( # noqa: PT017 + msg in str(exc) + for msg in ( + "attempt to multiply with overflow", + "attempt to add with overflow", + ) + ) + reject() + + expected = ( + df.with_row_count("index") + .sort("ts") + .with_columns( + getattr(pl.col("value"), func)( + window_size=window_size, by="ts", closed=closed + ), + "index", + ) + .sort("index") + .drop("index") + ) + assert_frame_equal(result, expected) + assert_series_equal(result["ts"], df["ts"]) + + expected_dict: dict[str, list[object]] = {"ts": [], "value": []} + for ts, _ in df.iter_rows(): + window = df.filter( + pl.col("ts").is_between( + pl.lit(ts, dtype=pl.Datetime(time_unit)).dt.offset_by( + f"-{window_size}" + ), + pl.lit(ts, dtype=pl.Datetime(time_unit)), + closed=closed, + ) + ) + expected_dict["ts"].append(ts) + if window.is_empty(): + expected_dict["value"].append(None) + else: + value = getattr(window["value"], aggregation)() + expected_dict["value"].append(value) + expected = pl.DataFrame(expected_dict).select( + pl.col("ts").cast(pl.Datetime(time_unit)), + pl.col("value").cast(result["value"].dtype), + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 27f57e660a7c..cdf6d845de58 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -835,18 +835,18 @@ def test_rolling_weighted_quantile_10031() -> None: ) -def test_rolling_aggregations_unsorted_raise_10991() -> None: +def test_rolling_aggregations_unsorted_temporarily_sorts_10991() -> None: df = pl.DataFrame( { "dt": [datetime(2020, 1, 3), datetime(2020, 1, 1), datetime(2020, 1, 2)], "val": [1, 2, 3], } ) - with pytest.raises( - pl.InvalidOperationError, - match="argument in operation 'rolling_sum' is not explicitly sorted", - ): - df.with_columns(roll=pl.col("val").rolling_sum("2d", by="dt", closed="right")) + result = df.with_columns( + roll=pl.col("val").rolling_sum("2d", by="dt", closed="right") + ) + expected = df.with_columns(roll=pl.Series([4, 2, 5])) + assert_frame_equal(result, expected) def test_rolling() -> None: From 261ea5148c2e4f82943fbe0c7e9f9c6e9708c81c Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 16 Sep 2023 16:51:33 +0200 Subject: [PATCH 2/7] quick fixup --- crates/polars-plan/src/dsl/mod.rs | 22 ++++++++++--------- .../tests/parametric/test_groupby_rolling.py | 12 +++++++++- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 548d0a4c018a..30c27d6541f3 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1254,16 +1254,18 @@ impl Expr { original_indices = None; }, IsSorted::Not => { - eprintln!( - "PolarsPerformanceWarning: Series is not known to be \ - sorted by `by` column, so Polars is temporarily \ - sorting it for you.\n\ - You can silence this warning by:\n\ - - passing `warn_if_unsorted=False`;\n\ - - sorting your data by your `by` column beforehand;\n\ - - setting `.set_sorted()` if you already know your data is sorted\n\ - before passing it to the rolling aggregation function" - ); + if options.warn_if_unsorted { + eprintln!( + "PolarsPerformanceWarning: Series is not known to be \ + sorted by `by` column, so Polars is temporarily \ + sorting it for you.\n\ + You can silence this warning by:\n\ + - passing `warn_if_unsorted=False`;\n\ + - sorting your data by your `by` column beforehand;\n\ + - setting `.set_sorted()` if you already know your data is sorted\n\ + before passing it to the rolling aggregation function" + ); + } sorting_indices = by.arg_sort(Default::default()); unsafe { by = by.take_unchecked(&sorting_indices)? }; unsafe { series = s[0].take_unchecked(&sorting_indices)? }; diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index 38536578ebe9..87d228be442f 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -166,7 +166,17 @@ def test_rolling_aggs( if window.is_empty(): expected_dict["value"].append(None) else: - value = getattr(window["value"], aggregation)() + try: + value = getattr(window["value"], aggregation)() + except pl.exceptions.PolarsPanicError as exc: + assert any( # noqa: PT017 + msg in str(exc) + for msg in ( + "attempt to multiply with overflow", + "attempt to add with overflow", + ) + ) + reject() expected_dict["value"].append(value) expected = pl.DataFrame(expected_dict).select( pl.col("ts").cast(pl.Datetime(time_unit)), From 716007ca327ebd1f2708f6c93d4945a23524373f Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 18 Sep 2023 10:34:18 +0200 Subject: [PATCH 3/7] try removing try-except --- py-polars/tests/parametric/test_groupby_rolling.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index 87d228be442f..38536578ebe9 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -166,17 +166,7 @@ def test_rolling_aggs( if window.is_empty(): expected_dict["value"].append(None) else: - try: - value = getattr(window["value"], aggregation)() - except pl.exceptions.PolarsPanicError as exc: - assert any( # noqa: PT017 - msg in str(exc) - for msg in ( - "attempt to multiply with overflow", - "attempt to add with overflow", - ) - ) - reject() + value = getattr(window["value"], aggregation)() expected_dict["value"].append(value) expected = pl.DataFrame(expected_dict).select( pl.col("ts").cast(pl.Datetime(time_unit)), From 0ff13865c383df8376fd82d70feeacde4f9c0d2b Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 28 Oct 2023 19:47:35 +0100 Subject: [PATCH 4/7] post-merge fixup --- py-polars/src/expr/rolling.rs | 5 --- .../tests/parametric/test_groupby_rolling.py | 44 +++++++++---------- py-polars/tests/unit/series/test_series.py | 2 +- 3 files changed, 21 insertions(+), 30 deletions(-) diff --git a/py-polars/src/expr/rolling.rs b/py-polars/src/expr/rolling.rs index 3ebb91f2b85c..b0be89a9c8b4 100644 --- a/py-polars/src/expr/rolling.rs +++ b/py-polars/src/expr/rolling.rs @@ -12,7 +12,6 @@ use crate::{PyExpr, PySeries}; #[pymethods] impl PyExpr { #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - #[allow(clippy::too_many_arguments)] fn rolling_sum( &self, window_size: &str, @@ -37,7 +36,6 @@ impl PyExpr { } #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - #[allow(clippy::too_many_arguments)] fn rolling_min( &self, window_size: &str, @@ -62,7 +60,6 @@ impl PyExpr { } #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - #[allow(clippy::too_many_arguments)] fn rolling_max( &self, window_size: &str, @@ -87,7 +84,6 @@ impl PyExpr { } #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - #[allow(clippy::too_many_arguments)] fn rolling_mean( &self, window_size: &str, @@ -165,7 +161,6 @@ impl PyExpr { } #[pyo3(signature = (window_size, weights, min_periods, center, by, closed, warn_if_unsorted))] - #[allow(clippy::too_many_arguments)] fn rolling_median( &self, window_size: &str, diff --git a/py-polars/tests/parametric/test_groupby_rolling.py b/py-polars/tests/parametric/test_groupby_rolling.py index 8f9eabfa3201..4ad028dd26e3 100644 --- a/py-polars/tests/parametric/test_groupby_rolling.py +++ b/py-polars/tests/parametric/test_groupby_rolling.py @@ -83,9 +83,9 @@ def test_rolling( @given( - window_size=st.timedeltas(min_value=timedelta(microseconds=0)).map( - _timedelta_to_pl_duration - ), + window_size=st.timedeltas( + min_value=timedelta(microseconds=0), max_value=timedelta(days=2) + ).map(_timedelta_to_pl_duration), closed=strategy_closed, data=st.data(), time_unit=strategy_time_unit, @@ -114,33 +114,29 @@ def test_rolling_aggs( # - that even if polars temporarily sorts the data under-the-hood, the # order that the user passed the data in is restored assume(window_size != "") - dataframe = data.draw( + df = data.draw( dataframes( [ - column("ts", dtype=pl.Datetime(time_unit)), - column("value", dtype=pl.Int64), + column( + "ts", + strategy=st.datetimes( + min_value=dt.datetime(2000, 1, 1), + max_value=dt.datetime(2001, 1, 1), + ), + dtype=pl.Datetime(time_unit), + ), + column( + "value", + strategy=st.integers(min_value=-100, max_value=100), + dtype=pl.Int64, + ), ], ) ) - # take unique because of https://github.com/pola-rs/polars/issues/11150 - df = dataframe.unique("ts") func = f"rolling_{aggregation}" - try: - result = df.with_columns( - getattr(pl.col("value"), func)( - window_size=window_size, by="ts", closed=closed, warn_if_unsorted=False - ) - ) - except pl.exceptions.PolarsPanicError as exc: - assert any( # noqa: PT017 - msg in str(exc) - for msg in ( - "attempt to multiply with overflow", - "attempt to add with overflow", - ) - ) - reject() - + result = df.with_columns( + getattr(pl.col("value"), func)(window_size=window_size, by="ts", closed=closed) + ) expected = ( df.with_row_count("index") .sort("ts") diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 4ab816fd5f86..03dd83cd7444 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -437,7 +437,7 @@ def test_power() -> None: assert_series_equal(a**a, pl.Series([1.0, 4.0], dtype=Float64)) assert_series_equal(b**b, pl.Series([None, 4.0], dtype=Float64)) assert_series_equal(a**b, pl.Series([None, 4.0], dtype=Float64)) - assert_series_equal(a**None, pl.Series([None] * len(a), dtype=Float64)) + assert_series_equal(a ** None, pl.Series([None] * len(a), dtype=Float64)) with pytest.raises(TypeError): c**2 with pytest.raises(pl.ColumnNotFoundError): From 463ca3af78263eefd37e8ea6db38a6d63db7b6d1 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 28 Oct 2023 19:50:58 +0100 Subject: [PATCH 5/7] lint --- py-polars/tests/unit/series/test_series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 03dd83cd7444..4ab816fd5f86 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -437,7 +437,7 @@ def test_power() -> None: assert_series_equal(a**a, pl.Series([1.0, 4.0], dtype=Float64)) assert_series_equal(b**b, pl.Series([None, 4.0], dtype=Float64)) assert_series_equal(a**b, pl.Series([None, 4.0], dtype=Float64)) - assert_series_equal(a ** None, pl.Series([None] * len(a), dtype=Float64)) + assert_series_equal(a**None, pl.Series([None] * len(a), dtype=Float64)) with pytest.raises(TypeError): c**2 with pytest.raises(pl.ColumnNotFoundError): From b1f7bb43f83a509d73fae6487d0b2cede6b77651 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Nov 2023 16:05:06 +0000 Subject: [PATCH 6/7] update --- crates/polars-plan/src/dsl/mod.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 1226c6062dbf..28db5a511279 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1223,7 +1223,7 @@ impl Expr { }, IsSorted::Not => { if options.warn_if_unsorted { - eprintln!( + polars_warn!( "PolarsPerformanceWarning: Series is not known to be \ sorted by `by` column, so Polars is temporarily \ sorting it for you.\n\ @@ -1235,13 +1235,13 @@ impl Expr { ); } sorting_indices = by.arg_sort(Default::default()); - unsafe { by = by.take_unchecked(&sorting_indices)? }; - unsafe { series = s[0].take_unchecked(&sorting_indices)? }; + unsafe { by = by.take_unchecked(&sorting_indices) }; + unsafe { series = s[0].take_unchecked(&sorting_indices) }; let int_range = UInt32Chunked::from_iter_values("", 0..s[0].len() as u32) .into_series(); original_indices = - unsafe { int_range.take_unchecked(&sorting_indices) }.ok() + Some(unsafe { int_range.take_unchecked(&sorting_indices) }) }, }; let by = by.datetime().unwrap(); @@ -1273,7 +1273,7 @@ impl Expr { IsSorted::Not => { let res = rolling_fn(&series, options)?; let indices = &original_indices.unwrap().arg_sort(Default::default()); - unsafe { res.take_unchecked(indices) }.map(Some) + Ok(Some(unsafe { res.take_unchecked(indices) })) }, } }, From 8798d6cbb81cf0f55783ccd860aaf1b54ce8c9b6 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Nov 2023 19:07:25 +0000 Subject: [PATCH 7/7] clippy --- crates/polars-plan/src/dsl/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 28db5a511279..1864f9759303 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1268,7 +1268,7 @@ impl Expr { match by_flag { IsSorted::Ascending => rolling_fn(&series, options).map(Some), IsSorted::Descending => { - Ok(rolling_fn(&series, options)?.reverse()).map(Some) + Ok(Some(rolling_fn(&series, options)?.reverse())) }, IsSorted::Not => { let res = rolling_fn(&series, options)?;