From 86ebde3e21fbd9d963e1f3ae83283fff273780c6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 18 Jul 2024 16:27:54 -0400 Subject: [PATCH 1/6] Moving over AggregateExt to ExprFunctionExt and adding in function settings for window functions --- datafusion-examples/examples/expr_api.rs | 4 +- datafusion/core/tests/expr_api/mod.rs | 2 +- datafusion/expr/src/expr.rs | 8 +- datafusion/expr/src/expr_fn.rs | 247 +++++++++++++++++- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 177 +------------ .../functions-aggregate/src/first_last.rs | 2 +- .../optimizer/src/optimize_projections/mod.rs | 2 +- .../src/replace_distinct_aggregate.rs | 2 +- .../src/single_distinct_to_groupby.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 2 +- datafusion/sql/src/unparser/expr.rs | 2 +- docs/source/user-guide/expressions.md | 2 +- 13 files changed, 260 insertions(+), 194 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index a5cf7011f811..09ac27df01b1 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{AggregateExt, ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ExprFunctionExt, ColumnarValue, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -95,7 +95,7 @@ fn expr_fn_demo() -> Result<()> { let agg = first_value.call(vec![col("price")]); assert_eq!(agg.to_string(), "first_value(price)"); - // You can use the AggregateExt trait to create more complex aggregates + // You can use the ExprFunctionExt trait to create more complex aggregates // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index f36f2d539845..d76b3c9dc1ec 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -21,7 +21,7 @@ use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; use datafusion_common::{assert_contains, DFSchema, ScalarValue}; -use datafusion_expr::AggregateExt; +use datafusion_expr::ExprFunctionExt; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a344e621ddb1..4768a8466b3f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -289,9 +289,9 @@ pub enum Expr { /// Calls an aggregate function with arguments, and optional /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. /// - /// See also [`AggregateExt`] to set these fields. + /// See also [`ExprFunctionExt`] to set these fields. /// - /// [`AggregateExt`]: crate::udaf::AggregateExt + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), @@ -641,9 +641,9 @@ impl AggregateFunctionDefinition { /// Aggregate function /// -/// See also [`AggregateExt`] to set these fields on `Expr` +/// See also [`ExprFunctionExt`] to set these fields on `Expr` /// -/// [`AggregateExt`]: crate::udaf::AggregateExt +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8b0213fd52fd..6d5eaf140ab6 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, + Placeholder, TryCast, Unnest, WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -30,12 +30,13 @@ use crate::{ AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl}; use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{Column, Result, ScalarValue}; +use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -676,6 +677,246 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } + +/// Extensions for configuring [`Expr::AggregateFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional aggregate options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// +/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; +/// # use sqlparser::ast::NullTreatment; +/// # fn count(arg: Expr) -> Expr { todo!{} } +/// # fn first_value(arg: Expr) -> Expr { todo!{} } +/// # fn main() -> Result<()> { +/// use datafusion_expr::ExprFunctionExt; +/// +/// // Create COUNT(x FILTER y > 5) +/// let agg = count(col("x")) +/// .filter(col("y").gt(lit(5))) +/// .build()?; +/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub trait ExprFunctionExt { + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + /// Add `FILTER ` + fn filter(self, filter: Expr) -> ExprFuncBuilder; + /// Add `DISTINCT` + fn distinct(self) -> ExprFuncBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder; + // Add `PARTITION BY` + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; + // Add appropriate window frame conditions + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; +} + +#[derive(Debug, Clone)] +pub enum ExprFuncKind { + Aggregate(AggregateFunction), + Window(WindowFunction), +} + +/// Implementation of [`ExprFunctionExt`]. +/// +/// See [`ExprFunctionExt`] for usage and examples +#[derive(Debug, Clone)] +pub struct ExprFuncBuilder { + fun: Option, + order_by: Option>, + filter: Option, + distinct: bool, + null_treatment: Option, + partition_by: Option>, + window_frame: Option, +} + +impl ExprFuncBuilder { + /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] + + fn new(fun: Option) -> Self { + Self { + fun, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + partition_by: None, + window_frame: None, + } + } + + /// Updates and returns the in progress [`Expr::AggregateFunction`] + /// + /// # Errors: + /// + /// Returns an error of this builder [`ExprFunctionExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] + pub fn build(self) -> Result { + let Self { + fun, + order_by, + filter, + distinct, + null_treatment, + partition_by, + window_frame, + } = self; + + let Some(fun) = fun else { + return plan_err!( + "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" + ); + }; + + if let Some(order_by) = &order_by { + for expr in order_by.iter() { + if !matches!(expr, Expr::Sort(_)) { + return plan_err!( + "ORDER BY expressions must be Expr::Sort, found {expr:?}" + ); + } + } + } + + let fun_expr = match fun { + ExprFuncKind::Aggregate(mut udaf) => { + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; + Expr::AggregateFunction(udaf) + } + ExprFuncKind::Window(mut udwf) => { + let has_order_by = order_by.as_ref().map(|o| o.len() > 0); + udwf.order_by = order_by.unwrap_or_default(); + udwf.partition_by = partition_by.unwrap_or_default(); + udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.null_treatment = null_treatment; + Expr::WindowFunction(udwf) + } + }; + + Ok(fun_expr) + } + + /// Add `ORDER BY ` + /// + /// Note: `order_by` must be [`Expr::Sort`] + pub fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + self.order_by = Some(order_by); + self + } + + /// Add `FILTER ` + pub fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + self.filter = Some(filter); + self + } + + /// Add `DISTINCT` + pub fn distinct(mut self) -> ExprFuncBuilder { + self.distinct = true; + self + } + + /// Add `RESPECT NULLS` or `IGNORE NULLS` + pub fn null_treatment(mut self, null_treatment: NullTreatment) -> ExprFuncBuilder { + self.null_treatment = Some(null_treatment); + self + } + + pub fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + self.partition_by = Some(partition_by); + self + } + + pub fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + self.window_frame = Some(window_frame); + self + } +} + +impl ExprFunctionExt for Expr { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), + Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.order_by = Some(order_by); + } + builder + } + fn filter(self, filter: Expr) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.filter = Some(filter); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn distinct(self) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.distinct = true; + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), + Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.null_treatment = Some(null_treatment); + } + builder + } + + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.partition_by = Some(partition_by); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.window_frame = Some(window_frame); + builder + } + _ => ExprFuncBuilder::new(None), + } + } +} + + #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index e1943c890e7c..354e795fe64d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -86,7 +86,7 @@ pub use signature::{ }; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateExt, AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1657e034fbe2..29267f30100a 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -24,9 +24,8 @@ use std::sync::Arc; use std::vec; use arrow::datatypes::{DataType, Field}; -use sqlparser::ast::NullTreatment; -use datafusion_common::{exec_err, not_impl_err, plan_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; use crate::expr::AggregateFunction; use crate::function::{ @@ -655,177 +654,3 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { (self.accumulator)(acc_args) } } - -/// Extensions for configuring [`Expr::AggregateFunction`] -/// -/// Adds methods to [`Expr`] that make it easy to set optional aggregate options -/// such as `ORDER BY`, `FILTER` and `DISTINCT` -/// -/// # Example -/// ```no_run -/// # use datafusion_common::Result; -/// # use datafusion_expr::{AggregateUDF, col, Expr, lit}; -/// # use sqlparser::ast::NullTreatment; -/// # fn count(arg: Expr) -> Expr { todo!{} } -/// # fn first_value(arg: Expr) -> Expr { todo!{} } -/// # fn main() -> Result<()> { -/// use datafusion_expr::AggregateExt; -/// -/// // Create COUNT(x FILTER y > 5) -/// let agg = count(col("x")) -/// .filter(col("y").gt(lit(5))) -/// .build()?; -/// // Create FIRST_VALUE(x ORDER BY y IGNORE NULLS) -/// let sort_expr = col("y").sort(true, true); -/// let agg = first_value(col("x")) -/// .order_by(vec![sort_expr]) -/// .null_treatment(NullTreatment::IgnoreNulls) -/// .build()?; -/// # Ok(()) -/// # } -/// ``` -pub trait AggregateExt { - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> AggregateBuilder; - /// Add `FILTER ` - fn filter(self, filter: Expr) -> AggregateBuilder; - /// Add `DISTINCT` - fn distinct(self) -> AggregateBuilder; - /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder; -} - -/// Implementation of [`AggregateExt`]. -/// -/// See [`AggregateExt`] for usage and examples -#[derive(Debug, Clone)] -pub struct AggregateBuilder { - udaf: Option, - order_by: Option>, - filter: Option, - distinct: bool, - null_treatment: Option, -} - -impl AggregateBuilder { - /// Create a new `AggregateBuilder`, see [`AggregateExt`] - - fn new(udaf: Option) -> Self { - Self { - udaf, - order_by: None, - filter: None, - distinct: false, - null_treatment: None, - } - } - - /// Updates and returns the in progress [`Expr::AggregateFunction`] - /// - /// # Errors: - /// - /// Returns an error of this builder [`AggregateExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] - pub fn build(self) -> Result { - let Self { - udaf, - order_by, - filter, - distinct, - null_treatment, - } = self; - - let Some(mut udaf) = udaf else { - return plan_err!( - "AggregateExt can only be used with Expr::AggregateFunction" - ); - }; - - if let Some(order_by) = &order_by { - for expr in order_by.iter() { - if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } - } - } - - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; - Ok(Expr::AggregateFunction(udaf)) - } - - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - pub fn order_by(mut self, order_by: Vec) -> AggregateBuilder { - self.order_by = Some(order_by); - self - } - - /// Add `FILTER ` - pub fn filter(mut self, filter: Expr) -> AggregateBuilder { - self.filter = Some(filter); - self - } - - /// Add `DISTINCT` - pub fn distinct(mut self) -> AggregateBuilder { - self.distinct = true; - self - } - - /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: NullTreatment) -> AggregateBuilder { - self.null_treatment = Some(null_treatment); - self - } -} - -impl AggregateExt for Expr { - fn order_by(self, order_by: Vec) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.order_by = Some(order_by); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn filter(self, filter: Expr) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.filter = Some(filter); - builder - } - _ => AggregateBuilder::new(None), - } - } - fn distinct(self) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.distinct = true; - builder - } - _ => AggregateBuilder::new(None), - } - } - fn null_treatment(self, null_treatment: NullTreatment) -> AggregateBuilder { - match self { - Expr::AggregateFunction(udaf) => { - let mut builder = AggregateBuilder::new(Some(udaf)); - builder.null_treatment = Some(null_treatment); - builder - } - _ => AggregateBuilder::new(None), - } - } -} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 0e619bacef82..862bd8c1378a 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,7 +31,7 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, + Accumulator, ExprFunctionExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 58c1ae297b02..9f04a01a3377 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -806,7 +806,7 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index fcd33be618f7..430517121f2a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,7 +23,7 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, AggregateExt, LogicalPlanBuilder}; +use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f2b4abdd6cbd..d776e6598cbe 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -354,7 +354,7 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::AggregateExt; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 0117502f400d..63d2c15dfd40 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,7 +59,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, ExprFunctionExt, AggregateFunction, AggregateUDF, ColumnarValue, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 950e7e11288a..3bed4540e14f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1342,7 +1342,7 @@ mod tests { table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_expr::{interval_month_day_nano_lit, AggregateExt}; + use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 6e693a0e7087..60036e440ffb 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -308,7 +308,7 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Function Builder -You can also use the `AggregateExt` trait to more easily build Aggregate arguments `Expr`. +You can also use the `ExprFunctionExt` trait to more easily build Aggregate arguments `Expr`. See `datafusion-examples/examples/expr_api.rs` for example usage. From 2dfb0f1d775b0fe3815f908fe72a98dd1116d3e4 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:30:26 -0400 Subject: [PATCH 2/6] Switch WindowFrame to only need the window function definition and arguments. Other parameters will be set via the ExprFuncBuilder --- datafusion/core/src/dataframe/mod.rs | 10 +- datafusion/core/tests/dataframe/mod.rs | 13 +-- datafusion/expr/src/expr.rs | 44 ++++++-- datafusion/expr/src/expr_fn.rs | 4 +- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/tree_node.rs | 20 ++-- datafusion/expr/src/utils.rs | 80 +++----------- datafusion/expr/src/window_function.rs | 104 ++++++++++++++++++ .../src/analyzer/count_wildcard_rule.rs | 11 +- .../optimizer/src/analyzer/type_coercion.rs | 20 ++-- .../optimizer/src/optimize_projections/mod.rs | 16 +-- .../simplify_expressions/expr_simplifier.rs | 14 +-- .../proto/src/logical_plan/from_proto.rs | 30 +---- .../tests/cases/roundtrip_logical_plan.rs | 49 ++------- datafusion/sql/src/expr/function.rs | 32 +++--- 15 files changed, 219 insertions(+), 229 deletions(-) create mode 100644 datafusion/expr/src/window_function.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c55b7c752765..e5628c917a57 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,8 +1696,7 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, ScalarFunctionImplementation, Volatility, WindowFunctionDefinition }; use datafusion_functions_aggregate::expr_fn::count_distinct; use datafusion_physical_expr::expressions::Column; @@ -1866,12 +1865,7 @@ mod tests { WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), - vec![col("aggregate_test_100.c1")], - vec![col("aggregate_test_100.c2")], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("aggregate_test_100.c1")])).partition_by(vec![col("aggregate_test_100.c2")]).build().unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 1b2a6770cf01..6eb70f4d26fe 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,9 +54,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition }; use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; @@ -182,16 +180,13 @@ async fn test_count_wildcard_on_window() -> Result<()> { .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], + vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame( WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )).build().unwrap() + ])? .explain(false, false)? .collect() .await?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4768a8466b3f..81058e402413 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,8 +28,7 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator, - Signature, + aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF }; use crate::{window_frame, Volatility}; @@ -769,6 +768,30 @@ impl fmt::Display for WindowFunctionDefinition { } } +impl From for WindowFunctionDefinition { + fn from(value: aggregate_function::AggregateFunction) -> Self { + Self::AggregateFunction(value) + } +} + +impl From for WindowFunctionDefinition { + fn from(value: BuiltInWindowFunction) -> Self { + Self::BuiltInWindowFunction(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::AggregateUDF(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::WindowUDF(value) + } +} + /// Window function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct WindowFunction { @@ -789,20 +812,17 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression pub fn new( - fun: WindowFunctionDefinition, + fun: impl Into, args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: window_frame::WindowFrame, - null_treatment: Option, + ) -> Self { Self { - fun, + fun: fun.into(), args, - partition_by, - order_by, - window_frame, - null_treatment, + partition_by: Vec::default(), + order_by: Vec::default(), + window_frame: WindowFrame::new(None), + null_treatment: None, } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 6d5eaf140ab6..262402a70523 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -758,12 +758,12 @@ impl ExprFuncBuilder { } } - /// Updates and returns the in progress [`Expr::AggregateFunction`] + /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] /// /// # Errors: /// /// Returns an error of this builder [`ExprFunctionExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] + /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] pub fn build(self) -> Result { let Self { fun, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 354e795fe64d..0a5cf4653a22 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -60,6 +60,7 @@ pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; +pub mod window_function; pub mod window_state; pub use accumulator::Accumulator; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f1df8609f903..3d7a72180ca6 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,7 @@ use crate::expr::{ Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; -use crate::Expr; +use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, @@ -294,14 +294,18 @@ impl TreeNode for Expr { transform_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new( + let mut builder = Expr::WindowFunction(WindowFunction::new( fun, - new_args, - new_partition_by, - new_order_by, - window_frame, - null_treatment, - )) + new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n) + } + builder.build().unwrap() + // new_partition_by, + // new_order_by, + // window_frame, + // null_treatment, + // )) }), Expr::AggregateFunction(AggregateFunction { args, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 45155cbd2c27..07392173334d 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1252,9 +1252,7 @@ impl AggregateOrderSensitivity { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, - WindowFunctionDefinition, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, WindowFrame, WindowFunctionDefinition }; #[test] @@ -1269,36 +1267,16 @@ mod tests { fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("age")])); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; @@ -1316,36 +1294,16 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("name")])); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - vec![], - vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + vec![col("age")])).order_by(vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()]).build().unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1372,27 +1330,19 @@ mod tests { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![ + vec![col("name")])).order_by(vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + ]).window_frame(WindowFrame::new(Some(false))) + .build().unwrap(), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")], - vec![], - vec![ + vec![col("age")])).order_by(vec![ Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), + ]).window_frame(WindowFrame::new(Some(false))) + .build().unwrap(), ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs new file mode 100644 index 000000000000..0fa1d4168655 --- /dev/null +++ b/datafusion/expr/src/window_function.rs @@ -0,0 +1,104 @@ +use datafusion_common::ScalarValue; + +use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; + + + +/// Create an expression to represent the `row_number` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn row_number() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::RowNumber, vec![]) +} + +/// Create an expression to represent the `rank` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn rank() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::Rank, vec![]) +} + +/// Create an expression to represent the `dense_rank` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn dense_rank() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::DenseRank, vec![]) +} + +/// Create an expression to represent the `percent_rank` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn percent_rank() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::PercentRank, vec![]) +} + +/// Create an expression to represent the `cume_dist` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn cume_dist() -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::CumeDist, vec![]) +} + +/// Create an expression to represent the `ntile` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn ntile(arg: Expr) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg]) +} + +/// Create an expression to represent the `lag` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn lag( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> WindowFunction { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + WindowFunction::new( + BuiltInWindowFunction::Lag, + vec![arg, shift_offset_lit, default_lit], + ) +} + +/// Create an expression to represent the `lead` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn lead( + arg: Expr, + shift_offset: Option, + default_value: Option, +) -> WindowFunction { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + WindowFunction::new( + BuiltInWindowFunction::Lead, + vec![arg, shift_offset_lit, default_lit], + ) +} + +/// Create an expression to represent the `first_value` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn first_value(arg: Expr) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![arg]) +} + +/// Create an expression to represent the `last_value` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn last_value(arg: Expr) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::LastValue, vec![arg]) +} + +/// Create an expression to represent the `nth_value` window function +/// +/// Note: call [`WindowFunction::build]` to create an [`Expr`] +pub fn nth_value(arg: Expr, n: i64) -> WindowFunction { + WindowFunction::new(BuiltInWindowFunction::NthValue, vec![arg, n.lit()]) +} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 959ffdaaa212..344e07f74345 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,6 +101,7 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, @@ -222,16 +223,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( + vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )).build()? + ])? .project(vec![count(wildcard())])? .build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 80a8c864e431..a0bae6d904d2 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -46,9 +46,7 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable, LogicalPlan, - Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits }; use crate::analyzer::AnalyzerRule; @@ -458,14 +456,16 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes({ + let mut builder = Expr::WindowFunction(WindowFunction::new( fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - )))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n); + } + builder.build()? + } + )) } Expr::Alias(_) | Expr::Column(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 9f04a01a3377..787146d90b00 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -815,7 +815,7 @@ mod tests { lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, + Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; @@ -1918,21 +1918,11 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.a")], - vec![col("test.b")], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("test.a")])).partition_by(vec![col("test.b")]).build().unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.b")], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![col("test.b")])); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 8414f39f3060..33d39b07ac05 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3858,12 +3858,7 @@ mod tests { let window_function_expr = Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![])); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3874,12 +3869,7 @@ mod tests { let window_function_expr = Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( udwf, - vec![], - vec![], - vec![], - WindowFrame::new(None), - None, - )); + vec![])); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b6b556a8ed6b..fd6d19d2fd08 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,6 +25,7 @@ use datafusion_common::{ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -300,7 +301,6 @@ pub fn parse_expr( ) })?; // TODO: support proto for null treatment - let null_treatment = None; regularize_window_order_by(&window_frame, &mut order_by)?; match window_function { @@ -314,12 +314,7 @@ pub fn parse_expr( registry, "expr", codec, - )?], - partition_by, - order_by, - window_frame, - None, - ))) + )?])).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -335,12 +330,7 @@ pub fn parse_expr( expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), - args, - partition_by, - order_by, - window_frame, - null_treatment, - ))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -354,12 +344,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), - args, - partition_by, - order_by, - window_frame, - None, - ))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -373,12 +358,7 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), - args, - partition_by, - order_by, - window_frame, - None, - ))) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 63d2c15dfd40..d37aed4c839b 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -2040,24 +2040,14 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2070,12 +2060,7 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![], - vec![col("col1")], - vec![col("col2")], - range_number_frame, - None, - )); + vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(range_number_frame).build().unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2086,12 +2071,7 @@ fn roundtrip_window() { let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); // 5. test with AggregateUDF #[derive(Debug)] @@ -2135,12 +2115,7 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2211,21 +2186,11 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), - vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), - vec![col("col1")], - vec![], - vec![], - row_number_frame.clone(), - None, - )); + vec![col("col1")])).window_frame(row_number_frame.clone()).build().unwrap(); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index dab328cc4908..1d5c10ef2be0 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,7 +23,7 @@ use datafusion_common::{ }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, + expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -314,23 +314,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = self.function_args_to_expr(args, schema, planner_context)?; - Expr::WindowFunction(expr::WindowFunction::new( + let mut builder = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args, - partition_by, - order_by, - window_frame, - null_treatment, - )) + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n); + }; + builder.build().unwrap() } - _ => Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?, - partition_by, - order_by, - window_frame, - null_treatment, - )), + _ => { + let mut builder = Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); + if let Some(n) = null_treatment { + builder = builder.null_treatment(n); + } + builder.build().unwrap() + }, }; return Ok(expr); } From 87ca52a51e6f1c60e4ac2793337c820e0e660a17 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:45:23 -0400 Subject: [PATCH 3/6] Changing null_treatment to take an option, but this is mostly for code cleanliness and not strictly required --- datafusion/expr/src/expr_fn.rs | 10 +++++----- datafusion/expr/src/tree_node.rs | 16 +++------------- .../optimizer/src/analyzer/type_coercion.rs | 10 ++-------- datafusion/sql/src/expr/function.rs | 16 ++++------------ 4 files changed, 14 insertions(+), 38 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 262402a70523..9fe3ad80a744 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -716,7 +716,7 @@ pub trait ExprFunctionExt { /// Add `DISTINCT` fn distinct(self) -> ExprFuncBuilder; /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder; + fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder; // Add `PARTITION BY` fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; // Add appropriate window frame conditions @@ -833,8 +833,8 @@ impl ExprFuncBuilder { } /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: NullTreatment) -> ExprFuncBuilder { - self.null_treatment = Some(null_treatment); + pub fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { + self.null_treatment = null_treatment.into(); self } @@ -881,14 +881,14 @@ impl ExprFunctionExt for Expr { _ => ExprFuncBuilder::new(None), } } - fn null_treatment(self, null_treatment: NullTreatment) -> ExprFuncBuilder { + fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder { let mut builder = match self { Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), _ => ExprFuncBuilder::new(None), }; if builder.fun.is_some() { - builder.null_treatment = Some(null_treatment); + builder.null_treatment = null_treatment.into(); } builder } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 3d7a72180ca6..f262613b2295 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -293,20 +293,10 @@ impl TreeNode for Expr { order_by, transform_vec(order_by, &mut f) )? - .update_data(|(new_args, new_partition_by, new_order_by)| { - let mut builder = Expr::WindowFunction(WindowFunction::new( + .update_data(|(new_args, new_partition_by, new_order_by)| Expr::WindowFunction(WindowFunction::new( fun, - new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n) - } - builder.build().unwrap() - // new_partition_by, - // new_order_by, - // window_frame, - // null_treatment, - // )) - }), + new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() + ), Expr::AggregateFunction(AggregateFunction { args, func_def, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a0bae6d904d2..e9f0d1795027 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -456,15 +456,9 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes({ - let mut builder = Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( fun, - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n); - } - builder.build()? - } + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build()? )) } Expr::Alias(_) diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 1d5c10ef2be0..43adc8db2d0f 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -314,22 +314,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let args = self.function_args_to_expr(args, schema, planner_context)?; - let mut builder = Expr::WindowFunction(expr::WindowFunction::new( + Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n); - }; - builder.build().unwrap() + args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() } _ => { - let mut builder = Expr::WindowFunction(expr::WindowFunction::new( + Expr::WindowFunction(expr::WindowFunction::new( fun, - self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame); - if let Some(n) = null_treatment { - builder = builder.null_treatment(n); - } - builder.build().unwrap() + self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() }, }; return Ok(expr); From 2ad0fe25e0c7d1c3656c7f8be36c2bf3da4bdfaa Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:49:46 -0400 Subject: [PATCH 4/6] Moving functions in ExprFuncBuilder over to be explicitly implementing ExprFunctionExt trait so we can guarantee a consistent user experience no matter which they call on the Expr and which on the builder --- datafusion/expr/src/expr_fn.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9fe3ad80a744..4b7f3fe210f0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -811,39 +811,42 @@ impl ExprFuncBuilder { Ok(fun_expr) } +} + +impl ExprFunctionExt for ExprFuncBuilder { /// Add `ORDER BY ` /// /// Note: `order_by` must be [`Expr::Sort`] - pub fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { self.order_by = Some(order_by); self } /// Add `FILTER ` - pub fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + fn filter(mut self, filter: Expr) -> ExprFuncBuilder { self.filter = Some(filter); self } /// Add `DISTINCT` - pub fn distinct(mut self) -> ExprFuncBuilder { + fn distinct(mut self) -> ExprFuncBuilder { self.distinct = true; self } /// Add `RESPECT NULLS` or `IGNORE NULLS` - pub fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { + fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { self.null_treatment = null_treatment.into(); self } - pub fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { self.partition_by = Some(partition_by); self } - pub fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { self.window_frame = Some(window_frame); self } From 26f987f0b01501c99ffafce952e1b57ac9bedec6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 19 Jul 2024 08:52:23 -0400 Subject: [PATCH 5/6] Apply cargo fmt --- datafusion-examples/examples/expr_api.rs | 2 +- datafusion/core/src/dataframe/mod.rs | 9 ++- datafusion/core/tests/dataframe/mod.rs | 21 +++--- datafusion/expr/src/expr.rs | 9 +-- datafusion/expr/src/expr_fn.rs | 53 +++++++++----- datafusion/expr/src/tree_node.rs | 13 ++-- datafusion/expr/src/utils.rs | 71 +++++++++++++------ datafusion/expr/src/window_function.rs | 2 - .../functions-aggregate/src/first_last.rs | 4 +- .../src/analyzer/count_wildcard_rule.rs | 15 ++-- .../optimizer/src/analyzer/type_coercion.rs | 15 ++-- .../optimizer/src/optimize_projections/mod.rs | 9 ++- .../simplify_expressions/expr_simplifier.rs | 14 ++-- .../proto/src/logical_plan/from_proto.rs | 32 +++++++-- .../tests/cases/roundtrip_logical_plan.rs | 56 ++++++++++++--- datafusion/sql/src/expr/function.rs | 27 +++++-- 16 files changed, 251 insertions(+), 101 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 09ac27df01b1..f7ced0d7e077 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{ExprFunctionExt, ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index e5628c917a57..0e3e9d79f886 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,7 +1696,8 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, ScalarFunctionImplementation, Volatility, WindowFunctionDefinition + array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + ScalarFunctionImplementation, Volatility, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::count_distinct; use datafusion_physical_expr::expressions::Column; @@ -1865,7 +1866,11 @@ mod tests { WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), - vec![col("aggregate_test_100.c1")])).partition_by(vec![col("aggregate_test_100.c2")]).build().unwrap(); + vec![col("aggregate_test_100.c1")], + )) + .partition_by(vec![col("aggregate_test_100.c2")]) + .build() + .unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 6eb70f4d26fe..2c3b2bb88b37 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,7 +54,9 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition + array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; @@ -180,13 +182,16 @@ async fn test_count_wildcard_on_window() -> Result<()> { .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame( - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )).build().unwrap() - ])? + vec![wildcard()], + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? .collect() .await?; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 81058e402413..37acdb957544 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -28,7 +28,8 @@ use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; use crate::{ - aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF + aggregate_function, built_in_window_function, udaf, BuiltInWindowFunction, + ExprSchemable, Operator, Signature, WindowFrame, WindowUDF, }; use crate::{window_frame, Volatility}; @@ -811,11 +812,7 @@ pub struct WindowFunction { impl WindowFunction { /// Create a new Window expression - pub fn new( - fun: impl Into, - args: Vec, - - ) -> Self { + pub fn new(fun: impl Into, args: Vec) -> Self { Self { fun: fun.into(), args, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4b7f3fe210f0..7281a1211aa8 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -30,7 +30,9 @@ use crate::{ AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl}; +use crate::{ + AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, +}; use arrow::compute::kernels::cast_utils::{ parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; @@ -677,7 +679,6 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) } - /// Extensions for configuring [`Expr::AggregateFunction`] /// /// Adds methods to [`Expr`] that make it easy to set optional aggregate options @@ -716,7 +717,10 @@ pub trait ExprFunctionExt { /// Add `DISTINCT` fn distinct(self) -> ExprFuncBuilder; /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder; + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder; // Add `PARTITION BY` fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; // Add appropriate window frame conditions @@ -803,7 +807,8 @@ impl ExprFuncBuilder { let has_order_by = order_by.as_ref().map(|o| o.len() > 0); udwf.order_by = order_by.unwrap_or_default(); udwf.partition_by = partition_by.unwrap_or_default(); - udwf.window_frame = window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.window_frame = + window_frame.unwrap_or(WindowFrame::new(has_order_by)); udwf.null_treatment = null_treatment; Expr::WindowFunction(udwf) } @@ -814,7 +819,6 @@ impl ExprFuncBuilder { } impl ExprFunctionExt for ExprFuncBuilder { - /// Add `ORDER BY ` /// /// Note: `order_by` must be [`Expr::Sort`] @@ -836,7 +840,10 @@ impl ExprFunctionExt for ExprFuncBuilder { } /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment(mut self, null_treatment: impl Into>) -> ExprFuncBuilder { + fn null_treatment( + mut self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { self.null_treatment = null_treatment.into(); self } @@ -845,7 +852,7 @@ impl ExprFunctionExt for ExprFuncBuilder { self.partition_by = Some(partition_by); self } - + fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { self.window_frame = Some(window_frame); self @@ -855,8 +862,12 @@ impl ExprFunctionExt for ExprFuncBuilder { impl ExprFunctionExt for Expr { fn order_by(self, order_by: Vec) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), - Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } _ => ExprFuncBuilder::new(None), }; if builder.fun.is_some() { @@ -867,7 +878,8 @@ impl ExprFunctionExt for Expr { fn filter(self, filter: Expr) -> ExprFuncBuilder { match self { Expr::AggregateFunction(udaf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.filter = Some(filter); builder } @@ -877,17 +889,25 @@ impl ExprFunctionExt for Expr { fn distinct(self) -> ExprFuncBuilder { match self { Expr::AggregateFunction(udaf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); builder.distinct = true; builder } _ => ExprFuncBuilder::new(None), } } - fn null_treatment(self, null_treatment: impl Into>) -> ExprFuncBuilder { + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { let mut builder = match self { - Expr::AggregateFunction(udaf) => ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))), - Expr::WindowFunction(udwf) => ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))), + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } _ => ExprFuncBuilder::new(None), }; if builder.fun.is_some() { @@ -895,7 +915,7 @@ impl ExprFunctionExt for Expr { } builder } - + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { @@ -906,7 +926,7 @@ impl ExprFunctionExt for Expr { _ => ExprFuncBuilder::new(None), } } - + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { match self { Expr::WindowFunction(udwf) => { @@ -919,7 +939,6 @@ impl ExprFunctionExt for Expr { } } - #[cfg(test)] mod test { use super::*; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f262613b2295..a97b9f010f79 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -293,10 +293,15 @@ impl TreeNode for Expr { order_by, transform_vec(order_by, &mut f) )? - .update_data(|(new_args, new_partition_by, new_order_by)| Expr::WindowFunction(WindowFunction::new( - fun, - new_args)).partition_by(new_partition_by).order_by(new_order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() - ), + .update_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }), Expr::AggregateFunction(AggregateFunction { args, func_def, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 07392173334d..b833c1db06a2 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1252,7 +1252,9 @@ impl AggregateOrderSensitivity { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, WindowFrame, WindowFunctionDefinition + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, + WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1267,16 +1269,20 @@ mod tests { fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])); + vec![col("name")], + )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])); + vec![col("name")], + )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")])); + vec![col("name")], + )); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")])); + vec![col("age")], + )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; let key = vec![]; @@ -1294,16 +1300,33 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); + vec![col("name")], + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])); + vec![col("name")], + )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), - vec![col("name")])).order_by(vec![age_asc.clone(), name_desc.clone()]).build().unwrap(); + vec![col("name")], + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")])).order_by(vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()]).build().unwrap(); + vec![col("age")], + )) + .order_by(vec![ + name_desc.clone(), + age_asc.clone(), + created_at_desc.clone(), + ]) + .build() + .unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1330,19 +1353,27 @@ mod tests { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")])).order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ]).window_frame(WindowFrame::new(Some(false))) - .build().unwrap(), + vec![col("name")], + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), - vec![col("age")])).order_by(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]).window_frame(WindowFrame::new(Some(false))) - .build().unwrap(), + vec![col("age")], + )) + .order_by(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), + Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), + ]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(), ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index 0fa1d4168655..f61c9110ffc9 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -2,8 +2,6 @@ use datafusion_common::ScalarValue; use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; - - /// Create an expression to represent the `row_number` window function /// /// Note: call [`WindowFunction::build]` to create an [`Expr`] diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 862bd8c1378a..39f1944452af 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,8 +31,8 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, ExprFunctionExt, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, + Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; use datafusion_physical_expr_common::sort_expr::{ diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 344e07f74345..a9114ee70a59 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -223,12 +223,15 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![wildcard()])).order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]).window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )).build()? - ])? + vec![wildcard()], + )) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build()?])? .project(vec![count(wildcard())])? .build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e9f0d1795027..16202f137d06 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -46,7 +46,10 @@ use datafusion_expr::type_coercion::other::{ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ - is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits + is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, + type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, + LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -456,9 +459,13 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( - fun, - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build()? + Ok(Transformed::yes( + Expr::WindowFunction(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, )) } Expr::Alias(_) diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 787146d90b00..16abf93f3807 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1918,11 +1918,16 @@ mod tests { let max1 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.a")])).partition_by(vec![col("test.b")]).build().unwrap(); + vec![col("test.a")], + )) + .partition_by(vec![col("test.b")]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("test.b")])); + vec![col("test.b")], + )); let col1 = col(max1.display_name()?); let col2 = col(max2.display_name()?); diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 33d39b07ac05..f96aa1697a76 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -3855,10 +3855,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![])); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = col("result_column"); assert_eq!(simplify(window_function_expr), expected); @@ -3866,10 +3865,9 @@ mod tests { let udwf = WindowFunctionDefinition::WindowUDF( WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), ); - let window_function_expr = - Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( - udwf, - vec![])); + let window_function_expr = Expr::WindowFunction( + datafusion_expr::expr::WindowFunction::new(udwf, vec![]), + ); let expected = window_function_expr.clone(); assert_eq!(simplify(window_function_expr), expected); diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index fd6d19d2fd08..f3f0e603e060 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -314,7 +314,13 @@ pub fn parse_expr( registry, "expr", codec, - )?])).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + )?], + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -330,7 +336,13 @@ pub fn parse_expr( expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -344,7 +356,13 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -358,7 +376,13 @@ pub fn parse_expr( .unwrap_or_else(Vec::new); Ok(Expr::WindowFunction(WindowFunction::new( expr::WindowFunctionDefinition::WindowUDF(udwf_function), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).build().unwrap()) + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .unwrap()) } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d37aed4c839b..30af1e1e202a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,7 +59,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, ExprFunctionExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -2040,14 +2040,26 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(WindowFrame::new(Some(false))).build().unwrap(); + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2060,7 +2072,13 @@ fn roundtrip_window() { WindowFunctionDefinition::BuiltInWindowFunction( datafusion_expr::BuiltInWindowFunction::Rank, ), - vec![])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(range_number_frame).build().unwrap(); + vec![], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(range_number_frame) + .build() + .unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2071,7 +2089,13 @@ fn roundtrip_window() { let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); // 5. test with AggregateUDF #[derive(Debug)] @@ -2115,7 +2139,13 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), - vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2186,11 +2216,21 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), - vec![col("col1")])).partition_by(vec![col("col1")]).order_by(vec![col("col2")]).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2")]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), - vec![col("col1")])).window_frame(row_number_frame.clone()).build().unwrap(); + vec![col("col1")], + )) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udwf(dummy_window_udf); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 43adc8db2d0f..d5571e0221dc 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,7 +23,8 @@ use datafusion_common::{ }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition + expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + WindowFunctionDefinition, }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, @@ -316,13 +317,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() + args, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() } - _ => { - Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?)).partition_by(partition_by).order_by(order_by).window_frame(window_frame).null_treatment(null_treatment).build().unwrap() - }, + _ => Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap(), }; return Ok(expr); } From e9234e9f8d1d6420b9ee5cb19f173eddd4db8786 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 21 Jul 2024 09:35:15 -0400 Subject: [PATCH 6/6] Removing builder structure on extension for expr functions and returning Expr instead --- datafusion-examples/examples/expr_api.rs | 7 +- datafusion/core/src/dataframe/mod.rs | 6 +- datafusion/core/src/physical_planner.rs | 31 +- datafusion/core/tests/dataframe/mod.rs | 8 +- datafusion/core/tests/expr_api/mod.rs | 44 +- datafusion/expr/src/expr.rs | 73 ++-- datafusion/expr/src/expr_fn.rs | 402 +++++++++--------- datafusion/expr/src/tree_node.rs | 42 +- datafusion/expr/src/udwf.rs | 4 +- datafusion/expr/src/utils.rs | 38 +- .../functions-aggregate/src/first_last.rs | 11 +- .../src/analyzer/count_wildcard_rule.rs | 6 +- .../optimizer/src/analyzer/type_coercion.rs | 36 +- .../optimizer/src/optimize_projections/mod.rs | 8 +- .../src/replace_distinct_aggregate.rs | 12 +- .../src/single_distinct_to_groupby.rs | 18 +- .../proto/src/logical_plan/from_proto.rs | 33 +- datafusion/proto/src/logical_plan/to_proto.rs | 23 +- .../tests/cases/roundtrip_logical_plan.rs | 69 ++- datafusion/sql/src/expr/function.rs | 22 +- datafusion/sql/src/unparser/expr.rs | 41 +- .../substrait/src/logical_plan/consumer.rs | 6 +- .../substrait/src/logical_plan/producer.rs | 23 +- 23 files changed, 466 insertions(+), 497 deletions(-) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index f7ced0d7e077..c3f6013512dd 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -33,7 +33,7 @@ use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; +use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -99,9 +99,8 @@ fn expr_fn_demo() -> Result<()> { // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) let agg = first_value .call(vec![col("price")]) - .order_by(vec![col("ts").sort(false, false)]) - .filter(col("quantity").gt(lit(100))) - .build()?; // build the aggregate + .order_by(vec![col("ts").sort(false, false)])? + .filter(col("quantity").gt(lit(100)))?; assert_eq!( agg.to_string(), "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST]" diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0e3e9d79f886..4cb1ec24ae16 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1696,7 +1696,7 @@ mod tests { use datafusion_common::{Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ExprFunctionExt, + array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::count_distinct; @@ -1868,9 +1868,7 @@ mod tests { ), vec![col("aggregate_test_100.c1")], )) - .partition_by(vec![col("aggregate_test_100.c2")]) - .build() - .unwrap(); + .partition_by(vec![col("aggregate_test_100.c2")])?; let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index efc83d8f6b5c..13a9766d9aab 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -101,7 +101,7 @@ fn create_function_physical_name( fun: &str, distinct: bool, args: &[Expr], - order_by: Option<&Vec>, + order_by: &Option>, ) -> Result { let names: Vec = args .iter() @@ -115,7 +115,7 @@ fn create_function_physical_name( let phys_name = format!("{}({}{})", fun, distinct_str, names.join(",")); - Ok(order_by + Ok(order_by.as_ref() .map(|order_by| format!("{} ORDER BY [{}]", phys_name, expr_vec_fmt!(order_by))) .unwrap_or(phys_name)) } @@ -220,7 +220,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { order_by, .. }) => { - create_function_physical_name(&fun.to_string(), false, args, Some(order_by)) + create_function_physical_name(&fun.to_string(), false, args, order_by) } Expr::AggregateFunction(AggregateFunction { func_def, @@ -233,7 +233,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { func_def.name(), *distinct, args, - order_by.as_ref(), + order_by, ), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( @@ -1744,22 +1744,25 @@ pub fn create_window_expr_with_name( let name = name.into(); let physical_schema: &Schema = &logical_schema.into(); match e { - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_fun) => { + let window_frame = window_fun.get_frame_or_default(); + let WindowFunction { + fun, + args, + partition_by, + order_by, + null_treatment, + .. + } = window_fun; + let physical_args = create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = create_physical_exprs(partition_by, logical_schema, execution_props)?; let order_by = - create_physical_sort_exprs(order_by, logical_schema, execution_props)?; + create_physical_sort_exprs(order_by.as_ref().unwrap_or(&vec![]), logical_schema, execution_props)?; - if !is_window_frame_bound_valid(window_frame) { + if !is_window_frame_bound_valid(&window_frame) { return plan_err!( "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", window_frame.start_bound, window_frame.end_bound diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2c3b2bb88b37..2809a6049a75 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -55,7 +55,7 @@ use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, - scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::expr_fn::{avg, count, sum}; @@ -184,14 +184,12 @@ async fn test_count_wildcard_on_window() -> Result<()> { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))])? .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build() - .unwrap()])? + ))?])? .explain(false, false)? .collect() .await?; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index d76b3c9dc1ec..62d2b3a95240 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -20,8 +20,7 @@ use arrow_array::builder::{ListBuilder, StringBuilder}; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Field}; use datafusion::prelude::*; -use datafusion_common::{assert_contains, DFSchema, ScalarValue}; -use datafusion_expr::ExprFunctionExt; +use datafusion_common::{assert_contains, DFSchema, Result, ScalarValue}; use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_functions_aggregate::first_last::first_value_udaf; use datafusion_functions_aggregate::sum::sum_udaf; @@ -173,7 +172,6 @@ async fn test_aggregate_error() { .call(vec![col("props")]) // not a sort column .order_by(vec![col("id")]) - .build() .unwrap_err() .to_string(); assert_contains!( @@ -183,22 +181,18 @@ async fn test_aggregate_error() { } #[tokio::test] -async fn test_aggregate_ext_order_by() { +async fn test_aggregate_ext_order_by() -> Result<()> { let agg = first_value_udaf().call(vec![col("props")]); // ORDER BY id ASC let agg_asc = agg .clone() - .order_by(vec![col("id").sort(true, true)]) - .build() - .unwrap() + .order_by(vec![col("id").sort(true, true)])? .alias("asc"); // ORDER BY id DESC let agg_desc = agg - .order_by(vec![col("id").sort(false, true)]) - .build() - .unwrap() + .order_by(vec![col("id").sort(false, true)])? .alias("desc"); evaluate_agg_test( @@ -224,16 +218,15 @@ async fn test_aggregate_ext_order_by() { ], ) .await; + Ok(()) } #[tokio::test] -async fn test_aggregate_ext_filter() { +async fn test_aggregate_ext_filter() -> Result<()> { let agg = first_value_udaf() .call(vec![col("i")]) - .order_by(vec![col("i").sort(true, true)]) - .filter(col("i").is_not_null()) - .build() - .unwrap() + .order_by(vec![col("i").sort(true, true)])? + .filter(col("i").is_not_null())? .alias("val"); #[rustfmt::skip] @@ -248,16 +241,15 @@ async fn test_aggregate_ext_filter() { ], ) .await; + Ok(()) } #[tokio::test] -async fn test_aggregate_ext_distinct() { +async fn test_aggregate_ext_distinct() -> Result<()> { let agg = sum_udaf() .call(vec![lit(5)]) // distinct sum should be 5, not 15 - .distinct() - .build() - .unwrap() + .distinct()? .alias("distinct"); evaluate_agg_test( @@ -271,25 +263,22 @@ async fn test_aggregate_ext_distinct() { ], ) .await; + Ok(()) } #[tokio::test] -async fn test_aggregate_ext_null_treatment() { +async fn test_aggregate_ext_null_treatment() -> Result<()> { let agg = first_value_udaf() .call(vec![col("i")]) - .order_by(vec![col("i").sort(true, true)]); + .order_by(vec![col("i").sort(true, true)])?; let agg_respect = agg .clone() - .null_treatment(NullTreatment::RespectNulls) - .build() - .unwrap() + .null_treatment(NullTreatment::RespectNulls)? .alias("respect"); let agg_ignore = agg - .null_treatment(NullTreatment::IgnoreNulls) - .build() - .unwrap() + .null_treatment(NullTreatment::IgnoreNulls)? .alias("ignore"); evaluate_agg_test( @@ -315,6 +304,7 @@ async fn test_aggregate_ext_null_treatment() { ], ) .await; + Ok(()) } /// Evaluates the specified expr as an aggregate and compares the result to the diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 37acdb957544..f217a4429769 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -803,9 +803,9 @@ pub struct WindowFunction { /// List of partition by expressions pub partition_by: Vec, /// List of order by expressions - pub order_by: Vec, + pub order_by: Option>, /// Window frame - pub window_frame: window_frame::WindowFrame, + pub window_frame: Option, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, } @@ -817,11 +817,18 @@ impl WindowFunction { fun: fun.into(), args, partition_by: Vec::default(), - order_by: Vec::default(), - window_frame: WindowFrame::new(None), + order_by: None, + window_frame: None, null_treatment: None, } } + + pub fn get_frame_or_default(&self) -> WindowFrame { + match &self.window_frame { + Some(frame) => frame.clone(), + None => WindowFrame::new(self.order_by.as_ref().map(|exprs| !exprs.is_empty())) + } + } } /// Find DataFusion's built-in window function by name. @@ -1833,14 +1840,15 @@ impl fmt::Display for Expr { Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + null_treatment, + .. + } = window_fun; fmt_function(f, &fun.to_string(), false, args, true)?; if let Some(nt) = null_treatment { @@ -1850,13 +1858,16 @@ impl fmt::Display for Expr { if !partition_by.is_empty() { write!(f, " PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; } - if !order_by.is_empty() { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if let Some(o) = order_by { + if !o.is_empty() { + write!(f, " ORDER BY [{}]", expr_vec_fmt!(o))?; + } } + let frame = window_fun.get_frame_or_default(); write!( f, " {} BETWEEN {} AND {}", - window_frame.units, window_frame.start_bound, window_frame.end_bound + frame.units, frame.start_bound, frame.end_bound )?; Ok(()) } @@ -2155,14 +2166,15 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { Expr::ScalarFunction(fun) => { w.write_str(fun.func.display_name(&fun.args)?.as_str())?; } - Expr::WindowFunction(WindowFunction { - fun, - args, - window_frame, - partition_by, - order_by, - null_treatment, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + null_treatment, + .. + } = window_fun; write_function_name(w, &fun.to_string(), false, args)?; if let Some(nt) = null_treatment { @@ -2173,12 +2185,15 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { w.write_str(" ")?; write!(w, "PARTITION BY [{}]", expr_vec_fmt!(partition_by))?; } - if !order_by.is_empty() { - w.write_str(" ")?; - write!(w, "ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if let Some(o) = order_by { + if !o.is_empty() { + w.write_str(" ")?; + write!(w, "ORDER BY [{}]", expr_vec_fmt!(o))?; + } } + let frame = window_fun.get_frame_or_default(); w.write_str(" ")?; - write!(w, "{window_frame}")?; + write!(w, "{frame}")?; } Expr::AggregateFunction(AggregateFunction { func_def, @@ -2192,8 +2207,8 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; - if let Some(order_by) = order_by { - write!(w, " ORDER BY [{}]", expr_vec_fmt!(order_by))?; + if let Some(o) = order_by { + write!(w, " ORDER BY [{}]", expr_vec_fmt!(o))?; }; if let Some(nt) = null_treatment { write!(w, " {}", nt)?; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 7281a1211aa8..cbfed26b1e6e 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, WindowFunction, + Placeholder, TryCast, Unnest, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, @@ -707,235 +707,243 @@ pub fn interval_month_day_nano_lit(value: &str) -> Expr { /// # Ok(()) /// # } /// ``` -pub trait ExprFunctionExt { - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(self, order_by: Vec) -> ExprFuncBuilder; - /// Add `FILTER ` - fn filter(self, filter: Expr) -> ExprFuncBuilder; - /// Add `DISTINCT` - fn distinct(self) -> ExprFuncBuilder; - /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment( - self, - null_treatment: impl Into>, - ) -> ExprFuncBuilder; - // Add `PARTITION BY` - fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; - // Add appropriate window frame conditions - fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; -} - -#[derive(Debug, Clone)] -pub enum ExprFuncKind { - Aggregate(AggregateFunction), - Window(WindowFunction), -} - -/// Implementation of [`ExprFunctionExt`]. -/// -/// See [`ExprFunctionExt`] for usage and examples -#[derive(Debug, Clone)] -pub struct ExprFuncBuilder { - fun: Option, - order_by: Option>, - filter: Option, - distinct: bool, - null_treatment: Option, - partition_by: Option>, - window_frame: Option, -} - -impl ExprFuncBuilder { - /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] - - fn new(fun: Option) -> Self { - Self { - fun, - order_by: None, - filter: None, - distinct: false, - null_treatment: None, - partition_by: None, - window_frame: None, - } - } - - /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] - /// - /// # Errors: - /// - /// Returns an error of this builder [`ExprFunctionExt`] was used with an - /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] - pub fn build(self) -> Result { - let Self { - fun, - order_by, - filter, - distinct, - null_treatment, - partition_by, - window_frame, - } = self; - - let Some(fun) = fun else { - return plan_err!( - "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" - ); - }; - - if let Some(order_by) = &order_by { - for expr in order_by.iter() { - if !matches!(expr, Expr::Sort(_)) { - return plan_err!( - "ORDER BY expressions must be Expr::Sort, found {expr:?}" - ); - } +// pub trait ExprFunctionExt { +// /// Add `ORDER BY ` +// /// +// /// Note: `order_by` must be [`Expr::Sort`] +// fn order_by(self, order_by: Vec) -> ExprFuncBuilder; +// /// Add `FILTER ` +// fn filter(self, filter: Expr) -> ExprFuncBuilder; +// /// Add `DISTINCT` +// fn distinct(self) -> ExprFuncBuilder; +// /// Add `RESPECT NULLS` or `IGNORE NULLS` +// fn null_treatment( +// self, +// null_treatment: impl Into>, +// ) -> ExprFuncBuilder; +// // Add `PARTITION BY` +// fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; +// // Add appropriate window frame conditions +// fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; +// } + +// #[derive(Debug, Clone)] +// pub enum ExprFuncKind { +// Aggregate(AggregateFunction), +// Window(WindowFunction), +// } + +// /// Implementation of [`ExprFunctionExt`]. +// /// +// /// See [`ExprFunctionExt`] for usage and examples +// #[derive(Debug, Clone)] +// pub struct ExprFuncBuilder { +// fun: Option, +// order_by: Option>, +// filter: Option, +// distinct: bool, +// null_treatment: Option, +// partition_by: Option>, +// window_frame: Option, +// } + +// impl ExprFuncBuilder { +// /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] + +// fn new(fun: Option) -> Self { +// Self { +// fun, +// order_by: None, +// filter: None, +// distinct: false, +// null_treatment: None, +// partition_by: None, +// window_frame: None, +// } +// } + +// /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] +// /// +// /// # Errors: +// /// +// /// Returns an error of this builder [`ExprFunctionExt`] was used with an +// /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] +// pub fn build(self) -> Result { +// let Self { +// fun, +// order_by, +// filter, +// distinct, +// null_treatment, +// partition_by, +// window_frame, +// } = self; + +// let Some(fun) = fun else { +// return plan_err!( +// "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" +// ); +// }; + +// if let Some(order_by) = &order_by { +// for expr in order_by.iter() { +// if !matches!(expr, Expr::Sort(_)) { +// return plan_err!( +// "ORDER BY expressions must be Expr::Sort, found {expr:?}" +// ); +// } +// } +// } + +// let fun_expr = match fun { +// ExprFuncKind::Aggregate(mut udaf) => { +// udaf.order_by = order_by; +// udaf.filter = filter.map(Box::new); +// udaf.distinct = distinct; +// udaf.null_treatment = null_treatment; +// Expr::AggregateFunction(udaf) +// } +// ExprFuncKind::Window(mut udwf) => { +// let has_order_by = order_by.as_ref().map(|o| o.len() > 0); +// udwf.order_by = order_by.unwrap_or_default(); +// udwf.partition_by = partition_by.unwrap_or_default(); +// udwf.window_frame = +// window_frame.unwrap_or(WindowFrame::new(has_order_by)); +// udwf.null_treatment = null_treatment; +// Expr::WindowFunction(udwf) +// } +// }; + +// Ok(fun_expr) +// } +// } + +// impl ExprFunctionExt for ExprFuncBuilder { +// /// Add `ORDER BY ` +// /// +// /// Note: `order_by` must be [`Expr::Sort`] +// fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { +// self.order_by = Some(order_by); +// self +// } + +// /// Add `FILTER ` +// fn filter(mut self, filter: Expr) -> ExprFuncBuilder { +// self.filter = Some(filter); +// self +// } + +// /// Add `DISTINCT` +// fn distinct(mut self) -> ExprFuncBuilder { +// self.distinct = true; +// self +// } + +// /// Add `RESPECT NULLS` or `IGNORE NULLS` +// fn null_treatment( +// mut self, +// null_treatment: impl Into>, +// ) -> ExprFuncBuilder { +// self.null_treatment = null_treatment.into(); +// self +// } + +// fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { +// self.partition_by = Some(partition_by); +// self +// } + +// fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { +// self.window_frame = Some(window_frame); +// self +// } +// } + +impl Expr { + pub fn order_by(mut self, order_by: impl Into>>) -> Result { + let order_by = order_by.into(); + if let Some(exprs) = &order_by { + if let Some(expr) = exprs.iter().find(|expr| !matches!(expr, Expr::Sort(_))) { + return plan_err!("ORDER BY expressions must be Expr::Sort, found {expr:?}"); } } - let fun_expr = match fun { - ExprFuncKind::Aggregate(mut udaf) => { - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; - Expr::AggregateFunction(udaf) - } - ExprFuncKind::Window(mut udwf) => { - let has_order_by = order_by.as_ref().map(|o| o.len() > 0); - udwf.order_by = order_by.unwrap_or_default(); - udwf.partition_by = partition_by.unwrap_or_default(); - udwf.window_frame = - window_frame.unwrap_or(WindowFrame::new(has_order_by)); - udwf.null_treatment = null_treatment; - Expr::WindowFunction(udwf) - } - }; - - Ok(fun_expr) - } -} - -impl ExprFunctionExt for ExprFuncBuilder { - /// Add `ORDER BY ` - /// - /// Note: `order_by` must be [`Expr::Sort`] - fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { - self.order_by = Some(order_by); - self - } - - /// Add `FILTER ` - fn filter(mut self, filter: Expr) -> ExprFuncBuilder { - self.filter = Some(filter); - self - } - - /// Add `DISTINCT` - fn distinct(mut self) -> ExprFuncBuilder { - self.distinct = true; - self - } - - /// Add `RESPECT NULLS` or `IGNORE NULLS` - fn null_treatment( - mut self, - null_treatment: impl Into>, - ) -> ExprFuncBuilder { - self.null_treatment = null_treatment.into(); - self - } - - fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { - self.partition_by = Some(partition_by); - self - } - - fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { - self.window_frame = Some(window_frame); - self - } -} - -impl ExprFunctionExt for Expr { - fn order_by(self, order_by: Vec) -> ExprFuncBuilder { - let mut builder = match self { + match &mut self { Expr::AggregateFunction(udaf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + udaf.order_by = order_by; } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + udwf.order_by = order_by; + } + _ => { + return plan_err!("order_by can only be used with AggregateFunction or WindowFunction expressions.") } - _ => ExprFuncBuilder::new(None), - }; - if builder.fun.is_some() { - builder.order_by = Some(order_by); } - builder + Ok(self) } - fn filter(self, filter: Expr) -> ExprFuncBuilder { - match self { + pub fn filter(mut self, filter: Expr) -> Result { + match &mut self { Expr::AggregateFunction(udaf) => { - let mut builder = - ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); - builder.filter = Some(filter); - builder + udaf.filter = Some(Box::new(filter)); + } + _ => { + return plan_err!("filter can only be used with AggregateFunction expressions.") } - _ => ExprFuncBuilder::new(None), } + Ok(self) } - fn distinct(self) -> ExprFuncBuilder { - match self { + pub fn distinct(mut self) -> Result { + match &mut self { Expr::AggregateFunction(udaf) => { - let mut builder = - ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); - builder.distinct = true; - builder + udaf.distinct = true; + } + _ => { + return plan_err!("distinct can only be used with AggregateFunction expressions.") } - _ => ExprFuncBuilder::new(None), } + Ok(self) } - fn null_treatment( - self, + pub fn null_treatment( + mut self, null_treatment: impl Into>, - ) -> ExprFuncBuilder { - let mut builder = match self { + ) -> Result { + match &mut self { Expr::AggregateFunction(udaf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + udaf.null_treatment = null_treatment.into(); + // ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) } Expr::WindowFunction(udwf) => { - ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + udwf.null_treatment = null_treatment.into(); + } + _ => { + return plan_err!("null_treatment can only be used with AggregateFunction or WindowFunction expressions.") } - _ => ExprFuncBuilder::new(None), - }; - if builder.fun.is_some() { - builder.null_treatment = null_treatment.into(); } - builder + Ok(self) } - fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { - match self { + pub fn partition_by(mut self, partition_by: Vec) -> Result { + match &mut self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); - builder.partition_by = Some(partition_by); - builder + udwf.partition_by = partition_by; + } + _ => { + return plan_err!("partition_by can only be used with WindowFunction expressions.") } - _ => ExprFuncBuilder::new(None), } + Ok(self) } - fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { - match self { + pub fn window_frame(mut self, window_frame: WindowFrame) -> Result { + match &mut self { Expr::WindowFunction(udwf) => { - let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); - builder.window_frame = Some(window_frame); - builder + udwf.window_frame = Some(window_frame); + } + _ => { + return plan_err!("window_frame can only be used with WindowFunction expressions.") } - _ => ExprFuncBuilder::new(None), } + Ok(self) } } diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index a97b9f010f79..dfae31c7b9f5 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -22,7 +22,7 @@ use crate::expr::{ Cast, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; -use crate::{Expr, ExprFunctionExt}; +use crate::Expr; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, @@ -111,7 +111,9 @@ impl TreeNode for Expr { }) => { let mut expr_vec = args.iter().collect::>(); expr_vec.extend(partition_by); - expr_vec.extend(order_by); + if let Some(o) = order_by { + expr_vec.extend(o); + } expr_vec } Expr::InList(InList { expr, list, .. }) => { @@ -279,29 +281,31 @@ impl TreeNode for Expr { ))) })? } - Expr::WindowFunction(WindowFunction { - args, - fun, - partition_by, - order_by, - window_frame, - null_treatment, - }) => map_until_stop_and_collect!( + Expr::WindowFunction(window_fun) => { + let WindowFunction { + args, + fun, + partition_by, + order_by, + null_treatment, + window_frame, + } = window_fun; + map_until_stop_and_collect!( transform_vec(args, &mut f), partition_by, transform_vec(partition_by, &mut f), order_by, - transform_vec(order_by, &mut f) + transform_option_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }), + Expr::WindowFunction(WindowFunction { + fun, + args: new_args, + partition_by: new_partition_by, + order_by: new_order_by, + window_frame: window_frame, + null_treatment: null_treatment, + })})} Expr::AggregateFunction(AggregateFunction { args, func_def, diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 1a6b21e3dd29..2c193c7fa416 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -141,8 +141,8 @@ impl WindowUDF { fun, args, partition_by, - order_by, - window_frame, + order_by: Some(order_by), + window_frame: Some(window_frame), null_treatment: None, }) } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b833c1db06a2..e8059d220422 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -464,8 +464,10 @@ type WindowSortKey = Vec<(Expr, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr pub fn generate_sort_key( partition_by: &[Expr], - order_by: &[Expr], + order_by: &Option>, ) -> Result { + let empty_vec = vec![]; + let order_by = order_by.as_ref().unwrap_or(&empty_vec); let normalized_order_by_keys = order_by .iter() .map(|e| match e { @@ -1252,9 +1254,7 @@ impl AggregateOrderSensitivity { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, ExprFunctionExt, - WindowFrame, WindowFunctionDefinition, + col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, WindowFunctionDefinition }; #[test] @@ -1302,9 +1302,7 @@ mod tests { WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], )) - .order_by(vec![age_asc.clone(), name_desc.clone()]) - .build() - .unwrap(); + .order_by(vec![age_asc.clone(), name_desc.clone()])?; let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("name")], @@ -1313,9 +1311,7 @@ mod tests { WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), vec![col("name")], )) - .order_by(vec![age_asc.clone(), name_desc.clone()]) - .build() - .unwrap(); + .order_by(vec![age_asc.clone(), name_desc.clone()])?; let sum4 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], @@ -1324,9 +1320,7 @@ mod tests { name_desc.clone(), age_asc.clone(), created_at_desc.clone(), - ]) - .build() - .unwrap(); + ])?; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1358,10 +1352,8 @@ mod tests { .order_by(vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), + ])? + .window_frame(WindowFrame::new(Some(false)))?, Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], @@ -1370,10 +1362,8 @@ mod tests { Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(), + ])? + .window_frame(WindowFrame::new(Some(false)))?, ]; let expected = vec![ Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), @@ -1392,7 +1382,7 @@ mod tests { let partition_by = &[col("age"), col("name"), col("created_at")]; for asc_ in asc_or_desc { for nulls_first_ in nulls_first_or_last { - let order_by = &[ + let order_by = Some(vec![ Expr::Sort(Sort { expr: Box::new(col("age")), asc: asc_, @@ -1403,7 +1393,7 @@ mod tests { asc: asc_, nulls_first: nulls_first_, }), - ]; + ]); let expected = vec![ ( @@ -1431,7 +1421,7 @@ mod tests { true, ), ]; - let result = generate_sort_key(partition_by, order_by)?; + let result = generate_sort_key(partition_by, &order_by)?; assert_eq!(expected, result); } } diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 39f1944452af..156abe64a0ed 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -31,7 +31,7 @@ use datafusion_common::{ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, ExprFunctionExt, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr_common::aggregate::utils::get_sort_options; @@ -42,16 +42,13 @@ use datafusion_physical_expr_common::sort_expr::{ create_func!(FirstValue, first_value_udaf); /// Returns the first value in a group of values. -pub fn first_value(expression: Expr, order_by: Option>) -> Expr { +pub fn first_value(expression: Expr, order_by: Option>) -> Result { if let Some(order_by) = order_by { first_value_udaf() .call(vec![expression]) - .order_by(order_by) - .build() - // guaranteed to be `Expr::AggregateFunction` - .unwrap() + .order_by(order_by.clone()) } else { - first_value_udaf().call(vec![expression]) + Ok(first_value_udaf().call(vec![expression])) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index a9114ee70a59..101e8f50138c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -101,7 +101,6 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; - use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, @@ -225,13 +224,12 @@ mod tests { WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], )) - .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .order_by(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))])? .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build()?])? + ))?])? .project(vec![count(wildcard())])? .build()?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 16202f137d06..c31cebb5f8a6 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -47,7 +47,7 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, + type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable, LogicalPlan, Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, }; @@ -429,14 +429,16 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { ))) } }, - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - }) => { + Expr::WindowFunction(window_fun) => { + let window_frame = window_fun.get_frame_or_default(); + let WindowFunction { + fun, + args, + partition_by, + order_by, + null_treatment, + .. + } = window_fun; let window_frame = coerce_window_frame(window_frame, self.schema, &order_by)?; @@ -461,11 +463,10 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { Ok(Transformed::yes( Expr::WindowFunction(WindowFunction::new(fun, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?, + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)? + .null_treatment(null_treatment)? )) } Expr::Alias(_) @@ -563,12 +564,13 @@ fn coerce_frame_bound( // Coerces the given `window_frame` to use appropriate natural types. // For example, ROWS and GROUPS frames use `UInt64` during calculations. fn coerce_window_frame( - window_frame: WindowFrame, + mut window_frame: WindowFrame, schema: &DFSchema, - expressions: &[Expr], + expressions: &Option>, ) -> Result { - let mut window_frame = window_frame; let current_types = expressions + .as_ref() + .unwrap_or(&vec![]) .iter() .map(|e| e.get_type(schema)) .collect::>>()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 16abf93f3807..27346ded95d5 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -806,7 +806,6 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; - use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, @@ -1880,8 +1879,7 @@ mod tests { let table_scan = test_table_scan()?; let aggr_with_filter = count_udaf() .call(vec![col("b")]) - .filter(col("c").gt(lit(42))) - .build()?; + .filter(col("c").gt(lit(42)))?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], @@ -1920,9 +1918,7 @@ mod tests { WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("test.a")], )) - .partition_by(vec![col("test.b")]) - .build() - .unwrap(); + .partition_by(vec![col("test.b")])?; let max2 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 430517121f2a..612c1c26e94a 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -23,8 +23,9 @@ use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; +use datafusion_expr::{col, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; +use itertools::Itertools; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] /// @@ -96,18 +97,15 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { // Construct the aggregation expression to be used to fetch the selected expressions. let first_value_udaf: std::sync::Arc = config.function_registry().unwrap().udaf("first_value")?; - let aggr_expr = select_expr.into_iter().map(|e| { + let aggr_expr: Vec = select_expr.into_iter().map(|e| { if let Some(order_by) = &sort_expr { first_value_udaf .call(vec![e]) .order_by(order_by.clone()) - .build() - // guaranteed to be `Expr::AggregateFunction` - .unwrap() } else { - first_value_udaf.call(vec![e]) + Ok(first_value_udaf.call(vec![e])) } - }); + }).try_collect()?; let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; let group_expr = normalize_cols(on_expr, input.as_ref())?; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index d776e6598cbe..9a69e3caf189 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -354,7 +354,6 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; @@ -675,9 +674,8 @@ mod tests { // count(DISTINCT a) FILTER (WHERE a > 5) let expr = count_udaf() .call(vec![col("a")]) - .distinct() - .filter(col("a").gt(lit(5))) - .build()?; + .distinct()? + .filter(col("a").gt(lit(5)))?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -718,9 +716,8 @@ mod tests { // count(DISTINCT a ORDER BY a) let expr = count_udaf() .call(vec![col("a")]) - .distinct() - .order_by(vec![col("a").sort(true, false)]) - .build()?; + .distinct()? + .order_by(vec![col("a").sort(true, false)])?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -738,10 +735,9 @@ mod tests { // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) let expr = count_udaf() .call(vec![col("a")]) - .distinct() - .filter(col("a").gt(lit(5))) - .order_by(vec![col("a").sort(true, false)]) - .build()?; + .distinct()? + .filter(col("a").gt(lit(5)))? + .order_by(vec![col("a").sort(true, false)])?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f3f0e603e060..a27049aa80e8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -25,7 +25,6 @@ use datafusion_common::{ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; -use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, @@ -316,11 +315,9 @@ pub fn parse_expr( codec, )?], )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .unwrap()) + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)?) } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) @@ -338,11 +335,9 @@ pub fn parse_expr( ), args, )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .unwrap()) + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)?) } window_expr_node::WindowFunction::Udaf(udaf_name) => { let udaf_function = match &expr.fun_definition { @@ -358,11 +353,9 @@ pub fn parse_expr( expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .unwrap()) + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)?) } window_expr_node::WindowFunction::Udwf(udwf_name) => { let udwf_function = match &expr.fun_definition { @@ -378,11 +371,9 @@ pub fn parse_expr( expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .build() - .unwrap()) + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)?) } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9607b918eb89..7aafe69f183f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -310,15 +310,16 @@ pub fn serialize_expr( expr_type: Some(ExprType::SimilarTo(pb)), } } - Expr::WindowFunction(expr::WindowFunction { - ref fun, - ref args, - ref partition_by, - ref order_by, - ref window_frame, - // TODO: support null treatment in proto - null_treatment: _, - }) => { + Expr::WindowFunction(window_fun) => { + let window_frame = window_fun.get_frame_or_default(); + let expr::WindowFunction { + ref fun, + ref args, + ref partition_by, + ref order_by, + // TODO: support null treatment in proto + .. + } = window_fun; let (window_function, fun_definition) = match fun { WindowFunctionDefinition::AggregateFunction(fun) => ( protobuf::window_expr_node::WindowFunction::AggrFunction( @@ -360,10 +361,10 @@ pub fn serialize_expr( None }; let partition_by = serialize_exprs(partition_by, codec)?; - let order_by = serialize_exprs(order_by, codec)?; + let order_by = serialize_exprs(order_by.as_ref().unwrap_or(&vec![]), codec)?; let window_frame: Option = - Some(window_frame.try_into()?); + Some((&window_frame).try_into()?); let window_expr = Box::new(protobuf::WindowExprNode { expr: arg_expr, window_function: Some(window_function), diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 30af1e1e202a..b0174539dbd7 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,7 +59,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprFunctionExt, + Accumulator, AggregateFunction, AggregateUDF, ColumnarValue, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -680,8 +680,8 @@ async fn roundtrip_expr_api() -> Result<()> { array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), count(lit(1)), count_distinct(lit(1)), - first_value(lit(1), None), - first_value(lit(1), Some(vec![lit(2).sort(true, true)])), + first_value(lit(1), None)?, + first_value(lit(1), Some(vec![lit(2).sort(true, true)]))?, avg(lit(1.5)), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), @@ -1856,14 +1856,13 @@ fn roundtrip_count() { } #[test] -fn roundtrip_count_distinct() { +fn roundtrip_count_distinct() -> Result<()>{ let test_expr = count_udaf() .call(vec![col("bananas")]) - .distinct() - .build() - .unwrap(); + .distinct()?; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); + Ok(()) } #[test] @@ -2032,7 +2031,7 @@ fn roundtrip_substr() { roundtrip_expr_test(test_expr_with_count, ctx); } #[test] -fn roundtrip_window() { +fn roundtrip_window() -> Result<()> { let ctx = SessionContext::new(); // 1. without window_frame @@ -2042,11 +2041,9 @@ fn roundtrip_window() { ), vec![], )) - .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(); + .partition_by(vec![col("col1")])? + .order_by(vec![col("col2")])? + .window_frame(WindowFrame::new(Some(false)))?; // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( @@ -2055,11 +2052,9 @@ fn roundtrip_window() { ), vec![], )) - .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) - .window_frame(WindowFrame::new(Some(false))) - .build() - .unwrap(); + .partition_by(vec![col("col1")])? + .order_by(vec![col("col2")])? + .window_frame(WindowFrame::new(Some(false)))?; // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -2074,11 +2069,9 @@ fn roundtrip_window() { ), vec![], )) - .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) - .window_frame(range_number_frame) - .build() - .unwrap(); + .partition_by(vec![col("col1")])? + .order_by(vec![col("col2")])? + .window_frame(range_number_frame)?; // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -2091,11 +2084,9 @@ fn roundtrip_window() { WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), vec![col("col1")], )) - .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) - .window_frame(row_number_frame.clone()) - .build() - .unwrap(); + .partition_by(vec![col("col1")])? + .order_by(vec![col("col2")])? + .window_frame(row_number_frame.clone())?; // 5. test with AggregateUDF #[derive(Debug)] @@ -2141,11 +2132,9 @@ fn roundtrip_window() { WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], )) - .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) - .window_frame(row_number_frame.clone()) - .build() - .unwrap(); + .partition_by(vec![col("col1")])? + .order_by(vec![col("col2")])? + .window_frame(row_number_frame.clone())?; ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2218,19 +2207,15 @@ fn roundtrip_window() { WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], )) - .partition_by(vec![col("col1")]) - .order_by(vec![col("col2")]) - .window_frame(row_number_frame.clone()) - .build() - .unwrap(); + .partition_by(vec![col("col1")])? + .order_by(vec![col("col2")])? + .window_frame(row_number_frame.clone())?; let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], )) - .window_frame(row_number_frame.clone()) - .build() - .unwrap(); + .window_frame(row_number_frame.clone())?; ctx.register_udwf(dummy_window_udf); @@ -2241,4 +2226,6 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr5, ctx.clone()); roundtrip_expr_test(test_expr6, ctx.clone()); roundtrip_expr_test(text_expr7, ctx); + + Ok(()) } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index d5571e0221dc..3ce6d96e74ed 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -23,7 +23,7 @@ use datafusion_common::{ }; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, + expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{ @@ -319,23 +319,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { WindowFunctionDefinition::AggregateFunction(aggregate_fun), args, )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)? + .null_treatment(null_treatment)? } _ => Expr::WindowFunction(expr::WindowFunction::new( fun, self.function_args_to_expr(args, schema, planner_context)?, )) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap(), + .partition_by(partition_by)? + .order_by(order_by)? + .window_frame(window_frame)? + .null_treatment(null_treatment)? }; return Ok(expr); } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3bed4540e14f..30a42fd312ff 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -249,14 +249,15 @@ impl Unparser<'_> { } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::WindowFunction(window_fun) => { + let WindowFunction { + fun, + args, + partition_by, + order_by, + .. + } = window_fun; + let window_frame = window_fun.get_frame_or_default(); let func_name = fun.name(); let args = self.function_args_to_sql(args)?; @@ -273,6 +274,8 @@ impl Unparser<'_> { } }; let order_by: Vec = order_by + .as_ref() + .unwrap_or(&vec![]) .iter() .map(|expr| expr_to_unparsed(expr)?.into_order_by_expr()) .collect::>>()?; @@ -1342,7 +1345,7 @@ mod tests { table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; + use datafusion_expr::interval_month_day_nano_lit; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; @@ -1583,17 +1586,13 @@ mod tests { ( count_udaf() .call(vec![Expr::Wildcard { qualifier: None }]) - .distinct() - .build() - .unwrap(), + .distinct()?, "count(DISTINCT *)", ), ( count_udaf() .call(vec![Expr::Wildcard { qualifier: None }]) - .filter(lit(true)) - .build() - .unwrap(), + .filter(lit(true))?, "count(*) FILTER (WHERE true)", ), ( @@ -1603,8 +1602,8 @@ mod tests { ), args: vec![col("col")], partition_by: vec![], - order_by: vec![], - window_frame: WindowFrame::new(None), + order_by: Some(vec![]), + window_frame: Some(WindowFrame::new(None)), null_treatment: None, }), r#"ROW_NUMBER(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, @@ -1614,12 +1613,12 @@ mod tests { fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], - order_by: vec![Expr::Sort(Sort::new( + order_by: Some(vec![Expr::Sort(Sort::new( Box::new(col("a")), false, true, - ))], - window_frame: WindowFrame::new_bounds( + ))]), + window_frame: Some(WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, datafusion_expr::WindowFrameBound::Preceding( ScalarValue::UInt32(Some(6)), @@ -1627,7 +1626,7 @@ mod tests { datafusion_expr::WindowFrameBound::Following( ScalarValue::UInt32(Some(2)), ), - ), + )), null_treatment: None, }), r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1365630d5079..ea216d865de6 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1234,12 +1234,12 @@ pub async fn from_substrait_rex( extensions, ) .await?, - order_by, - window_frame: datafusion::logical_expr::WindowFrame::new_bounds( + order_by: Some(order_by), + window_frame: Some(datafusion::logical_expr::WindowFrame::new_bounds( bound_units, from_substrait_bound(&window.lower_bound, true)?, from_substrait_bound(&window.upper_bound, false)?, - ), + )), null_treatment: None, })) } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 7849d0bd431e..2126e38af500 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1238,14 +1238,15 @@ pub fn to_substrait_rex( Expr::Alias(Alias { expr, .. }) => { to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { + Expr::WindowFunction(window_fun) => { + let window_frame = window_fun.get_frame_or_default(); + let WindowFunction { + fun, + args, + partition_by, + order_by, + .. + } = window_fun; // function reference let function_anchor = register_function(fun.to_string(), extension_info); // arguments @@ -1268,12 +1269,14 @@ pub fn to_substrait_rex( .collect::>>()?; // order by expressions let order_by = order_by + .as_ref() + .unwrap_or(&vec![]) .iter() .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) .collect::>>()?; // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; + let bounds = to_substrait_bounds(&window_frame)?; + let bound_type = to_substrait_bound_type(&window_frame)?; Ok(make_substrait_window_function( function_anchor, arguments,