17
17
18
18
//! Defines physical expressions that can evaluated at runtime during query execution
19
19
20
- use arrow:: array:: { Array , ArrayRef } ;
20
+ use arrow:: array:: { Array , ArrayRef , AsArray } ;
21
21
use arrow:: datatypes:: DataType ;
22
22
use arrow_schema:: Field ;
23
23
@@ -29,6 +29,7 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
29
29
use datafusion_expr:: utils:: format_state_name;
30
30
use datafusion_expr:: AggregateUDFImpl ;
31
31
use datafusion_expr:: { Accumulator , Signature , Volatility } ;
32
+ use std:: collections:: HashSet ;
32
33
use std:: sync:: Arc ;
33
34
34
35
make_udaf_expr_and_func ! (
@@ -82,6 +83,14 @@ impl AggregateUDFImpl for ArrayAgg {
82
83
}
83
84
84
85
fn state_fields ( & self , args : StateFieldsArgs ) -> Result < Vec < Field > > {
86
+ if args. is_distinct {
87
+ return Ok ( vec ! [ Field :: new_list(
88
+ format_state_name( args. name, "distinct_array_agg" ) ,
89
+ Field :: new( "item" , args. input_type. clone( ) , true ) ,
90
+ true ,
91
+ ) ] ) ;
92
+ }
93
+
85
94
Ok ( vec ! [ Field :: new_list(
86
95
format_state_name( args. name, "array_agg" ) ,
87
96
Field :: new( "item" , args. input_type. clone( ) , true ) ,
@@ -90,6 +99,12 @@ impl AggregateUDFImpl for ArrayAgg {
90
99
}
91
100
92
101
fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
102
+ if acc_args. is_distinct {
103
+ return Ok ( Box :: new ( DistinctArrayAggAccumulator :: try_new (
104
+ acc_args. input_type ,
105
+ ) ?) ) ;
106
+ }
107
+
93
108
Ok ( Box :: new ( ArrayAggAccumulator :: try_new ( acc_args. input_type ) ?) )
94
109
}
95
110
}
@@ -170,3 +185,65 @@ impl Accumulator for ArrayAggAccumulator {
170
185
- std:: mem:: size_of_val ( & self . datatype )
171
186
}
172
187
}
188
+
189
+ #[ derive( Debug ) ]
190
+ struct DistinctArrayAggAccumulator {
191
+ values : HashSet < ScalarValue > ,
192
+ datatype : DataType ,
193
+ }
194
+
195
+ impl DistinctArrayAggAccumulator {
196
+ pub fn try_new ( datatype : & DataType ) -> Result < Self > {
197
+ Ok ( Self {
198
+ values : HashSet :: new ( ) ,
199
+ datatype : datatype. clone ( ) ,
200
+ } )
201
+ }
202
+ }
203
+
204
+ impl Accumulator for DistinctArrayAggAccumulator {
205
+ fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
206
+ Ok ( vec ! [ self . evaluate( ) ?] )
207
+ }
208
+
209
+ fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
210
+ assert_eq ! ( values. len( ) , 1 , "batch input should only include 1 column!" ) ;
211
+
212
+ let array = & values[ 0 ] ;
213
+
214
+ for i in 0 ..array. len ( ) {
215
+ let scalar = ScalarValue :: try_from_array ( & array, i) ?;
216
+ self . values . insert ( scalar) ;
217
+ }
218
+
219
+ Ok ( ( ) )
220
+ }
221
+
222
+ fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
223
+ if states. is_empty ( ) {
224
+ return Ok ( ( ) ) ;
225
+ }
226
+
227
+ states[ 0 ]
228
+ . as_list :: < i32 > ( )
229
+ . iter ( )
230
+ . flatten ( )
231
+ . try_for_each ( |val| self . update_batch ( & [ val] ) )
232
+ }
233
+
234
+ fn evaluate ( & mut self ) -> Result < ScalarValue > {
235
+ let values: Vec < ScalarValue > = self . values . iter ( ) . cloned ( ) . collect ( ) ;
236
+ if values. is_empty ( ) {
237
+ return Ok ( ScalarValue :: new_null_list ( self . datatype . clone ( ) , true , 1 ) ) ;
238
+ }
239
+ let arr = ScalarValue :: new_list ( & values, & self . datatype , true ) ;
240
+ Ok ( ScalarValue :: List ( arr) )
241
+ }
242
+
243
+ fn size ( & self ) -> usize {
244
+ std:: mem:: size_of_val ( self ) + ScalarValue :: size_of_hashset ( & self . values )
245
+ - std:: mem:: size_of_val ( & self . values )
246
+ + self . datatype . size ( )
247
+ - std:: mem:: size_of_val ( & self . datatype )
248
+ }
249
+ }
0 commit comments