@@ -23,7 +23,7 @@ pub mod tdigest;
23
23
pub mod utils;
24
24
25
25
use arrow:: datatypes:: { DataType , Field , Schema } ;
26
- use datafusion_common:: { not_impl_err, DFSchema , Result } ;
26
+ use datafusion_common:: { internal_err , not_impl_err, DFSchema , Result } ;
27
27
use datafusion_expr:: function:: StateFieldsArgs ;
28
28
use datafusion_expr:: type_coercion:: aggregates:: check_arg_count;
29
29
use datafusion_expr:: ReversedUDAF ;
@@ -33,7 +33,7 @@ use datafusion_expr::{
33
33
use std:: fmt:: Debug ;
34
34
use std:: { any:: Any , sync:: Arc } ;
35
35
36
- use self :: utils:: { down_cast_any_ref, ordering_fields } ;
36
+ use self :: utils:: down_cast_any_ref;
37
37
use crate :: physical_expr:: PhysicalExpr ;
38
38
use crate :: sort_expr:: { LexOrdering , PhysicalSortExpr } ;
39
39
use crate :: utils:: reverse_order_bys;
@@ -55,6 +55,8 @@ use datafusion_expr::utils::AggregateOrderSensitivity;
55
55
/// `is_reversed` is used to indicate whether the aggregation is running in reverse order,
56
56
/// it could be used to hint Accumulator to accumulate in the reversed order,
57
57
/// you can just set to false if you are not reversing expression
58
+ ///
59
+ /// You can also create the expression with [`AggregateExprBuilder`].
58
60
#[ allow( clippy:: too_many_arguments) ]
59
61
pub fn create_aggregate_expr (
60
62
fun : & AggregateUDF ,
@@ -66,45 +68,24 @@ pub fn create_aggregate_expr(
66
68
name : impl Into < String > ,
67
69
ignore_nulls : bool ,
68
70
is_distinct : bool ,
69
- is_reversed : bool ,
71
+ _is_reversed : bool ,
// NOTE(review): `_is_reversed` is never forwarded to the builder below (there is no
// `reversed()` call in this function), whereas the removed implementation stored
// `is_reversed` on the expression. Confirm that ignoring the flag here is intended.
70
72
) -> Result < Arc < dyn AggregateExpr > > {
71
- debug_assert_eq ! ( sort_exprs. len( ) , ordering_req. len( ) ) ;
72
-
73
- let input_exprs_types = input_phy_exprs
74
- . iter ( )
75
- . map ( |arg| arg. data_type ( schema) )
76
- . collect :: < Result < Vec < _ > > > ( ) ?;
77
-
78
- check_arg_count (
79
- fun. name ( ) ,
80
- & input_exprs_types,
81
- & fun. signature ( ) . type_signature ,
82
- ) ?;
83
-
84
- let ordering_types = ordering_req
85
- . iter ( )
86
- . map ( |e| e. expr . data_type ( schema) )
87
- . collect :: < Result < Vec < _ > > > ( ) ?;
88
-
89
- let ordering_fields = ordering_fields ( ordering_req, & ordering_types) ;
90
- let name = name. into ( ) ;
91
-
92
- Ok ( Arc :: new ( AggregateFunctionExpr {
93
- fun : fun. clone ( ) ,
94
- args : input_phy_exprs. to_vec ( ) ,
95
- logical_args : input_exprs. to_vec ( ) ,
96
- data_type : fun. return_type ( & input_exprs_types) ?,
97
- name,
98
- schema : schema. clone ( ) ,
99
- dfschema : DFSchema :: empty ( ) ,
100
- sort_exprs : sort_exprs. to_vec ( ) ,
101
- ordering_req : ordering_req. to_vec ( ) ,
102
- ignore_nulls,
103
- ordering_fields,
104
- is_distinct,
105
- input_type : input_exprs_types[ 0 ] . clone ( ) ,
106
- is_reversed,
107
- } ) )
73
// Construction is delegated to `AggregateExprBuilder`; the borrowed inputs are copied
// into owned values (`to_vec` / `clone`) because the builder takes ownership.
+ let mut builder =
74
+ AggregateExprBuilder :: new ( Arc :: new ( fun. clone ( ) ) , input_phy_exprs. to_vec ( ) ) ;
75
+ builder = builder. sort_exprs ( sort_exprs. to_vec ( ) ) ;
76
+ builder = builder. order_by ( ordering_req. to_vec ( ) ) ;
77
+ builder = builder. logical_exprs ( input_exprs. to_vec ( ) ) ;
78
+ builder = builder. schema ( schema. clone ( ) ) ;
79
+ builder = builder. name ( name) ;
80
+
81
// The boolean options are flag-setting builder calls without arguments, so each one is
// applied only when the corresponding parameter is true.
+ if ignore_nulls {
82
+ builder = builder. ignore_nulls ( ) ;
83
+ }
84
+ if is_distinct {
85
+ builder = builder. distinct ( ) ;
86
+ }
87
+
88
// `build` performs the argument-count check and return-type resolution that the removed
// code above did inline, and returns the finished expression or the first error.
+ builder. build ( )
108
89
}
109
90
110
91
#[ allow( clippy:: too_many_arguments) ]
@@ -121,44 +102,177 @@ pub fn create_aggregate_expr_with_dfschema(
121
102
is_distinct : bool ,
122
103
is_reversed : bool ,
123
104
) -> Result < Arc < dyn AggregateExpr > > {
124
- debug_assert_eq ! ( sort_exprs. len( ) , ordering_req. len( ) ) ;
125
-
105
+ let mut builder =
106
+ AggregateExprBuilder :: new ( Arc :: new ( fun. clone ( ) ) , input_phy_exprs. to_vec ( ) ) ;
107
+ builder = builder. sort_exprs ( sort_exprs. to_vec ( ) ) ;
108
+ builder = builder. order_by ( ordering_req. to_vec ( ) ) ;
109
+ builder = builder. logical_exprs ( input_exprs. to_vec ( ) ) ;
110
+ builder = builder. dfschema ( dfschema. clone ( ) ) ;
126
111
let schema: Schema = dfschema. into ( ) ;
112
+ builder = builder. schema ( schema) ;
113
+ builder = builder. name ( name) ;
114
+
115
+ if ignore_nulls {
116
+ builder = builder. ignore_nulls ( ) ;
117
+ }
118
+ if is_distinct {
119
+ builder = builder. distinct ( ) ;
120
+ }
121
+ if is_reversed {
122
+ builder = builder. reversed ( ) ;
123
+ }
124
+
125
+ builder. build ( )
126
+ }
127
127
128
- let input_exprs_types = input_phy_exprs
129
- . iter ( )
130
- . map ( |arg| arg. data_type ( & schema) )
131
- . collect :: < Result < Vec < _ > > > ( ) ?;
132
-
133
- check_arg_count (
134
- fun. name ( ) ,
135
- & input_exprs_types,
136
- & fun. signature ( ) . type_signature ,
137
- ) ?;
138
-
139
- let ordering_types = ordering_req
140
- . iter ( )
141
- . map ( |e| e. expr . data_type ( & schema) )
142
- . collect :: < Result < Vec < _ > > > ( ) ?;
143
-
144
- let ordering_fields = ordering_fields ( ordering_req, & ordering_types) ;
145
-
146
- Ok ( Arc :: new ( AggregateFunctionExpr {
147
- fun : fun. clone ( ) ,
148
- args : input_phy_exprs. to_vec ( ) ,
149
- logical_args : input_exprs. to_vec ( ) ,
150
- data_type : fun. return_type ( & input_exprs_types) ?,
151
- name : name. into ( ) ,
152
- schema : schema. clone ( ) ,
153
- dfschema : dfschema. clone ( ) ,
154
- sort_exprs : sort_exprs. to_vec ( ) ,
155
- ordering_req : ordering_req. to_vec ( ) ,
156
- ignore_nulls,
157
- ordering_fields,
158
- is_distinct,
159
- input_type : input_exprs_types[ 0 ] . clone ( ) ,
160
- is_reversed,
161
- } ) )
128
/// Builder that collects the inputs of an aggregate expression and turns them into an
/// `Arc<dyn AggregateExpr>` via `build`.
///
/// Only `fun` and `args` are supplied to `new`; every other field starts out empty or
/// `false` and is set through the builder methods.
+ #[ derive( Debug , Clone ) ]
129
+ pub struct AggregateExprBuilder {
130
/// The aggregate UDF the expression is built for; `build` queries it for its name,
/// signature and return type
+ fun : Arc < AggregateUDF > ,
131
+ /// Physical expressions of the aggregate function
132
+ args : Vec < Arc < dyn PhysicalExpr > > ,
133
+ /// Logical expressions of the aggregate function, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
134
+ logical_args : Vec < Expr > ,
135
/// Name given to the built expression (an empty string unless `name` is called)
+ name : String ,
136
+ /// Arrow Schema for the aggregate function
137
+ schema : Schema ,
138
+ /// Datafusion Schema for the aggregate function
139
+ dfschema : DFSchema ,
140
+ /// The logical order by expressions, it will be deprecated in <https://github.com/apache/datafusion/issues/11359>
141
+ sort_exprs : Vec < Expr > ,
142
+ /// The physical order by expressions
143
+ ordering_req : LexOrdering ,
144
+ /// Whether to ignore null values
145
+ ignore_nulls : bool ,
146
+ /// Whether the aggregate function is distinct
147
+ is_distinct : bool ,
148
+ /// Whether the expression is reversed
149
+ is_reversed : bool ,
150
+ }
151
+
152
+ impl AggregateExprBuilder {
153
/// Creates a builder for `fun` applied to the physical arguments `args`.
///
/// All other settings start at their defaults: empty name, empty Arrow and DataFusion
/// schemas, no logical arguments, no ordering, and every flag set to `false`.
+ pub fn new ( fun : Arc < AggregateUDF > , args : Vec < Arc < dyn PhysicalExpr > > ) -> Self {
154
+ Self {
155
+ fun,
156
+ args,
157
+ logical_args : vec ! [ ] ,
158
+ name : String :: new ( ) ,
159
+ schema : Schema :: empty ( ) ,
160
+ dfschema : DFSchema :: empty ( ) ,
161
+ sort_exprs : vec ! [ ] ,
162
+ ordering_req : vec ! [ ] ,
163
+ ignore_nulls : false ,
164
+ is_distinct : false ,
165
+ is_reversed : false ,
166
+ }
167
+ }
168
+
169
/// Consumes the builder and constructs the aggregate expression.
///
/// # Errors
///
/// Returns an internal error when `args` is empty. Also propagates any error raised
/// while resolving the data types of the ordering or argument expressions against
/// `schema`, by `check_arg_count`, or by `fun.return_type`.
+ pub fn build ( self ) -> Result < Arc < dyn AggregateExpr > > {
170
+ let Self {
171
+ fun,
172
+ args,
173
+ logical_args,
174
+ name,
175
+ schema,
176
+ dfschema,
177
+ sort_exprs,
178
+ ordering_req,
179
+ ignore_nulls,
180
+ is_distinct,
181
+ is_reversed,
182
+ } = self ;
183
// Reject empty `args` up front: the first argument's type is read by index further down.
+ if args. is_empty ( ) {
184
+ return internal_err ! ( "args should not be empty" ) ;
185
+ }
186
+
187
// Stays empty when there is no ordering requirement.
+ let mut ordering_fields = vec ! [ ] ;
188
+
189
// The logical and physical ORDER BY lists are expected to have the same length; this is
// only checked when debug assertions are enabled.
+ debug_assert_eq ! ( sort_exprs. len( ) , ordering_req. len( ) ) ;
190
+ if !ordering_req. is_empty ( ) {
191
+ let ordering_types = ordering_req
192
+ . iter ( )
193
+ . map ( |e| e. expr . data_type ( & schema) )
194
+ . collect :: < Result < Vec < _ > > > ( ) ?;
195
+
196
+ ordering_fields = utils:: ordering_fields ( & ordering_req, & ordering_types) ;
197
+ }
198
+
199
// One data type per argument, resolved against the Arrow schema; the first failure aborts.
+ let input_exprs_types = args
200
+ . iter ( )
201
+ . map ( |arg| arg. data_type ( & schema) )
202
+ . collect :: < Result < Vec < _ > > > ( ) ?;
203
+
204
+ check_arg_count (
205
+ fun. name ( ) ,
206
+ & input_exprs_types,
207
+ & fun. signature ( ) . type_signature ,
208
+ ) ?;
209
+
210
+ let data_type = fun. return_type ( & input_exprs_types) ?;
211
+
212
+ Ok ( Arc :: new ( AggregateFunctionExpr {
213
// Moves the UDF out of the `Arc` when this is the only reference, otherwise clones it.
+ fun : Arc :: unwrap_or_clone ( fun) ,
214
+ args,
215
+ logical_args,
216
+ data_type,
217
+ name,
218
+ schema,
219
+ dfschema,
220
+ sort_exprs,
221
+ ordering_req,
222
+ ignore_nulls,
223
+ ordering_fields,
224
+ is_distinct,
225
// Index 0 exists: `args` was checked to be non-empty and one type is collected per argument.
+ input_type : input_exprs_types[ 0 ] . clone ( ) ,
226
+ is_reversed,
227
+ } ) )
228
+ }
229
+
230
/// Sets the name of the expression.
+ pub fn name ( mut self , name : impl Into < String > ) -> Self {
231
+ self . name = name. into ( ) ;
232
+ self
233
+ }
234
+
235
/// Sets the Arrow schema; `build` resolves argument and ordering data types against it.
+ pub fn schema ( mut self , schema : Schema ) -> Self {
236
+ self . schema = schema;
237
+ self
238
+ }
239
+
240
/// Sets the DataFusion schema that is stored on the built expression.
+ pub fn dfschema ( mut self , dfschema : DFSchema ) -> Self {
241
+ self . dfschema = dfschema;
242
+ self
243
+ }
244
+
245
/// Sets the physical ORDER BY expressions (`ordering_req`).
+ pub fn order_by ( mut self , order_by : LexOrdering ) -> Self {
246
+ self . ordering_req = order_by;
247
+ self
248
+ }
249
+
250
/// Marks the expression as reversed.
+ pub fn reversed ( mut self ) -> Self {
251
+ self . is_reversed = true ;
252
+ self
253
+ }
254
+
255
/// Marks the aggregate as distinct.
+ pub fn distinct ( mut self ) -> Self {
256
+ self . is_distinct = true ;
257
+ self
258
+ }
259
+
260
/// Marks the aggregate as ignoring null values.
+ pub fn ignore_nulls ( mut self ) -> Self {
261
+ self . ignore_nulls = true ;
262
+ self
263
+ }
264
+
265
+ /// Sets the logical ORDER BY expressions. This method will be deprecated in <https://github.com/apache/datafusion/issues/11359>
266
+ pub fn sort_exprs ( mut self , sort_exprs : Vec < Expr > ) -> Self {
267
+ self . sort_exprs = sort_exprs;
268
+ self
269
+ }
270
+
271
+ /// Sets the logical argument expressions. This method will be deprecated in <https://github.com/apache/datafusion/issues/11359>
272
+ pub fn logical_exprs ( mut self , logical_args : Vec < Expr > ) -> Self {
273
+ self . logical_args = logical_args;
274
+ self
275
+ }
162
276
}
163
277
164
278
/// An aggregate expression that:
0 commit comments