Skip to content

Commit 986ba1a

Browse files
committed
Use `PartitionedOutput` instead of `Vec<Option<RecordBatch>>` for the partitioned-output execution state.
1 parent 4108065 commit 986ba1a

File tree

1 file changed

+48
-11
lines changed

1 file changed

+48
-11
lines changed

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ pub(crate) enum ExecutionState {
6565
/// here and are sliced off as needed in batch_size chunks
6666
ProducingOutput(RecordBatch),
6767

68-
ProducingPartitionedOutput(Vec<Option<RecordBatch>>),
68+
ProducingPartitionedOutput(PartitionedOutput),
6969
/// Produce intermediate aggregate state for each input row without
7070
/// aggregation.
7171
///
@@ -78,6 +78,7 @@ pub(crate) enum ExecutionState {
7878
use super::order::GroupOrdering;
7979
use super::AggregateExec;
8080

81+
#[derive(Debug, Clone, Default)]
8182
struct PartitionedOutput {
8283
partitions: Vec<Option<RecordBatch>>,
8384
start_idx: usize,
@@ -133,7 +134,7 @@ impl PartitionedOutput {
133134
// cut off `batch_size` rows as `output`,
134135
// and set back `remaining`.
135136
let size = self.batch_size;
136-
let num_remaining = batch.num_rows() - size;
137+
let num_remaining = partition_batch.num_rows() - size;
137138
let remaining = partition_batch.slice(size, num_remaining);
138139
let output = partition_batch.slice(0, size);
139140
self.partitions[part_idx] = Some(remaining);
@@ -718,7 +719,7 @@ impl Stream for GroupedHashAggregateStream {
718719
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
719720

720721
loop {
721-
match &self.exec_state {
722+
match &mut self.exec_state {
722723
ExecutionState::ReadingInput => 'reading_input: {
723724
match ready!(self.input.poll_next_unpin(cx)) {
724725
// new batch to aggregate
@@ -800,6 +801,7 @@ impl Stream for GroupedHashAggregateStream {
800801
ExecutionState::ProducingOutput(batch) => {
801802
// slice off a part of the batch, if needed
802803
let output_batch;
804+
let batch = batch.clone();
803805
let size = self.batch_size;
804806
(self.exec_state, output_batch) = if batch.num_rows() <= size {
805807
(
@@ -810,7 +812,7 @@ impl Stream for GroupedHashAggregateStream {
810812
} else {
811813
ExecutionState::ReadingInput
812814
},
813-
batch.clone(),
815+
batch,
814816
)
815817
} else {
816818
// output first batch_size rows
@@ -833,7 +835,30 @@ impl Stream for GroupedHashAggregateStream {
833835
return Poll::Ready(None);
834836
}
835837

836-
ExecutionState::ProducingPartitionedOutput(_) => todo!(),
838+
ExecutionState::ProducingPartitionedOutput(parts) => {
839+
// slice off a part of the batch, if needed
840+
let batch_opt = parts.next_batch();
841+
if let Some(batch) = batch_opt {
842+
// output first batch_size rows
843+
let size = self.batch_size;
844+
let num_remaining = batch.num_rows() - size;
845+
let remaining = batch.slice(size, num_remaining);
846+
let output = batch.slice(0, size);
847+
self.exec_state = ExecutionState::ProducingOutput(remaining);
848+
849+
return Poll::Ready(Some(Ok(
850+
output.record_output(&self.baseline_metrics)
851+
)));
852+
} else {
853+
self.exec_state = if self.input_done {
854+
ExecutionState::Done
855+
} else if self.should_skip_aggregation() {
856+
ExecutionState::SkippingAggregation
857+
} else {
858+
ExecutionState::ReadingInput
859+
};
860+
}
861+
}
837862
}
838863
}
839864
}
@@ -1222,8 +1247,12 @@ impl GroupedHashAggregateStream {
12221247
self.exec_state = ExecutionState::ProducingOutput(batch);
12231248
} else {
12241249
let batches = self.emit(EmitTo::All, false)?;
1225-
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
1226-
self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
1250+
self.exec_state =
1251+
ExecutionState::ProducingPartitionedOutput(PartitionedOutput::new(
1252+
batches,
1253+
self.batch_size,
1254+
self.group_values.num_partitions(),
1255+
));
12271256
}
12281257
}
12291258
Ok(())
@@ -1290,8 +1319,11 @@ impl GroupedHashAggregateStream {
12901319
ExecutionState::ProducingOutput(batch)
12911320
} else {
12921321
let batches = self.emit(EmitTo::All, false)?;
1293-
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
1294-
ExecutionState::ProducingPartitionedOutput(batches)
1322+
ExecutionState::ProducingPartitionedOutput(PartitionedOutput::new(
1323+
batches,
1324+
self.batch_size,
1325+
self.group_values.num_partitions(),
1326+
))
12951327
}
12961328
} else {
12971329
// If spill files exist, stream-merge them.
@@ -1330,8 +1362,13 @@ impl GroupedHashAggregateStream {
13301362
self.exec_state = ExecutionState::ProducingOutput(batch);
13311363
} else {
13321364
let batches = self.emit(EmitTo::All, false)?;
1333-
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
1334-
self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
1365+
self.exec_state = ExecutionState::ProducingPartitionedOutput(
1366+
PartitionedOutput::new(
1367+
batches,
1368+
self.batch_size,
1369+
self.group_values.num_partitions(),
1370+
),
1371+
);
13351372
}
13361373
}
13371374
}

0 commit comments

Comments
 (0)