
Commit 553c6a3

Introduce ExecutionState::ProducingPartitionedOutput to process partitioned outputs.
1 parent: bec7a3a

2 files changed: +73 −12 lines


datafusion/physical-plan/src/aggregates/group_values/mod.rs

Lines changed: 9 additions & 2 deletions
@@ -162,9 +162,16 @@ pub trait GroupValues: Send {
     fn clear_shrink(&mut self, batch: &RecordBatch);
 }
 
-pub fn new_group_values(schema: SchemaRef, partitioning_group_values: bool, num_partitions: usize) -> Result<GroupValuesLike> {
+pub fn new_group_values(
+    schema: SchemaRef,
+    partitioning_group_values: bool,
+    num_partitions: usize,
+) -> Result<GroupValuesLike> {
     let group_values = if partitioning_group_values && schema.fields.len() > 1 {
-        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(schema, num_partitions)?))
+        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(
+            schema,
+            num_partitions,
+        )?))
     } else {
         GroupValuesLike::Single(new_single_group_values(schema)?)
     };
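For orientation, here is a minimal, self-contained sketch of the single-vs-partitioned dispatch that new_group_values selects between. GroupValuesLike, is_partitioned, and len all appear in this commit; the Vec<u64> payloads below are invented stand-ins for illustration, not DataFusion's real group-values types.

// Simplified model of the two-variant dispatch; payload types are
// placeholders, not the real GroupValues implementations.
enum GroupValuesLike {
    Single(Vec<u64>),           // one group-values table for all input
    Partitioned(Vec<Vec<u64>>), // one table per output partition
}

impl GroupValuesLike {
    fn is_partitioned(&self) -> bool {
        matches!(self, GroupValuesLike::Partitioned(_))
    }

    // Total number of groups across all partitions.
    fn len(&self) -> usize {
        match self {
            GroupValuesLike::Single(groups) => groups.len(),
            GroupValuesLike::Partitioned(parts) => parts.iter().map(Vec::len).sum(),
        }
    }
}

fn main() {
    let single = GroupValuesLike::Single(vec![7]);
    assert!(!single.is_partitioned());

    let partitioned = GroupValuesLike::Partitioned(vec![vec![1, 2], vec![3]]);
    assert!(partitioned.is_partitioned());
    assert_eq!(partitioned.len(), 3);
}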

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 64 additions & 10 deletions
@@ -64,6 +64,8 @@ pub(crate) enum ExecutionState {
     /// When producing output, the remaining rows to output are stored
     /// here and are sliced off as needed in batch_size chunks
     ProducingOutput(RecordBatch),
+
+    ProducingPartitionedOutput(Vec<Option<RecordBatch>>),
     /// Produce intermediate aggregate state for each input row without
     /// aggregation.
     ///
@@ -76,6 +78,29 @@
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
+struct PartitionedOutput {
+    batches: Vec<Option<RecordBatch>>,
+    current_idx: usize,
+    exhausted: bool,
+}
+
+impl PartitionedOutput {
+    pub fn new(batches: Vec<RecordBatch>) -> Self {
+        let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+
+        Self {
+            batches,
+            current_idx: 0,
+            exhausted: false,
+        }
+    }
+
+    pub fn next_batch(&mut self) -> Option<RecordBatch> {
+        todo!() // body not yet implemented in this commit
+    }
+}
+
+
 /// This encapsulates the spilling state
 struct SpillState {
     // ========================================================================
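next_batch is given no body in this commit (reconstructed above as a todo!()). A plausible implementation, written here over a generic T so the snippet stands alone (the real payload is RecordBatch), would hand out each partition's batch in turn and flag exhaustion once every slot is empty; this is a hedged sketch, not the commit's eventual code.

// Hypothetical sketch of PartitionedOutput::next_batch.
struct PartitionedOutput<T> {
    batches: Vec<Option<T>>,
    current_idx: usize,
    exhausted: bool,
}

impl<T> PartitionedOutput<T> {
    fn next_batch(&mut self) -> Option<T> {
        if self.exhausted {
            return None;
        }
        // Scan forward, taking the first partition batch not yet consumed.
        while self.current_idx < self.batches.len() {
            let taken = self.batches[self.current_idx].take();
            self.current_idx += 1;
            if taken.is_some() {
                return taken;
            }
        }
        // Every slot was None: remember that we are done.
        self.exhausted = true;
        None
    }
}

fn main() {
    let mut out = PartitionedOutput {
        batches: vec![Some(1), None, Some(3)],
        current_idx: 0,
        exhausted: false,
    };
    assert_eq!(out.next_batch(), Some(1));
    assert_eq!(out.next_batch(), Some(3));
    assert_eq!(out.next_batch(), None);
}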
@@ -677,7 +702,9 @@ impl Stream for GroupedHashAggregateStream {
                 }
 
                 if let Some(to_emit) = self.group_ordering.emit_to() {
-                    let batch = extract_ok!(self.emit(to_emit, false));
+                    let mut batch = extract_ok!(self.emit(to_emit, false));
+                    assert_eq!(batch.len(), 1);
+                    let batch = batch.pop().unwrap();
                     self.exec_state = ExecutionState::ProducingOutput(batch);
                     timer.done();
                     // make sure the exec_state just set is not overwritten below
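The assert/pop pattern above, repeated in the hunks below, reflects the changed contract of emit: it now returns a Vec<RecordBatch> (one batch per partition) rather than a single RecordBatch, and every non-partitioned path asserts that exactly one batch comes back before unwrapping it.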
@@ -759,6 +786,8 @@
                     let _ = self.update_memory_reservation();
                     return Poll::Ready(None);
                 }
+
+                ExecutionState::ProducingPartitionedOutput(_) => todo!(),
             }
         }
     }
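The new match arm is still a todo!(). One way the stream might eventually drain this state, modeled on the existing ProducingOutput arm, is sketched below; the control flow and the transition to Done are assumptions, and String stands in for RecordBatch so the snippet compiles on its own.

// Hypothetical sketch of the eventual poll arm as a standalone state machine.
enum ExecutionState {
    ProducingPartitionedOutput(Vec<Option<String>>),
    Done,
}

fn poll_once(state: &mut ExecutionState) -> Option<String> {
    if let ExecutionState::ProducingPartitionedOutput(slots) = state {
        // Hand the next unconsumed partition batch downstream.
        if let Some(batch) = slots.iter_mut().find_map(|slot| slot.take()) {
            return Some(batch);
        }
        // All partitions drained: finish the stream.
        *state = ExecutionState::Done;
    }
    None
}

fn main() {
    let mut state = ExecutionState::ProducingPartitionedOutput(vec![
        Some("p0".to_string()),
        Some("p1".to_string()),
    ]);
    assert_eq!(poll_once(&mut state).as_deref(), Some("p0"));
    assert_eq!(poll_once(&mut state).as_deref(), Some("p1"));
    assert_eq!(poll_once(&mut state), None);
    assert!(matches!(state, ExecutionState::Done));
}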
@@ -1101,7 +1130,9 @@ impl GroupedHashAggregateStream {
 
     /// Emit all rows, sort them, and store them on disk.
     fn spill(&mut self) -> Result<()> {
-        let emit = self.emit(EmitTo::All, true)?;
+        let mut emit = self.emit(EmitTo::All, true)?;
+        assert_eq!(emit.len(), 1);
+        let emit = emit.pop().unwrap();
         let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
         let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
         // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1138,9 +1169,16 @@
             && matches!(self.mode, AggregateMode::Partial)
             && self.update_memory_reservation().is_err()
         {
-            let n = self.group_values.len() / self.batch_size * self.batch_size;
-            let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            if !self.group_values.is_partitioned() {
+                let n = self.group_values.len() / self.batch_size * self.batch_size;
+                let mut batch = self.emit(EmitTo::First(n), false)?;
+                let batch = batch.pop().unwrap();
+                self.exec_state = ExecutionState::ProducingOutput(batch);
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+                self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+            }
         }
         Ok(())
     }
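The rounding in the non-partitioned branch emits only whole batch_size chunks, so downstream batches stay full-sized. A quick worked example with hypothetical numbers:

fn main() {
    // Integer division floors, so n is the largest multiple of batch_size
    // that fits: 2_500 / 1_024 = 2, hence n = 2 * 1_024 = 2_048 groups are
    // emitted and 452 stay resident until the next emission.
    let (len, batch_size) = (2_500_usize, 1_024_usize);
    let n = len / batch_size * batch_size;
    assert_eq!(n, 2_048);
}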
@@ -1150,7 +1188,9 @@
     /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
     /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
     fn update_merged_stream(&mut self) -> Result<()> {
-        let batch = self.emit(EmitTo::All, true)?;
+        let mut batch = self.emit(EmitTo::All, true)?;
+        assert_eq!(batch.len(), 1);
+        let batch = batch.pop().unwrap();
         // clear up memory for streaming_merge
         self.clear_all();
         self.update_memory_reservation()?;
@@ -1198,8 +1238,15 @@
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
-            let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            if !self.group_values.is_partitioned() {
+                let mut batch = self.emit(EmitTo::All, false)?;
+                let batch = batch.pop().unwrap();
+                ExecutionState::ProducingOutput(batch)
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+                ExecutionState::ProducingPartitionedOutput(batches)
+            }
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -1231,8 +1278,15 @@
     fn switch_to_skip_aggregation(&mut self) -> Result<()> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
            if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                if !self.group_values.is_partitioned() {
+                    let mut batch = self.emit(EmitTo::All, false)?;
+                    let batch = batch.pop().unwrap();
+                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                } else {
+                    let batches = self.emit(EmitTo::All, false)?;
+                    let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+                    self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+                }
             }
         }
 
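The same partitioned-vs-single branch now appears three times (the memory-pressure emit, the final output, and skip-aggregation above). A possible consolidation is sketched below; make_output_state is an invented name and String replaces RecordBatch so the sketch is self-contained, so treat it as a refactoring idea rather than part of the commit.

// Hypothetical helper collapsing the repeated branch into one place.
enum ExecutionState {
    ProducingOutput(String),
    ProducingPartitionedOutput(Vec<Option<String>>),
}

fn make_output_state(is_partitioned: bool, mut batches: Vec<String>) -> ExecutionState {
    if !is_partitioned {
        // A non-partitioned emit yields exactly one batch.
        assert_eq!(batches.len(), 1);
        ExecutionState::ProducingOutput(batches.pop().unwrap())
    } else {
        // Wrap each partition's batch so it can be taken independently later.
        ExecutionState::ProducingPartitionedOutput(batches.into_iter().map(Some).collect())
    }
}

fn main() {
    let state = make_output_state(true, vec!["p0".into(), "p1".into()]);
    assert!(matches!(
        state,
        ExecutionState::ProducingPartitionedOutput(ref v) if v.len() == 2
    ));
}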
