Skip to content

Commit 66a8570

Browse files
lewiszlw and alamb authored
Rename input_type --> input_types on AggregateFunctionExpr / AccumulatorArgs / StateFieldsArgs (#11666)
* UDAF input types
* Rename
* Update COMMENTS.md
* Update datafusion/functions-aggregate/COMMENTS.md

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 35c2e7e commit 66a8570

File tree

14 files changed: +57 -53 lines changed

14 files changed

+57
-53
lines changed

datafusion/core/tests/user_defined/user_defined_plan.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ use arrow::{
6868
record_batch::RecordBatch,
6969
util::pretty::pretty_format_batches,
7070
};
71+
use async_trait::async_trait;
72+
use futures::{Stream, StreamExt};
73+
74+
use datafusion::execution::session_state::SessionStateBuilder;
7175
use datafusion::{
7276
common::cast::{as_int64_array, as_string_array},
7377
common::{arrow_datafusion_err, internal_err, DFSchemaRef},
@@ -90,16 +94,12 @@ use datafusion::{
9094
physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner},
9195
prelude::{SessionConfig, SessionContext},
9296
};
93-
94-
use async_trait::async_trait;
95-
use datafusion::execution::session_state::SessionStateBuilder;
9697
use datafusion_common::config::ConfigOptions;
9798
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
9899
use datafusion_common::ScalarValue;
99100
use datafusion_expr::Projection;
100101
use datafusion_optimizer::optimizer::ApplyOrder;
101102
use datafusion_optimizer::AnalyzerRule;
102-
use futures::{Stream, StreamExt};
103103

104104
/// Execute the specified sql and return the resulting record batches
105105
/// pretty printed as a String.

datafusion/expr/src/function.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ pub struct AccumulatorArgs<'a> {
9494
/// ```
9595
pub is_distinct: bool,
9696

97-
/// The input type of the aggregate function.
98-
pub input_type: &'a DataType,
97+
/// The input types of the aggregate function.
98+
pub input_types: &'a [DataType],
9999

100100
/// The logical expression of arguments the aggregate function takes.
101101
pub input_exprs: &'a [Expr],
@@ -109,8 +109,8 @@ pub struct StateFieldsArgs<'a> {
109109
/// The name of the aggregate function.
110110
pub name: &'a str,
111111

112-
/// The input type of the aggregate function.
113-
pub input_type: &'a DataType,
112+
/// The input types of the aggregate function.
113+
pub input_types: &'a [DataType],
114114

115115
/// The return type of the aggregate function.
116116
pub return_type: &'a DataType,

datafusion/functions-aggregate/COMMENTS.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ first argument and the definition looks like this:
5454
// `input_type` : data type of the first argument
5555
let mut fields = vec![Field::new_list(
5656
format_state_name(self.name(), "nth_value"),
57-
Field::new("item", args.input_type.clone(), true /* nullable of list item */ ),
57+
Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ),
5858
false, // nullable of list itself
5959
)];
6060
```

datafusion/functions-aggregate/src/approx_distinct.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ impl AggregateUDFImpl for ApproxDistinct {
277277
}
278278

279279
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
280-
let accumulator: Box<dyn Accumulator> = match acc_args.input_type {
280+
let accumulator: Box<dyn Accumulator> = match &acc_args.input_types[0] {
281281
// TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL
282282
// TODO support for boolean (trivial case)
283283
// https://github.com/apache/datafusion/issues/1109

datafusion/functions-aggregate/src/approx_median.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl AggregateUDFImpl for ApproxMedian {
113113

114114
Ok(Box::new(ApproxPercentileAccumulator::new(
115115
0.5_f64,
116-
acc_args.input_type.clone(),
116+
acc_args.input_types[0].clone(),
117117
)))
118118
}
119119
}

datafusion/functions-aggregate/src/approx_percentile_cont.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ impl ApproxPercentileCont {
104104
None
105105
};
106106

107-
let accumulator: ApproxPercentileAccumulator = match args.input_type {
107+
let accumulator: ApproxPercentileAccumulator = match &args.input_types[0] {
108108
t @ (DataType::UInt8
109109
| DataType::UInt16
110110
| DataType::UInt32

datafusion/functions-aggregate/src/array_agg.rs

+7-5
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ impl AggregateUDFImpl for ArrayAgg {
9090
return Ok(vec![Field::new_list(
9191
format_state_name(args.name, "distinct_array_agg"),
9292
// See COMMENTS.md to understand why nullable is set to true
93-
Field::new("item", args.input_type.clone(), true),
93+
Field::new("item", args.input_types[0].clone(), true),
9494
true,
9595
)]);
9696
}
9797

9898
let mut fields = vec![Field::new_list(
9999
format_state_name(args.name, "array_agg"),
100100
// See COMMENTS.md to understand why nullable is set to true
101-
Field::new("item", args.input_type.clone(), true),
101+
Field::new("item", args.input_types[0].clone(), true),
102102
true,
103103
)];
104104

@@ -119,12 +119,14 @@ impl AggregateUDFImpl for ArrayAgg {
119119
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
120120
if acc_args.is_distinct {
121121
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
122-
acc_args.input_type,
122+
&acc_args.input_types[0],
123123
)?));
124124
}
125125

126126
if acc_args.sort_exprs.is_empty() {
127-
return Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?));
127+
return Ok(Box::new(ArrayAggAccumulator::try_new(
128+
&acc_args.input_types[0],
129+
)?));
128130
}
129131

130132
let ordering_req = limited_convert_logical_sort_exprs_to_physical_with_dfschema(
@@ -138,7 +140,7 @@ impl AggregateUDFImpl for ArrayAgg {
138140
.collect::<Result<Vec<_>>>()?;
139141

140142
OrderSensitiveArrayAggAccumulator::try_new(
141-
acc_args.input_type,
143+
&acc_args.input_types[0],
142144
&ordering_dtypes,
143145
ordering_req,
144146
acc_args.is_reversed,

datafusion/functions-aggregate/src/average.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ impl AggregateUDFImpl for Avg {
9393
}
9494
use DataType::*;
9595
// instantiate specialized accumulator based for the type
96-
match (acc_args.input_type, acc_args.data_type) {
96+
match (&acc_args.input_types[0], acc_args.data_type) {
9797
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
9898
(
9999
Decimal128(sum_precision, sum_scale),
@@ -120,7 +120,7 @@ impl AggregateUDFImpl for Avg {
120120
})),
121121
_ => exec_err!(
122122
"AvgAccumulator for ({} --> {})",
123-
acc_args.input_type,
123+
&acc_args.input_types[0],
124124
acc_args.data_type
125125
),
126126
}
@@ -135,7 +135,7 @@ impl AggregateUDFImpl for Avg {
135135
),
136136
Field::new(
137137
format_state_name(args.name, "sum"),
138-
args.input_type.clone(),
138+
args.input_types[0].clone(),
139139
true,
140140
),
141141
])
@@ -154,10 +154,10 @@ impl AggregateUDFImpl for Avg {
154154
) -> Result<Box<dyn GroupsAccumulator>> {
155155
use DataType::*;
156156
// instantiate specialized accumulator based for the type
157-
match (args.input_type, args.data_type) {
157+
match (&args.input_types[0], args.data_type) {
158158
(Float64, Float64) => {
159159
Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
160-
args.input_type,
160+
&args.input_types[0],
161161
args.data_type,
162162
|sum: f64, count: u64| Ok(sum / count as f64),
163163
)))
@@ -176,7 +176,7 @@ impl AggregateUDFImpl for Avg {
176176
move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128);
177177

178178
Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(
179-
args.input_type,
179+
&args.input_types[0],
180180
args.data_type,
181181
avg_fn,
182182
)))
@@ -197,15 +197,15 @@ impl AggregateUDFImpl for Avg {
197197
};
198198

199199
Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
200-
args.input_type,
200+
&args.input_types[0],
201201
args.data_type,
202202
avg_fn,
203203
)))
204204
}
205205

206206
_ => not_impl_err!(
207207
"AvgGroupsAccumulator for ({} --> {})",
208-
args.input_type,
208+
&args.input_types[0],
209209
args.data_type
210210
),
211211
}

datafusion/functions-aggregate/src/count.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ impl AggregateUDFImpl for Count {
127127
Ok(vec![Field::new_list(
128128
format_state_name(args.name, "count distinct"),
129129
// See COMMENTS.md to understand why nullable is set to true
130-
Field::new("item", args.input_type.clone(), true),
130+
Field::new("item", args.input_types[0].clone(), true),
131131
false,
132132
)])
133133
} else {
@@ -148,7 +148,7 @@ impl AggregateUDFImpl for Count {
148148
return not_impl_err!("COUNT DISTINCT with multiple arguments");
149149
}
150150

151-
let data_type = acc_args.input_type;
151+
let data_type = &acc_args.input_types[0];
152152
Ok(match data_type {
153153
// try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
154154
DataType::Int8 => Box::new(

datafusion/functions-aggregate/src/first_last.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -440,14 +440,14 @@ impl AggregateUDFImpl for LastValue {
440440
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
441441
let StateFieldsArgs {
442442
name,
443-
input_type,
443+
input_types,
444444
return_type: _,
445445
ordering_fields,
446446
is_distinct: _,
447447
} = args;
448448
let mut fields = vec![Field::new(
449449
format_state_name(name, "last_value"),
450-
input_type.clone(),
450+
input_types[0].clone(),
451451
true,
452452
)];
453453
fields.extend(ordering_fields.to_vec());

datafusion/functions-aggregate/src/median.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ impl AggregateUDFImpl for Median {
102102

103103
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
104104
//Intermediate state is a list of the elements we have collected so far
105-
let field = Field::new("item", args.input_type.clone(), true);
105+
let field = Field::new("item", args.input_types[0].clone(), true);
106106
let state_name = if args.is_distinct {
107107
"distinct_median"
108108
} else {
@@ -133,7 +133,7 @@ impl AggregateUDFImpl for Median {
133133
};
134134
}
135135

136-
let dt = acc_args.input_type;
136+
let dt = &acc_args.input_types[0];
137137
downcast_integer! {
138138
dt => (helper, dt),
139139
DataType::Float16 => helper!(Float16Type, dt),

datafusion/functions-aggregate/src/nth_value.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl AggregateUDFImpl for NthValueAgg {
114114

115115
NthValueAccumulator::try_new(
116116
n,
117-
acc_args.input_type,
117+
&acc_args.input_types[0],
118118
&ordering_dtypes,
119119
ordering_req,
120120
)
@@ -125,7 +125,7 @@ impl AggregateUDFImpl for NthValueAgg {
125125
let mut fields = vec![Field::new_list(
126126
format_state_name(self.name(), "nth_value"),
127127
// See COMMENTS.md to understand why nullable is set to true
128-
Field::new("item", args.input_type.clone(), true),
128+
Field::new("item", args.input_types[0].clone(), true),
129129
false,
130130
)];
131131
let orderings = args.ordering_fields.to_vec();

datafusion/functions-aggregate/src/stddev.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ mod tests {
335335
name: "a",
336336
is_distinct: false,
337337
is_reversed: false,
338-
input_type: &DataType::Float64,
338+
input_types: &[DataType::Float64],
339339
input_exprs: &[datafusion_expr::col("a")],
340340
};
341341

@@ -348,7 +348,7 @@ mod tests {
348348
name: "a",
349349
is_distinct: false,
350350
is_reversed: false,
351-
input_type: &DataType::Float64,
351+
input_types: &[DataType::Float64],
352352
input_exprs: &[datafusion_expr::col("a")],
353353
};
354354

datafusion/physical-expr-common/src/aggregate/mod.rs

+20-18
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,33 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
pub mod count_distinct;
19-
pub mod groups_accumulator;
20-
pub mod merge_arrays;
21-
pub mod stats;
22-
pub mod tdigest;
23-
pub mod utils;
18+
use std::fmt::Debug;
19+
use std::{any::Any, sync::Arc};
2420

2521
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
22+
23+
use datafusion_common::exec_err;
2624
use datafusion_common::{internal_err, not_impl_err, DFSchema, Result};
2725
use datafusion_expr::function::StateFieldsArgs;
2826
use datafusion_expr::type_coercion::aggregates::check_arg_count;
27+
use datafusion_expr::utils::AggregateOrderSensitivity;
2928
use datafusion_expr::ReversedUDAF;
3029
use datafusion_expr::{
3130
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
3231
};
33-
use std::fmt::Debug;
34-
use std::{any::Any, sync::Arc};
3532

36-
use self::utils::down_cast_any_ref;
3733
use crate::physical_expr::PhysicalExpr;
3834
use crate::sort_expr::{LexOrdering, PhysicalSortExpr};
3935
use crate::utils::reverse_order_bys;
4036

41-
use datafusion_common::exec_err;
42-
use datafusion_expr::utils::AggregateOrderSensitivity;
37+
use self::utils::down_cast_any_ref;
38+
39+
pub mod count_distinct;
40+
pub mod groups_accumulator;
41+
pub mod merge_arrays;
42+
pub mod stats;
43+
pub mod tdigest;
44+
pub mod utils;
4345

4446
/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
4547
/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF.
@@ -225,7 +227,7 @@ impl AggregateExprBuilder {
225227
ignore_nulls,
226228
ordering_fields,
227229
is_distinct,
228-
input_type: input_exprs_types[0].clone(),
230+
input_types: input_exprs_types,
229231
is_reversed,
230232
}))
231233
}
@@ -466,7 +468,7 @@ pub struct AggregateFunctionExpr {
466468
ordering_fields: Vec<Field>,
467469
is_distinct: bool,
468470
is_reversed: bool,
469-
input_type: DataType,
471+
input_types: Vec<DataType>,
470472
}
471473

472474
impl AggregateFunctionExpr {
@@ -504,7 +506,7 @@ impl AggregateExpr for AggregateFunctionExpr {
504506
fn state_fields(&self) -> Result<Vec<Field>> {
505507
let args = StateFieldsArgs {
506508
name: &self.name,
507-
input_type: &self.input_type,
509+
input_types: &self.input_types,
508510
return_type: &self.data_type,
509511
ordering_fields: &self.ordering_fields,
510512
is_distinct: self.is_distinct,
@@ -525,7 +527,7 @@ impl AggregateExpr for AggregateFunctionExpr {
525527
ignore_nulls: self.ignore_nulls,
526528
sort_exprs: &self.sort_exprs,
527529
is_distinct: self.is_distinct,
528-
input_type: &self.input_type,
530+
input_types: &self.input_types,
529531
input_exprs: &self.logical_args,
530532
name: &self.name,
531533
is_reversed: self.is_reversed,
@@ -542,7 +544,7 @@ impl AggregateExpr for AggregateFunctionExpr {
542544
ignore_nulls: self.ignore_nulls,
543545
sort_exprs: &self.sort_exprs,
544546
is_distinct: self.is_distinct,
545-
input_type: &self.input_type,
547+
input_types: &self.input_types,
546548
input_exprs: &self.logical_args,
547549
name: &self.name,
548550
is_reversed: self.is_reversed,
@@ -614,7 +616,7 @@ impl AggregateExpr for AggregateFunctionExpr {
614616
ignore_nulls: self.ignore_nulls,
615617
sort_exprs: &self.sort_exprs,
616618
is_distinct: self.is_distinct,
617-
input_type: &self.input_type,
619+
input_types: &self.input_types,
618620
input_exprs: &self.logical_args,
619621
name: &self.name,
620622
is_reversed: self.is_reversed,
@@ -630,7 +632,7 @@ impl AggregateExpr for AggregateFunctionExpr {
630632
ignore_nulls: self.ignore_nulls,
631633
sort_exprs: &self.sort_exprs,
632634
is_distinct: self.is_distinct,
633-
input_type: &self.input_type,
635+
input_types: &self.input_types,
634636
input_exprs: &self.logical_args,
635637
name: &self.name,
636638
is_reversed: self.is_reversed,

0 commit comments

Comments
 (0)