Skip to content

Commit 870857a

Browse files
authored
Make SUM and AVG Aggregate Type Coercion Explicit (#7369)
* Make Aggregate Type Coercion Explicit
* Clippy
1 parent ffccbe6 commit 870857a

File tree

19 files changed

+385
-416
lines changed

19 files changed

+385
-416
lines changed

datafusion/core/src/physical_plan/aggregates/mod.rs

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use datafusion_execution::TaskContext;
3535
use datafusion_expr::Accumulator;
3636
use datafusion_physical_expr::{
3737
equivalence::project_equivalence_properties,
38-
expressions::{Avg, CastExpr, Column, Sum},
38+
expressions::Column,
3939
normalize_out_expr_with_columns_map, reverse_order_bys,
4040
utils::{convert_to_expr, get_indices_of_matching_exprs},
4141
AggregateExpr, LexOrdering, LexOrderingReq, OrderingEquivalenceProperties,
@@ -1010,40 +1010,7 @@ fn aggregate_expressions(
10101010
| AggregateMode::SinglePartitioned => Ok(aggr_expr
10111011
.iter()
10121012
.map(|agg| {
1013-
let pre_cast_type = if let Some(Sum {
1014-
data_type,
1015-
pre_cast_to_sum_type,
1016-
..
1017-
}) = agg.as_any().downcast_ref::<Sum>()
1018-
{
1019-
if *pre_cast_to_sum_type {
1020-
Some(data_type.clone())
1021-
} else {
1022-
None
1023-
}
1024-
} else if let Some(Avg {
1025-
sum_data_type,
1026-
pre_cast_to_sum_type,
1027-
..
1028-
}) = agg.as_any().downcast_ref::<Avg>()
1029-
{
1030-
if *pre_cast_to_sum_type {
1031-
Some(sum_data_type.clone())
1032-
} else {
1033-
None
1034-
}
1035-
} else {
1036-
None
1037-
};
1038-
let mut result = agg
1039-
.expressions()
1040-
.into_iter()
1041-
.map(|expr| {
1042-
pre_cast_type.clone().map_or(expr.clone(), |cast_type| {
1043-
Arc::new(CastExpr::new(expr, cast_type, None))
1044-
})
1045-
})
1046-
.collect::<Vec<_>>();
1013+
let mut result = agg.expressions().clone();
10471014
// In partial mode, append ordering requirements to expressions' results.
10481015
// Ordering requirements are used by subsequent executors to satisfy the required
10491016
// ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes.

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ use datafusion_expr::{
3939

4040
use datafusion::prelude::{SessionConfig, SessionContext};
4141
use datafusion_common::{Result, ScalarValue};
42-
use datafusion_physical_expr::expressions::{col, lit};
42+
use datafusion_expr::type_coercion::aggregates::coerce_types;
43+
use datafusion_physical_expr::expressions::{cast, col, lit};
4344
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
4445
use test_utils::add_empty_batches;
4546

@@ -261,6 +262,14 @@ fn get_random_function(
261262
let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
262263
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
263264
let (window_fn, new_args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
265+
if let WindowFunction::AggregateFunction(f) = window_fn {
266+
let a = args[0].clone();
267+
let dt = a.data_type(schema.as_ref()).unwrap();
268+
let sig = f.signature();
269+
let coerced = coerce_types(f, &[dt], &sig).unwrap();
270+
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
271+
}
272+
264273
for new_arg in new_args {
265274
args.push(new_arg.clone());
266275
}

datafusion/expr/src/aggregate_function.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -228,20 +228,16 @@ impl AggregateFunction {
228228
// Note that this function *must* return the same type that the respective physical expression returns
229229
// or the execution panics.
230230

231-
let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
232-
self,
233-
input_expr_types,
234-
&self.signature(),
235-
)
236-
// original errors are all related to wrong function signature
237-
// aggregate them for better error message
238-
.map_err(|_| {
239-
DataFusionError::Plan(utils::generate_signature_error_msg(
240-
&format!("{self}"),
241-
self.signature(),
242-
input_expr_types,
243-
))
244-
})?;
231+
let coerced_data_types = coerce_types(self, input_expr_types, &self.signature())
232+
// original errors are all related to wrong function signature
233+
// aggregate them for better error message
234+
.map_err(|_| {
235+
DataFusionError::Plan(utils::generate_signature_error_msg(
236+
&format!("{self}"),
237+
self.signature(),
238+
input_expr_types,
239+
))
240+
})?;
245241

246242
match self {
247243
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use arrow::datatypes::{
1919
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
2020
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
2121
};
22+
2223
use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
2324
use std::ops::Deref;
2425

@@ -89,6 +90,7 @@ pub fn coerce_types(
8990
input_types: &[DataType],
9091
signature: &Signature,
9192
) -> Result<Vec<DataType>> {
93+
use DataType::*;
9294
// Validate input_types matches (at least one of) the func signature.
9395
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
9496

@@ -105,26 +107,44 @@ pub fn coerce_types(
105107
AggregateFunction::Sum => {
106108
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
107109
// smallint, int, bigint, real, double precision, decimal, or interval.
108-
if !is_sum_support_arg_type(&input_types[0]) {
109-
return plan_err!(
110-
"The function {:?} does not support inputs of type {:?}.",
111-
agg_fun,
112-
input_types[0]
113-
);
114-
}
115-
Ok(input_types.to_vec())
110+
let v = match &input_types[0] {
111+
Decimal128(p, s) => Decimal128(*p, *s),
112+
Decimal256(p, s) => Decimal256(*p, *s),
113+
d if d.is_signed_integer() => Int64,
114+
d if d.is_unsigned_integer() => UInt64,
115+
d if d.is_floating() => Float64,
116+
Dictionary(_, v) => {
117+
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
118+
}
119+
_ => {
120+
return plan_err!(
121+
"The function {:?} does not support inputs of type {:?}.",
122+
agg_fun,
123+
input_types[0]
124+
)
125+
}
126+
};
127+
Ok(vec![v])
116128
}
117129
AggregateFunction::Avg => {
118130
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
119131
// smallint, int, bigint, real, double precision, decimal, or interval
120-
if !is_avg_support_arg_type(&input_types[0]) {
121-
return plan_err!(
122-
"The function {:?} does not support inputs of type {:?}.",
123-
agg_fun,
124-
input_types[0]
125-
);
126-
}
127-
Ok(input_types.to_vec())
132+
let v = match &input_types[0] {
133+
Decimal128(p, s) => Decimal128(*p, *s),
134+
Decimal256(p, s) => Decimal256(*p, *s),
135+
d if d.is_numeric() => Float64,
136+
Dictionary(_, v) => {
137+
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
138+
}
139+
_ => {
140+
return plan_err!(
141+
"The function {:?} does not support inputs of type {:?}.",
142+
agg_fun,
143+
input_types[0]
144+
)
145+
}
146+
};
147+
Ok(vec![v])
128148
}
129149
AggregateFunction::BitAnd
130150
| AggregateFunction::BitOr
@@ -160,7 +180,7 @@ pub fn coerce_types(
160180
input_types[0]
161181
);
162182
}
163-
Ok(input_types.to_vec())
183+
Ok(vec![Float64, Float64])
164184
}
165185
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
166186
if !is_covariance_support_arg_type(&input_types[0]) {
@@ -170,7 +190,7 @@ pub fn coerce_types(
170190
input_types[0]
171191
);
172192
}
173-
Ok(input_types.to_vec())
193+
Ok(vec![Float64, Float64])
174194
}
175195
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
176196
if !is_stddev_support_arg_type(&input_types[0]) {
@@ -180,7 +200,7 @@ pub fn coerce_types(
180200
input_types[0]
181201
);
182202
}
183-
Ok(input_types.to_vec())
203+
Ok(vec![Float64])
184204
}
185205
AggregateFunction::Correlation => {
186206
if !is_correlation_support_arg_type(&input_types[0]) {
@@ -190,7 +210,7 @@ pub fn coerce_types(
190210
input_types[0]
191211
);
192212
}
193-
Ok(input_types.to_vec())
213+
Ok(vec![Float64, Float64])
194214
}
195215
AggregateFunction::RegrSlope
196216
| AggregateFunction::RegrIntercept
@@ -211,7 +231,7 @@ pub fn coerce_types(
211231
input_types[0]
212232
);
213233
}
214-
Ok(input_types.to_vec())
234+
Ok(vec![Float64, Float64])
215235
}
216236
AggregateFunction::ApproxPercentileCont => {
217237
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
@@ -357,11 +377,9 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
357377
/// function return type of a sum
358378
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
359379
match arg_type {
360-
arg_type if SIGNED_INTEGERS.contains(arg_type) => Ok(DataType::Int64),
361-
arg_type if UNSIGNED_INTEGERS.contains(arg_type) => Ok(DataType::UInt64),
362-
// In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
363-
// the result type of floating-point is FLOAT64 with the double precision.
364-
DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
380+
DataType::Int64 => Ok(DataType::Int64),
381+
DataType::UInt64 => Ok(DataType::UInt64),
382+
DataType::Float64 => Ok(DataType::Float64),
365383
DataType::Decimal128(precision, scale) => {
366384
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
367385
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -374,9 +392,6 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
374392
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
375393
Ok(DataType::Decimal256(new_precision, *scale))
376394
}
377-
DataType::Dictionary(_, dict_value_type) => {
378-
sum_return_type(dict_value_type.as_ref())
379-
}
380395
other => plan_err!("SUM does not support type \"{other:?}\""),
381396
}
382397
}
@@ -601,21 +616,29 @@ mod tests {
601616
assert_eq!(*input_type, result.unwrap());
602617
}
603618
}
604-
// test sum, avg
605-
let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
606-
let input_types = vec![
607-
vec![DataType::Int32],
608-
vec![DataType::Float32],
609-
vec![DataType::Decimal128(20, 3)],
610-
vec![DataType::Decimal256(20, 3)],
611-
];
612-
for fun in funs {
613-
for input_type in &input_types {
614-
let signature = fun.signature();
615-
let result = coerce_types(&fun, input_type, &signature);
616-
assert_eq!(*input_type, result.unwrap());
617-
}
618-
}
619+
// test sum
620+
let fun = AggregateFunction::Sum;
621+
let signature = fun.signature();
622+
let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap();
623+
assert_eq!(r[0], DataType::Int64);
624+
let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap();
625+
assert_eq!(r[0], DataType::Float64);
626+
let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap();
627+
assert_eq!(r[0], DataType::Decimal128(20, 3));
628+
let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap();
629+
assert_eq!(r[0], DataType::Decimal256(20, 3));
630+
631+
// test avg
632+
let fun = AggregateFunction::Avg;
633+
let signature = fun.signature();
634+
let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap();
635+
assert_eq!(r[0], DataType::Float64);
636+
let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap();
637+
assert_eq!(r[0], DataType::Float64);
638+
let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap();
639+
assert_eq!(r[0], DataType::Decimal128(20, 3));
640+
let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap();
641+
assert_eq!(r[0], DataType::Decimal256(20, 3));
619642

620643
// ApproxPercentileCont input types
621644
let input_types = vec![

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_u
4444
use datafusion_expr::utils::from_plan;
4545
use datafusion_expr::{
4646
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
47-
type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, LogicalPlan, Operator,
48-
Projection, WindowFrame, WindowFrameBound, WindowFrameUnits,
47+
type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr,
48+
LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits,
4949
};
5050
use datafusion_expr::{ExprSchemable, Signature};
5151

@@ -381,6 +381,19 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
381381
}) => {
382382
let window_frame =
383383
coerce_window_frame(window_frame, &self.schema, &order_by)?;
384+
385+
let args = match &fun {
386+
window_function::WindowFunction::AggregateFunction(fun) => {
387+
coerce_agg_exprs_for_signature(
388+
fun,
389+
&args,
390+
&self.schema,
391+
&fun.signature(),
392+
)?
393+
}
394+
_ => args,
395+
};
396+
384397
let expr = Expr::WindowFunction(WindowFunction::new(
385398
fun,
386399
args,
@@ -961,7 +974,7 @@ mod test {
961974
None,
962975
));
963976
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
964-
let expected = "Projection: AVG(Int64(12))\n EmptyRelation";
977+
let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation";
965978
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
966979

967980
let empty = empty_with_type(DataType::Int32);
@@ -974,7 +987,7 @@ mod test {
974987
None,
975988
));
976989
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
977-
let expected = "Projection: AVG(a)\n EmptyRelation";
990+
let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation";
978991
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
979992
Ok(())
980993
}

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ fn subquery_filter_with_cast() -> Result<()> {
7070
\n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.AVG(test.col_int32)\
7171
\n TableScan: test projection=[col_int32]\
7272
\n SubqueryAlias: __scalar_sq_1\
73-
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
73+
\n Aggregate: groupBy=[[]], aggr=[[AVG(CAST(test.col_int32 AS Float64))]]\
7474
\n Projection: test.col_int32\
7575
\n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
7676
\n TableScan: test projection=[col_int32, col_utf8]";

0 commit comments

Comments
 (0)