@@ -19,6 +19,7 @@ use arrow::datatypes::{
19
19
DataType , TimeUnit , DECIMAL128_MAX_PRECISION , DECIMAL128_MAX_SCALE ,
20
20
DECIMAL256_MAX_PRECISION , DECIMAL256_MAX_SCALE ,
21
21
} ;
22
+
22
23
use datafusion_common:: { internal_err, plan_err, DataFusionError , Result } ;
23
24
use std:: ops:: Deref ;
24
25
@@ -89,6 +90,7 @@ pub fn coerce_types(
89
90
input_types : & [ DataType ] ,
90
91
signature : & Signature ,
91
92
) -> Result < Vec < DataType > > {
93
+ use DataType :: * ;
92
94
// Validate input_types matches (at least one of) the func signature.
93
95
check_arg_count ( agg_fun, input_types, & signature. type_signature ) ?;
94
96
@@ -105,26 +107,44 @@ pub fn coerce_types(
105
107
AggregateFunction :: Sum => {
106
108
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
107
109
// smallint, int, bigint, real, double precision, decimal, or interval.
108
- if !is_sum_support_arg_type ( & input_types[ 0 ] ) {
109
- return plan_err ! (
110
- "The function {:?} does not support inputs of type {:?}." ,
111
- agg_fun,
112
- input_types[ 0 ]
113
- ) ;
114
- }
115
- Ok ( input_types. to_vec ( ) )
110
+ let v = match & input_types[ 0 ] {
111
+ Decimal128 ( p, s) => Decimal128 ( * p, * s) ,
112
+ Decimal256 ( p, s) => Decimal256 ( * p, * s) ,
113
+ d if d. is_signed_integer ( ) => Int64 ,
114
+ d if d. is_unsigned_integer ( ) => UInt64 ,
115
+ d if d. is_floating ( ) => Float64 ,
116
+ Dictionary ( _, v) => {
117
+ return coerce_types ( agg_fun, & [ v. as_ref ( ) . clone ( ) ] , signature)
118
+ }
119
+ _ => {
120
+ return plan_err ! (
121
+ "The function {:?} does not support inputs of type {:?}." ,
122
+ agg_fun,
123
+ input_types[ 0 ]
124
+ )
125
+ }
126
+ } ;
127
+ Ok ( vec ! [ v] )
116
128
}
117
129
AggregateFunction :: Avg => {
118
130
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
119
131
// smallint, int, bigint, real, double precision, decimal, or interval
120
- if !is_avg_support_arg_type ( & input_types[ 0 ] ) {
121
- return plan_err ! (
122
- "The function {:?} does not support inputs of type {:?}." ,
123
- agg_fun,
124
- input_types[ 0 ]
125
- ) ;
126
- }
127
- Ok ( input_types. to_vec ( ) )
132
+ let v = match & input_types[ 0 ] {
133
+ Decimal128 ( p, s) => Decimal128 ( * p, * s) ,
134
+ Decimal256 ( p, s) => Decimal256 ( * p, * s) ,
135
+ d if d. is_numeric ( ) => Float64 ,
136
+ Dictionary ( _, v) => {
137
+ return coerce_types ( agg_fun, & [ v. as_ref ( ) . clone ( ) ] , signature)
138
+ }
139
+ _ => {
140
+ return plan_err ! (
141
+ "The function {:?} does not support inputs of type {:?}." ,
142
+ agg_fun,
143
+ input_types[ 0 ]
144
+ )
145
+ }
146
+ } ;
147
+ Ok ( vec ! [ v] )
128
148
}
129
149
AggregateFunction :: BitAnd
130
150
| AggregateFunction :: BitOr
@@ -160,7 +180,7 @@ pub fn coerce_types(
160
180
input_types[ 0 ]
161
181
) ;
162
182
}
163
- Ok ( input_types . to_vec ( ) )
183
+ Ok ( vec ! [ Float64 , Float64 ] )
164
184
}
165
185
AggregateFunction :: Covariance | AggregateFunction :: CovariancePop => {
166
186
if !is_covariance_support_arg_type ( & input_types[ 0 ] ) {
@@ -170,7 +190,7 @@ pub fn coerce_types(
170
190
input_types[ 0 ]
171
191
) ;
172
192
}
173
- Ok ( input_types . to_vec ( ) )
193
+ Ok ( vec ! [ Float64 , Float64 ] )
174
194
}
175
195
AggregateFunction :: Stddev | AggregateFunction :: StddevPop => {
176
196
if !is_stddev_support_arg_type ( & input_types[ 0 ] ) {
@@ -180,7 +200,7 @@ pub fn coerce_types(
180
200
input_types[ 0 ]
181
201
) ;
182
202
}
183
- Ok ( input_types . to_vec ( ) )
203
+ Ok ( vec ! [ Float64 ] )
184
204
}
185
205
AggregateFunction :: Correlation => {
186
206
if !is_correlation_support_arg_type ( & input_types[ 0 ] ) {
@@ -190,7 +210,7 @@ pub fn coerce_types(
190
210
input_types[ 0 ]
191
211
) ;
192
212
}
193
- Ok ( input_types . to_vec ( ) )
213
+ Ok ( vec ! [ Float64 , Float64 ] )
194
214
}
195
215
AggregateFunction :: RegrSlope
196
216
| AggregateFunction :: RegrIntercept
@@ -211,7 +231,7 @@ pub fn coerce_types(
211
231
input_types[ 0 ]
212
232
) ;
213
233
}
214
- Ok ( input_types . to_vec ( ) )
234
+ Ok ( vec ! [ Float64 , Float64 ] )
215
235
}
216
236
AggregateFunction :: ApproxPercentileCont => {
217
237
if !is_approx_percentile_cont_supported_arg_type ( & input_types[ 0 ] ) {
@@ -357,11 +377,9 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
357
377
/// function return type of a sum
358
378
pub fn sum_return_type ( arg_type : & DataType ) -> Result < DataType > {
359
379
match arg_type {
360
- arg_type if SIGNED_INTEGERS . contains ( arg_type) => Ok ( DataType :: Int64 ) ,
361
- arg_type if UNSIGNED_INTEGERS . contains ( arg_type) => Ok ( DataType :: UInt64 ) ,
362
- // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
363
- // the result type of floating-point is FLOAT64 with the double precision.
364
- DataType :: Float64 | DataType :: Float32 => Ok ( DataType :: Float64 ) ,
380
+ DataType :: Int64 => Ok ( DataType :: Int64 ) ,
381
+ DataType :: UInt64 => Ok ( DataType :: UInt64 ) ,
382
+ DataType :: Float64 => Ok ( DataType :: Float64 ) ,
365
383
DataType :: Decimal128 ( precision, scale) => {
366
384
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
367
385
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
@@ -374,9 +392,6 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
374
392
let new_precision = DECIMAL256_MAX_PRECISION . min ( * precision + 10 ) ;
375
393
Ok ( DataType :: Decimal256 ( new_precision, * scale) )
376
394
}
377
- DataType :: Dictionary ( _, dict_value_type) => {
378
- sum_return_type ( dict_value_type. as_ref ( ) )
379
- }
380
395
other => plan_err ! ( "SUM does not support type \" {other:?}\" " ) ,
381
396
}
382
397
}
@@ -601,21 +616,29 @@ mod tests {
601
616
assert_eq ! ( * input_type, result. unwrap( ) ) ;
602
617
}
603
618
}
604
- // test sum, avg
605
- let funs = vec ! [ AggregateFunction :: Sum , AggregateFunction :: Avg ] ;
606
- let input_types = vec ! [
607
- vec![ DataType :: Int32 ] ,
608
- vec![ DataType :: Float32 ] ,
609
- vec![ DataType :: Decimal128 ( 20 , 3 ) ] ,
610
- vec![ DataType :: Decimal256 ( 20 , 3 ) ] ,
611
- ] ;
612
- for fun in funs {
613
- for input_type in & input_types {
614
- let signature = fun. signature ( ) ;
615
- let result = coerce_types ( & fun, input_type, & signature) ;
616
- assert_eq ! ( * input_type, result. unwrap( ) ) ;
617
- }
618
- }
619
+ // test sum
620
+ let fun = AggregateFunction :: Sum ;
621
+ let signature = fun. signature ( ) ;
622
+ let r = coerce_types ( & fun, & [ DataType :: Int32 ] , & signature) . unwrap ( ) ;
623
+ assert_eq ! ( r[ 0 ] , DataType :: Int64 ) ;
624
+ let r = coerce_types ( & fun, & [ DataType :: Float32 ] , & signature) . unwrap ( ) ;
625
+ assert_eq ! ( r[ 0 ] , DataType :: Float64 ) ;
626
+ let r = coerce_types ( & fun, & [ DataType :: Decimal128 ( 20 , 3 ) ] , & signature) . unwrap ( ) ;
627
+ assert_eq ! ( r[ 0 ] , DataType :: Decimal128 ( 20 , 3 ) ) ;
628
+ let r = coerce_types ( & fun, & [ DataType :: Decimal256 ( 20 , 3 ) ] , & signature) . unwrap ( ) ;
629
+ assert_eq ! ( r[ 0 ] , DataType :: Decimal256 ( 20 , 3 ) ) ;
630
+
631
+ // test avg
632
+ let fun = AggregateFunction :: Avg ;
633
+ let signature = fun. signature ( ) ;
634
+ let r = coerce_types ( & fun, & [ DataType :: Int32 ] , & signature) . unwrap ( ) ;
635
+ assert_eq ! ( r[ 0 ] , DataType :: Float64 ) ;
636
+ let r = coerce_types ( & fun, & [ DataType :: Float32 ] , & signature) . unwrap ( ) ;
637
+ assert_eq ! ( r[ 0 ] , DataType :: Float64 ) ;
638
+ let r = coerce_types ( & fun, & [ DataType :: Decimal128 ( 20 , 3 ) ] , & signature) . unwrap ( ) ;
639
+ assert_eq ! ( r[ 0 ] , DataType :: Decimal128 ( 20 , 3 ) ) ;
640
+ let r = coerce_types ( & fun, & [ DataType :: Decimal256 ( 20 , 3 ) ] , & signature) . unwrap ( ) ;
641
+ assert_eq ! ( r[ 0 ] , DataType :: Decimal256 ( 20 , 3 ) ) ;
619
642
620
643
// ApproxPercentileCont input types
621
644
let input_types = vec ! [
0 commit comments