Skip to content

Commit bec7a3a

Browse files
committed
impl emit_partitioned.
1 parent f24f3c3 commit bec7a3a

File tree

1 file changed: +40 additions, −15 deletions

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 40 additions & 15 deletions

Original file line number · Diff line number · Diff line change
@@ -39,7 +39,9 @@ use arrow::array::*;
3939
use arrow::datatypes::SchemaRef;
4040
use arrow_schema::SortOptions;
4141
use datafusion_common::utils::get_arrayref_at_indices;
42-
use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
42+
use datafusion_common::{
43+
arrow_datafusion_err, internal_datafusion_err, DataFusionError, Result,
44+
};
4345
use datafusion_execution::disk_manager::RefCountedTempFile;
4446
use datafusion_execution::memory_pool::proxy::VecAllocExt;
4547
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
@@ -868,7 +870,7 @@ impl GroupedHashAggregateStream {
868870
get_filter_at_indices(opt_filter, &batch_indices)
869871
})
870872
.collect::<Result<Vec<_>>>()?;
871-
873+
872874
// Update the accumulators of each partition
873875
for (part_idx, part_start_end) in offsets.windows(2).enumerate() {
874876
let (offset, length) =
@@ -886,7 +888,8 @@ impl GroupedHashAggregateStream {
886888
.map(|array| array.slice(offset, length))
887889
.collect::<Vec<_>>();
888890

889-
let part_opt_filter = opt_filter.as_ref().map(|f| f.slice(offset, length));
891+
let part_opt_filter =
892+
opt_filter.as_ref().map(|f| f.slice(offset, length));
890893
let part_opt_filter =
891894
part_opt_filter.as_ref().map(|filter| filter.as_boolean());
892895

@@ -980,21 +983,29 @@ impl GroupedHashAggregateStream {
980983

981984
/// Create an output RecordBatch with the group keys and
982985
/// accumulator states/values specified in emit_to
983-
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
984-
self.emit_single(emit_to, spilling)
986+
fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>> {
987+
if !self.group_values.is_partitioned() {
988+
self.emit_single(emit_to, spilling)
989+
} else {
990+
self.emit_partitioned(emit_to)
991+
}
985992
}
986993

987994
/// Create an output RecordBatch with the group keys and
988995
/// accumulator states/values specified in emit_to
989-
fn emit_single(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
996+
fn emit_single(
997+
&mut self,
998+
emit_to: EmitTo,
999+
spilling: bool,
1000+
) -> Result<Vec<RecordBatch>> {
9901001
let schema = if spilling {
9911002
Arc::clone(&self.spill_state.spill_schema)
9921003
} else {
9931004
self.schema()
9941005
};
9951006

9961007
if self.group_values.is_empty() {
997-
return Ok(RecordBatch::new_empty(schema));
1008+
return Ok(Vec::new());
9981009
}
9991010

10001011
let group_values = self.group_values.as_single_mut();
@@ -1023,10 +1034,11 @@ impl GroupedHashAggregateStream {
10231034
// over the target memory size after emission, we can emit again rather than returning Err.
10241035
let _ = self.update_memory_reservation();
10251036
let batch = RecordBatch::try_new(schema, output)?;
1026-
Ok(batch)
1037+
1038+
Ok(vec![batch])
10271039
}
10281040

1029-
fn emit_partitioned(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>> {
1041+
fn emit_partitioned(&mut self, emit_to: EmitTo) -> Result<Vec<RecordBatch>> {
10301042
assert!(
10311043
self.mode == AggregateMode::Partial
10321044
&& matches!(self.group_ordering, GroupOrdering::None)
@@ -1035,22 +1047,35 @@ impl GroupedHashAggregateStream {
10351047
let schema = self.schema();
10361048

10371049
if self.group_values.is_empty() {
1038-
return Ok(RecordBatch::new_empty(schema));
1050+
return Ok(Vec::new());
10391051
}
10401052

10411053
let group_values = self.group_values.as_partitioned_mut();
1042-
let mut output = group_values.emit(emit_to)?;
1054+
let mut partitioned_outputs = group_values.emit(emit_to)?;
10431055

10441056
// Next output each aggregate value
1045-
for acc in self.accumulators[0].iter_mut() {
1046-
output.extend(acc.state(emit_to)?)
1057+
for (output, accs) in partitioned_outputs
1058+
.iter_mut()
1059+
.zip(self.accumulators.iter_mut())
1060+
{
1061+
for acc in accs.iter_mut() {
1062+
output.extend(acc.state(emit_to)?);
1063+
}
10471064
}
10481065

10491066
// emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
10501067
// over the target memory size after emission, we can emit again rather than returning Err.
10511068
let _ = self.update_memory_reservation();
1052-
let batch = RecordBatch::try_new(schema, output)?;
1053-
Ok(batch)
1069+
1070+
let batch_parts = partitioned_outputs
1071+
.into_iter()
1072+
.map(|part| {
1073+
RecordBatch::try_new(schema.clone(), part)
1074+
.map_err(|e| arrow_datafusion_err!(e))
1075+
})
1076+
.collect::<Result<Vec<_>>>()?;
1077+
1078+
Ok(batch_parts)
10541079
}
10551080

10561081
/// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly

0 commit comments

Comments (0)