@@ -1200,8 +1200,10 @@ mod tests {
1200
1200
1201
1201
use arrow:: array:: { Float64Array , UInt32Array } ;
1202
1202
use arrow:: compute:: { concat_batches, SortOptions } ;
1203
- use arrow:: datatypes:: DataType ;
1204
- use arrow_array:: { Float32Array , Int32Array } ;
1203
+ use arrow:: datatypes:: { DataType , Int32Type } ;
1204
+ use arrow_array:: {
1205
+ DictionaryArray , Float32Array , Int32Array , StructArray , UInt64Array ,
1206
+ } ;
1205
1207
use datafusion_common:: {
1206
1208
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError ,
1207
1209
ScalarValue ,
@@ -1214,6 +1216,7 @@ mod tests {
1214
1216
use datafusion_functions_aggregate:: count:: count_udaf;
1215
1217
use datafusion_functions_aggregate:: first_last:: { first_value_udaf, last_value_udaf} ;
1216
1218
use datafusion_functions_aggregate:: median:: median_udaf;
1219
+ use datafusion_functions_aggregate:: sum:: sum_udaf;
1217
1220
use datafusion_physical_expr:: expressions:: lit;
1218
1221
use datafusion_physical_expr:: PhysicalSortExpr ;
1219
1222
@@ -2316,6 +2319,127 @@ mod tests {
2316
2319
Ok ( ( ) )
2317
2320
}
2318
2321
2322
+ #[ tokio:: test]
2323
+ async fn test_agg_exec_struct_of_dicts ( ) -> Result < ( ) > {
2324
+ let batch = RecordBatch :: try_new (
2325
+ Arc :: new ( Schema :: new ( vec ! [
2326
+ Field :: new(
2327
+ "labels" . to_string( ) ,
2328
+ DataType :: Struct (
2329
+ vec![
2330
+ Field :: new_dict(
2331
+ "a" . to_string( ) ,
2332
+ DataType :: Dictionary (
2333
+ Box :: new( DataType :: Int32 ) ,
2334
+ Box :: new( DataType :: Utf8 ) ,
2335
+ ) ,
2336
+ true ,
2337
+ 0 ,
2338
+ false ,
2339
+ ) ,
2340
+ Field :: new_dict(
2341
+ "b" . to_string( ) ,
2342
+ DataType :: Dictionary (
2343
+ Box :: new( DataType :: Int32 ) ,
2344
+ Box :: new( DataType :: Utf8 ) ,
2345
+ ) ,
2346
+ true ,
2347
+ 0 ,
2348
+ false ,
2349
+ ) ,
2350
+ ]
2351
+ . into( ) ,
2352
+ ) ,
2353
+ false ,
2354
+ ) ,
2355
+ Field :: new( "value" , DataType :: UInt64 , false ) ,
2356
+ ] ) ) ,
2357
+ vec ! [
2358
+ Arc :: new( StructArray :: from( vec![
2359
+ (
2360
+ Arc :: new( Field :: new_dict(
2361
+ "a" . to_string( ) ,
2362
+ DataType :: Dictionary (
2363
+ Box :: new( DataType :: Int32 ) ,
2364
+ Box :: new( DataType :: Utf8 ) ,
2365
+ ) ,
2366
+ true ,
2367
+ 0 ,
2368
+ false ,
2369
+ ) ) ,
2370
+ Arc :: new(
2371
+ vec![ Some ( "a" ) , None , Some ( "a" ) ]
2372
+ . into_iter( )
2373
+ . collect:: <DictionaryArray <Int32Type >>( ) ,
2374
+ ) as ArrayRef ,
2375
+ ) ,
2376
+ (
2377
+ Arc :: new( Field :: new_dict(
2378
+ "b" . to_string( ) ,
2379
+ DataType :: Dictionary (
2380
+ Box :: new( DataType :: Int32 ) ,
2381
+ Box :: new( DataType :: Utf8 ) ,
2382
+ ) ,
2383
+ true ,
2384
+ 0 ,
2385
+ false ,
2386
+ ) ) ,
2387
+ Arc :: new(
2388
+ vec![ Some ( "b" ) , Some ( "c" ) , Some ( "b" ) ]
2389
+ . into_iter( )
2390
+ . collect:: <DictionaryArray <Int32Type >>( ) ,
2391
+ ) as ArrayRef ,
2392
+ ) ,
2393
+ ] ) ) ,
2394
+ Arc :: new( UInt64Array :: from( vec![ 1 , 1 , 1 ] ) ) ,
2395
+ ] ,
2396
+ )
2397
+ . expect ( "Failed to create RecordBatch" ) ;
2398
+
2399
+ let group_by = PhysicalGroupBy :: new_single ( vec ! [ (
2400
+ col( "labels" , & batch. schema( ) ) ?,
2401
+ "labels" . to_string( ) ,
2402
+ ) ] ) ;
2403
+
2404
+ let aggr_expr = vec ! [ AggregateExprBuilder :: new(
2405
+ sum_udaf( ) ,
2406
+ vec![ col( "value" , & batch. schema( ) ) ?] ,
2407
+ )
2408
+ . schema( Arc :: clone( & batch. schema( ) ) )
2409
+ . alias( String :: from( "SUM(value)" ) )
2410
+ . build( ) ?] ;
2411
+
2412
+ let input = Arc :: new ( MemoryExec :: try_new (
2413
+ & [ vec ! [ batch. clone( ) ] ] ,
2414
+ Arc :: < arrow_schema:: Schema > :: clone ( & batch. schema ( ) ) ,
2415
+ None ,
2416
+ ) ?) ;
2417
+ let aggregate_exec = Arc :: new ( AggregateExec :: try_new (
2418
+ AggregateMode :: FinalPartitioned ,
2419
+ group_by,
2420
+ aggr_expr,
2421
+ vec ! [ None ] ,
2422
+ Arc :: clone ( & input) as Arc < dyn ExecutionPlan > ,
2423
+ batch. schema ( ) ,
2424
+ ) ?) ;
2425
+
2426
+ let session_config = SessionConfig :: default ( ) ;
2427
+ let ctx = TaskContext :: default ( ) . with_session_config ( session_config) ;
2428
+ let output = collect ( aggregate_exec. execute ( 0 , Arc :: new ( ctx) ) ?) . await ?;
2429
+
2430
+ let expected = [
2431
+ "+--------------+------------+" ,
2432
+ "| labels | SUM(value) |" ,
2433
+ "+--------------+------------+" ,
2434
+ "| {a: a, b: b} | 2 |" ,
2435
+ "| {a: , b: c} | 1 |" ,
2436
+ "+--------------+------------+" ,
2437
+ ] ;
2438
+ assert_batches_eq ! ( expected, & output) ;
2439
+
2440
+ Ok ( ( ) )
2441
+ }
2442
+
2319
2443
#[ tokio:: test]
2320
2444
async fn test_skip_aggregation_after_first_batch ( ) -> Result < ( ) > {
2321
2445
let schema = Arc :: new ( Schema :: new ( vec ! [
0 commit comments