 
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::vec;
+use std::{mem, vec};
 
 use crate::aggregates::group_values::{new_group_values, GroupValuesLike};
 use crate::aggregates::order::GroupOrderingFull;
@@ -79,28 +79,74 @@ use super::order::GroupOrdering;
 use super::AggregateExec;
 
 struct PartitionedOutput {
-    batches: Vec<Option<RecordBatch>>,
-    current_idx: usize,
-    exhausted: bool
+    partitions: Vec<Option<RecordBatch>>,
+    start_idx: usize,
+    batch_size: usize,
+    num_partitions: usize,
 }
 
 impl PartitionedOutput {
-    pub fn new(batches: Vec<RecordBatch>) -> Self {
-        let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+    pub fn new(
+        src_batches: Vec<RecordBatch>,
+        batch_size: usize,
+        num_partitions: usize,
+    ) -> Self {
+        let partitions = src_batches.into_iter().map(|batch| Some(batch)).collect();
 
         Self {
-            batches,
-            current_idx: 0,
-            exhausted: false,
+            partitions,
+            start_idx: 0,
+            batch_size,
+            num_partitions,
         }
     }
 
     pub fn next_batch(&mut self) -> Option<RecordBatch> {
+        let mut current_idx = self.start_idx;
+        loop {
+            // If the current partition still holds data, extract a batch from it.
+            let batch_opt = if self.partitions[current_idx].is_some() {
+                Some(self.extract_batch_from_partition(current_idx))
+            } else {
+                None
+            };
+
+            // Advance `current_idx`, wrapping around to the first partition.
+            current_idx = (current_idx + 1) % self.num_partitions;
+
+            if batch_opt.is_some() {
+                // Found a batch: remember where to resume next time and return it.
+                self.start_idx = current_idx;
+                return batch_opt;
+            } else if self.start_idx == current_idx {
+                // Wrapped around to the starting partition without finding any
+                // data, so every partition is exhausted; return `None`.
+                return batch_opt;
+            }
+            // Otherwise, keep checking the next partition.
+        }
+    }
 
+    pub fn extract_batch_from_partition(&mut self, part_idx: usize) -> RecordBatch {
+        let partition_batch = mem::take(&mut self.partitions[part_idx]).unwrap();
+        if partition_batch.num_rows() > self.batch_size {
+            // More rows remain than `batch_size`: cut off the first `batch_size`
+            // rows as `output`, and put the remaining rows back into the slot.
+            let size = self.batch_size;
+            let num_remaining = partition_batch.num_rows() - size;
+            let remaining = partition_batch.slice(size, num_remaining);
+            let output = partition_batch.slice(0, size);
+            self.partitions[part_idx] = Some(remaining);
+
+            output
+        } else {
+            // These are the last rows in this partition, so return them directly;
+            // the slot has already been set to `None` by `mem::take`.
+            partition_batch
+        }
     }
 }
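Not part of this diff: a rough usage sketch of the `PartitionedOutput` above, to illustrate the intended round-robin emission. The helper `make_batch`, the column name `a`, the row counts, and the asserted batch sizes are all invented for the example.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::record_batch::RecordBatch;

/// Build a single-column batch with `n` rows (illustrative test data only).
fn make_batch(n: i32) -> RecordBatch {
    let col: ArrayRef = Arc::new(Int32Array::from((0..n).collect::<Vec<i32>>()));
    RecordBatch::try_from_iter(vec![("a", col)]).unwrap()
}

#[test]
fn emits_partitions_round_robin_in_batch_size_chunks() {
    // Two partitions holding 25 and 7 rows, emitted in chunks of at most 10 rows.
    let mut output = PartitionedOutput::new(vec![make_batch(25), make_batch(7)], 10, 2);

    let mut sizes = vec![];
    while let Some(batch) = output.next_batch() {
        sizes.push(batch.num_rows());
    }

    // The 25-row partition is split into 10 + 10 + 5, interleaved with the
    // single 7-row batch from the second partition.
    assert_eq!(sizes, vec![10, 7, 10, 5]);
}
```

The interleaving falls out of `next_batch` advancing `start_idx` after every successful extraction, so a large partition is drained in `batch_size` chunks rather than monopolizing the output.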
 
-
 /// This encapsulates the spilling state
 struct SpillState {
     // ========================================================================