@@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::DataType;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
-use datafusion::physical_plan::aggregates::{
-    AggregateExec, AggregateMode, PhysicalGroupBy,
-};
+use arrow_array::cast::AsArray;
+use arrow_array::types::Int64Type;
+use arrow_array::Array;
+use hashbrown::HashMap;
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
+use tokio::task::JoinSet;
 
+use datafusion::common::Result;
+use datafusion::datasource::MemTable;
+use datafusion::physical_plan::aggregates::{
+    AggregateExec, AggregateMode, PhysicalGroupBy,
+};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
+use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion};
 use datafusion_physical_expr::expressions::{col, Sum};
 use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr};
-use test_utils::add_empty_batches;
+use datafusion_physical_plan::InputOrderMode;
+use test_utils::{add_empty_batches, StringBatchGenerator};
 
-#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
-async fn aggregate_test() {
+/// Tests that the streaming aggregate and the batch (non-streaming) aggregate
+/// produce the same results
+#[tokio::test(flavor = "multi_thread")]
+async fn streaming_aggregate_test() {
     let test_cases = vec![
         vec!["a"],
         vec!["b", "a"],
@@ -50,18 +61,18 @@ async fn aggregate_test() {
     let n = 300;
     let distincts = vec![10, 20];
     for distinct in distincts {
-        let mut handles = Vec::new();
+        let mut join_set = JoinSet::new();
         for i in 0..n {
             let test_idx = i % test_cases.len();
             let group_by_columns = test_cases[test_idx].clone();
-            let job = tokio::spawn(run_aggregate_test(
+            join_set.spawn(run_aggregate_test(
                 make_staggered_batches::<true>(1000, distinct, i as u64),
                 group_by_columns,
             ));
-            handles.push(job);
         }
-        for job in handles {
-            job.await.unwrap();
+        while let Some(join_handle) = join_set.join_next().await {
+            // propagate errors
+            join_handle.unwrap();
         }
     }
 }
@@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>(
     }
     add_empty_batches(batches, &mut rng)
 }
+
+/// Test group by with string/large string columns
+#[tokio::test(flavor = "multi_thread")]
+async fn group_by_strings() {
+    let mut join_set = JoinSet::new();
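+    // exercise every combination of string type (Utf8 / LargeUtf8),
+    // sorted vs unsorted input, and generator case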
+    for large in [true, false] {
+        for sorted in [true, false] {
+            for generator in StringBatchGenerator::interesting_cases() {
+                join_set.spawn(group_by_string_test(generator, sorted, large));
+            }
+        }
+    }
+    while let Some(join_handle) = join_set.join_next().await {
+        // propagate errors
+        join_handle.unwrap();
+    }
+}
+
+/// Run GROUP BY <x> using SQL and ensure the results are correct
+///
+/// If sorted is true, the input batches will be sorted by the group by column
+/// to test the streaming group by case
+///
+/// If large is true, the input batches will use LargeStringArray
+async fn group_by_string_test(
+    mut generator: StringBatchGenerator,
+    sorted: bool,
+    large: bool,
+) {
+    let column_name = "a";
+    let input = if sorted {
+        generator.make_sorted_input_batches(large)
+    } else {
+        generator.make_input_batches()
+    };
+
+    let expected = compute_counts(&input, column_name);
+
+    let schema = input[0].schema();
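+    // use a small batch size so the aggregation runs across many batches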
+    let session_config = SessionConfig::new().with_batch_size(50);
+    let ctx = SessionContext::new_with_config(session_config);
+
+    let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap();
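+    // declaring the sort order lets the planner choose a streaming
+    // (ordered) aggregate for the sorted case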
+    let provider = if sorted {
+        let sort_expr = datafusion::prelude::col("a").sort(true, true);
+        provider.with_sort_order(vec![vec![sort_expr]])
+    } else {
+        provider
+    };
+
+    ctx.register_table("t", Arc::new(provider)).unwrap();
+
+    let df = ctx
+        .sql("SELECT a, COUNT(*) FROM t GROUP BY a")
+        .await
+        .unwrap();
+    verify_ordered_aggregate(&df, sorted).await;
+    let results = df.collect().await.unwrap();
+
+    // verify that the results are correct
+    let actual = extract_result_counts(results);
+    assert_eq!(expected, actual);
+}
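+
+/// Asserts that the `AggregateExec` in the physical plan of `frame` uses a
+/// sorted `InputOrderMode` when `expected_sort` is true, and `Linear` otherwise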
+async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) {
+    struct Visitor {
+        expected_sort: bool,
+    }
+    let mut visitor = Visitor { expected_sort };
+
+    impl TreeNodeVisitor for Visitor {
+        type N = Arc<dyn ExecutionPlan>;
+        fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
+            if let Some(exec) = node.as_any().downcast_ref::<AggregateExec>() {
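+                // sorted input should be planned with a sort-aware input
+                // mode; unsorted input must fall back to `Linear`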
+                if self.expected_sort {
+                    assert!(matches!(
+                        exec.input_order_mode(),
+                        InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted
+                    ));
+                } else {
+                    assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear));
+                }
+            }
+            Ok(VisitRecursion::Continue)
+        }
+    }
+
+    let plan = frame.clone().create_physical_plan().await.unwrap();
+    plan.visit(&mut visitor).unwrap();
+}
+
+/// Compute the count of each distinct value in the specified column
+///
+/// ```text
+/// +---------------+---------------+
+/// | a             | b             |
+/// +---------------+---------------+
+/// | 𭏷𑩁          | 𘱦𫎛           |
+/// |               | 𬿪            |
+/// ```
+fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap<Option<String>, i64> {
+    let mut output = HashMap::new();
+    for arr in batches
+        .iter()
+        .map(|batch| batch.column_by_name(col).unwrap())
+    {
+        for value in to_str_vec(arr) {
+            output.entry(value).and_modify(|e| *e += 1).or_insert(1);
+        }
+    }
+    output
+}
+
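+/// Converts a string array (`Utf8` or `LargeUtf8`) to a vec of owned optional strings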
+fn to_str_vec(array: &ArrayRef) -> Vec<Option<String>> {
+    match array.data_type() {
+        DataType::Utf8 => array
+            .as_string::<i32>()
+            .iter()
+            .map(|x| x.map(|x| x.to_string()))
+            .collect(),
+        DataType::LargeUtf8 => array
+            .as_string::<i64>()
+            .iter()
+            .map(|x| x.map(|x| x.to_string()))
+            .collect(),
+        _ => panic!("unexpected type"),
+    }
+}
+
+/// Extracts the value of the first column and the count from the second column
+/// ```text
+/// +----------------+----------+
+/// | a              | COUNT(*) |
+/// +----------------+----------+
+/// |                | 8        |
+/// |                | 11       |
+/// ```
+fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i64> {
+    let group_arrays = results.iter().map(|batch| batch.column(0));
+
+    let count_arrays = results
+        .iter()
+        .map(|batch| batch.column(1).as_primitive::<Int64Type>());
+
+    let mut output = HashMap::new();
+    for (group_arr, count_arr) in group_arrays.zip(count_arrays) {
+        assert_eq!(group_arr.len(), count_arr.len());
+        let group_values = to_str_vec(group_arr);
+        for (group, count) in group_values.into_iter().zip(count_arr.iter()) {
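+            // each group key must appear in at most one output row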
+            assert!(output.get(&group).is_none());
+            let count = count.unwrap(); // counts can never be null
+            output.insert(group, count);
+        }
+    }
+    output
+}