From 165ec6c489eea072c9c10957ad6572bc75da83c6 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sat, 26 Oct 2024 15:31:25 +0200 Subject: [PATCH] perf: Move rolling_corr/cov to an actual implementation on Series --- .../polars-plan/src/dsl/function_expr/mod.rs | 12 ++ .../src/dsl/function_expr/rolling.rs | 115 ++++++++++++++++++ .../src/dsl/function_expr/schema.rs | 2 +- .../src/dsl/functions/correlation.rs | 67 +++------- crates/polars-python/src/lazyframe/visit.rs | 2 +- .../src/lazyframe/visitor/expr_nodes.rs | 3 + 6 files changed, 147 insertions(+), 54 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 6cebaa301b85..1cacd0649309 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -954,6 +954,18 @@ impl From for SpecialEq> { Std(options) => map!(rolling::rolling_std, options.clone()), #[cfg(feature = "moment")] Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias), + CorrCov { + rolling_options, + corr_cov_options, + is_corr, + } => { + map_as_slice!( + rolling::rolling_corr_cov, + rolling_options.clone(), + corr_cov_options, + is_corr + ) + }, } }, #[cfg(feature = "rolling_window_by")] diff --git a/crates/polars-plan/src/dsl/function_expr/rolling.rs b/crates/polars-plan/src/dsl/function_expr/rolling.rs index c108c92b571a..47700ec64270 100644 --- a/crates/polars-plan/src/dsl/function_expr/rolling.rs +++ b/crates/polars-plan/src/dsl/function_expr/rolling.rs @@ -1,6 +1,10 @@ +use std::ops::BitAnd; + +use polars_core::utils::Container; use polars_time::chunkedarray::*; use super::*; +use crate::dsl::pow::pow; #[derive(Clone, PartialEq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -14,6 +18,12 @@ pub enum RollingFunction { Std(RollingOptionsFixedWindow), #[cfg(feature = "moment")] Skew(usize, bool), + CorrCov { + rolling_options: RollingOptionsFixedWindow, + corr_cov_options: RollingCovOptions, + // Whether is Corr or Cov + is_corr: bool, + }, } impl Display for RollingFunction { @@ -30,6 +40,13 @@ impl Display for RollingFunction { Std(_) => "rolling_std", #[cfg(feature = "moment")] Skew(..) => "rolling_skew", + CorrCov { is_corr, .. } => { + if *is_corr { + "rolling_corr" + } else { + "rolling_cov" + } + }, }; write!(f, "{name}") @@ -47,6 +64,9 @@ impl Hash for RollingFunction { window_size.hash(state); bias.hash(state) }, + CorrCov { is_corr, .. } => { + is_corr.hash(state); + }, _ => {}, } } @@ -111,3 +131,98 @@ pub(super) fn rolling_skew(s: &Column, window_size: usize, bias: bool) -> Polars .rolling_skew(window_size, bias) .map(Column::from) } + +fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series { + match dtype { + DataType::Float64 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f64) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + DataType::Float32 => { + let values = (0..len) + .map(|v| std::cmp::min(window_size, v + 1) as f32) + .collect::>(); + Series::new(PlSmallStr::EMPTY, values) + }, + _ => unreachable!(), + } +} + +pub(super) fn rolling_corr_cov( + s: &[Column], + rolling_options: RollingOptionsFixedWindow, + cov_options: RollingCovOptions, + is_corr: bool, +) -> PolarsResult { + let mut x = s[0].as_materialized_series().rechunk(); + let mut y = s[1].as_materialized_series().rechunk(); + + if !x.dtype().is_float() { + x = x.cast(&DataType::Float64)?; + } + if !y.dtype().is_float() { + y = y.cast(&DataType::Float64)?; + } + let dtype = x.dtype().clone(); + + let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?; + let rolling_options_count = RollingOptionsFixedWindow { + window_size: rolling_options.window_size, + min_periods: 0, + ..Default::default() + }; + + let count_x_y = if (x.null_count() + y.null_count()) > 0 { + // mask out nulls on both sides before compute mean/var + let valids = x.is_not_null().bitand(y.is_not_null()); + let valids_arr = valids.clone().downcast_into_array(); + let valids_bitmap = valids_arr.values(); + + unsafe { + let xarr = &mut x.chunks_mut()[0]; + *xarr = xarr.with_validity(Some(valids_bitmap.clone())); + let yarr = &mut y.chunks_mut()[0]; + *yarr = yarr.with_validity(Some(valids_bitmap.clone())); + x.compute_len(); + y.compute_len(); + } + valids + .cast(&dtype) + .unwrap() + .rolling_sum(rolling_options_count)? + } else { + det_count_x_y(rolling_options.window_size, x.len(), &dtype) + }; + + let mean_x = x.rolling_mean(rolling_options.clone())?; + let mean_y = y.rolling_mean(rolling_options.clone())?; + let ddof = Series::new( + PlSmallStr::EMPTY, + &[AnyValue::from(cov_options.ddof).cast(&dtype)], + ); + + let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap() + * (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap()) + .unwrap(); + + if is_corr { + let var_x = x.rolling_var(rolling_options.clone())?; + let var_y = y.rolling_var(rolling_options.clone())?; + + let base = (var_x * var_y).unwrap(); + let sc = Scalar::new( + base.dtype().clone(), + AnyValue::Float64(0.5).cast(&dtype).into_static(), + ); + let denominator = pow(&mut [base.into_column(), sc.into_column("".into())]) + .unwrap() + .unwrap() + .take_materialized_series(); + + Ok((numerator / denominator)?.into_column()) + } else { + Ok(numerator.into_column()) + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index bc09cca94215..dbefe7b3a74c 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -65,7 +65,7 @@ impl FunctionExpr { use RollingFunction::*; match rolling_func { Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(), - Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(), + Mean(_) | Quantile(_) | Var(_) | Std(_) | CorrCov {..} => mapper.map_to_float_dtype(), #[cfg(feature = "moment")] Skew(..) => mapper.map_to_float_dtype(), } diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index dd7521ad20a9..d3cc1c1a7545 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -70,8 +70,7 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E } } -#[cfg(feature = "rolling_window")] -pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { +fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr { // see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804 let rolling_options = RollingOptionsFixedWindow { window_size: options.window_size as usize, @@ -79,59 +78,23 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; - let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) - .then(lit(1.0)) - .otherwise(lit(Null {})); - - let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let var_x = (x.clone() * non_null_mask.clone()).rolling_var(rolling_options.clone()); - let var_y = (y.clone() * non_null_mask.clone()).rolling_var(rolling_options); - - let rolling_options_count = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: 0, - ..Default::default() - }; - let ddof = options.ddof as f64; - let count_x_y = (x + y) - .is_not_null() - .cast(DataType::Float64) - .rolling_sum(rolling_options_count); - let numerator = (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof))); - let denominator = (var_x * var_y).pow(lit(0.5)); + Expr::Function { + input: vec![x, y], + function: FunctionExpr::RollingExpr(RollingFunction::CorrCov { + rolling_options, + corr_cov_options: options, + is_corr, + }), + options: Default::default(), + } +} - numerator / denominator +#[cfg(feature = "rolling_window")] +pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { + dispatch_corr_cov(x, y, options, true) } #[cfg(feature = "rolling_window")] pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { - // see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700 - let rolling_options = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: options.min_periods as usize, - ..Default::default() - }; - - let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) - .then(lit(1.0)) - .otherwise(lit(Null {})); - - let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); - let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options); - let rolling_options_count = RollingOptionsFixedWindow { - window_size: options.window_size as usize, - min_periods: 0, - ..Default::default() - }; - let count_x_y = (x + y) - .is_not_null() - .cast(DataType::Float64) - .rolling_sum(rolling_options_count); - - let ddof = options.ddof as f64; - - (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof))) + dispatch_corr_cov(x, y, options, false) } diff --git a/crates/polars-python/src/lazyframe/visit.rs b/crates/polars-python/src/lazyframe/visit.rs index 5a98398703b9..bc4cebb360a2 100644 --- a/crates/polars-python/src/lazyframe/visit.rs +++ b/crates/polars-python/src/lazyframe/visit.rs @@ -57,7 +57,7 @@ impl NodeTraverser { // Increment major on breaking changes to the IR (e.g. renaming // fields, reordering tuples), minor on backwards compatible // changes (e.g. exposing a new expression node). - const VERSION: Version = (3, 0); + const VERSION: Version = (3, 1); pub fn new(root: Node, lp_arena: Arena, expr_arena: Arena) -> Self { Self { diff --git a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs index 41f04b7c1cad..06a98e3fe970 100644 --- a/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs +++ b/crates/polars-python/src/lazyframe/visitor/expr_nodes.rs @@ -1187,6 +1187,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { RollingFunction::Skew(_, _) => { return Err(PyNotImplementedError::new_err("rolling skew")) }, + RollingFunction::CorrCov { .. } => { + return Err(PyNotImplementedError::new_err("rolling cor_cov")) + }, }, FunctionExpr::RollingExprBy(rolling) => match rolling { RollingFunctionBy::MinBy(_) => {