feat(rust, python): in rolling aggregation functions, sort under-the-hood for user if data is unsorted (with warning) #11134

Closed
wants to merge 12 commits into from
55 changes: 49 additions & 6 deletions crates/polars-plan/src/dsl/mod.rs
@@ -1,7 +1,5 @@
#![allow(ambiguous_glob_reexports)]
//! Domain specific language for the Lazy API.
#[cfg(feature = "rolling_window")]
use polars_core::utils::ensure_sorted_arg;
#[cfg(feature = "dtype-categorical")]
pub mod cat;
#[cfg(feature = "dtype-categorical")]
@@ -1229,19 +1227,54 @@ impl Expr {
move |s| {
let mut by = s[1].clone();
by = by.rechunk();
let s = &s[0];
let series: Series;

polars_ensure!(
options.weights.is_none(),
ComputeError: "`weights` is not supported in 'rolling by' expression"
);
let (by, tz) = match by.dtype() {
let (mut by, tz) = match by.dtype() {
DataType::Datetime(tu, tz) => {
(by.cast(&DataType::Datetime(*tu, None))?, tz)
},
_ => (by.clone(), &None),
};
ensure_sorted_arg(&by, expr_name)?;
let sorting_indices;
let original_indices;
let by_flag = by.is_sorted_flag();
match by_flag {
IsSorted::Ascending => {
series = s[0].clone();
original_indices = None;
},
IsSorted::Descending => {
series = s[0].reverse();
by = by.reverse();
original_indices = None;
},
IsSorted::Not => {
if options.warn_if_unsorted {
eprintln!(
Collaborator

Can you use polars_error's polars_warn function?

"PolarsPerformanceWarning: Series is not known to be \
sorted by `by` column, so Polars is temporarily \
sorting it for you.\n\
You can silence this warning by:\n\
- passing `warn_if_unsorted=False`;\n\
- sorting your data by your `by` column beforehand;\n\
- setting `.set_sorted()` if you already know your data is sorted\n\
before passing it to the rolling aggregation function"
);
}
sorting_indices = by.arg_sort(Default::default());
unsafe { by = by.take_unchecked(&sorting_indices)? };
unsafe { series = s[0].take_unchecked(&sorting_indices)? };
let int_range =
UInt32Chunked::from_iter_values("", 0..s[0].len() as u32)
Collaborator

This should be IdxCa and IdxSize instead of u32, no? EDIT: see below, you shouldn't save this at all and instead save sorting_indices.
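As a side illustration of the `u32` concern raised above, here is a minimal, std-only sketch (not Polars code). The `IdxSize` alias below is a hypothetical local stand-in for a configurable index type, and the helper names are made up for this example. It shows that an `as u32` cast wraps silently once the length no longer fits, whereas a checked conversion or a wider index type surfaces or avoids the problem.

```rust
/// Hypothetical local alias standing in for a configurable index type.
/// (Assumption for illustration only; it is not imported from Polars.)
type IdxSize = u64;

/// Lossy conversion, as in `len as u32`: wraps silently when `len` does not fit.
fn truncating_len(len: u64) -> u32 {
    len as u32
}

/// Checked conversion: reports the overflow instead of wrapping.
fn checked_len(len: u64) -> Option<u32> {
    u32::try_from(len).ok()
}

/// Identity index range `0..len`, built with the wider alias.
fn identity_indices(len: IdxSize) -> Vec<IdxSize> {
    (0..len).collect()
}

fn main() {
    let too_long: u64 = u32::MAX as u64 + 1;
    println!("{}", truncating_len(too_long)); // prints 0: the cast wrapped around
    println!("{:?}", checked_len(too_long)); // prints None
    println!("{:?}", identity_indices(3)); // prints [0, 1, 2]
}
```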

.into_series();
original_indices =
unsafe { int_range.take_unchecked(&sorting_indices) }.ok()
},
};
let by = by.datetime().unwrap();
let by_values = by.cont_slice().map_err(|_| {
polars_err!(
@@ -1263,7 +1296,17 @@ impl Expr {
fn_params: options.fn_params.clone(),
};

rolling_fn(s, options).map(Some)
match by_flag {
IsSorted::Ascending => rolling_fn(&series, options).map(Some),
IsSorted::Descending => {
Ok(rolling_fn(&series, options)?.reverse()).map(Some)
},
IsSorted::Not => {
let res = rolling_fn(&series, options)?;
let indices = &original_indices.unwrap().arg_sort(Default::default());
Collaborator

This should not be another sort and take; it should be a scatter: out[sorting_indices[i]] = res[i]. As far as I'm aware, we currently lack a good scatter kernel, though, so this may be blocked until I get to that.
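To make the difference concrete, here is a small std-only sketch, assuming plain slices rather than Polars arrays. All function names are hypothetical, and `running_sum` is only a toy stand-in for the rolling function. It computes the result on sorted data and then restores the original row order in two ways: with a second argsort plus take (as the diff does), and with a single scatter pass `out[sorting_indices[i]] = res[i]`.

```rust
/// Indices that would sort `keys` ascending (stable).
fn arg_sort<K: Ord + Copy>(keys: &[K]) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..keys.len()).collect();
    idx.sort_by_key(|&i| keys[i]);
    idx
}

/// Gather: out[i] = values[indices[i]].
fn take<T: Copy>(values: &[T], indices: &[usize]) -> Vec<T> {
    indices.iter().map(|&i| values[i]).collect()
}

/// Scatter: out[indices[i]] = values[i].
/// Assumes `indices` is a permutation of 0..values.len().
fn scatter<T: Copy + Default>(values: &[T], indices: &[usize]) -> Vec<T> {
    let mut out = vec![T::default(); values.len()];
    for (i, &dst) in indices.iter().enumerate() {
        out[dst] = values[i];
    }
    out
}

/// Toy stand-in for the rolling function: a running sum in sorted order.
fn running_sum(values: &[i64]) -> Vec<i64> {
    let mut acc: i64 = 0;
    values
        .iter()
        .map(|v| {
            acc += *v;
            acc
        })
        .collect()
}

/// Restore order with one O(n) scatter pass.
fn rolling_via_scatter(by: &[i64], values: &[i64]) -> Vec<i64> {
    let sorting_indices = arg_sort(by);
    let res = running_sum(&take(values, &sorting_indices));
    scatter(&res, &sorting_indices)
}

/// Restore order with a second sort (inverse permutation) and a take.
fn rolling_via_second_sort(by: &[i64], values: &[i64]) -> Vec<i64> {
    let sorting_indices = arg_sort(by);
    let res = running_sum(&take(values, &sorting_indices));
    let inverse = arg_sort(&sorting_indices);
    take(&res, &inverse)
}

fn main() {
    let by = [30, 10, 20];
    let values = [3, 1, 2];
    println!("{:?}", rolling_via_scatter(&by, &values)); // prints [6, 1, 3]
    println!("{:?}", rolling_via_second_sort(&by, &values)); // prints [6, 1, 3]
}
```

Both paths give the same output; the scatter version simply skips the extra O(n log n) sort of the index array.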

Collaborator Author

thanks - OK marking as blocked then

Collaborator Author

hi @orlp - just checking whether a scatter kernel might be on the horizon, and if not whether you have suggestions on what to do instead

Collaborator

@orlp orlp Nov 2, 2023

@MarcoGorelli We don't have a scatter kernel yet, but we do have PolarsDataType::ZeroablePhysical now, which allows you to write a reasonably efficient scatter in a couple lines of code.

You can look in gather_skip_nulls.rs for an example, but basically you do

let mut out: Vec<T::ZeroablePhysical<'a>> = zeroed_vec(len);
unsafe {
    for i in 0..n {
        let out_idx = *sorted_indices.get_unchecked(i);
        *out.get_unchecked_mut(out_idx) = res.get(i).into();
    }
}
let out_arr = T::Array::from_zeroable_vec(out, dtype);

Collaborator

Depending on the behavior of nulls you may also have to construct a null bitmap.
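The zero-initialized buffer and the null bitmap mentioned above can be sketched with the standard library alone. This is an illustration under stated assumptions, not the Polars kernel: `Option<i64>` stands in for nullable results, a `Vec<bool>` stands in for the validity bitmap, and `scatter_with_validity` is a hypothetical name.

```rust
/// Scatter nullable results back to their original positions.
/// Returns a zero-initialized value buffer plus a validity mask
/// (true = value present). Assumes `sorting_indices` is a permutation
/// of 0..res.len().
fn scatter_with_validity(
    res: &[Option<i64>],
    sorting_indices: &[usize],
) -> (Vec<i64>, Vec<bool>) {
    assert_eq!(res.len(), sorting_indices.len());
    let mut values = vec![0i64; res.len()];
    let mut validity = vec![false; res.len()];
    for (i, &out_idx) in sorting_indices.iter().enumerate() {
        // Null results leave the zeroed slot untouched and stay marked invalid.
        if let Some(v) = res[i] {
            values[out_idx] = v;
            validity[out_idx] = true;
        }
    }
    (values, validity)
}

fn main() {
    // Results in sorted order; the first window had too few observations.
    let res = [None, Some(3), Some(6)];
    let sorting_indices = [1, 2, 0];
    let (values, validity) = scatter_with_validity(&res, &sorting_indices);
    println!("{:?}", values); // prints [6, 0, 3]
    println!("{:?}", validity); // prints [true, false, true]
}
```

The zero at position 1 is a placeholder only; readers of the buffer are expected to consult the mask before using a value.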

Collaborator Author

Thanks

Noob question but where would I get T from in this function? Do I first need to refactor and make finish_rolling generic?

unsafe { res.take_unchecked(indices) }.map(Some)
},
}
},
&[col(by)],
output_type,
3 changes: 3 additions & 0 deletions crates/polars-time/src/chunkedarray/rolling_window/mod.rs
@@ -38,6 +38,8 @@ pub struct RollingOptions {
pub closed_window: Option<ClosedWindow>,
/// Optional parameters for the rolling function
pub fn_params: DynArgs,
/// Warn if data is not known to be sorted by `by` column (if passed)
pub warn_if_unsorted: bool,
}

#[cfg(feature = "rolling_window")]
@@ -51,6 +53,7 @@ impl Default for RollingOptions {
by: None,
closed_window: None,
fn_params: None,
warn_if_unsorted: true,
}
}
}
51 changes: 45 additions & 6 deletions py-polars/polars/expr/expr.py
@@ -5276,6 +5276,7 @@ def rolling_min(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling min (moving min) over the values in this array.
@@ -5347,6 +5348,8 @@ def rolling_min(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -5472,7 +5475,7 @@ def rolling_min(
)
return self._from_pyexpr(
self._pyexpr.rolling_min(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

@@ -5486,6 +5489,7 @@ def rolling_max(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling max (moving max) over the values in this array.
@@ -5553,6 +5557,8 @@ def rolling_max(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -5705,7 +5711,7 @@ def rolling_max(
)
return self._from_pyexpr(
self._pyexpr.rolling_max(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

@@ -5719,6 +5725,7 @@ def rolling_mean(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling mean (moving mean) over the values in this array.
@@ -5790,6 +5797,8 @@ def rolling_mean(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -5942,7 +5951,13 @@ def rolling_mean(
)
return self._from_pyexpr(
self._pyexpr.rolling_mean(
window_size, weights, min_periods, center, by, closed
window_size,
weights,
min_periods,
center,
by,
closed,
warn_if_unsorted,
)
)

@@ -5956,6 +5971,7 @@ def rolling_sum(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling sum (moving sum) over the values in this array.
@@ -6023,6 +6039,8 @@ def rolling_sum(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -6175,7 +6193,7 @@ def rolling_sum(
)
return self._from_pyexpr(
self._pyexpr.rolling_sum(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

@@ -6190,6 +6208,7 @@ def rolling_std(
by: str | None = None,
closed: ClosedInterval = "left",
ddof: int = 1,
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling standard deviation.
@@ -6259,6 +6278,8 @@ def rolling_std(
applicable if `by` has been set.
ddof
"Delta Degrees of Freedom": The divisor for a length N window is N - ddof
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -6411,7 +6432,14 @@ def rolling_std(
)
return self._from_pyexpr(
self._pyexpr.rolling_std(
window_size, weights, min_periods, center, by, closed, ddof
window_size,
weights,
min_periods,
center,
by,
closed,
ddof,
warn_if_unsorted,
)
)

@@ -6426,6 +6454,7 @@ def rolling_var(
by: str | None = None,
closed: ClosedInterval = "left",
ddof: int = 1,
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling variance.
@@ -6495,6 +6524,8 @@ def rolling_var(
applicable if `by` has been set.
ddof
"Delta Degrees of Freedom": The divisor for a length N window is N - ddof
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -6654,6 +6685,7 @@ def rolling_var(
by,
closed,
ddof,
warn_if_unsorted,
)
)

@@ -6667,6 +6699,7 @@ def rolling_median(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling median.
@@ -6734,6 +6767,8 @@ def rolling_median(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -6812,7 +6847,7 @@ def rolling_median(
)
return self._from_pyexpr(
self._pyexpr.rolling_median(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

@@ -6828,6 +6863,7 @@ def rolling_quantile(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling quantile.
@@ -6899,6 +6935,8 @@ def rolling_quantile(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).

Warnings
--------
@@ -7013,6 +7051,7 @@ def rolling_quantile(
center,
by,
closed,
warn_if_unsorted,
)
)
