Commit 7214b3e

introduce ExecutionState::ProducingPartitionedOutput to process the partitioned outputs.

1 parent bec7a3a · commit 7214b3e
2 files changed, +46 -12 lines

datafusion/physical-plan/src/aggregates/group_values/mod.rs (+9 -2)
@@ -162,9 +162,16 @@ pub trait GroupValues: Send {
     fn clear_shrink(&mut self, batch: &RecordBatch);
 }
 
-pub fn new_group_values(schema: SchemaRef, partitioning_group_values: bool, num_partitions: usize) -> Result<GroupValuesLike> {
+pub fn new_group_values(
+    schema: SchemaRef,
+    partitioning_group_values: bool,
+    num_partitions: usize,
+) -> Result<GroupValuesLike> {
     let group_values = if partitioning_group_values && schema.fields.len() > 1 {
-        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(schema, num_partitions)?))
+        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(
+            schema,
+            num_partitions,
+        )?))
     } else {
         GroupValuesLike::Single(new_single_group_values(schema)?)
     };
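The `GroupValuesLike` type returned here is not part of this diff; inferred from the two constructor arms above and the `is_partitioned()` calls in row_hash.rs below, its shape is presumably something like this sketch (the variant payloads and the helper are assumptions, not the actual definition):

// Sketch only -- inferred from the call sites in this commit.
pub enum GroupValuesLike {
    /// The pre-existing path: one hash table over all groups.
    Single(Box<dyn GroupValues>),
    /// The new path: one group-values table per output partition.
    Partitioned(Box<PartitionedGroupValuesRows>),
}

impl GroupValuesLike {
    pub fn is_partitioned(&self) -> bool {
        matches!(self, Self::Partitioned(_))
    }
}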

datafusion/physical-plan/src/aggregates/row_hash.rs (+37 -10)
@@ -64,6 +64,8 @@ pub(crate) enum ExecutionState {
     /// When producing output, the remaining rows to output are stored
     /// here and are sliced off as needed in batch_size chunks
     ProducingOutput(RecordBatch),
+
+    ProducingPartitionedOutput(Vec<RecordBatch>),
     /// Produce intermediate aggregate state for each input row without
     /// aggregation.
     ///
@@ -677,7 +679,9 @@ impl Stream for GroupedHashAggregateStream {
                 }
 
                 if let Some(to_emit) = self.group_ordering.emit_to() {
-                    let batch = extract_ok!(self.emit(to_emit, false));
+                    let mut batch = extract_ok!(self.emit(to_emit, false));
+                    assert_eq!(batch.len(), 1);
+                    let batch = batch.pop().unwrap();
                     self.exec_state = ExecutionState::ProducingOutput(batch);
                     timer.done();
                     // make sure the exec_state just set is not overwritten below
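Every `emit` call site in this file now receives a `Vec<RecordBatch>` and, on the non-partitioned path, asserts a single element and pops it. That implies the method's changed signature is roughly the following (inferred from the call sites; the method body is not in this diff):

// Assumed post-change shape of the emit method: one batch on the
// single path, one batch per partition on the partitioned path.
fn emit(&mut self, to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>>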
@@ -759,6 +763,7 @@ impl Stream for GroupedHashAggregateStream {
                     let _ = self.update_memory_reservation();
                     return Poll::Ready(None);
                 }
+                ExecutionState::ProducingPartitionedOutput(_) => todo!(),
             }
         }
     }
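The new arm is left as `todo!()` in this commit. A minimal, self-contained sketch of how draining such a state could look, with a plain `Vec<u32>` standing in for Arrow's `RecordBatch` and without the `batch_size` slicing the real `ProducingOutput` arm performs (names and logic are illustrative, not the eventual implementation):

// Stand-in for arrow's RecordBatch, just for this sketch.
type Batch = Vec<u32>;

enum State {
    ProducingOutput(Batch),
    ProducingPartitionedOutput(Vec<Batch>),
    Done,
}

/// Return one batch per call until the state is exhausted.
fn next_output(state: &mut State) -> Option<Batch> {
    match state {
        State::ProducingOutput(batch) => {
            let out = std::mem::take(batch);
            *state = State::Done;
            Some(out)
        }
        // Pop one partition's batch per poll; partition order is
        // not significant for this sketch.
        State::ProducingPartitionedOutput(batches) => {
            let out = batches.pop();
            if batches.is_empty() {
                *state = State::Done;
            }
            out
        }
        State::Done => None,
    }
}

fn main() {
    let mut state = State::ProducingPartitionedOutput(vec![vec![1, 2], vec![3]]);
    while let Some(batch) = next_output(&mut state) {
        println!("{batch:?}");
    }
}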
@@ -1101,7 +1106,9 @@ impl GroupedHashAggregateStream {
 
     /// Emit all rows, sort them, and store them on disk.
     fn spill(&mut self) -> Result<()> {
-        let emit = self.emit(EmitTo::All, true)?;
+        let mut emit = self.emit(EmitTo::All, true)?;
+        assert_eq!(emit.len(), 1);
+        let emit = emit.pop().unwrap();
         let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
         let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
         // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1138,9 +1145,15 @@
             && matches!(self.mode, AggregateMode::Partial)
             && self.update_memory_reservation().is_err()
         {
-            let n = self.group_values.len() / self.batch_size * self.batch_size;
-            let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            if !self.group_values.is_partitioned() {
+                let n = self.group_values.len() / self.batch_size * self.batch_size;
+                let mut batch = self.emit(EmitTo::First(n), false)?;
+                let batch = batch.pop().unwrap();
+                self.exec_state = ExecutionState::ProducingOutput(batch);
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+            }
         }
         Ok(())
     }
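On the non-partitioned branch this keeps the pre-existing trick of emitting only a whole multiple of `batch_size` rows, so the remainder held in memory stays below one batch; the partitioned branch emits everything instead, presumably because per-partition remainder bookkeeping does not exist yet. A quick check of the rounding arithmetic:

fn main() {
    let batch_size = 4096;
    let num_groups = 10_000;
    // Integer division rounds down to whole batches:
    // emit 8192 rows, keep 1808 in memory.
    let n = num_groups / batch_size * batch_size;
    assert_eq!(n, 8192);
    assert_eq!(num_groups - n, 1808);
}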
@@ -1150,7 +1163,9 @@
     /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
     /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
     fn update_merged_stream(&mut self) -> Result<()> {
-        let batch = self.emit(EmitTo::All, true)?;
+        let mut batch = self.emit(EmitTo::All, true)?;
+        assert_eq!(batch.len(), 1);
+        let batch = batch.pop().unwrap();
         // clear up memory for streaming_merge
         self.clear_all();
         self.update_memory_reservation()?;
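Note that both spill sites (`spill` above and this merge path) assert that `emit` returned exactly one batch, i.e. in this commit spilling still assumes non-partitioned group values. A hypothetical guard that would state the invariant directly instead of via the length assertion:

// Hypothetical -- not in this commit.
debug_assert!(
    !self.group_values.is_partitioned(),
    "spilling is only supported for non-partitioned group values"
);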
@@ -1198,8 +1213,14 @@
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
-            let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            if !self.group_values.is_partitioned() {
+                let mut batch = self.emit(EmitTo::All, false)?;
+                let batch = batch.pop().unwrap();
+                ExecutionState::ProducingOutput(batch)
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                ExecutionState::ProducingPartitionedOutput(batches)
+            }
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -1231,8 +1252,14 @@
     fn switch_to_skip_aggregation(&mut self) -> Result<()> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                if !self.group_values.is_partitioned() {
+                    let mut batch = self.emit(EmitTo::All, false)?;
+                    let batch = batch.pop().unwrap();
+                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                } else {
+                    let batches = self.emit(EmitTo::All, false)?;
+                    self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+                }
             }
         }
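The emit-then-pop sequence now repeats across several call sites in this file; a small hypothetical helper on `GroupedHashAggregateStream` could keep them flat (the name `emit_single` and its placement are illustrative only):

/// Hypothetical convenience wrapper: emit on the single
/// (non-partitioned) path and unwrap the one batch it must return.
fn emit_single(&mut self, to: EmitTo, spilling: bool) -> Result<RecordBatch> {
    let mut batches = self.emit(to, spilling)?;
    assert_eq!(batches.len(), 1);
    Ok(batches.pop().unwrap())
}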
