@@ -21,13 +21,17 @@ use crate::fuzz_cases::aggregation_fuzzer::{
     AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder,
 };

-use arrow::array::{types::Int64Type, Array, ArrayRef, AsArray, Int64Array, RecordBatch};
+use arrow::array::{
+    types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch,
+    StringArray,
+};
 use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::{
     DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
     DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
 };
 use arrow::util::pretty::pretty_format_batches;
+use arrow_schema::{Field, Schema, SchemaRef};
 use datafusion::common::Result;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::datasource::source::DataSourceExec;
@@ -42,14 +46,18 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}
 use datafusion_common::HashMap;
 use datafusion_common_runtime::JoinSet;
 use datafusion_functions_aggregate::sum::sum_udaf;
-use datafusion_physical_expr::expressions::col;
+use datafusion_physical_expr::expressions::{col, lit, Column};
 use datafusion_physical_expr::PhysicalSortExpr;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use datafusion_physical_plan::InputOrderMode;
 use test_utils::{add_empty_batches, StringBatchGenerator};

+use datafusion_execution::memory_pool::FairSpillPool;
+use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+use datafusion_execution::TaskContext;
+use datafusion_physical_plan::metrics::MetricValue;
 use rand::rngs::StdRng;
-use rand::{thread_rng, Rng, SeedableRng};
+use rand::{random, thread_rng, Rng, SeedableRng};

 // ========================================================================
 // The new aggregation fuzz tests based on [`AggregationFuzzer`]
@@ -663,3 +671,134 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i
     }
     output
 }
+
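+/// Asserts, via the operator's `SpillCount` metric, that the given
+/// `AggregateExec` did (or did not) spill during execution.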
+fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
+    if let Some(metrics_set) = single_aggregate.metrics() {
+        let mut spill_count = 0;
+
+        // Inspect metrics for SpillCount
+        for metric in metrics_set.iter() {
+            if let MetricValue::SpillCount(count) = metric.value() {
+                spill_count = count.value();
+                break;
+            }
+        }
+
+        if expect_spill && spill_count == 0 {
+            panic!("Expected spill but SpillCount metric not found or SpillCount was 0.");
+        } else if !expect_spill && spill_count > 0 {
+            panic!("Expected no spill but found SpillCount metric with value greater than 0.");
+        }
+    } else {
+        panic!("No metrics returned from the operator; cannot verify spilling.");
+    }
+}
+
+// Fix for https://github.com/apache/datafusion/issues/15530
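+//
+// Runs a single-mode aggregation over high-cardinality random group keys with
+// a deliberately small memory pool, then verifies via the `SpillCount` metric
+// that the operator spilled.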
+#[tokio::test]
+async fn test_single_mode_aggregate_with_spill() -> Result<()> {
+    let scan_schema = Arc::new(Schema::new(vec![
+        Field::new("col_0", DataType::Int64, true),
+        Field::new("col_1", DataType::Utf8, true),
+        Field::new("col_2", DataType::Utf8, true),
+        Field::new("col_3", DataType::Utf8, true),
+        Field::new("col_4", DataType::Utf8, true),
+        Field::new("col_5", DataType::Int32, true),
+        Field::new("col_6", DataType::Utf8, true),
+        Field::new("col_7", DataType::Utf8, true),
+        Field::new("col_8", DataType::Utf8, true),
+    ]));
+
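+    // Group on four of the randomly generated columns; with random values,
+    // nearly every row yields a distinct group key, so aggregation state keeps
+    // growing as input is consumed.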
+    let group_by = PhysicalGroupBy::new_single(vec![
+        (Arc::new(Column::new("col_1", 1)), "col_1".to_string()),
+        (Arc::new(Column::new("col_7", 7)), "col_7".to_string()),
+        (Arc::new(Column::new("col_0", 0)), "col_0".to_string()),
+        (Arc::new(Column::new("col_8", 8)), "col_8".to_string()),
+    ]);
+
+    fn generate_int64_array() -> ArrayRef {
+        Arc::new(Int64Array::from_iter_values(
+            (0..1024).map(|_| random::<i64>()),
+        ))
+    }
+    fn generate_int32_array() -> ArrayRef {
+        Arc::new(Int32Array::from_iter_values(
+            (0..1024).map(|_| random::<i32>()),
+        ))
+    }
+
+    fn generate_string_array() -> ArrayRef {
+        Arc::new(StringArray::from(
+            (0..1024)
+                .map(|_| -> String {
+                    thread_rng()
+                        .sample_iter::<char, _>(rand::distributions::Standard)
+                        .take(5)
+                        .collect()
+                })
+                .collect::<Vec<_>>(),
+        ))
+    }
+
+    fn generate_record_batch(schema: &SchemaRef) -> Result<RecordBatch> {
+        RecordBatch::try_new(
+            Arc::clone(schema),
+            vec![
+                generate_int64_array(),
+                generate_string_array(),
+                generate_string_array(),
+                generate_string_array(),
+                generate_string_array(),
+                generate_int32_array(),
+                generate_string_array(),
+                generate_string_array(),
+                generate_string_array(),
+            ],
+        )
+        .map_err(|err| err.into())
+    }
+
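+    // SUM(1) keeps per-group aggregate state tiny, so most of the memory
+    // pressure comes from the group keys themselves.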
+    let aggregate_expressions = vec![Arc::new(
+        AggregateExprBuilder::new(sum_udaf(), vec![lit(1i64)])
+            .schema(Arc::clone(&scan_schema))
+            .alias("SUM(1i64)")
+            .build()?,
+    )];
+
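+    // Five batches of 1024 rows each (5120 random rows in total) feed the
+    // in-memory scan.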
+    let batches = (0..5)
+        .map(|_| generate_record_batch(&scan_schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    let plan: Arc<dyn ExecutionPlan> =
+        MemorySourceConfig::try_new_exec(&[batches], Arc::clone(&scan_schema), None)
+            .unwrap();
+
+    let single_aggregate = Arc::new(AggregateExec::try_new(
+        AggregateMode::Single,
+        group_by,
+        aggregate_expressions.clone(),
+        vec![None; aggregate_expressions.len()],
+        plan,
+        Arc::clone(&scan_schema),
+    )?);
+
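+    // Cap memory at 250_000 bytes with a FairSpillPool and use a small batch
+    // size so the aggregation exceeds its budget and must spill.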
+    let memory_pool = Arc::new(FairSpillPool::new(250000));
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(248))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            )),
+    );
+
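+    // Drain all output from the aggregation; spilling is verified below via
+    // the SpillCount metric.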
+    datafusion_physical_plan::common::collect(
+        single_aggregate.execute(0, Arc::clone(&task_ctx))?,
+    )
+    .await?;
+
+    assert_spill_count_metric(true, single_aggregate);
+
+    Ok(())
+}