@@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream {
609
609
match & self . exec_state {
610
610
ExecutionState :: ReadingInput => ' reading_input: {
611
611
match ready ! ( self . input. poll_next_unpin( cx) ) {
612
- // new batch to aggregate
613
- Some ( Ok ( batch) ) => {
612
+ // New batch to aggregate in partial aggregation operator
613
+ Some ( Ok ( batch) ) if self . mode == AggregateMode :: Partial => {
614
614
let timer = elapsed_compute. timer ( ) ;
615
615
let input_rows = batch. num_rows ( ) ;
616
616
617
- // Make sure we have enough capacity for `batch`, otherwise spill
618
- extract_ok ! ( self . spill_previous_if_necessary( & batch) ) ;
619
-
620
617
// Do the grouping
621
618
extract_ok ! ( self . group_aggregate_batch( batch) ) ;
622
619
@@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream {
649
646
650
647
timer. done ( ) ;
651
648
}
649
+
650
+ // New batch to aggregate in terminal aggregation operator
651
+ // (Final/FinalPartitioned/Single/SinglePartitioned)
652
+ Some ( Ok ( batch) ) => {
653
+ let timer = elapsed_compute. timer ( ) ;
654
+
655
+ // Make sure we have enough capacity for `batch`, otherwise spill
656
+ extract_ok ! ( self . spill_previous_if_necessary( & batch) ) ;
657
+
658
+ // Do the grouping
659
+ extract_ok ! ( self . group_aggregate_batch( batch) ) ;
660
+
661
+ // If we can begin emitting rows, do so,
662
+ // otherwise keep consuming input
663
+ assert ! ( !self . input_done) ;
664
+
665
+ // If the number of group values equals or exceeds the soft limit,
666
+ // emit all groups and switch to producing output
667
+ if self . hit_soft_group_limit ( ) {
668
+ timer. done ( ) ;
669
+ extract_ok ! ( self . set_input_done_and_produce_output( ) ) ;
670
+ // make sure the exec_state just set is not overwritten below
671
+ break ' reading_input;
672
+ }
673
+
674
+ if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
675
+ let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
676
+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
677
+ timer. done ( ) ;
678
+ // make sure the exec_state just set is not overwritten below
679
+ break ' reading_input;
680
+ }
681
+
682
+ timer. done ( ) ;
683
+ }
684
+
685
+ // Found error from input stream
652
686
Some ( Err ( e) ) => {
653
687
// inner had error, return to caller
654
688
return Poll :: Ready ( Some ( Err ( e) ) ) ;
655
689
}
690
+
691
+ // Found end from input stream
656
692
None => {
657
693
// inner is done, emit all rows and switch to producing output
658
694
extract_ok ! ( self . set_input_done_and_produce_output( ) ) ;
@@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream {
691
727
(
692
728
if self . input_done {
693
729
ExecutionState :: Done
694
- } else if self . should_skip_aggregation ( ) {
730
+ }
731
+ // In Partial aggregation, we also need to check
732
+ // if we should trigger partial skipping
733
+ else if self . mode == AggregateMode :: Partial
734
+ && self . should_skip_aggregation ( )
735
+ {
695
736
ExecutionState :: SkippingAggregation
696
737
} else {
697
738
ExecutionState :: ReadingInput
@@ -879,10 +920,10 @@ impl GroupedHashAggregateStream {
879
920
if self . group_values . len ( ) > 0
880
921
&& batch. num_rows ( ) > 0
881
922
&& matches ! ( self . group_ordering, GroupOrdering :: None )
882
- && !matches ! ( self . mode, AggregateMode :: Partial )
883
923
&& !self . spill_state . is_stream_merging
884
924
&& self . update_memory_reservation ( ) . is_err ( )
885
925
{
926
+ assert_ne ! ( self . mode, AggregateMode :: Partial ) ;
886
927
// Use input batch (Partial mode) schema for spilling because
887
928
// the spilled data will be merged and re-evaluated later.
888
929
self . spill_state . spill_schema = batch. schema ( ) ;
@@ -927,9 +968,9 @@ impl GroupedHashAggregateStream {
927
968
fn emit_early_if_necessary ( & mut self ) -> Result < ( ) > {
928
969
if self . group_values . len ( ) >= self . batch_size
929
970
&& matches ! ( self . group_ordering, GroupOrdering :: None )
930
- && matches ! ( self . mode, AggregateMode :: Partial )
931
971
&& self . update_memory_reservation ( ) . is_err ( )
932
972
{
973
+ assert_eq ! ( self . mode, AggregateMode :: Partial ) ;
933
974
let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
934
975
let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
935
976
self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
@@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream {
1002
1043
}
1003
1044
1004
1045
/// Updates skip aggregation probe state.
1046
+ ///
1047
+ /// Note: this should only be called in Partial aggregation mode (`AggregateMode::Partial`).
1005
1048
fn update_skip_aggregation_probe ( & mut self , input_rows : usize ) {
1006
1049
if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
1007
1050
// Skip aggregation probe is not supported if stream has any spills,
@@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream {
1013
1056
1014
1057
/// In case the probe indicates that aggregation may be
1015
1058
/// skipped, forces stream to produce currently accumulated output.
1059
+ ///
1060
+ /// Note: this should only be called in Partial aggregation mode (`AggregateMode::Partial`).
1016
1061
fn switch_to_skip_aggregation ( & mut self ) -> Result < ( ) > {
1017
1062
if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
1018
1063
if probe. should_skip ( ) {
@@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream {
1026
1071
1027
1072
/// Returns true if the aggregation probe indicates that aggregation
1028
1073
/// should be skipped.
1074
+ ///
1075
+ /// Note: this should only be called in Partial aggregation mode (`AggregateMode::Partial`).
1029
1076
fn should_skip_aggregation ( & self ) -> bool {
1030
1077
self . skip_aggregation_probe
1031
1078
. as_ref ( )
0 commit comments