Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: add partial assertion for skip aggregation probe #12640

Merged
merged 6 commits into from
Sep 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 55 additions & 8 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream {
match &self.exec_state {
ExecutionState::ReadingInput => 'reading_input: {
match ready!(self.input.poll_next_unpin(cx)) {
// new batch to aggregate
Some(Ok(batch)) => {
// New batch to aggregate in partial aggregation operator
Some(Ok(batch)) if self.mode == AggregateMode::Partial => {
let timer = elapsed_compute.timer();
let input_rows = batch.num_rows();

// Make sure we have enough capacity for `batch`, otherwise spill
extract_ok!(self.spill_previous_if_necessary(&batch));

// Do the grouping
extract_ok!(self.group_aggregate_batch(batch));

Expand Down Expand Up @@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream {

timer.done();
}

// New batch to aggregate in terminal aggregation operator
// (Final/FinalPartitioned/Single/SinglePartitioned)
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();

// Make sure we have enough capacity for `batch`, otherwise spill
extract_ok!(self.spill_previous_if_necessary(&batch));

// Do the grouping
extract_ok!(self.group_aggregate_batch(batch));

// If we can begin emitting rows, do so,
// otherwise keep consuming input
assert!(!self.input_done);

// If the number of group values equals or exceeds the soft limit,
// emit all groups and switch to producing output
if self.hit_soft_group_limit() {
timer.done();
extract_ok!(self.set_input_done_and_produce_output());
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}

if let Some(to_emit) = self.group_ordering.emit_to() {
let batch = extract_ok!(self.emit(to_emit, false));
self.exec_state = ExecutionState::ProducingOutput(batch);
timer.done();
// make sure the exec_state just set is not overwritten below
break 'reading_input;
}

timer.done();
}

// Found error from input stream
Some(Err(e)) => {
// inner had error, return to caller
return Poll::Ready(Some(Err(e)));
}

// Found end from input stream
None => {
// inner is done, emit all rows and switch to producing output
extract_ok!(self.set_input_done_and_produce_output());
Expand Down Expand Up @@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream {
(
if self.input_done {
ExecutionState::Done
} else if self.should_skip_aggregation() {
}
// In Partial aggregation, we also need to check
// if we should trigger partial skipping
else if self.mode == AggregateMode::Partial
&& self.should_skip_aggregation()
{
ExecutionState::SkippingAggregation
} else {
ExecutionState::ReadingInput
Expand Down Expand Up @@ -879,10 +920,10 @@ impl GroupedHashAggregateStream {
if self.group_values.len() > 0
&& batch.num_rows() > 0
&& matches!(self.group_ordering, GroupOrdering::None)
&& !matches!(self.mode, AggregateMode::Partial)
&& !self.spill_state.is_stream_merging
&& self.update_memory_reservation().is_err()
{
assert_ne!(self.mode, AggregateMode::Partial);
// Use input batch (Partial mode) schema for spilling because
// the spilled data will be merged and re-evaluated later.
self.spill_state.spill_schema = batch.schema();
Expand Down Expand Up @@ -927,9 +968,9 @@ impl GroupedHashAggregateStream {
fn emit_early_if_necessary(&mut self) -> Result<()> {
if self.group_values.len() >= self.batch_size
&& matches!(self.group_ordering, GroupOrdering::None)
&& matches!(self.mode, AggregateMode::Partial)
&& self.update_memory_reservation().is_err()
{
assert_eq!(self.mode, AggregateMode::Partial);
let n = self.group_values.len() / self.batch_size * self.batch_size;
let batch = self.emit(EmitTo::First(n), false)?;
self.exec_state = ExecutionState::ProducingOutput(batch);
Expand Down Expand Up @@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream {
}

/// Updates skip aggregation probe state.
///
/// Note: this method should only be called in Partial aggregation mode.
fn update_skip_aggregation_probe(&mut self, input_rows: usize) {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
// Skip aggregation probe is not supported if stream has any spills,
Expand All @@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream {

/// In case the probe indicates that aggregation may be
/// skipped, forces stream to produce currently accumulated output.
///
/// Note: this method should only be called in Partial aggregation mode.
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
if probe.should_skip() {
Expand All @@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream {

/// Returns true if the aggregation probe indicates that aggregation
/// should be skipped.
///
/// Note: this method should only be called in Partial aggregation mode.
fn should_skip_aggregation(&self) -> bool {
self.skip_aggregation_probe
.as_ref()
Expand Down