diff --git a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs index 8f45bbbef2fb..5984106f1521 100644 --- a/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs +++ b/crates/polars-arrow/src/legacy/kernels/ewm/mod.rs @@ -2,9 +2,12 @@ mod average; mod variance; pub use average::*; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; pub use variance::*; -#[derive(Debug, Copy, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq)] #[must_use] pub struct EWMOptions { pub alpha: f64, diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs b/crates/polars-plan/src/dsl/function_expr/ewm.rs new file mode 100644 index 000000000000..a26285eef33a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/ewm.rs @@ -0,0 +1,13 @@ +use super::*; + +pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult { + s.ewm_mean(options) +} + +pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { + s.ewm_std(options) +} + +pub(super) fn ewm_var(s: &Series, options: EWMOptions) -> PolarsResult { + s.ewm_var(options) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 7f51397f8dfe..24a61df540f0 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -18,6 +18,8 @@ mod cum; #[cfg(feature = "temporal")] mod datetime; mod dispatch; +#[cfg(feature = "ewma")] +mod ewm; mod fill_null; #[cfg(feature = "fused")] mod fused; @@ -259,6 +261,18 @@ pub enum FunctionExpr { SumHorizontal, MaxHorizontal, MinHorizontal, + #[cfg(feature = "ewma")] + EwmMean { + options: EWMOptions, + }, + #[cfg(feature = "ewma")] + EwmStd { + options: EWMOptions, + }, + #[cfg(feature = "ewma")] + EwmVar { + options: EWMOptions, + }, } impl Hash for FunctionExpr { @@ -433,6 +447,12 @@ impl Display for FunctionExpr { SumHorizontal => "sum_horizontal", MaxHorizontal => "max_horizontal", MinHorizontal => "min_horizontal", + #[cfg(feature = "ewma")] + EwmMean { .. } => "ewm_mean", + #[cfg(feature = "ewma")] + EwmStd { .. } => "ewm_std", + #[cfg(feature = "ewma")] + EwmVar { .. } => "ewm_var", }; write!(f, "{s}") } @@ -755,6 +775,12 @@ impl From for SpecialEq> { SumHorizontal => map_as_slice!(dispatch::sum_horizontal), MaxHorizontal => wrap!(dispatch::max_horizontal), MinHorizontal => wrap!(dispatch::min_horizontal), + #[cfg(feature = "ewma")] + EwmMean { options } => map!(ewm::ewm_mean, options), + #[cfg(feature = "ewma")] + EwmStd { options } => map!(ewm::ewm_std, options), + #[cfg(feature = "ewma")] + EwmVar { options } => map!(ewm::ewm_var, options), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 3d8996e74431..3cb2e2a852f7 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -242,6 +242,12 @@ impl FunctionExpr { SumHorizontal => mapper.map_to_supertype(), MaxHorizontal => mapper.map_to_supertype(), MinHorizontal => mapper.map_to_supertype(), + #[cfg(feature = "ewma")] + EwmMean { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma")] + EwmStd { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma")] + EwmVar { .. } => mapper.map_to_float_dtype(), } } } diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 3b335e6d5475..aff4c2378a12 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1666,43 +1666,19 @@ impl Expr { #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving average. pub fn ewm_mean(self, options: EWMOptions) -> Self { - use DataType::*; - self.apply( - move |s| s.ewm_mean(options).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("ewm_mean") + self.apply_private(FunctionExpr::EwmMean { options }) } #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving standard deviation. pub fn ewm_std(self, options: EWMOptions) -> Self { - use DataType::*; - self.apply( - move |s| s.ewm_std(options).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("ewm_std") + self.apply_private(FunctionExpr::EwmStd { options }) } #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving variance. pub fn ewm_var(self, options: EWMOptions) -> Self { - use DataType::*; - self.apply( - move |s| s.ewm_var(options).map(Some), - GetOutput::map_dtype(|dt| match dt { - Float64 | Float32 => dt.clone(), - _ => Float64, - }), - ) - .with_fmt("ewm_var") + self.apply_private(FunctionExpr::EwmVar { options }) } /// Returns whether any of the values in the column are `true`.