@@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter};
20
20
use std:: mem:: { size_of, size_of_val} ;
21
21
use std:: sync:: Arc ;
22
22
23
- use arrow:: array:: { downcast_integer, ArrowNumericType } ;
23
+ use arrow:: array:: {
24
+ downcast_integer, ArrowNumericType , BooleanArray , ListArray , PrimitiveArray ,
25
+ PrimitiveBuilder ,
26
+ } ;
27
+ use arrow:: buffer:: { OffsetBuffer , ScalarBuffer } ;
24
28
use arrow:: {
25
29
array:: { ArrayRef , AsArray } ,
26
30
datatypes:: {
@@ -33,12 +37,17 @@ use arrow::array::Array;
33
37
use arrow:: array:: ArrowNativeTypeOp ;
34
38
use arrow:: datatypes:: { ArrowNativeType , ArrowPrimitiveType } ;
35
39
36
- use datafusion_common:: { DataFusionError , HashSet , Result , ScalarValue } ;
40
+ use datafusion_common:: {
41
+ internal_datafusion_err, internal_err, DataFusionError , HashSet , Result , ScalarValue ,
42
+ } ;
37
43
use datafusion_expr:: function:: StateFieldsArgs ;
38
44
use datafusion_expr:: {
39
45
function:: AccumulatorArgs , utils:: format_state_name, Accumulator , AggregateUDFImpl ,
40
46
Documentation , Signature , Volatility ,
41
47
} ;
48
+ use datafusion_expr:: { EmitTo , GroupsAccumulator } ;
49
+ use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: accumulate:: accumulate;
50
+ use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: nulls:: filtered_null_mask;
42
51
use datafusion_functions_aggregate_common:: utils:: Hashable ;
43
52
use datafusion_macros:: user_doc;
44
53
@@ -165,6 +174,45 @@ impl AggregateUDFImpl for Median {
165
174
}
166
175
}
167
176
177
+ fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
178
+ !args. is_distinct
179
+ }
180
+
181
+ fn create_groups_accumulator (
182
+ & self ,
183
+ args : AccumulatorArgs ,
184
+ ) -> Result < Box < dyn GroupsAccumulator > > {
185
+ let num_args = args. exprs . len ( ) ;
186
+ if num_args != 1 {
187
+ return internal_err ! (
188
+ "median should only have 1 arg, but found num args:{}" ,
189
+ args. exprs. len( )
190
+ ) ;
191
+ }
192
+
193
+ let dt = args. exprs [ 0 ] . data_type ( args. schema ) ?;
194
+
195
+ macro_rules! helper {
196
+ ( $t: ty, $dt: expr) => {
197
+ Ok ( Box :: new( MedianGroupsAccumulator :: <$t>:: new( $dt) ) )
198
+ } ;
199
+ }
200
+
201
+ downcast_integer ! {
202
+ dt => ( helper, dt) ,
203
+ DataType :: Float16 => helper!( Float16Type , dt) ,
204
+ DataType :: Float32 => helper!( Float32Type , dt) ,
205
+ DataType :: Float64 => helper!( Float64Type , dt) ,
206
+ DataType :: Decimal128 ( _, _) => helper!( Decimal128Type , dt) ,
207
+ DataType :: Decimal256 ( _, _) => helper!( Decimal256Type , dt) ,
208
+ _ => Err ( DataFusionError :: NotImplemented ( format!(
209
+ "MedianGroupsAccumulator not supported for {} with {}" ,
210
+ args. name,
211
+ dt,
212
+ ) ) ) ,
213
+ }
214
+ }
215
+
168
216
fn aliases ( & self ) -> & [ String ] {
169
217
& [ ]
170
218
}
@@ -230,6 +278,216 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
230
278
}
231
279
}
232
280
281
+ /// The median groups accumulator accumulates the raw input values
282
+ ///
283
+ /// For calculating the accurate medians of groups, we need to store all values
284
+ /// of groups before final evaluation.
285
+ /// So values in each group will be stored in a `Vec<T>`, and the total group values
286
+ /// will be actually organized as a `Vec<Vec<T>>`.
287
+ ///
288
+ #[ derive( Debug ) ]
289
+ struct MedianGroupsAccumulator < T : ArrowNumericType + Send > {
290
+ data_type : DataType ,
291
+ group_values : Vec < Vec < T :: Native > > ,
292
+ }
293
+
294
+ impl < T : ArrowNumericType + Send > MedianGroupsAccumulator < T > {
295
+ pub fn new ( data_type : DataType ) -> Self {
296
+ Self {
297
+ data_type,
298
+ group_values : Vec :: new ( ) ,
299
+ }
300
+ }
301
+ }
302
+
303
+ impl < T : ArrowNumericType + Send > GroupsAccumulator for MedianGroupsAccumulator < T > {
304
+ fn update_batch (
305
+ & mut self ,
306
+ values : & [ ArrayRef ] ,
307
+ group_indices : & [ usize ] ,
308
+ opt_filter : Option < & BooleanArray > ,
309
+ total_num_groups : usize ,
310
+ ) -> Result < ( ) > {
311
+ assert_eq ! ( values. len( ) , 1 , "single argument to update_batch" ) ;
312
+ let values = values[ 0 ] . as_primitive :: < T > ( ) ;
313
+
314
+ // Push the `not nulls + not filtered` row into its group
315
+ self . group_values . resize ( total_num_groups, Vec :: new ( ) ) ;
316
+ accumulate (
317
+ group_indices,
318
+ values,
319
+ opt_filter,
320
+ |group_index, new_value| {
321
+ self . group_values [ group_index] . push ( new_value) ;
322
+ } ,
323
+ ) ;
324
+
325
+ Ok ( ( ) )
326
+ }
327
+
328
+ fn merge_batch (
329
+ & mut self ,
330
+ values : & [ ArrayRef ] ,
331
+ group_indices : & [ usize ] ,
332
+ // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
333
+ _opt_filter : Option < & BooleanArray > ,
334
+ total_num_groups : usize ,
335
+ ) -> Result < ( ) > {
336
+ assert_eq ! ( values. len( ) , 1 , "one argument to merge_batch" ) ;
337
+
338
+ // The merged values should be organized like as a `ListArray` which is nullable
339
+ // (input with nulls usually generated from `convert_to_state`), but `inner array` of
340
+ // `ListArray` is `non-nullable`.
341
+ //
342
+ // Following is the possible and impossible input `values`:
343
+ //
344
+ // # Possible values
345
+ // ```text
346
+ // group 0: [1, 2, 3]
347
+ // group 1: null (list array is nullable)
348
+ // group 2: [6, 7, 8]
349
+ // ...
350
+ // group n: [...]
351
+ // ```
352
+ //
353
+ // # Impossible values
354
+ // ```text
355
+ // group x: [1, 2, null] (values in list array is non-nullable)
356
+ // ```
357
+ //
358
+ let input_group_values = values[ 0 ] . as_list :: < i32 > ( ) ;
359
+
360
+ // Ensure group values big enough
361
+ self . group_values . resize ( total_num_groups, Vec :: new ( ) ) ;
362
+
363
+ // Extend values to related groups
364
+ // TODO: avoid using iterator of the `ListArray`, this will lead to
365
+ // many calls of `slice` of its ``inner array`, and `slice` is not
366
+ // so efficient(due to the calculation of `null_count` for each `slice`).
367
+ group_indices
368
+ . iter ( )
369
+ . zip ( input_group_values. iter ( ) )
370
+ . for_each ( |( & group_index, values_opt) | {
371
+ if let Some ( values) = values_opt {
372
+ let values = values. as_primitive :: < T > ( ) ;
373
+ self . group_values [ group_index] . extend ( values. values ( ) . iter ( ) ) ;
374
+ }
375
+ } ) ;
376
+
377
+ Ok ( ( ) )
378
+ }
379
+
380
+ fn state ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
381
+ // Emit values
382
+ let emit_group_values = emit_to. take_needed ( & mut self . group_values ) ;
383
+
384
+ // Build offsets
385
+ let mut offsets = Vec :: with_capacity ( self . group_values . len ( ) + 1 ) ;
386
+ offsets. push ( 0 ) ;
387
+ let mut cur_len = 0_i32 ;
388
+ for group_value in & emit_group_values {
389
+ cur_len += group_value. len ( ) as i32 ;
390
+ offsets. push ( cur_len) ;
391
+ }
392
+ // TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
393
+ // but safety should be considered more carefully here(and I am not sure if it can get
394
+ // performance improvement when we introduce checks to keep the safety...).
395
+ //
396
+ // Can see more details in:
397
+ // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
398
+ //
399
+ let offsets = OffsetBuffer :: new ( ScalarBuffer :: from ( offsets) ) ;
400
+
401
+ // Build inner array
402
+ let flatten_group_values =
403
+ emit_group_values. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
404
+ let group_values_array =
405
+ PrimitiveArray :: < T > :: new ( ScalarBuffer :: from ( flatten_group_values) , None )
406
+ . with_data_type ( self . data_type . clone ( ) ) ;
407
+
408
+ // Build the result list array
409
+ let result_list_array = ListArray :: new (
410
+ Arc :: new ( Field :: new_list_field ( self . data_type . clone ( ) , true ) ) ,
411
+ offsets,
412
+ Arc :: new ( group_values_array) ,
413
+ None ,
414
+ ) ;
415
+
416
+ Ok ( vec ! [ Arc :: new( result_list_array) ] )
417
+ }
418
+
419
+ fn evaluate ( & mut self , emit_to : EmitTo ) -> Result < ArrayRef > {
420
+ // Emit values
421
+ let emit_group_values = emit_to. take_needed ( & mut self . group_values ) ;
422
+
423
+ // Calculate median for each group
424
+ let mut evaluate_result_builder =
425
+ PrimitiveBuilder :: < T > :: new ( ) . with_data_type ( self . data_type . clone ( ) ) ;
426
+ for values in emit_group_values {
427
+ let median = calculate_median :: < T > ( values) ;
428
+ evaluate_result_builder. append_option ( median) ;
429
+ }
430
+
431
+ Ok ( Arc :: new ( evaluate_result_builder. finish ( ) ) )
432
+ }
433
+
434
+ fn convert_to_state (
435
+ & self ,
436
+ values : & [ ArrayRef ] ,
437
+ opt_filter : Option < & BooleanArray > ,
438
+ ) -> Result < Vec < ArrayRef > > {
439
+ assert_eq ! ( values. len( ) , 1 , "one argument to merge_batch" ) ;
440
+
441
+ let input_array = values[ 0 ] . as_primitive :: < T > ( ) ;
442
+
443
+ // Directly convert the input array to states, each row will be
444
+ // seen as a respective group.
445
+ // For detail, the `input_array` will be converted to a `ListArray`.
446
+ // And if row is `not null + not filtered`, it will be converted to a list
447
+ // with only one element; otherwise, this row in `ListArray` will be set
448
+ // to null.
449
+
450
+ // Reuse values buffer in `input_array` to build `values` in `ListArray`
451
+ let values = PrimitiveArray :: < T > :: new ( input_array. values ( ) . clone ( ) , None )
452
+ . with_data_type ( self . data_type . clone ( ) ) ;
453
+
454
+ // `offsets` in `ListArray`, each row as a list element
455
+ let offset_end = i32:: try_from ( input_array. len ( ) ) . map_err ( |e| {
456
+ internal_datafusion_err ! (
457
+ "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
458
+ )
459
+ } ) ?;
460
+ let offsets = ( 0 ..=offset_end) . collect :: < Vec < _ > > ( ) ;
461
+ // Safety: all checks in `OffsetBuffer::new` are ensured to pass
462
+ let offsets = unsafe { OffsetBuffer :: new_unchecked ( ScalarBuffer :: from ( offsets) ) } ;
463
+
464
+ // `nulls` for converted `ListArray`
465
+ let nulls = filtered_null_mask ( opt_filter, input_array) ;
466
+
467
+ let converted_list_array = ListArray :: new (
468
+ Arc :: new ( Field :: new_list_field ( self . data_type . clone ( ) , true ) ) ,
469
+ offsets,
470
+ Arc :: new ( values) ,
471
+ nulls,
472
+ ) ;
473
+
474
+ Ok ( vec ! [ Arc :: new( converted_list_array) ] )
475
+ }
476
+
477
+ fn supports_convert_to_state ( & self ) -> bool {
478
+ true
479
+ }
480
+
481
+ fn size ( & self ) -> usize {
482
+ self . group_values
483
+ . iter ( )
484
+ . map ( |values| values. capacity ( ) * size_of :: < T > ( ) )
485
+ . sum :: < usize > ( )
486
+ // account for size of self.grou_values too
487
+ + self . group_values . capacity ( ) * size_of :: < Vec < T > > ( )
488
+ }
489
+ }
490
+
233
491
/// The distinct median accumulator accumulates the raw input values
234
492
/// as `ScalarValue`s
235
493
///
0 commit comments