Skip to content

Rename input_type --> input_types on AggregateFunctionExpr / AccumulatorArgs / StateFieldsArgs #11666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ use arrow::{
record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};

use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::{
common::cast::{as_int64_array, as_string_array},
common::{arrow_datafusion_err, internal_err, DFSchemaRef},
Expand All @@ -90,16 +94,12 @@ use datafusion::{
physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner},
prelude::{SessionConfig, SessionContext},
};

use async_trait::async_trait;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::ScalarValue;
use datafusion_expr::Projection;
use datafusion_optimizer::optimizer::ApplyOrder;
use datafusion_optimizer::AnalyzerRule;
use futures::{Stream, StreamExt};

/// Execute the specified sql and return the resulting record batches
/// pretty printed as a String.
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ pub struct AccumulatorArgs<'a> {
/// ```
pub is_distinct: bool,

/// The input type of the aggregate function.
pub input_type: &'a DataType,
/// The input types of the aggregate function.
pub input_types: &'a [DataType],

/// The logical expression of arguments the aggregate function takes.
pub input_exprs: &'a [Expr],
Expand All @@ -109,8 +109,8 @@ pub struct StateFieldsArgs<'a> {
/// The name of the aggregate function.
pub name: &'a str,

/// The input type of the aggregate function.
pub input_type: &'a DataType,
/// The input types of the aggregate function.
pub input_types: &'a [DataType],

/// The return type of the aggregate function.
pub return_type: &'a DataType,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/COMMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ first argument and the definition looks like this:
// `input_types[0]` : data type of the first argument
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
Field::new("item", args.input_type.clone(), true /* nullable of list item */ ),
Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ),
false, // nullable of list itself
)];
```
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl AggregateUDFImpl for ApproxDistinct {
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let accumulator: Box<dyn Accumulator> = match acc_args.input_type {
let accumulator: Box<dyn Accumulator> = match &acc_args.input_types[0] {
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
// TODO support for boolean (trivial case)
// https://github.com/apache/datafusion/issues/1109
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/approx_median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian {

Ok(Box::new(ApproxPercentileAccumulator::new(
0.5_f64,
acc_args.input_type.clone(),
acc_args.input_types[0].clone(),
)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl ApproxPercentileCont {
None
};

let accumulator: ApproxPercentileAccumulator = match args.input_type {
let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] {
t @ (DataType::UInt8
| DataType::UInt16
| DataType::UInt32
Expand Down
12 changes: 7 additions & 5 deletions datafusion/functions-aggregate/src/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ impl AggregateUDFImpl for ArrayAgg {
return Ok(vec![Field::new_list(
format_state_name(args.name, "distinct_array_agg"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
true,
)]);
}

let mut fields = vec![Field::new_list(
format_state_name(args.name, "array_agg"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
true,
)];

Expand All @@ -119,12 +119,14 @@ impl AggregateUDFImpl for ArrayAgg {
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if acc_args.is_distinct {
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
acc_args.input_type,
&acc_args.input_types[0],
)?));
}

if acc_args.sort_exprs.is_empty() {
return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?));
return Ok(Box::new(ArrayAggAccumulator::try_new(
&acc_args.input_types[0],
)?));
}

let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema(
Expand All @@ -138,7 +140,7 @@ impl AggregateUDFImpl for ArrayAgg {
.collect::<Result<Vec<_>>>()?;

OrderSensitiveArrayAggAccumulator::try_new(
acc_args.input_type,
&acc_args.input_types[0],
&ordering_dtypes,
ordering_req,
acc_args.is_reversed,
Expand Down
16 changes: 8 additions & 8 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl AggregateUDFImpl for Avg {
}
use DataType::*;
// instantiate specialized accumulator based on the type
match (acc_args.input_type, acc_args.data_type) {
match (&acc_args.input_types[0], acc_args.data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
(
Decimal128(sum_precision, sum_scale),
Expand All @@ -120,7 +120,7 @@ impl AggregateUDFImpl for Avg {
})),
_ => exec_err!(
"AvgAccumulator for ({} --> {})",
acc_args.input_type,
&acc_args.input_types[0],
acc_args.data_type
),
}
Expand All @@ -135,7 +135,7 @@ impl AggregateUDFImpl for Avg {
),
Field::new(
format_state_name(args.name, "sum"),
args.input_type.clone(),
args.input_types[0].clone(),
true,
),
])
Expand All @@ -154,10 +154,10 @@ impl AggregateUDFImpl for Avg {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;
// instantiate specialized accumulator based on the type
match (args.input_type, args.data_type) {
match (&args.input_types[0], args.data_type) {
(Float64, Float64) => {
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
args.input_type,
&args.input_types[0],
args.data_type,
|sum: f64, count: u64| Ok(sum / count as f64),
)))
Expand All @@ -176,7 +176,7 @@ impl AggregateUDFImpl for Avg {
move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);

Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
args.input_type,
&args.input_types[0],
args.data_type,
avg_fn,
)))
Expand All @@ -197,15 +197,15 @@ impl AggregateUDFImpl for Avg {
};

Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
args.input_type,
&args.input_types[0],
args.data_type,
avg_fn,
)))
}

_ => not_impl_err!(
"AvgGroupsAccumulator for ({} --> {})",
args.input_type,
&args.input_types[0],
args.data_type
),
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl AggregateUDFImpl for Count {
Ok(vec![Field::new_list(
format_state_name(args.name, "count distinct"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
false,
)])
} else {
Expand All @@ -147,7 +147,7 @@ impl AggregateUDFImpl for Count {
return not_impl_err!("COUNT DISTINCT with multiple arguments");
}

let data_type = acc_args.input_type;
let data_type = &acc_args.input_types[0];
Ok(match data_type {
// try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
DataType::Int8 => Box::new(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,14 +440,14 @@ impl AggregateUDFImpl for LastValue {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let StateFieldsArgs {
name,
input_type,
input_types,
return_type: _,
ordering_fields,
is_distinct: _,
} = args;
let mut fields = vec![Field::new(
format_state_name(name, "last_value"),
input_type.clone(),
input_types[0].clone(),
true,
)];
fields.extend(ordering_fields.to_vec());
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl AggregateUDFImpl for Median {

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
//Intermediate state is a list of the elements we have collected so far
let field = Field::new("item", args.input_type.clone(), true);
let field = Field::new("item", args.input_types[0].clone(), true);
let state_name = if args.is_distinct {
"distinct_median"
} else {
Expand Down Expand Up @@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median {
};
}

let dt = acc_args.input_type;
let dt = &acc_args.input_types[0];
downcast_integer! {
dt => (helper, dt),
DataType::Float16 => helper!(Float16Type, dt),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/nth_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl AggregateUDFImpl for NthValueAgg {

NthValueAccumulator::try_new(
n,
acc_args.input_type,
&acc_args.input_types[0],
&ordering_dtypes,
ordering_req,
)
Expand All @@ -125,7 +125,7 @@ impl AggregateUDFImpl for NthValueAgg {
let mut fields = vec![Field::new_list(
format_state_name(self.name(), "nth_value"),
// See COMMENTS.md to understand why nullable is set to true
Field::new("item", args.input_type.clone(), true),
Field::new("item", args.input_types[0].clone(), true),
false,
)];
let orderings = args.ordering_fields.to_vec();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ mod tests {
name: "a",
is_distinct: false,
is_reversed: false,
input_type: &DataType::Float64,
input_types: &[DataType::Float64],
input_exprs: &[datafusion_expr::col("a")],
};

Expand All @@ -348,7 +348,7 @@ mod tests {
name: "a",
is_distinct: false,
is_reversed: false,
input_type: &DataType::Float64,
input_types: &[DataType::Float64],
input_exprs: &[datafusion_expr::col("a")],
};

Expand Down
38 changes: 20 additions & 18 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,33 @@
// specific language governing permissions and limitations
// under the License.

pub mod count_distinct;
pub mod groups_accumulator;
pub mod merge_arrays;
pub mod stats;
pub mod tdigest;
pub mod utils;
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

use datafusion_common::exec_err;
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
use datafusion_expr::function::StateFieldsArgs;
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::utils::AggregateOrderSensitivity;
use datafusion_expr::ReversedUDAF;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

use self::utils::down_cast_any_ref;
use crate::physical_expr::PhysicalExpr;
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
use crate::utils::reverse_order_bys;

use datafusion_common::exec_err;
use datafusion_expr::utils::AggregateOrderSensitivity;
use self::utils::down_cast_any_ref;

pub mod count_distinct;
pub mod groups_accumulator;
pub mod merge_arrays;
pub mod stats;
pub mod tdigest;
pub mod utils;

/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
/// This function errors when `args` can't be coerced to a valid argument type of the UDAF.
Expand Down Expand Up @@ -225,7 +227,7 @@ impl AggregateExprBuilder {
ignore_nulls,
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
input_types: input_exprs_types,
is_reversed,
}))
}
Expand Down Expand Up @@ -466,7 +468,7 @@ pub struct AggregateFunctionExpr {
ordering_fields: Vec<Field>,
is_distinct: bool,
is_reversed: bool,
input_type: DataType,
input_types: Vec<DataType>,
}

impl AggregateFunctionExpr {
Expand Down Expand Up @@ -504,7 +506,7 @@ impl AggregateExpr for AggregateFunctionExpr {
fn state_fields(&self) -> Result<Vec<Field>> {
let args = StateFieldsArgs {
name: &self.name,
input_type: &self.input_type,
input_types: &self.input_types,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
Expand All @@ -525,7 +527,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand All @@ -542,7 +544,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand Down Expand Up @@ -614,7 +616,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand All @@ -630,7 +632,7 @@ impl AggregateExpr for AggregateFunctionExpr {
ignore_nulls: self.ignore_nulls,
sort_exprs: &self.sort_exprs,
is_distinct: self.is_distinct,
input_type: &self.input_type,
input_types: &self.input_types,
input_exprs: &self.logical_args,
name: &self.name,
is_reversed: self.is_reversed,
Expand Down