@@ -19,7 +19,7 @@
 
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::vec;
+use std::{mem, vec};
 
 use crate::aggregates::group_values::{new_group_values, GroupValuesLike};
 use crate::aggregates::order::GroupOrderingFull;
@@ -79,28 +79,76 @@ use super::order::GroupOrdering;
 use super::AggregateExec;
 
 struct PartitionedOutput {
-    batches: Vec<Option<RecordBatch>>,
-    current_idx: usize,
-    exhausted: bool
+    partitions: Vec<Option<RecordBatch>>,
+    start_idx: usize,
+    batch_size: usize,
+    num_partitions: usize,
+    exhausted: bool,
 }
 
 impl PartitionedOutput {
-    pub fn new(batches: Vec<RecordBatch>) -> Self {
-        let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+    pub fn new(
+        src_batches: Vec<RecordBatch>,
+        batch_size: usize,
+        num_partitions: usize,
+    ) -> Self {
+        let partitions = src_batches.into_iter().map(|batch| Some(batch)).collect();
 
         Self {
-            batches,
-            current_idx: 0,
+            partitions,
+            start_idx: 0,
+            batch_size,
+            num_partitions,
             exhausted: false,
         }
     }
 
     pub fn next_batch(&mut self) -> Option<RecordBatch> {
+        let mut current_idx = self.start_idx;
+        loop {
+            // If the current partition still has rows, take a batch from it.
+            let batch_opt = if self.partitions[current_idx].is_some() {
+                Some(self.extract_batch_from_partition(current_idx))
+            } else {
+                None
+            };
+
+            // Advance `current_idx`, wrapping around to the first partition.
+            current_idx = (current_idx + 1) % self.num_partitions;
+
+            if batch_opt.is_some() {
+                // Found a batch: remember where to resume, then return it.
+                self.start_idx = current_idx;
+                return batch_opt;
+            } else if self.start_idx == current_idx {
+                // Checked every partition without finding data: `batch_opt` is `None`.
+                return batch_opt;
+            }
+            // Otherwise, continue with the next partition.
+        }
+    }
 
+    pub fn extract_batch_from_partition(&mut self, part_idx: usize) -> RecordBatch {
+        let partition_batch = mem::take(&mut self.partitions[part_idx]).unwrap();
+        if partition_batch.num_rows() > self.batch_size {
+            // If more than `batch_size` rows remain,
+            // cut off the first `batch_size` rows as `output`,
+            // and put the `remaining` rows back.
+            let size = self.batch_size;
+            let num_remaining = partition_batch.num_rows() - size;
+            let remaining = partition_batch.slice(size, num_remaining);
+            let output = partition_batch.slice(0, size);
+            self.partitions[part_idx] = Some(remaining);
+
+            output
+        } else {
+            // These are the last rows of the partition; just return them,
+            // since `mem::take` already set the slot to `None`.
+            partition_batch
+        }
     }
 }
 
-
 /// This encapsulates the spilling state
 struct SpillState {
     // ========================================================================
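
Below is a minimal, hypothetical usage sketch (not part of this diff) showing how `PartitionedOutput` could be driven from inside its module. It assumes the `arrow` crate's `Int32Array` / `RecordBatch` APIs; the schema, column name `v`, the `demo` function, and the row values are made up for illustration.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn demo() -> Result<(), arrow::error::ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));

    // Two partitions: one with 5 rows, one with 2 rows.
    let p0 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef],
    )?;
    let p1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![6, 7])) as ArrayRef],
    )?;

    // With batch_size = 2, partition 0 is sliced into 2 + 2 + 1 rows and
    // interleaved round-robin with partition 1's single 2-row batch,
    // so the emitted batch sizes are 2, 2, 2, 1.
    let mut output = PartitionedOutput::new(vec![p0, p1], 2, 2);
    while let Some(batch) = output.next_batch() {
        println!("emitted {} rows", batch.num_rows());
    }
    Ok(())
}
```

Since `RecordBatch::slice` is zero-copy and `mem::take` moves the batch out of its slot without cloning, splitting a partition's output costs O(1) per emitted batch.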