diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 4aca12a5eeb0..4c696453ed15 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -114,6 +114,7 @@ abs = ["polars-plan/abs"] random = ["polars-plan/random"] dynamic_group_by = ["polars-plan/dynamic_group_by", "polars-time", "temporal"] ewma = ["polars-plan/ewma"] +ewma_by = ["polars-plan/ewma_by"] dot_diagram = ["polars-plan/dot_diagram"] diagonal_concat = [] unique_counts = ["polars-plan/unique_counts"] diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index f132a2800c0a..168998e8c330 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -129,6 +129,7 @@ repeat_by = [] peaks = [] cum_agg = [] ewma = [] +ewma_by = [] abs = [] cov = [] gather = [] diff --git a/crates/polars-ops/src/series/ops/ewm_by.rs b/crates/polars-ops/src/series/ops/ewm_by.rs new file mode 100644 index 000000000000..f47e03239ae3 --- /dev/null +++ b/crates/polars-ops/src/series/ops/ewm_by.rs @@ -0,0 +1,176 @@ +use num_traits::{Float, FromPrimitive, One, Zero}; +use polars_core::prelude::*; + +pub fn ewm_mean_by( + s: &Series, + times: &Series, + half_life: i64, + assume_sorted: bool, +) -> PolarsResult { + let func = match assume_sorted { + true => ewm_mean_by_impl_sorted, + false => ewm_mean_by_impl, + }; + match (s.dtype(), times.dtype()) { + (DataType::Float64, DataType::Int64) => { + Ok(func(s.f64().unwrap(), times.i64().unwrap(), half_life).into_series()) + }, + (DataType::Float32, DataType::Int64) => { + Ok(ewm_mean_by_impl(s.f32().unwrap(), times.i64().unwrap(), half_life).into_series()) + }, + #[cfg(feature = "dtype-datetime")] + (_, DataType::Datetime(time_unit, _)) => { + let half_life = adjust_half_life_to_time_unit(half_life, time_unit); + ewm_mean_by(s, ×.cast(&DataType::Int64)?, half_life, assume_sorted) + }, + #[cfg(feature = "dtype-date")] + (_, DataType::Date) => ewm_mean_by( + s, + 
×.cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?, + half_life, + assume_sorted, + ), + (_, DataType::UInt64 | DataType::UInt32 | DataType::Int32) => { + ewm_mean_by(s, ×.cast(&DataType::Int64)?, half_life, assume_sorted) + }, + (DataType::UInt64 | DataType::UInt32 | DataType::Int64 | DataType::Int32, _) => { + ewm_mean_by( + &s.cast(&DataType::Float64)?, + times, + half_life, + assume_sorted, + ) + }, + _ => { + polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \ + Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \ + UInt64, or UInt32") + }, + } +} + +/// Sort on behalf of user +fn ewm_mean_by_impl( + values: &ChunkedArray, + times: &Int64Chunked, + half_life: i64, +) -> ChunkedArray +where + T: PolarsFloatType, + T::Native: Float + Zero + One, + ChunkedArray: ChunkTakeUnchecked, +{ + let sorting_indices = times.arg_sort(Default::default()); + let values = unsafe { values.take_unchecked(&sorting_indices) }; + let times = unsafe { times.take_unchecked(&sorting_indices) }; + let sorting_indices = sorting_indices + .cont_slice() + .expect("`arg_sort` should have returned a single chunk"); + + let mut out = vec![None; times.len()]; + + let mut skip_rows: usize = 0; + let mut prev_time: i64 = 0; + let mut prev_result = T::Native::zero(); + for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() { + if let (Some(time), Some(value)) = (time, value) { + prev_time = time; + prev_result = value; + unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = Some(prev_result); + } + skip_rows = idx + 1; + break; + }; + } + values + .iter() + .zip(times.iter()) + .enumerate() + .skip(skip_rows) + .for_each(|(idx, (value, time))| { + let result_opt = match (time, value) { + (Some(time), Some(value)) => { + let result = update(value, prev_result, time, prev_time, half_life); + prev_time = time; + prev_result = result; + Some(result) + }, + _ => None, + }; + 
unsafe { + let out_idx = sorting_indices.get_unchecked(idx); + *out.get_unchecked_mut(*out_idx as usize) = result_opt; + } + }); + ChunkedArray::::from_iter_options(values.name(), out.into_iter()) +} + +/// Fastpath if `times` is known to already be sorted. +fn ewm_mean_by_impl_sorted( + values: &ChunkedArray, + times: &Int64Chunked, + half_life: i64, +) -> ChunkedArray +where + T: PolarsFloatType, + T::Native: Float + Zero + One, +{ + let mut out = Vec::with_capacity(times.len()); + + let mut skip_rows: usize = 0; + let mut prev_time: i64 = 0; + let mut prev_result = T::Native::zero(); + for (idx, (value, time)) in values.iter().zip(times.iter()).enumerate() { + if let (Some(time), Some(value)) = (time, value) { + prev_time = time; + prev_result = value; + out.push(Some(prev_result)); + skip_rows = idx + 1; + break; + }; + } + values + .iter() + .zip(times.iter()) + .skip(skip_rows) + .for_each(|(value, time)| { + let result_opt = match (time, value) { + (Some(time), Some(value)) => { + let result = update(value, prev_result, time, prev_time, half_life); + prev_time = time; + prev_result = result; + Some(result) + }, + _ => None, + }; + out.push(result_opt); + }); + ChunkedArray::::from_iter_options(values.name(), out.into_iter()) +} + +fn adjust_half_life_to_time_unit(half_life: i64, time_unit: &TimeUnit) -> i64 { + match time_unit { + TimeUnit::Milliseconds => half_life / 1_000_000, + TimeUnit::Microseconds => half_life / 1_000, + TimeUnit::Nanoseconds => half_life, + } +} + +fn update(value: T, prev_result: T, time: i64, prev_time: i64, half_life: i64) -> T +where + T: Float + Zero + One + FromPrimitive, +{ + if value != prev_result { + let delta_time = time - prev_time; + // equivalent to: alpha = 1 - exp(-delta_time*ln(2) / half_life) + let one_minus_alpha = T::from_f64(0.5) + .unwrap() + .powf(T::from_i64(delta_time).unwrap() / T::from_i64(half_life).unwrap()); + let alpha = T::one() - one_minus_alpha; + alpha * value + one_minus_alpha * prev_result + } else 
{ + value + } +} diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 0ebcff5daace..a87a9ef9a29d 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -16,6 +16,8 @@ mod cut; mod diff; #[cfg(feature = "ewma")] mod ewm; +#[cfg(feature = "ewma_by")] +mod ewm_by; #[cfg(feature = "round_series")] mod floor_divide; #[cfg(feature = "fused")] @@ -78,6 +80,8 @@ pub use cut::*; pub use diff::*; #[cfg(feature = "ewma")] pub use ewm::*; +#[cfg(feature = "ewma_by")] +pub use ewm_by::*; #[cfg(feature = "round_series")] pub use floor_divide::*; #[cfg(feature = "fused")] diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index de5774feef07..18279f58fe00 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -130,6 +130,7 @@ abs = ["polars-ops/abs"] random = ["polars-core/random"] dynamic_group_by = ["polars-core/dynamic_group_by"] ewma = ["polars-ops/ewma"] +ewma_by = ["polars-ops/ewma_by"] dot_diagram = [] unique_counts = ["polars-ops/unique_counts"] log = ["polars-ops/log"] @@ -205,6 +206,7 @@ features = [ "cutqcut", "async", "ewma", + "ewma_by", "random", "chunked_ids", "repeat_by", diff --git a/crates/polars-plan/src/dsl/function_expr/ewm.rs b/crates/polars-plan/src/dsl/function_expr/ewm.rs index b824ca3013e9..ef7faafcce55 100644 --- a/crates/polars-plan/src/dsl/function_expr/ewm.rs +++ b/crates/polars-plan/src/dsl/function_expr/ewm.rs @@ -4,6 +4,31 @@ pub(super) fn ewm_mean(s: &Series, options: EWMOptions) -> PolarsResult polars_ops::prelude::ewm_mean(s, options) } +pub(super) fn ewm_mean_by( + s: &[Series], + half_life: Duration, + check_sorted: bool, +) -> PolarsResult { + let time_zone = match s[1].dtype() { + DataType::Datetime(_, Some(time_zone)) => Some(time_zone.as_str()), + _ => None, + }; + polars_ensure!(!half_life.negative(), InvalidOperation: "half_life cannot be negative"); + 
polars_ensure!(half_life.is_constant_duration(time_zone), + InvalidOperation: "expected `half_life` to be a constant duration \ + (i.e. one independent of differing month durations or of daylight savings time), got {}.\n\ + \n\ + You may want to try:\n\ + - using `'730h'` instead of `'1mo'`\n\ + - using `'24h'` instead of `'1d'` if your series is time-zone-aware", half_life); + // `half_life` was verified above to be a constant duration, so `duration_ns()` is safe to use. + let half_life = half_life.duration_ns(); + let values = &s[0]; + let times = &s[1]; + let assume_sorted = !check_sorted || times.is_sorted_flag() == IsSorted::Ascending; + polars_ops::prelude::ewm_mean_by(values, times, half_life, assume_sorted) +} + pub(super) fn ewm_std(s: &Series, options: EWMOptions) -> PolarsResult { polars_ops::prelude::ewm_std(s, options) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 82d04e7da55f..ca371f5e8efd 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -313,6 +313,11 @@ pub enum FunctionExpr { EwmMean { options: EWMOptions, }, + #[cfg(feature = "ewma_by")] + EwmMeanBy { + half_life: Duration, + check_sorted: bool, + }, #[cfg(feature = "ewma")] EwmStd { options: EWMOptions, @@ -520,6 +525,11 @@ impl Hash for FunctionExpr { BackwardFill { limit } | ForwardFill { limit } => limit.hash(state), #[cfg(feature = "ewma")] EwmMean { options } => options.hash(state), + #[cfg(feature = "ewma_by")] + EwmMeanBy { + half_life, + check_sorted, + } => (half_life, check_sorted).hash(state), #[cfg(feature = "ewma")] EwmStd { options } => options.hash(state), #[cfg(feature = "ewma")] @@ -705,6 +715,8 @@ impl Display for FunctionExpr { MeanHorizontal => "mean_horizontal", #[cfg(feature = "ewma")] EwmMean { .. } => "ewm_mean", + #[cfg(feature = "ewma_by")] + EwmMeanBy { .. } => "ewm_mean_by", #[cfg(feature = "ewma")] EwmStd { .. 
} => "ewm_std", #[cfg(feature = "ewma")] @@ -1073,6 +1085,11 @@ impl From for SpecialEq> { MeanHorizontal => wrap!(dispatch::mean_horizontal), #[cfg(feature = "ewma")] EwmMean { options } => map!(ewm::ewm_mean, options), + #[cfg(feature = "ewma_by")] + EwmMeanBy { + half_life, + check_sorted, + } => map_as_slice!(ewm::ewm_mean_by, half_life, check_sorted), #[cfg(feature = "ewma")] EwmStd { options } => map!(ewm::ewm_std, options), #[cfg(feature = "ewma")] diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 98ce87676eb0..597b2003f955 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -286,6 +286,8 @@ impl FunctionExpr { MeanHorizontal => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] EwmMean { .. } => mapper.map_to_float_dtype(), + #[cfg(feature = "ewma_by")] + EwmMeanBy { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] EwmStd { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "ewma")] diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 2bd5d5d9fa30..671bdec3117a 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1534,6 +1534,20 @@ impl Expr { self.apply_private(FunctionExpr::EwmMean { options }) } + #[cfg(feature = "ewma_by")] + /// Calculate the time-based exponentially-weighted moving average of `self`, with `times` as the time axis and `half_life` as the decay half-life. + pub fn ewm_mean_by(self, times: Expr, half_life: Duration, check_sorted: bool) -> Self { + self.apply_many_private( + FunctionExpr::EwmMeanBy { + half_life, + check_sorted, + }, + &[times], + false, + false, + ) + } + #[cfg(feature = "ewma")] /// Calculate the exponentially-weighted moving standard deviation. 
pub fn ewm_std(self, options: EWMOptions) -> Self { diff --git a/crates/polars-time/src/windows/duration.rs b/crates/polars-time/src/windows/duration.rs index c4b70c3372a1..60b9c235c8b1 100644 --- a/crates/polars-time/src/windows/duration.rs +++ b/crates/polars-time/src/windows/duration.rs @@ -7,6 +7,7 @@ use arrow::legacy::kernels::{Ambiguous, NonExistent}; use arrow::legacy::time_zone::Tz; use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, MILLISECONDS, + NANOSECONDS, }; use chrono::{Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike}; use polars_core::export::arrow::temporal_conversions::MICROSECONDS; @@ -14,7 +15,6 @@ use polars_core::prelude::{ datetime_to_timestamp_ms, datetime_to_timestamp_ns, datetime_to_timestamp_us, polars_bail, PolarsResult, }; -use polars_core::utils::arrow::temporal_conversions::NANOSECONDS; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -378,6 +378,11 @@ impl Duration { self.nsecs } + /// Returns whether the duration is negative. + pub fn negative(&self) -> bool { + self.negative + } + /// Estimated duration of the window duration. Not a very good one if not a constant duration. 
#[doc(hidden)] pub const fn duration_ns(&self) -> i64 { diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 60c8429201d7..9056f42abfaa 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -151,6 +151,7 @@ dot_diagram = ["polars-lazy?/dot_diagram"] dot_product = ["polars-core/dot_product"] dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] ewma = ["polars-ops/ewma", "polars-lazy?/ewma"] +ewma_by = ["polars-ops/ewma_by", "polars-lazy?/ewma_by"] extract_groups = ["polars-lazy?/extract_groups"] extract_jsonpath = [ "polars-core/strings", diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index b10889e0f108..7e9fe9f71d5c 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -53,6 +53,7 @@ features = [ "dtype-full", "dynamic_group_by", "ewma", + "ewma_by", "fmt", "interpolate", "is_first_distinct", diff --git a/py-polars/docs/source/reference/expressions/computation.rst b/py-polars/docs/source/reference/expressions/computation.rst index 1060fe7a8907..663f04df0185 100644 --- a/py-polars/docs/source/reference/expressions/computation.rst +++ b/py-polars/docs/source/reference/expressions/computation.rst @@ -35,6 +35,7 @@ Computation Expr.dot Expr.entropy Expr.ewm_mean + Expr.ewm_mean_by Expr.ewm_std Expr.ewm_var Expr.exp diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index 3e38b447c8a4..90335b9aad23 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -32,6 +32,7 @@ Computation Series.dot Series.entropy Series.ewm_mean + Series.ewm_mean_by Series.ewm_std Series.ewm_var Series.exp diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 15eae731fbfd..5436a6b87e5e 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -8753,6 +8753,101 @@ def ewm_mean( self._pyexpr.ewm_mean(alpha, 
adjust, min_periods, ignore_nulls) ) + def ewm_mean_by( + self, + by: str | IntoExpr, + *, + half_life: str | timedelta, + check_sorted: bool = True, + ) -> Self: + r""" + Calculate time-based exponentially weighted moving average. + + Given observations :math:`x_0, x_1, \ldots, x_{n-1}` at times + :math:`t_0, t_1, \ldots, t_{n-1}`, the EWMA is calculated as + + .. math:: + + y_0 &= x_0 + + \alpha_i &= 1 - \exp(-\lambda(t_i - t_{i-1})) + + y_i &= \alpha_i x_i + (1 - \alpha_i) y_{i-1}; \quad i > 0 + + where :math:`\lambda` equals :math:`\ln(2) / \text{half_life}`. + + Parameters + ---------- + by + Times to calculate average by. Should be ``Datetime``, ``Date``, ``UInt64``, + ``UInt32``, ``Int64``, or ``Int32`` data type. + half_life + Unit over which observation decays to half its value. + + Can be created either from a timedelta, or + by using the following string language: + + - 1ns (1 nanosecond) + - 1us (1 microsecond) + - 1ms (1 millisecond) + - 1s (1 second) + - 1m (1 minute) + - 1h (1 hour) + - 1d (1 day) + - 1w (1 week) + - 1i (1 index count) + + Or combine them: + "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds + + Note that `half_life` is treated as a constant duration - calendar + durations such as months (or even days in the time-zone-aware case) + are not supported, please express your duration in an approximately + equivalent number of hours (e.g. '730h' instead of '1mo'). + check_sorted + Check whether `by` column is sorted. + Incorrectly setting this to `False` will lead to incorrect output. + + Returns + ------- + Expr + Float32 if input is Float32, otherwise Float64. + + Examples + -------- + >>> from datetime import date, timedelta + >>> df = pl.DataFrame( + ... { + ... "values": [0, 1, 2, None, 4], + ... "times": [ + ... date(2020, 1, 1), + ... date(2020, 1, 3), + ... date(2020, 1, 10), + ... date(2020, 1, 15), + ... date(2020, 1, 17), + ... ], + ... } + ... ).sort("times") + >>> df.with_columns( + ... 
result=pl.col("values").ewm_mean_by("times", half_life="4d"), + ... ) + shape: (5, 3) + ┌────────┬────────────┬──────────┐ + │ values ┆ times ┆ result │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ date ┆ f64 │ + ╞════════╪════════════╪══════════╡ + │ 0 ┆ 2020-01-01 ┆ 0.0 │ + │ 1 ┆ 2020-01-03 ┆ 0.292893 │ + │ 2 ┆ 2020-01-10 ┆ 1.492474 │ + │ null ┆ 2020-01-15 ┆ null │ + │ 4 ┆ 2020-01-17 ┆ 3.254508 │ + └────────┴────────────┴──────────┘ + """ + by = parse_as_expression(by) + half_life = parse_as_duration_string(half_life) + return self._from_pyexpr(self._pyexpr.ewm_mean_by(by, half_life, check_sorted)) + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( self, diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 91b6c8159cdf..9e26692d288e 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6881,6 +6881,92 @@ def ewm_mean( ] """ + def ewm_mean_by( + self, + by: str | IntoExpr, + *, + half_life: str | timedelta, + ) -> Series: + r""" + Calculate time-based exponentially weighted moving average. + + Given observations :math:`x_0, x_1, \ldots, x_{n-1}` at times + :math:`t_0, t_1, \ldots, t_{n-1}`, the EWMA is calculated as + + .. math:: + + y_0 &= x_0 + + \alpha_i &= 1 - \exp(-\lambda(t_i - t_{i-1})) + + y_i &= \alpha_i x_i + (1 - \alpha_i) y_{i-1}; \quad i > 0 + + where :math:`\lambda` equals :math:`\ln(2) / \text{half_life}`. + + Parameters + ---------- + by + Times to calculate average by. Should be ``Datetime``, ``Date``, ``UInt64``, + ``UInt32``, ``Int64``, or ``Int32`` data type. + half_life + Unit over which observation decays to half its value. 
+ + Can be created either from a timedelta, or + by using the following string language: + + - 1ns (1 nanosecond) + - 1us (1 microsecond) + - 1ms (1 millisecond) + - 1s (1 second) + - 1m (1 minute) + - 1h (1 hour) + - 1d (1 day) + - 1w (1 week) + - 1i (1 index count) + + Or combine them: + "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds + + Note that `half_life` is treated as a constant duration - calendar + durations such as months (or even days in the time-zone-aware case) + are not supported, please express your duration in an approximately + equivalent number of hours (e.g. '730h' instead of '1mo'). + check_sorted + Check whether `by` column is sorted. + Incorrectly setting this to `False` will lead to incorrect output. + + Returns + ------- + Series + Float32 if input is Float32, otherwise Float64. + + Examples + -------- + >>> from datetime import date, timedelta + >>> df = pl.DataFrame( + ... { + ... "values": [0, 1, 2, None, 4], + ... "times": [ + ... date(2020, 1, 1), + ... date(2020, 1, 3), + ... date(2020, 1, 10), + ... date(2020, 1, 15), + ... date(2020, 1, 17), + ... ], + ... } + ... 
).sort("times") + >>> df["values"].ewm_mean_by(df["times"], half_life="4d") + shape: (5,) + Series: 'values' [f64] + [ + 0.0 + 0.292893 + 1.492474 + null + 3.254508 + ] + """ + @deprecate_nonkeyword_arguments(version="0.19.10") def ewm_std( self, diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index e39d935fa569..62a4d48bea3d 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -787,6 +787,14 @@ impl PyExpr { }; self.inner.clone().ewm_mean(options).into() } + fn ewm_mean_by(&self, times: PyExpr, half_life: &str, check_sorted: bool) -> Self { + let half_life = Duration::parse(half_life); + self.inner + .clone() + .ewm_mean_by(times.inner, half_life, check_sorted) + .into() + } + fn ewm_std( &self, alpha: f64, diff --git a/py-polars/tests/parametric/time_series/test_ewm_by.py b/py-polars/tests/parametric/time_series/test_ewm_by.py new file mode 100644 index 000000000000..e283a1f14c0f --- /dev/null +++ b/py-polars/tests/parametric/time_series/test_ewm_by.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import hypothesis.strategies as st +from hypothesis import given + +import polars as pl +from polars.testing import assert_frame_equal +from polars.testing.parametric.primitives import column, dataframes + + +@given( + data=st.data(), + half_life=st.integers(min_value=1, max_value=1000), +) +def test_ewm_by(data: st.DataObject, half_life: int) -> None: + # For evenly spaced times (here the row index), ewm_mean_by must match ewm_mean(adjust=False) + df = data.draw( + dataframes( + [ + column( + "values", + strategy=st.floats(min_value=-100, max_value=100), + dtype=pl.Float64, + ), + ], + min_size=1, + ) + ) + result = df.with_row_index().select( + pl.col("values").ewm_mean_by( + by="index", half_life=f"{half_life}i", check_sorted=False + ) + ) + expected = df.select( + pl.col("values").ewm_mean(half_life=half_life, ignore_nulls=False, adjust=False) + ) + assert_frame_equal(result, expected) + result = ( + df.with_row_index() + 
.sort("values") + .with_columns( + pl.col("values").ewm_mean_by(by="index", half_life=f"{half_life}i") + ) + .sort("index") + .select("values") + ) + assert_frame_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_ewm_by.py b/py-polars/tests/unit/operations/test_ewm_by.py new file mode 100644 index 000000000000..fcd87fd83f5d --- /dev/null +++ b/py-polars/tests/unit/operations/test_ewm_by.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from datetime import date, datetime, timedelta +from typing import TYPE_CHECKING + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars.type_aliases import PolarsIntegerType, TimeUnit + + +def test_ewma_by_date() -> None: + df = pl.LazyFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + date(2020, 1, 4), + date(2020, 1, 11), + date(2020, 1, 16), + date(2020, 1, 18), + ], + } + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result.collect(), expected) + assert result.schema["values"] == pl.Float64 + assert result.collect().schema["values"] == pl.Float64 + + +def test_ewma_by_date_constant() -> None: + df = pl.DataFrame( + { + "values": [1, 1, 1], + "times": [ + date(2020, 1, 4), + date(2020, 1, 11), + date(2020, 1, 16), + ], + } + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame({"values": [1.0, 1, 1]}) + assert_frame_equal(result, expected) + + +def test_ewma_f32() -> None: + df = pl.LazyFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + date(2020, 1, 4), + date(2020, 1, 11), + date(2020, 1, 16), + date(2020, 1, 18), + ], + }, + schema_overrides={"values": pl.Float32}, + ) + result = df.select( + pl.col("values").ewm_mean_by("times", 
half_life=timedelta(days=2)), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]}, + schema_overrides={"values": pl.Float32}, + ) + assert_frame_equal(result.collect(), expected) + assert result.schema["values"] == pl.Float32 + assert result.collect().schema["values"] == pl.Float32 + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +@pytest.mark.parametrize("time_zone", [None, "UTC"]) +def test_ewma_by_datetime(time_unit: TimeUnit, time_zone: str | None) -> None: + df = pl.DataFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + datetime(2020, 1, 4), + datetime(2020, 1, 11), + datetime(2020, 1, 16), + datetime(2020, 1, 18), + ], + }, + schema_overrides={"times": pl.Datetime(time_unit, time_zone)}, + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life=timedelta(days=2)), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) +def test_ewma_by_datetime_tz_aware(time_unit: TimeUnit) -> None: + df = pl.DataFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + datetime(2020, 1, 4), + datetime(2020, 1, 11), + datetime(2020, 1, 16), + datetime(2020, 1, 18), + ], + }, + schema_overrides={"times": pl.Datetime(time_unit, "Asia/Kathmandu")}, + ) + msg = "expected `half_life` to be a constant duration" + with pytest.raises(pl.InvalidOperationError, match=msg): + df.select( + pl.col("values").ewm_mean_by("times", half_life="2d"), + ) + + result = df.select( + pl.col("values").ewm_mean_by("times", half_life="48h0ns"), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("data_type", [pl.Int64, pl.Int32, pl.UInt64, pl.UInt32]) +def test_ewma_by_index(data_type: PolarsIntegerType) -> 
None: + df = pl.LazyFrame( + { + "values": [3.0, 1.0, 2.0, None, 4.0], + "times": [ + None, + 4, + 11, + 16, + 18, + ], + }, + schema_overrides={"times": data_type}, + ) + result = df.select( + pl.col("values").ewm_mean_by("times", half_life="2i"), + ) + expected = pl.DataFrame( + {"values": [None, 1.0, 1.9116116523516815, None, 3.815410804703363]} + ) + assert_frame_equal(result.collect(), expected) + assert result.schema["values"] == pl.Float64 + assert result.collect().schema["values"] == pl.Float64 + + +def test_ewma_by_empty() -> None: + df = pl.DataFrame({"values": []}, schema_overrides={"values": pl.Float64}) + result = df.with_row_index().select( + pl.col("values").ewm_mean_by("index", half_life="2i"), + ) + expected = pl.DataFrame({"values": []}, schema_overrides={"values": pl.Float64}) + assert_frame_equal(result, expected) + + +def test_ewma_by_warn_if_unsorted() -> None: + df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]}) + + # With `check_sorted=False` the rows are processed in the given (unsorted) order, + # so the negative time delta yields the deliberately "incorrect" 4.0 below. 
+ result = df.select( + pl.col("values").ewm_mean_by("by", half_life="2i", check_sorted=False), + ) + expected = pl.DataFrame({"values": [3.0, 4.0]}) + assert_frame_equal(result, expected) + + result = df.with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + expected = pl.DataFrame({"values": [2.5, 2.0], "by": [3, 1]}) + assert_frame_equal(result, expected) + result = df.sort("by").with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + assert_frame_equal(result, expected.sort("by")) + + +def test_ewma_by_invalid() -> None: + df = pl.DataFrame({"values": [1, 2]}) + with pytest.raises(pl.InvalidOperationError, match="half_life cannot be negative"): + df.with_row_index().select( + pl.col("values").ewm_mean_by("index", half_life="-2i"), + ) + df = pl.DataFrame({"values": [[1, 2], [3, 4]]}) + with pytest.raises( + pl.InvalidOperationError, match=r"expected series to be Float64, Float32, .*" + ): + df.with_row_index().select( + pl.col("values").ewm_mean_by("index", half_life="2i"), + ) + + +def test_ewma_by_warn_two_chunks() -> None: + df = pl.DataFrame({"values": [3.0, 2.0], "by": [3, 1]}) + df = pl.concat([df, df], rechunk=False) + + result = df.with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + expected = pl.DataFrame({"values": [2.5, 2.0, 2.5, 2], "by": [3, 1, 3, 1]}) + assert_frame_equal(result, expected) + result = df.sort("by").with_columns( + pl.col("values").ewm_mean_by("by", half_life="2i"), + ) + assert_frame_equal(result, expected.sort("by"))