-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(rust, python): in rolling aggregation functions, sort under-the-hood for user if data is unsorted (with warning) #11134
Changes from 5 commits
9f8d828
4cc16b2
261ea51
102ed94
716007c
da6738b
0ff1386
463ca3a
83a08da
b1f7bb4
8798d6c
dee189c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,5 @@ | ||
#![allow(ambiguous_glob_reexports)] | ||
//! Domain specific language for the Lazy API. | ||
#[cfg(feature = "rolling_window")] | ||
use polars_core::utils::ensure_sorted_arg; | ||
#[cfg(feature = "dtype-categorical")] | ||
pub mod cat; | ||
#[cfg(feature = "dtype-categorical")] | ||
|
@@ -1229,19 +1227,54 @@ impl Expr { | |
move |s| { | ||
let mut by = s[1].clone(); | ||
by = by.rechunk(); | ||
let s = &s[0]; | ||
let series: Series; | ||
|
||
polars_ensure!( | ||
options.weights.is_none(), | ||
ComputeError: "`weights` is not supported in 'rolling by' expression" | ||
); | ||
let (by, tz) = match by.dtype() { | ||
let (mut by, tz) = match by.dtype() { | ||
DataType::Datetime(tu, tz) => { | ||
(by.cast(&DataType::Datetime(*tu, None))?, tz) | ||
}, | ||
_ => (by.clone(), &None), | ||
}; | ||
ensure_sorted_arg(&by, expr_name)?; | ||
let sorting_indices; | ||
let original_indices; | ||
let by_flag = by.is_sorted_flag(); | ||
match by_flag { | ||
IsSorted::Ascending => { | ||
series = s[0].clone(); | ||
original_indices = None; | ||
}, | ||
IsSorted::Descending => { | ||
series = s[0].reverse(); | ||
by = by.reverse(); | ||
original_indices = None; | ||
}, | ||
IsSorted::Not => { | ||
if options.warn_if_unsorted { | ||
eprintln!( | ||
"PolarsPerformanceWarning: Series is not known to be \ | ||
sorted by `by` column, so Polars is temporarily \ | ||
sorting it for you.\n\ | ||
You can silence this warning by:\n\ | ||
- passing `warn_if_unsorted=False`;\n\ | ||
- sorting your data by your `by` column beforehand;\n\ | ||
- setting `.set_sorted()` if you already know your data is sorted\n\ | ||
before passing it to the rolling aggregation function" | ||
); | ||
} | ||
sorting_indices = by.arg_sort(Default::default()); | ||
unsafe { by = by.take_unchecked(&sorting_indices)? }; | ||
unsafe { series = s[0].take_unchecked(&sorting_indices)? }; | ||
let int_range = | ||
UInt32Chunked::from_iter_values("", 0..s[0].len() as u32) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be |
||
.into_series(); | ||
original_indices = | ||
unsafe { int_range.take_unchecked(&sorting_indices) }.ok() | ||
}, | ||
}; | ||
let by = by.datetime().unwrap(); | ||
let by_values = by.cont_slice().map_err(|_| { | ||
polars_err!( | ||
|
@@ -1263,7 +1296,17 @@ impl Expr { | |
fn_params: options.fn_params.clone(), | ||
}; | ||
|
||
rolling_fn(s, options).map(Some) | ||
match by_flag { | ||
IsSorted::Ascending => rolling_fn(&series, options).map(Some), | ||
IsSorted::Descending => { | ||
Ok(rolling_fn(&series, options)?.reverse()).map(Some) | ||
}, | ||
IsSorted::Not => { | ||
let res = rolling_fn(&series, options)?; | ||
let indices = &original_indices.unwrap().arg_sort(Default::default()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should not be another sort and take; instead this should be a scatter: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. thanks - OK marking as blocked then There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. hi @orlp - just checking whether a scatter kernel might be on the horizon, and if not whether you have suggestions on what to do instead There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @MarcoGorelli We don't have a scatter kernel yet, but we do have You can look in let mut out: Vec<T::ZeroablePhysical<'a>> = zeroed_vec(len);
unsafe {
for i in 0..n {
let out_idx = *sorted_indices.get_unchecked(i);
*out.get_unchecked_mut(out_idx) = res.get(i).into();
}
}
let out_arr = T::Array::from_zeroable_vec(out, dtype); There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Depending on the behavior of nulls you may also have to construct a null bitmap. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks. Noob question, but where would I get |
||
unsafe { res.take_unchecked(indices) }.map(Some) | ||
}, | ||
} | ||
}, | ||
&[col(by)], | ||
output_type, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use `polars_error`'s `polars_warn` function?