Skip to content

Commit 4108065

Browse files
committed
introduce PartitionedOutput.
1 parent 553c6a3 commit 4108065

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

datafusion/physical-plan/src/aggregates/row_hash.rs

+56-10
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use std::sync::Arc;
2121
use std::task::{Context, Poll};
22-
use std::vec;
22+
use std::{mem, vec};
2323

2424
use crate::aggregates::group_values::{new_group_values, GroupValuesLike};
2525
use crate::aggregates::order::GroupOrderingFull;
@@ -79,28 +79,74 @@ use super::order::GroupOrdering;
7979
use super::AggregateExec;
8080

8181
struct PartitionedOutput {
82-
batches: Vec<Option<RecordBatch>>,
83-
current_idx: usize,
84-
exhausted: bool
82+
partitions: Vec<Option<RecordBatch>>,
83+
start_idx: usize,
84+
batch_size: usize,
85+
num_partitions: usize,
8586
}
8687

8788
impl PartitionedOutput {
88-
pub fn new(batches: Vec<RecordBatch>) -> Self {
89-
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
89+
pub fn new(
90+
src_batches: Vec<RecordBatch>,
91+
batch_size: usize,
92+
num_partitions: usize,
93+
) -> Self {
94+
let partitions = src_batches.into_iter().map(|batch| Some(batch)).collect();
9095

9196
Self {
92-
batches,
93-
current_idx: 0,
94-
exhausted: false,
97+
partitions,
98+
start_idx: 0,
99+
batch_size,
100+
num_partitions,
95101
}
96102
}
97103

98104
pub fn next_batch(&mut self) -> Option<RecordBatch> {
105+
let mut current_idx = self.start_idx;
106+
loop {
107+
// If found a partition having data,
108+
let batch_opt = if self.partitions[current_idx].is_some() {
109+
Some(self.extract_batch_from_partition(current_idx))
110+
} else {
111+
None
112+
};
113+
114+
// Advance the `current_idx`
115+
current_idx = (current_idx + 1) % self.num_partitions;
116+
117+
if batch_opt.is_some() {
118+
// If found batch, we update the `start_idx` and return it
119+
self.start_idx = current_idx;
120+
return batch_opt;
121+
} else if self.start_idx == current_idx {
122+
// If not found, and has loop to end, we return None
123+
return batch_opt;
124+
}
125+
// Otherwise, we loop to check next partition
126+
}
127+
}
99128

129+
pub fn extract_batch_from_partition(&mut self, part_idx: usize) -> RecordBatch {
130+
let partition_batch = mem::take(&mut self.partitions[part_idx]).unwrap();
131+
if partition_batch.num_rows() > self.batch_size {
132+
// If still the exist rows num > `batch_size`,
133+
// cut off `batch_size` rows as `output``,
134+
// and set back `remaining`.
135+
let size = self.batch_size;
136+
let num_remaining = batch.num_rows() - size;
137+
let remaining = partition_batch.slice(size, num_remaining);
138+
let output = partition_batch.slice(0, size);
139+
self.partitions[part_idx] = Some(remaining);
140+
141+
output
142+
} else {
143+
// If they are the last rows in `partition_batch`, just return,
144+
// because `partition_batch` has been set to `None`.
145+
partition_batch
146+
}
100147
}
101148
}
102149

103-
104150
/// This encapsulates the spilling state
105151
struct SpillState {
106152
// ========================================================================

0 commit comments

Comments
 (0)