Skip to content

Commit

Permalink
perf: Move rolling_corr/cov to an actual implementation on Series
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 26, 2024
1 parent 7e9e784 commit 165ec6c
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 54 deletions.
12 changes: 12 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,18 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
Std(options) => map!(rolling::rolling_std, options.clone()),
#[cfg(feature = "moment")]
Skew(window_size, bias) => map!(rolling::rolling_skew, window_size, bias),
CorrCov {
rolling_options,
corr_cov_options,
is_corr,
} => {
map_as_slice!(
rolling::rolling_corr_cov,
rolling_options.clone(),
corr_cov_options,
is_corr
)
},
}
},
#[cfg(feature = "rolling_window_by")]
Expand Down
115 changes: 115 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/rolling.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use std::ops::BitAnd;

use polars_core::utils::Container;
use polars_time::chunkedarray::*;

use super::*;
use crate::dsl::pow::pow;

#[derive(Clone, PartialEq, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand All @@ -14,6 +18,12 @@ pub enum RollingFunction {
Std(RollingOptionsFixedWindow),
#[cfg(feature = "moment")]
Skew(usize, bool),
CorrCov {
rolling_options: RollingOptionsFixedWindow,
corr_cov_options: RollingCovOptions,
// Whether is Corr or Cov
is_corr: bool,
},
}

impl Display for RollingFunction {
Expand All @@ -30,6 +40,13 @@ impl Display for RollingFunction {
Std(_) => "rolling_std",
#[cfg(feature = "moment")]
Skew(..) => "rolling_skew",
CorrCov { is_corr, .. } => {
if *is_corr {
"rolling_corr"
} else {
"rolling_cov"
}
},
};

write!(f, "{name}")
Expand All @@ -47,6 +64,9 @@ impl Hash for RollingFunction {
window_size.hash(state);
bias.hash(state)
},
CorrCov { is_corr, .. } => {
is_corr.hash(state);
},
_ => {},
}
}
Expand Down Expand Up @@ -111,3 +131,98 @@ pub(super) fn rolling_skew(s: &Column, window_size: usize, bias: bool) -> Polars
.rolling_skew(window_size, bias)
.map(Column::from)
}

fn det_count_x_y(window_size: usize, len: usize, dtype: &DataType) -> Series {
match dtype {
DataType::Float64 => {
let values = (0..len)
.map(|v| std::cmp::min(window_size, v + 1) as f64)
.collect::<Vec<_>>();
Series::new(PlSmallStr::EMPTY, values)
},
DataType::Float32 => {
let values = (0..len)
.map(|v| std::cmp::min(window_size, v + 1) as f32)
.collect::<Vec<_>>();
Series::new(PlSmallStr::EMPTY, values)
},
_ => unreachable!(),
}
}

pub(super) fn rolling_corr_cov(
s: &[Column],
rolling_options: RollingOptionsFixedWindow,
cov_options: RollingCovOptions,
is_corr: bool,
) -> PolarsResult<Column> {
let mut x = s[0].as_materialized_series().rechunk();
let mut y = s[1].as_materialized_series().rechunk();

if !x.dtype().is_float() {
x = x.cast(&DataType::Float64)?;
}
if !y.dtype().is_float() {
y = y.cast(&DataType::Float64)?;
}
let dtype = x.dtype().clone();

let mean_x_y = (&x * &y)?.rolling_mean(rolling_options.clone())?;
let rolling_options_count = RollingOptionsFixedWindow {
window_size: rolling_options.window_size,
min_periods: 0,
..Default::default()
};

let count_x_y = if (x.null_count() + y.null_count()) > 0 {
// mask out nulls on both sides before compute mean/var
let valids = x.is_not_null().bitand(y.is_not_null());
let valids_arr = valids.clone().downcast_into_array();
let valids_bitmap = valids_arr.values();

unsafe {
let xarr = &mut x.chunks_mut()[0];
*xarr = xarr.with_validity(Some(valids_bitmap.clone()));
let yarr = &mut y.chunks_mut()[0];
*yarr = yarr.with_validity(Some(valids_bitmap.clone()));
x.compute_len();
y.compute_len();
}
valids
.cast(&dtype)
.unwrap()
.rolling_sum(rolling_options_count)?
} else {
det_count_x_y(rolling_options.window_size, x.len(), &dtype)
};

let mean_x = x.rolling_mean(rolling_options.clone())?;
let mean_y = y.rolling_mean(rolling_options.clone())?;
let ddof = Series::new(
PlSmallStr::EMPTY,
&[AnyValue::from(cov_options.ddof).cast(&dtype)],
);

let numerator = ((mean_x_y - (mean_x * mean_y).unwrap()).unwrap()
* (count_x_y.clone() / (count_x_y - ddof).unwrap()).unwrap())
.unwrap();

if is_corr {
let var_x = x.rolling_var(rolling_options.clone())?;
let var_y = y.rolling_var(rolling_options.clone())?;

let base = (var_x * var_y).unwrap();
let sc = Scalar::new(
base.dtype().clone(),
AnyValue::Float64(0.5).cast(&dtype).into_static(),
);
let denominator = pow(&mut [base.into_column(), sc.into_column("".into())])
.unwrap()
.unwrap()
.take_materialized_series();

Ok((numerator / denominator)?.into_column())
} else {
Ok(numerator.into_column())
}
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl FunctionExpr {
use RollingFunction::*;
match rolling_func {
Min(_) | Max(_) | Sum(_) => mapper.with_same_dtype(),
Mean(_) | Quantile(_) | Var(_) | Std(_) => mapper.map_to_float_dtype(),
Mean(_) | Quantile(_) | Var(_) | Std(_) | CorrCov {..} => mapper.map_to_float_dtype(),
#[cfg(feature = "moment")]
Skew(..) => mapper.map_to_float_dtype(),
}
Expand Down
67 changes: 15 additions & 52 deletions crates/polars-plan/src/dsl/functions/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,68 +70,31 @@ pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> E
}
}

#[cfg(feature = "rolling_window")]
pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
fn dispatch_corr_cov(x: Expr, y: Expr, options: RollingCovOptions, is_corr: bool) -> Expr {
// see: https://github.com/pandas-dev/pandas/blob/v1.5.1/pandas/core/window/rolling.py#L1780-L1804
let rolling_options = RollingOptionsFixedWindow {
window_size: options.window_size as usize,
min_periods: options.min_periods as usize,
..Default::default()
};

let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null()))
.then(lit(1.0))
.otherwise(lit(Null {}));

let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone());
let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone());
let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone());
let var_x = (x.clone() * non_null_mask.clone()).rolling_var(rolling_options.clone());
let var_y = (y.clone() * non_null_mask.clone()).rolling_var(rolling_options);

let rolling_options_count = RollingOptionsFixedWindow {
window_size: options.window_size as usize,
min_periods: 0,
..Default::default()
};
let ddof = options.ddof as f64;
let count_x_y = (x + y)
.is_not_null()
.cast(DataType::Float64)
.rolling_sum(rolling_options_count);
let numerator = (mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof)));
let denominator = (var_x * var_y).pow(lit(0.5));
Expr::Function {
input: vec![x, y],
function: FunctionExpr::RollingExpr(RollingFunction::CorrCov {
rolling_options,
corr_cov_options: options,
is_corr,
}),
options: Default::default(),
}
}

numerator / denominator
#[cfg(feature = "rolling_window")]
pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
dispatch_corr_cov(x, y, options, true)
}

#[cfg(feature = "rolling_window")]
pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr {
// see: https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/core/window/rolling.py#L1700
let rolling_options = RollingOptionsFixedWindow {
window_size: options.window_size as usize,
min_periods: options.min_periods as usize,
..Default::default()
};

let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null()))
.then(lit(1.0))
.otherwise(lit(Null {}));

let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone());
let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone());
let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options);
let rolling_options_count = RollingOptionsFixedWindow {
window_size: options.window_size as usize,
min_periods: 0,
..Default::default()
};
let count_x_y = (x + y)
.is_not_null()
.cast(DataType::Float64)
.rolling_sum(rolling_options_count);

let ddof = options.ddof as f64;

(mean_x_y - mean_x * mean_y) * (count_x_y.clone() / (count_x_y - lit(ddof)))
dispatch_corr_cov(x, y, options, false)
}
2 changes: 1 addition & 1 deletion crates/polars-python/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl NodeTraverser {
// Increment major on breaking changes to the IR (e.g. renaming
// fields, reordering tuples), minor on backwards compatible
// changes (e.g. exposing a new expression node).
const VERSION: Version = (3, 0);
const VERSION: Version = (3, 1);

pub fn new(root: Node, lp_arena: Arena<IR>, expr_arena: Arena<AExpr>) -> Self {
Self {
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
RollingFunction::Skew(_, _) => {
return Err(PyNotImplementedError::new_err("rolling skew"))
},
RollingFunction::CorrCov { .. } => {
return Err(PyNotImplementedError::new_err("rolling cor_cov"))
},
},
FunctionExpr::RollingExprBy(rolling) => match rolling {
RollingFunctionBy::MinBy(_) => {
Expand Down

0 comments on commit 165ec6c

Please sign in to comment.