Skip to content

Commit ab92626

Browse files
committed
re-design the sketch.
1 parent fd237f8 commit ab92626

File tree

3 files changed

+85
-10
lines changed

3 files changed

+85
-10
lines changed

datafusion/expr-common/src/groups_accumulator.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ pub enum EmitTo {
3131
/// For example, if `n=10`, group_index `0, 1, ... 9` are emitted
3232
/// and group indexes '`10, 11, 12, ...` become `0, 1, 2, ...`.
3333
First(usize),
34+
/// Emit all groups managed by the current block
35+
CurrentBlock(bool),
3436
}
3537

3638
impl EmitTo {
@@ -52,6 +54,7 @@ impl EmitTo {
5254
std::mem::swap(v, &mut t);
5355
t
5456
}
57+
EmitTo::CurrentBlock(_) => unimplemented!(),
5558
}
5659
}
5760
}
@@ -143,6 +146,12 @@ pub trait GroupsAccumulator: Send {
143146
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
144147
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
145148

149+
/// Returns `true` if blocked emission is supported
150+
/// Blocked emission makes it possible to avoid splitting results during aggregation.
151+
fn supports_blocked_emission(&self) -> bool {
152+
false
153+
}
154+
146155
/// Merges intermediate state (the output from [`Self::state`])
147156
/// into this accumulator's current state.
148157
///

datafusion/physical-plan/src/aggregates/group_values/mod.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use bytes_view::GroupValuesBytesView;
2222
use datafusion_common::Result;
2323

2424
pub(crate) mod primitive;
25-
use datafusion_expr::EmitTo;
25+
use datafusion_expr::{groups_accumulator::GroupIndicesType, EmitTo};
2626
use primitive::GroupValuesPrimitive;
2727

2828
mod row;
@@ -36,7 +36,12 @@ use datafusion_physical_expr::binary_map::OutputType;
3636
/// An interning store for group keys
3737
pub trait GroupValues: Send {
3838
/// Calculates the `groups` for each input row of `cols`
39-
fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()>;
39+
fn intern(
40+
&mut self,
41+
cols: &[ArrayRef],
42+
groups: &mut Vec<u64>,
43+
group_type: GroupIndicesType,
44+
) -> Result<()>;
4045

4146
/// Returns the number of bytes used by this [`GroupValues`]
4247
fn size(&self) -> usize;
@@ -52,6 +57,12 @@ pub trait GroupValues: Send {
5257

5358
/// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
5459
fn clear_shrink(&mut self, batch: &RecordBatch);
60+
61+
/// Returns `true` if blocked emission is supported
62+
/// Blocked emission makes it possible to avoid splitting results during aggregation.
63+
fn supports_blocked_emission(&self) -> bool {
64+
false
65+
}
5566
}
5667

5768
pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub(crate) enum ExecutionState {
6262
/// When producing output, the remaining rows to output are stored
6363
/// here and are sliced off as needed in batch_size chunks
6464
ProducingOutput(RecordBatch),
65+
ProducingBlocks(Option<usize>),
6566
/// Produce intermediate aggregate state for each input row without
6667
/// aggregation.
6768
///
@@ -387,6 +388,10 @@ pub(crate) struct GroupedHashAggregateStream {
387388
/// Optional probe for skipping data aggregation, if supported by
388389
/// current stream.
389390
skip_aggregation_probe: Option<SkipAggregationProbe>,
391+
392+
enable_blocked_group_states: bool,
393+
394+
block_size: usize,
390395
}
391396

392397
impl GroupedHashAggregateStream {
@@ -676,6 +681,43 @@ impl Stream for GroupedHashAggregateStream {
676681
)));
677682
}
678683

684+
ExecutionState::ProducingBlocks(blocks) => {
685+
if let Some(blk) = blocks {
686+
if blk > 0 {
687+
self.exec_state = ExecutionState::ProducingBlocks(Some(*blk - 1));
688+
} else {
689+
self.exec_state = if self.input_done {
690+
ExecutionState::Done
691+
} else if self.should_skip_aggregation() {
692+
ExecutionState::SkippingAggregation
693+
} else {
694+
ExecutionState::ReadingInput
695+
};
696+
continue;
697+
}
698+
}
699+
700+
let emit_result = self.emit(EmitTo::CurrentBlock(true), false);
701+
if emit_result.is_err() {
702+
return Poll::Ready(Some(emit_result));
703+
}
704+
705+
let emit_batch = emit_result.unwrap();
706+
if emit_batch.num_rows() == 0 {
707+
self.exec_state = if self.input_done {
708+
ExecutionState::Done
709+
} else if self.should_skip_aggregation() {
710+
ExecutionState::SkippingAggregation
711+
} else {
712+
ExecutionState::ReadingInput
713+
};
714+
}
715+
716+
return Poll::Ready(Some(Ok(
717+
emit_batch.record_output(&self.baseline_metrics)
718+
)));
719+
}
720+
679721
ExecutionState::Done => {
680722
// release the memory reservation since sending back output batch itself needs
681723
// some memory reservation, so make some room for it.
@@ -900,10 +942,15 @@ impl GroupedHashAggregateStream {
900942
&& matches!(self.group_ordering, GroupOrdering::None)
901943
&& matches!(self.mode, AggregateMode::Partial)
902944
&& self.update_memory_reservation().is_err()
903-
{
904-
let n = self.group_values.len() / self.batch_size * self.batch_size;
905-
let batch = self.emit(EmitTo::First(n), false)?;
906-
self.exec_state = ExecutionState::ProducingOutput(batch);
945+
{
946+
if self.enable_blocked_group_states {
947+
let n = self.group_values.len() / self.batch_size * self.batch_size;
948+
let batch = self.emit(EmitTo::First(n), false)?;
949+
self.exec_state = ExecutionState::ProducingOutput(batch);
950+
} else {
951+
let blocks = self.group_values.len() / self.block_size;
952+
self.exec_state = ExecutionState::ProducingBlocks(Some(blocks));
953+
}
907954
}
908955
Ok(())
909956
}
@@ -961,8 +1008,12 @@ impl GroupedHashAggregateStream {
9611008
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
9621009
let timer = elapsed_compute.timer();
9631010
self.exec_state = if self.spill_state.spills.is_empty() {
964-
let batch = self.emit(EmitTo::All, false)?;
965-
ExecutionState::ProducingOutput(batch)
1011+
if !self.enable_blocked_group_states {
1012+
let batch = self.emit(EmitTo::All, false)?;
1013+
ExecutionState::ProducingOutput(batch)
1014+
} else {
1015+
ExecutionState::ProducingBlocks(None)
1016+
}
9661017
} else {
9671018
// If spill files exist, stream-merge them.
9681019
self.update_merged_stream()?;
@@ -994,8 +1045,12 @@ impl GroupedHashAggregateStream {
9941045
fn switch_to_skip_aggregation(&mut self) -> Result<()> {
9951046
if let Some(probe) = self.skip_aggregation_probe.as_mut() {
9961047
if probe.should_skip() {
997-
let batch = self.emit(EmitTo::All, false)?;
998-
self.exec_state = ExecutionState::ProducingOutput(batch);
1048+
if !self.enable_blocked_group_states {
1049+
let batch = self.emit(EmitTo::All, false)?;
1050+
self.exec_state = ExecutionState::ProducingOutput(batch);
1051+
} else {
1052+
self.exec_state = ExecutionState::ProducingBlocks(None);
1053+
}
9991054
}
10001055
}
10011056

0 commit comments

Comments
 (0)