From c6dc59f1b13d71377102028250be2f710097b676 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 12 Oct 2023 17:57:52 +0800 Subject: [PATCH] refactor: *_horizontal dependent on reduce_expr to expression architecture --- crates/polars-ops/Cargo.toml | 1 + .../polars-ops/src/series/ops/horizontal.rs | 12 +++++++ crates/polars-plan/Cargo.toml | 2 +- .../src/dsl/function_expr/dispatch.rs | 8 +++++ .../polars-plan/src/dsl/function_expr/mod.rs | 6 ++++ .../src/dsl/function_expr/schema.rs | 2 ++ .../src/dsl/functions/horizontal.rs | 34 +++++++++++++------ crates/polars/Cargo.toml | 2 +- 8 files changed, 55 insertions(+), 12 deletions(-) diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index a0f5cf9d1cec..e0e85a31804f 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -90,6 +90,7 @@ string_from_radix = ["polars-core/strings"] extract_jsonpath = ["serde_json", "jsonpath_lib", "polars-json"] log = [] hash = [] +zip_with = ["polars-core/zip_with"] group_by_list = ["polars-core/group_by_list"] rolling_window = ["polars-core/rolling_window"] moment = ["polars-core/moment"] diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 1328f2a2ce77..d7053f03e63d 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -66,3 +66,15 @@ pub fn all_horizontal(s: &[Series]) -> PolarsResult { .with_name("all"); Ok(out.into_series()) } + +#[cfg(feature = "zip_with")] +pub fn max_horizontal(s: &[Series]) -> PolarsResult> { + let df = DataFrame::new_no_checks(Vec::from(s)); + df.hmax().map(|opt_s| opt_s.map(|s| s.with_name("max"))) +} + +#[cfg(feature = "zip_with")] +pub fn min_horizontal(s: &[Series]) -> PolarsResult> { + let df = DataFrame::new_no_checks(Vec::from(s)); + df.hmin().map(|opt_s| opt_s.map(|s| s.with_name("min"))) +} diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index de17d1ee0054..3bd23550b514 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -17,7 +17,7 @@ polars-arrow = { workspace = true } polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } polars-ffi = { workspace = true, optional = true } polars-io = { workspace = true, features = ["lazy"], default-features = false } -polars-ops = { workspace = true, default-features = false } +polars-ops = { workspace = true, features = ["zip_with"], default-features = false } polars-time = { workspace = true, optional = true } polars-utils = { workspace = true } diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index 755594c6cef5..bbe11390fb01 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -63,3 +63,11 @@ pub(super) fn forward_fill(s: &Series, limit: FillNullLimit) -> PolarsResult PolarsResult { polars_ops::prelude::sum_horizontal(s) } + +pub(super) fn max_horizontal(s: &mut [Series]) -> PolarsResult> { + polars_ops::prelude::max_horizontal(s) +} + +pub(super) fn min_horizontal(s: &mut [Series]) -> PolarsResult> { + polars_ops::prelude::min_horizontal(s) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 0c6a7f077c50..7f51397f8dfe 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -257,6 +257,8 @@ pub enum FunctionExpr { limit: FillNullLimit, }, SumHorizontal, + MaxHorizontal, + MinHorizontal, } impl Hash for FunctionExpr { @@ -429,6 +431,8 @@ impl Display for FunctionExpr { BackwardFill { .. } => "backward_fill", ForwardFill { .. } => "forward_fill", SumHorizontal => "sum_horizontal", + MaxHorizontal => "max_horizontal", + MinHorizontal => "min_horizontal", }; write!(f, "{s}") } @@ -749,6 +753,8 @@ impl From for SpecialEq> { BackwardFill { limit } => map!(dispatch::backward_fill, limit), ForwardFill { limit } => map!(dispatch::forward_fill, limit), SumHorizontal => map_as_slice!(dispatch::sum_horizontal), + MaxHorizontal => wrap!(dispatch::max_horizontal), + MinHorizontal => wrap!(dispatch::min_horizontal), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index c27f745a6fd2..d789d0dc9d44 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -237,6 +237,8 @@ impl FunctionExpr { BackwardFill { .. } => mapper.with_same_dtype(), ForwardFill { .. } => mapper.with_same_dtype(), SumHorizontal => mapper.map_to_supertype(), + MaxHorizontal => mapper.map_to_supertype(), + MinHorizontal => mapper.map_to_supertype(), } } } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index ca4fc074d006..8517ef115387 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -236,11 +236,18 @@ pub fn max_horizontal>(exprs: E) -> Expr { if exprs.is_empty() { return Expr::Columns(Vec::new()); } - let func = |s1, s2| { - let df = DataFrame::new_no_checks(vec![s1, s2]); - df.hmax() - }; - reduce_exprs(func, exprs).alias("max") + + Expr::Function { + input: exprs, + function: FunctionExpr::MaxHorizontal, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyFlat, + input_wildcard_expansion: true, + auto_explode: true, + allow_rename: true, + ..Default::default() + }, + } } /// Create a new column with the the minimum value per row. @@ -251,11 +258,18 @@ pub fn min_horizontal>(exprs: E) -> Expr { if exprs.is_empty() { return Expr::Columns(Vec::new()); } - let func = |s1, s2| { - let df = DataFrame::new_no_checks(vec![s1, s2]); - df.hmin() - }; - reduce_exprs(func, exprs).alias("min") + + Expr::Function { + input: exprs, + function: FunctionExpr::MinHorizontal, + options: FunctionOptions { + collect_groups: ApplyOptions::ApplyFlat, + input_wildcard_expansion: true, + auto_explode: true, + allow_rename: true, + ..Default::default() + }, + } } /// Create a new column with the the sum of the values in each row. diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index cdae9fbe6c4d..19de558f8177 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -112,7 +112,7 @@ sort_multiple = ["polars-core/sort_multiple"] # extra operations approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique"] is_in = ["polars-lazy?/is_in"] -zip_with = ["polars-core/zip_with"] +zip_with = ["polars-core/zip_with", "polars-ops/zip_with"] round_series = ["polars-core/round_series", "polars-lazy?/round_series", "polars-ops/round_series"] checked_arithmetic = ["polars-core/checked_arithmetic"] repeat_by = ["polars-ops/repeat_by", "polars-lazy?/repeat_by"]