diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index a44f522ba95a..47804b927e64 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -68,6 +68,10 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; + +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::{as_int64_array, as_string_array}, common::{arrow_datafusion_err, internal_err, DFSchemaRef}, @@ -90,16 +94,12 @@ use datafusion::{ physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, }; - -use async_trait::async_trait; -use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::Projection; use datafusion_optimizer::optimizer::ApplyOrder; use datafusion_optimizer::AnalyzerRule; -use futures::{Stream, StreamExt}; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index d722e55de487..d8be2b434732 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -94,8 +94,8 @@ pub struct AccumulatorArgs<'a> { /// ``` pub is_distinct: bool, - /// The input type of the aggregate function. - pub input_type: &'a DataType, + /// The input types of the aggregate function. + pub input_types: &'a [DataType], /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], @@ -109,8 +109,8 @@ pub struct StateFieldsArgs<'a> { /// The name of the aggregate function. pub name: &'a str, - /// The input type of the aggregate function. - pub input_type: &'a DataType, + /// The input types of the aggregate function. + pub input_types: &'a [DataType], /// The return type of the aggregate function. pub return_type: &'a DataType, diff --git a/datafusion/functions-aggregate/COMMENTS.md b/datafusion/functions-aggregate/COMMENTS.md index 23a996faf007..e669e1355711 100644 --- a/datafusion/functions-aggregate/COMMENTS.md +++ b/datafusion/functions-aggregate/COMMENTS.md @@ -54,7 +54,7 @@ first argument and the definition looks like this: // `input_type` : data type of the first argument let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), - Field::new("item", args.input_type.clone(), true /* nullable of list item */ ), + Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ), false, // nullable of list itself )]; ``` diff --git a/datafusion/functions-aggregate/src/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs index 7c6aef9944f6..56ef32e7ebe0 100644 --- a/datafusion/functions-aggregate/src/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -277,7 +277,7 @@ impl AggregateUDFImpl for ApproxDistinct { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let accumulator: Box = match acc_args.input_type { + let accumulator: Box = match &acc_args.input_types[0] { // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL // TODO support for boolean (trivial case) // https://github.com/apache/datafusion/issues/1109 diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs index bc723c862953..e12e3445a83e 100644 --- a/datafusion/functions-aggregate/src/approx_median.rs +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian { Ok(Box::new(ApproxPercentileAccumulator::new( 0.5_f64, - acc_args.input_type.clone(), + acc_args.input_types[0].clone(), ))) } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index dfb94a84cbec..16837dc80748 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -104,7 +104,7 @@ impl ApproxPercentileCont { None }; - let accumulator: ApproxPercentileAccumulator = match args.input_type { + let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index c25d592428bb..36c9d6a0d7c8 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -90,7 +90,7 @@ impl AggregateUDFImpl for ArrayAgg { return Ok(vec![Field::new_list( format_state_name(args.name, "distinct_array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_types[0].clone(), true), true, )]); } @@ -98,7 +98,7 @@ impl AggregateUDFImpl for ArrayAgg { let mut fields = vec![Field::new_list( format_state_name(args.name, "array_agg"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_types[0].clone(), true), true, )]; @@ -119,12 +119,14 @@ impl AggregateUDFImpl for ArrayAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if acc_args.is_distinct { return Ok(Box::new(DistinctArrayAggAccumulator::try_new( - acc_args.input_type, + &acc_args.input_types[0], )?)); } if acc_args.sort_exprs.is_empty() { - return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)); + return Ok(Box::new(ArrayAggAccumulator::try_new( + &acc_args.input_types[0], + )?)); } let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema( @@ -138,7 +140,7 @@ impl AggregateUDFImpl for ArrayAgg { .collect::>>()?; OrderSensitiveArrayAggAccumulator::try_new( - acc_args.input_type, + &acc_args.input_types[0], &ordering_dtypes, ordering_req, acc_args.is_reversed, diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 18642fb84329..228bce1979a3 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -93,7 +93,7 @@ impl AggregateUDFImpl for Avg { } use DataType::*; // instantiate specialized accumulator based for the type - match (acc_args.input_type, acc_args.data_type) { + match (&acc_args.input_types[0], acc_args.data_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -120,7 +120,7 @@ impl AggregateUDFImpl for Avg { })), _ => exec_err!( "AvgAccumulator for ({} --> {})", - acc_args.input_type, + &acc_args.input_types[0], acc_args.data_type ), } @@ -135,7 +135,7 @@ impl AggregateUDFImpl for Avg { ), Field::new( format_state_name(args.name, "sum"), - args.input_type.clone(), + args.input_types[0].clone(), true, ), ]) @@ -154,10 +154,10 @@ impl AggregateUDFImpl for Avg { ) -> Result> { use DataType::*; // instantiate specialized accumulator based for the type - match (args.input_type, args.data_type) { + match (&args.input_types[0], args.data_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - args.input_type, + &args.input_types[0], args.data_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) @@ -176,7 +176,7 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - args.input_type, + &args.input_types[0], args.data_type, avg_fn, ))) @@ -197,7 +197,7 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - args.input_type, + &args.input_types[0], args.data_type, avg_fn, ))) @@ -205,7 +205,7 @@ impl AggregateUDFImpl for Avg { _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - args.input_type, + &args.input_types[0], args.data_type ), } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 56850d0e02a1..a6e4450d404c 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -126,7 +126,7 @@ impl AggregateUDFImpl for Count { Ok(vec![Field::new_list( format_state_name(args.name, "count distinct"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_types[0].clone(), true), false, )]) } else { @@ -147,7 +147,7 @@ impl AggregateUDFImpl for Count { return not_impl_err!("COUNT DISTINCT with multiple arguments"); } - let data_type = acc_args.input_type; + let data_type = &acc_args.input_types[0]; Ok(match data_type { // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator DataType::Int8 => Box::new( diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 8969937d377c..587767b8e356 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -440,14 +440,14 @@ impl AggregateUDFImpl for LastValue { fn state_fields(&self, args: StateFieldsArgs) -> Result> { let StateFieldsArgs { name, - input_type, + input_types, return_type: _, ordering_fields, is_distinct: _, } = args; let mut fields = vec![Field::new( format_state_name(name, "last_value"), - input_type.clone(), + input_types[0].clone(), true, )]; fields.extend(ordering_fields.to_vec()); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index bb926b8da271..febf1fcd2fef 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -102,7 +102,7 @@ impl AggregateUDFImpl for Median { fn state_fields(&self, args: StateFieldsArgs) -> Result> { //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", args.input_type.clone(), true); + let field = Field::new("item", args.input_types[0].clone(), true); let state_name = if args.is_distinct { "distinct_median" } else { @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median { }; } - let dt = acc_args.input_type; + let dt = &acc_args.input_types[0]; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), diff --git a/datafusion/functions-aggregate/src/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs index 74f77f3f4b86..dc7c6c86f213 100644 --- a/datafusion/functions-aggregate/src/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -114,7 +114,7 @@ impl AggregateUDFImpl for NthValueAgg { NthValueAccumulator::try_new( n, - acc_args.input_type, + &acc_args.input_types[0], &ordering_dtypes, ordering_req, ) @@ -125,7 +125,7 @@ impl AggregateUDFImpl for NthValueAgg { let mut fields = vec![Field::new_list( format_state_name(self.name(), "nth_value"), // See COMMENTS.md to understand why nullable is set to true - Field::new("item", args.input_type.clone(), true), + Field::new("item", args.input_types[0].clone(), true), false, )]; let orderings = args.ordering_fields.to_vec(); diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 247962dc2ce1..df757ddc0422 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -335,7 +335,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_type: &DataType::Float64, + input_types: &[DataType::Float64], input_exprs: &[datafusion_expr::col("a")], }; @@ -348,7 +348,7 @@ mod tests { name: "a", is_distinct: false, is_reversed: false, - input_type: &DataType::Float64, + input_types: &[DataType::Float64], input_exprs: &[datafusion_expr::col("a")], }; diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 4eede6567504..6febe1464b27 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -15,31 +15,33 @@ // specific language governing permissions and limitations // under the License. -pub mod count_distinct; -pub mod groups_accumulator; -pub mod merge_arrays; -pub mod stats; -pub mod tdigest; -pub mod utils; +use std::fmt::Debug; +use std::{any::Any, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + +use datafusion_common::exec_err; use datafusion_common::{internal_err, not_impl_err, DFSchema, Result}; use datafusion_expr::function::StateFieldsArgs; use datafusion_expr::type_coercion::aggregates::check_arg_count; +use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::ReversedUDAF; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, }; -use std::fmt::Debug; -use std::{any::Any, sync::Arc}; -use self::utils::down_cast_any_ref; use crate::physical_expr::PhysicalExpr; use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; use crate::utils::reverse_order_bys; -use datafusion_common::exec_err; -use datafusion_expr::utils::AggregateOrderSensitivity; +use self::utils::down_cast_any_ref; + +pub mod count_distinct; +pub mod groups_accumulator; +pub mod merge_arrays; +pub mod stats; +pub mod tdigest; +pub mod utils; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. @@ -225,7 +227,7 @@ impl AggregateExprBuilder { ignore_nulls, ordering_fields, is_distinct, - input_type: input_exprs_types[0].clone(), + input_types: input_exprs_types, is_reversed, })) } @@ -466,7 +468,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, is_reversed: bool, - input_type: DataType, + input_types: Vec, } impl AggregateFunctionExpr { @@ -504,7 +506,7 @@ impl AggregateExpr for AggregateFunctionExpr { fn state_fields(&self) -> Result> { let args = StateFieldsArgs { name: &self.name, - input_type: &self.input_type, + input_types: &self.input_types, return_type: &self.data_type, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, @@ -525,7 +527,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, @@ -542,7 +544,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, @@ -614,7 +616,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed, @@ -630,7 +632,7 @@ impl AggregateExpr for AggregateFunctionExpr { ignore_nulls: self.ignore_nulls, sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, - input_type: &self.input_type, + input_types: &self.input_types, input_exprs: &self.logical_args, name: &self.name, is_reversed: self.is_reversed,