From 99fc615593b7b737b79f949a146e398dc1ff3206 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Mon, 9 Oct 2023 12:40:00 +0800 Subject: [PATCH] refactor: Make arg_min(max), diff in list namespace non-anonymous --- .../polars-plan/src/dsl/function_expr/list.rs | 24 +++++++++++++++++++ .../polars-plan/src/dsl/function_expr/mod.rs | 4 ++++ .../src/dsl/function_expr/schema.rs | 4 ++++ crates/polars-plan/src/dsl/list.rs | 21 +++++----------- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 27cadf0a0d79..8cf674fd2dc5 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -23,6 +23,13 @@ pub enum ListFunction { Max, Min, Mean, + ArgMin, + ArgMax, + #[cfg(feature = "diff")] + Diff { + n: i64, + null_behavior: NullBehavior, + }, Sort(SortOptions), Reverse, Unique(bool), @@ -56,6 +63,10 @@ impl Display for ListFunction { Min => "min", Max => "max", Mean => "mean", + ArgMin => "arg_min", + ArgMax => "arg_max", + #[cfg(feature = "diff")] + Diff { .. } => "diff", Length => "length", Sort(_) => "sort", Reverse => "reverse", @@ -321,6 +332,19 @@ pub(super) fn mean(s: &Series) -> PolarsResult { Ok(s.list()?.lst_mean()) } +pub(super) fn arg_min(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_arg_min().into_series()) +} + +pub(super) fn arg_max(s: &Series) -> PolarsResult { + Ok(s.list()?.lst_arg_max().into_series()) +} + +#[cfg(feature = "diff")] +pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult { + Ok(s.list()?.lst_diff(n, null_behavior)?.into_series()) +} + pub(super) fn sort(s: &Series, options: SortOptions) -> PolarsResult { Ok(s.list()?.lst_sort(options).into_series()) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index bffa79acb106..1fdcba2d0111 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -579,6 +579,10 @@ impl From for SpecialEq> { Max => map!(list::max), Min => map!(list::min), Mean => map!(list::mean), + ArgMin => map!(list::arg_min), + ArgMax => map!(list::arg_max), + #[cfg(feature = "diff")] + Diff { n, null_behavior } => map!(list::diff, n, null_behavior), Sort(options) => map!(list::sort, options), Reverse => map!(list::reverse), Unique(is_stable) => map!(list::unique, is_stable), diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 44b224e2a988..bc5c435b99cd 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -113,6 +113,10 @@ impl FunctionExpr { Min => mapper.map_to_list_inner_dtype(), Max => mapper.map_to_list_inner_dtype(), Mean => mapper.with_dtype(DataType::Float64), + ArgMin => mapper.with_dtype(IDX_DTYPE), + ArgMax => mapper.with_dtype(IDX_DTYPE), + #[cfg(feature = "diff")] + Diff { .. } => mapper.with_same_dtype(), Sort(_) => mapper.with_same_dtype(), Reverse => mapper.with_same_dtype(), Unique(_) => mapper.with_same_dtype(), diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 86261e43a82d..0e19de499679 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -138,32 +138,23 @@ impl ListNameSpace { /// Return the index of the minimal value of every sublist pub fn arg_min(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_arg_min().into_series())), - GetOutput::from_type(IDX_DTYPE), - ) - .with_fmt("list.arg_min") + .map_private(FunctionExpr::ListExpr(ListFunction::ArgMin)) } /// Return the index of the maximum value of every sublist pub fn arg_max(self) -> Expr { self.0 - .map( - |s| Ok(Some(s.list()?.lst_arg_max().into_series())), - GetOutput::from_type(IDX_DTYPE), - ) - .with_fmt("list.arg_max") + .map_private(FunctionExpr::ListExpr(ListFunction::ArgMax)) } /// Diff every sublist. #[cfg(feature = "diff")] pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr { self.0 - .map( - move |s| Ok(Some(s.list()?.lst_diff(n, null_behavior)?.into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.diff") + .map_private(FunctionExpr::ListExpr(ListFunction::Diff { + n, + null_behavior, + })) } /// Shift every sublist.