Skip to content

Commit f1aa27f

Browse files
authored
Minor: add partial assertion for skip aggregation probe (#12640)
* Add a Partial-mode assertion for the skip aggregation probe and improve comments. * Fix formatting. * Use pattern matching on the aggregation mode to improve readability. * Only check `should_skip_aggregation` in Partial aggregation. * Turn some condition checks into assertions. * Use a clearer way to distinguish the partial and terminal branches.
1 parent da70fab commit f1aa27f

File tree

1 file changed

+55
-8
lines changed

1 file changed

+55
-8
lines changed

datafusion/physical-plan/src/aggregates/row_hash.rs

+55-8
Original file line numberDiff line numberDiff line change
@@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream {
609609
match &self.exec_state {
610610
ExecutionState::ReadingInput => 'reading_input: {
611611
match ready!(self.input.poll_next_unpin(cx)) {
612-
// new batch to aggregate
613-
Some(Ok(batch)) => {
612+
// New batch to aggregate in partial aggregation operator
613+
Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
614614
let timer = elapsed_compute.timer();
615615
let input_rows = batch.num_rows();
616616

617-
// Make sure we have enough capacity for `batch`, otherwise spill
618-
extract_ok!(self.spill_previous_if_necessary(&batch));
619-
620617
// Do the grouping
621618
extract_ok!(self.group_aggregate_batch(batch));
622619

@@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream {
649646

650647
timer.done();
651648
}
649+
650+
// New batch to aggregate in terminal aggregation operator
651+
// (Final/FinalPartitioned/Single/SinglePartitioned)
652+
Some(Ok(batch)) => {
653+
let timer = elapsed_compute.timer();
654+
655+
// Make sure we have enough capacity for `batch`, otherwise spill
656+
extract_ok!(self.spill_previous_if_necessary(&batch));
657+
658+
// Do the grouping
659+
extract_ok!(self.group_aggregate_batch(batch));
660+
661+
// If we can begin emitting rows, do so,
662+
// otherwise keep consuming input
663+
assert!(!self.input_done);
664+
665+
// If the number of group values equals or exceeds the soft limit,
666+
// emit all groups and switch to producing output
667+
if self.hit_soft_group_limit() {
668+
timer.done();
669+
extract_ok!(self.set_input_done_and_produce_output());
670+
// make sure the exec_state just set is not overwritten below
671+
break 'reading_input;
672+
}
673+
674+
if let Some(to_emit) = self.group_ordering.emit_to() {
675+
let batch = extract_ok!(self.emit(to_emit, false));
676+
self.exec_state = ExecutionState::ProducingOutput(batch);
677+
timer.done();
678+
// make sure the exec_state just set is not overwritten below
679+
break 'reading_input;
680+
}
681+
682+
timer.done();
683+
}
684+
685+
// Found error from input stream
652686
Some(Err(e)) => {
653687
// inner had error, return to caller
654688
return Poll::Ready(Some(Err(e)));
655689
}
690+
691+
// Found end from input stream
656692
None => {
657693
// inner is done, emit all rows and switch to producing output
658694
extract_ok!(self.set_input_done_and_produce_output());
@@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream {
691727
(
692728
if self.input_done {
693729
ExecutionState::Done
694-
} else if self.should_skip_aggregation() {
730+
}
731+
// In Partial aggregation, we also need to check
732+
// if we should trigger partial skipping
733+
else if self.mode == AggregateMode::Partial
734+
&& self.should_skip_aggregation()
735+
{
695736
ExecutionState::SkippingAggregation
696737
} else {
697738
ExecutionState::ReadingInput
@@ -879,10 +920,10 @@ impl GroupedHashAggregateStream {
879920
if self.group_values.len() > 0
880921
&& batch.num_rows() > 0
881922
&& matches!(self.group_ordering, GroupOrdering::None)
882-
&& !matches!(self.mode, AggregateMode::Partial)
883923
&& !self.spill_state.is_stream_merging
884924
&& self.update_memory_reservation().is_err()
885925
{
926+
assert_ne!(self.mode, AggregateMode::Partial);
886927
// Use input batch (Partial mode) schema for spilling because
887928
// the spilled data will be merged and re-evaluated later.
888929
self.spill_state.spill_schema = batch.schema();
@@ -927,9 +968,9 @@ impl GroupedHashAggregateStream {
927968
fn emit_early_if_necessary(&mut self) -> Result<()> {
928969
if self.group_values.len() >= self.batch_size
929970
&& matches!(self.group_ordering, GroupOrdering::None)
930-
&& matches!(self.mode, AggregateMode::Partial)
931971
&& self.update_memory_reservation().is_err()
932972
{
973+
assert_eq!(self.mode, AggregateMode::Partial);
933974
let n = self.group_values.len() / self.batch_size * self.batch_size;
934975
let batch = self.emit(EmitTo::First(n), false)?;
935976
self.exec_state = ExecutionState::ProducingOutput(batch);
@@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream {
10021043
}
10031044

10041045
/// Updates skip aggregation probe state.
1046+
///
1047+
/// Notice: It should only be called in Partial aggregation
10051048
fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
10061049
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
10071050
// Skip aggregation probe is not supported if stream has any spills,
@@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream {
10131056

10141057
/// In case the probe indicates that aggregation may be
10151058
/// skipped, forces stream to produce currently accumulated output.
1059+
///
1060+
/// Notice: It should only be called in Partial aggregation
10161061
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
10171062
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
10181063
if probe.should_skip() {
@@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream {
10261071

10271072
/// Returns true if the aggregation probe indicates that aggregation
10281073
/// should be skipped.
1074+
///
1075+
/// Notice: It should only be called in Partial aggregation
10291076
fn should_skip_aggregation(&self) -> bool {
10301077
self.skip_aggregation_probe
10311078
.as_ref()

0 commit comments

Comments
 (0)