Commit 0b90a8a

Generate hash aggregation output in smaller record batches (#3461)

* change how the final aggregation row group is created; this prevents cloning the whole state, which doubled the memory needed for aggregation. This PR relates to #1570
* Fix clippy issues
* read batch size from `session_config`
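The key idea of the change: the final aggregation state is no longer converted into a single output RecordBatch (a second full-size copy of the state); output is materialized one `batch_size` window at a time. A minimal standalone sketch of that windowing idea (`emit_in_windows` is illustrative, not a function from this patch):

// Illustrative only: emit output in fixed-size windows over the final
// aggregation state instead of materializing everything at once.
fn emit_in_windows<T: Clone>(
    state: &[T],
    batch_size: usize,
) -> impl Iterator<Item = Vec<T>> + '_ {
    (0..state.len())
        .step_by(batch_size.max(1))
        .map(move |start| {
            // Each window clones only up to `batch_size` entries, so peak
            // extra memory is O(batch_size) rather than O(state.len()).
            state[start..(start + batch_size).min(state.len())].to_vec()
        })
}

fn main() {
    let state: Vec<u32> = (0..10).collect();
    for window in emit_in_windows(&state, 4) {
        println!("{window:?}"); // [0,1,2,3], [4,5,6,7], [8,9]
    }
}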

2 files changed: +74 −49 lines changed

datafusion/core/src/physical_plan/aggregates/mod.rs (2 additions, 0 deletions)

@@ -298,6 +298,7 @@ impl ExecutionPlan for AggregateExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let batch_size = context.session_config().batch_size();
         let input = self.input.execute(partition, context)?;

         let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -318,6 +319,7 @@ impl ExecutionPlan for AggregateExec {
                 self.aggr_expr.clone(),
                 input,
                 baseline_metrics,
+                batch_size,
             )?))
         } else {
             Ok(Box::pin(GroupedHashAggregateStream::new(
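Since the batch size is now read from the session configuration, the size of the emitted aggregation batches can be tuned per session. A hedged usage sketch, assuming DataFusion's `SessionConfig::with_batch_size` builder of that era and a CSV file `data.csv` with columns `a` and `b` (both assumptions, not part of the patch):

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // A smaller batch_size splits aggregation output into more, smaller
    // RecordBatches; DataFusion's default at the time was 8192 rows.
    let config = SessionConfig::new().with_batch_size(1024);
    let ctx = SessionContext::with_config(config);

    // Assumes a CSV file with columns `a` and `b` exists at this path.
    ctx.register_csv("t", "data.csv", CsvReadOptions::new()).await?;
    let df = ctx.sql("SELECT a, COUNT(b) FROM t GROUP BY a").await?;
    df.show().await?;
    Ok(())
}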

datafusion/core/src/physical_plan/aggregates/row_hash.rs (72 additions, 49 deletions)

@@ -85,7 +85,11 @@ pub(crate) struct GroupedHashAggregateStreamV2 {

     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
-    finished: bool,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+    /// if the result is chunked into batches,
+    /// last offset is preserved for continuation.
+    row_group_skip_position: usize,
 }

 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
@@ -105,6 +109,7 @@ impl GroupedHashAggregateStreamV2 {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();

@@ -135,7 +140,8 @@ impl GroupedHashAggregateStreamV2 {
             aggregate_expressions,
             aggr_state: Default::default(),
             random_state: Default::default(),
-            finished: false,
+            batch_size,
+            row_group_skip_position: 0,
         })
     }
 }
@@ -148,56 +154,62 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }

         let elapsed_compute = this.baseline_metrics.elapsed_compute();

         loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &mut this.accumulators,
-                        &this.group_schema,
-                        this.aggr_layout.clone(),
-                        batch,
-                        &mut this.aggr_state,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+            let result: ArrowResult<Option<RecordBatch>> =
+                match ready!(this.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &mut this.accumulators,
+                            &this.group_schema,
+                            this.aggr_layout.clone(),
+                            batch,
+                            &mut this.aggr_state,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
                     }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.group_schema,
+                            &this.aggr_schema,
+                            this.batch_size,
+                            this.row_group_skip_position,
+                            &mut this.aggr_state,
+                            &mut this.accumulators,
+                            &this.schema,
+                        );
+
+                        timer.done();
+                        result
+                    }
+                };
+
+            this.row_group_skip_position += this.batch_size;
+            match result {
+                Ok(Some(result)) => {
+                    return Poll::Ready(Some(Ok(
+                        result.record_output(&this.baseline_metrics)
+                    )))
                 }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.group_schema,
-                        &this.aggr_schema,
-                        &mut this.aggr_state,
-                        &mut this.accumulators,
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
+                Ok(None) => return Poll::Ready(None),
+                Err(error) => return Poll::Ready(Some(Err(error))),
+            }
         }
     }
 }
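The reworked poll loop is the core of the change: once the input stream is exhausted, each subsequent poll emits the next window of results and advances `row_group_skip_position` until `create_batch_from_map` signals completion with `None`. A self-contained sketch of the same cursor-driven `Stream` pattern over a plain `Vec` (a simplified stand-in, not the DataFusion types):

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;

/// Simplified analogue of the output phase above: yield up to `batch_size`
/// items per poll, tracking progress with a cursor like
/// `row_group_skip_position` instead of a `finished` flag.
struct ChunkedOutput {
    state: Vec<u64>,
    batch_size: usize,
    skip: usize,
}

impl Stream for ChunkedOutput {
    type Item = Vec<u64>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Cursor strictly past the end: the stream is done. Like the patch,
        // a length that is an exact multiple of batch_size yields one
        // trailing empty chunk before termination.
        if this.skip > this.state.len() {
            return Poll::Ready(None);
        }
        let end = (this.skip + this.batch_size).min(this.state.len());
        let chunk = this.state[this.skip..end].to_vec();
        this.skip += this.batch_size;
        Poll::Ready(Some(chunk))
    }
}

Driven with `futures::StreamExt::next()`, this yields the full chunks, then (for exact multiples of `batch_size`) one empty chunk, then `None`, mirroring the `Ok(Some)` / `Ok(None)` handling in the patch.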
@@ -419,23 +431,34 @@ fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
 }

 /// Create a RecordBatch with all group keys and accumulator' states or values.
+#[allow(clippy::too_many_arguments)]
 fn create_batch_from_map(
     mode: &AggregateMode,
     group_schema: &Schema,
     aggr_schema: &Schema,
+    batch_size: usize,
+    skip_items: usize,
     aggr_state: &mut AggregationState,
     accumulators: &mut [AccumulatorItemV2],
     output_schema: &Schema,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<Option<RecordBatch>> {
+    if skip_items > aggr_state.group_states.len() {
+        return Ok(None);
+    }
+
     if aggr_state.group_states.is_empty() {
-        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+        return Ok(Some(RecordBatch::new_empty(Arc::new(
+            output_schema.to_owned(),
+        ))));
     }

     let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned);

     let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state
         .group_states
         .iter()
+        .skip(skip_items)
+        .take(batch_size)
         .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
         .unzip();

@@ -471,7 +494,7 @@ fn create_batch_from_map(
         .map(|(col, desired_field)| cast(col, desired_field.data_type()))
         .collect::<ArrowResult<Vec<_>>>()?;

-    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns).map(Some)
 }

 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
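The new return contract of `create_batch_from_map` is easy to miss: the cursor check uses strictly-greater-than, so a cursor landing exactly on the end still yields one empty window before the caller sees `None`. A small illustrative check of that boundary behavior (the `window` closure is a hypothetical stand-in, not from the patch):

fn main() {
    // Boundary behavior of the skip/take windowing with batch_size = 2,
    // using the same strictly-greater-than test as the patch.
    let groups: Vec<i32> = (0..4).collect();
    let window = |skip: usize| -> Option<Vec<i32>> {
        if skip > groups.len() {
            return None; // corresponds to Ok(None): all output emitted
        }
        Some(groups.iter().skip(skip).take(2).cloned().collect())
    };
    // The caller advances the cursor by batch_size each poll: 0, 2, 4, 6.
    assert_eq!(window(0), Some(vec![0, 1]));
    assert_eq!(window(2), Some(vec![2, 3]));
    assert_eq!(window(4), Some(vec![])); // cursor == len: one empty batch
    assert_eq!(window(6), None); // cursor > len: stream ends
}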
