Skip to content

Commit

Permalink
refactor: Make arg_min(max), diff in list namespace non-anonymous
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Oct 9, 2023
1 parent 3b3c4a0 commit 99fc615
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 15 deletions.
24 changes: 24 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ pub enum ListFunction {
Max,
Min,
Mean,
ArgMin,
ArgMax,
#[cfg(feature = "diff")]
Diff {
n: i64,
null_behavior: NullBehavior,
},
Sort(SortOptions),
Reverse,
Unique(bool),
Expand Down Expand Up @@ -56,6 +63,10 @@ impl Display for ListFunction {
Min => "min",
Max => "max",
Mean => "mean",
ArgMin => "arg_min",
ArgMax => "arg_max",
#[cfg(feature = "diff")]
Diff { .. } => "diff",
Length => "length",
Sort(_) => "sort",
Reverse => "reverse",
Expand Down Expand Up @@ -321,6 +332,19 @@ pub(super) fn mean(s: &Series) -> PolarsResult<Series> {
Ok(s.list()?.lst_mean())
}

pub(super) fn arg_min(s: &Series) -> PolarsResult<Series> {
Ok(s.list()?.lst_arg_min().into_series())
}

pub(super) fn arg_max(s: &Series) -> PolarsResult<Series> {
Ok(s.list()?.lst_arg_max().into_series())
}

#[cfg(feature = "diff")]
pub(super) fn diff(s: &Series, n: i64, null_behavior: NullBehavior) -> PolarsResult<Series> {
Ok(s.list()?.lst_diff(n, null_behavior)?.into_series())
}

pub(super) fn sort(s: &Series, options: SortOptions) -> PolarsResult<Series> {
Ok(s.list()?.lst_sort(options).into_series())
}
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Max => map!(list::max),
Min => map!(list::min),
Mean => map!(list::mean),
ArgMin => map!(list::arg_min),
ArgMax => map!(list::arg_max),
#[cfg(feature = "diff")]
Diff { n, null_behavior } => map!(list::diff, n, null_behavior),
Sort(options) => map!(list::sort, options),
Reverse => map!(list::reverse),
Unique(is_stable) => map!(list::unique, is_stable),
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ impl FunctionExpr {
Min => mapper.map_to_list_inner_dtype(),
Max => mapper.map_to_list_inner_dtype(),
Mean => mapper.with_dtype(DataType::Float64),
ArgMin => mapper.with_dtype(IDX_DTYPE),
ArgMax => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "diff")]
Diff { .. } => mapper.with_same_dtype(),
Sort(_) => mapper.with_same_dtype(),
Reverse => mapper.with_same_dtype(),
Unique(_) => mapper.with_same_dtype(),
Expand Down
21 changes: 6 additions & 15 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,32 +138,23 @@ impl ListNameSpace {
/// Return the index of the minimal value of every sublist
pub fn arg_min(self) -> Expr {
self.0
.map(
|s| Ok(Some(s.list()?.lst_arg_min().into_series())),
GetOutput::from_type(IDX_DTYPE),
)
.with_fmt("list.arg_min")
.map_private(FunctionExpr::ListExpr(ListFunction::ArgMin))
}

/// Return the index of the maximum value of every sublist
pub fn arg_max(self) -> Expr {
self.0
.map(
|s| Ok(Some(s.list()?.lst_arg_max().into_series())),
GetOutput::from_type(IDX_DTYPE),
)
.with_fmt("list.arg_max")
.map_private(FunctionExpr::ListExpr(ListFunction::ArgMax))
}

/// Diff every sublist.
#[cfg(feature = "diff")]
pub fn diff(self, n: i64, null_behavior: NullBehavior) -> Expr {
self.0
.map(
move |s| Ok(Some(s.list()?.lst_diff(n, null_behavior)?.into_series())),
GetOutput::same_type(),
)
.with_fmt("list.diff")
.map_private(FunctionExpr::ListExpr(ListFunction::Diff {
n,
null_behavior,
}))
}

/// Shift every sublist.
Expand Down

0 comments on commit 99fc615

Please sign in to comment.