From 72efc1c48ae6c6530e6c1a646a5a8664889af162 Mon Sep 17 00:00:00 2001 From: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 12 Nov 2023 19:16:52 +0000 Subject: [PATCH] chore(rust): use unwrap_or_else and get_unchecked_release in rolling kernels --- Cargo.lock | 1 + crates/polars-arrow/Cargo.toml | 1 + .../kernels/rolling/no_nulls/quantile.rs | 19 +++++++++++++------ .../src/legacy/kernels/rolling/nulls/mod.rs | 8 +++----- .../legacy/kernels/rolling/nulls/quantile.rs | 15 ++++++++++----- 5 files changed, 28 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f6310d9c1744..3f37a725f26e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2658,6 +2658,7 @@ dependencies = [ "multiversion", "num-traits", "polars-error", + "polars-utils", "proptest", "rand 0.8.5", "regex", diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 540acf8a177a..56cc496fccbc 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -26,6 +26,7 @@ foreign_vec = { version = "0.1" } hashbrown = { workspace = true } num-traits = { workspace = true } polars-error = { workspace = true } +polars-utils = { workspace = true } serde = { workspace = true, features = ["derive"], optional = true } simdutf8 = { workspace = true } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs index 5b1f78f63669..546fbc0bf9c2 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/quantile.rs @@ -2,6 +2,7 @@ use std::fmt::Debug; use num_traits::ToPrimitive; use polars_error::polars_ensure; +use polars_utils::slice::GetSaferUnchecked; use super::QuantileInterpolOptions::*; use super::*; @@ -60,12 +61,16 @@ impl< if top_idx == idx { // safety // we are in bounds - unsafe { *vals.get_unchecked(idx) } + unsafe { *vals.get_unchecked_release(idx) } } else { // safety // we are in bounds - let (mid, mid_plus_1) = - unsafe { (*vals.get_unchecked(idx), *vals.get_unchecked(idx + 1)) }; + let (mid, mid_plus_1) = unsafe { + ( + *vals.get_unchecked_release(idx), + *vals.get_unchecked_release(idx + 1), + ) + }; (mid + mid_plus_1) / T::from::(2.0f64).unwrap() } @@ -77,16 +82,18 @@ impl< if top_idx == idx { // safety // we are in bounds - unsafe { *vals.get_unchecked(idx) } + unsafe { *vals.get_unchecked_release(idx) } } else { let proportion = T::from(float_idx - idx as f64).unwrap(); - proportion * (vals[top_idx] - vals[idx]) + vals[idx] + proportion + * (*vals.get_unchecked_release(top_idx) - *vals.get_unchecked_release(idx)) + + *vals.get_unchecked_release(idx) } }, _ => { // safety // we are in bounds - unsafe { *vals.get_unchecked(idx) } + unsafe { *vals.get_unchecked_release(idx) } }, } } diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs index 6942a7e8abea..d9ea26de2744 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/mod.rs @@ -49,14 +49,12 @@ where // Safety; we are in bounds let mut agg_window = unsafe { Agg::new(values, validity, start, end, params) }; - let mut validity = match create_validity(min_periods, len, window_size, det_offsets_fn) { - Some(v) => v, - None => { + let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn) + .unwrap_or_else(|| { let mut validity = MutableBitmap::with_capacity(len); validity.extend_constant(len, true); validity - }, - }; + }); let out = (0..len) .map(|idx| { diff --git a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs index 501f7c8e1f9a..f4cfd4da9e51 100644 --- a/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs +++ b/crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs @@ -1,3 +1,5 @@ +use polars_utils::slice::GetSaferUnchecked; + use super::*; pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> { @@ -65,7 +67,8 @@ impl< QuantileInterpolOptions::Midpoint => { let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize; Some( - (values[idx].unwrap() + values[top_idx].unwrap()) + (values.get_unchecked_release(idx).unwrap() + + values.get_unchecked_release(top_idx).unwrap()) / T::from::(2.0f64).unwrap(), ) }, @@ -74,16 +77,18 @@ impl< let top_idx = f64::ceil(float_idx) as usize; if top_idx == idx { - Some(values[idx].unwrap()) + Some(values.get_unchecked_release(idx).unwrap()) } else { let proportion = T::from(float_idx - idx as f64).unwrap(); Some( - proportion * (values[top_idx].unwrap() - values[idx].unwrap()) - + values[idx].unwrap(), + proportion + * (values.get_unchecked_release(top_idx).unwrap() + - values.get_unchecked_release(idx).unwrap()) + + values.get_unchecked_release(idx).unwrap(), ) } }, - _ => Some(values[idx].unwrap()), + _ => Some(values.get_unchecked_release(idx).unwrap()), } }