@@ -64,6 +64,8 @@ pub(crate) enum ExecutionState {
64
64
/// When producing output, the remaining rows to output are stored
65
65
/// here and are sliced off as needed in batch_size chunks
66
66
ProducingOutput ( RecordBatch ) ,
67
+
68
+ ProducingPartitionedOutput ( Vec < Option < RecordBatch > > ) ,
67
69
/// Produce intermediate aggregate state for each input row without
68
70
/// aggregation.
69
71
///
@@ -76,6 +78,29 @@ pub(crate) enum ExecutionState {
76
78
use super :: order:: GroupOrdering ;
77
79
use super :: AggregateExec ;
78
80
81
+ struct PartitionedOutput {
82
+ batches : Vec < Option < RecordBatch > > ,
83
+ current_idx : usize ,
84
+ exhausted : bool
85
+ }
86
+
87
+ impl PartitionedOutput {
88
+ pub fn new ( batches : Vec < RecordBatch > ) -> Self {
89
+ let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
90
+
91
+ Self {
92
+ batches,
93
+ current_idx : 0 ,
94
+ exhausted : false ,
95
+ }
96
+ }
97
+
98
+ pub fn next_batch ( & mut self ) -> Option < RecordBatch > {
99
+
100
+ }
101
+ }
102
+
103
+
79
104
/// This encapsulates the spilling state
80
105
struct SpillState {
81
106
// ========================================================================
@@ -677,7 +702,9 @@ impl Stream for GroupedHashAggregateStream {
677
702
}
678
703
679
704
if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
680
- let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
705
+ let mut batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
706
+ assert_eq ! ( batch. len( ) , 1 ) ;
707
+ let batch = batch. pop ( ) . unwrap ( ) ;
681
708
self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
682
709
timer. done ( ) ;
683
710
// make sure the exec_state just set is not overwritten below
@@ -759,6 +786,8 @@ impl Stream for GroupedHashAggregateStream {
759
786
let _ = self . update_memory_reservation ( ) ;
760
787
return Poll :: Ready ( None ) ;
761
788
}
789
+
790
+ ExecutionState :: ProducingPartitionedOutput ( _) => todo ! ( ) ,
762
791
}
763
792
}
764
793
}
@@ -1101,7 +1130,9 @@ impl GroupedHashAggregateStream {
1101
1130
1102
1131
/// Emit all rows, sort them, and store them on disk.
1103
1132
fn spill ( & mut self ) -> Result < ( ) > {
1104
- let emit = self . emit ( EmitTo :: All , true ) ?;
1133
+ let mut emit = self . emit ( EmitTo :: All , true ) ?;
1134
+ assert_eq ! ( emit. len( ) , 1 ) ;
1135
+ let emit = emit. pop ( ) . unwrap ( ) ;
1105
1136
let sorted = sort_batch ( & emit, & self . spill_state . spill_expr , None ) ?;
1106
1137
let spillfile = self . runtime . disk_manager . create_tmp_file ( "HashAggSpill" ) ?;
1107
1138
// TODO: slice large `sorted` and write to multiple files in parallel
@@ -1138,9 +1169,16 @@ impl GroupedHashAggregateStream {
1138
1169
&& matches ! ( self . mode, AggregateMode :: Partial )
1139
1170
&& self . update_memory_reservation ( ) . is_err ( )
1140
1171
{
1141
- let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
1142
- let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
1143
- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1172
+ if !self . group_values . is_partitioned ( ) {
1173
+ let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
1174
+ let mut batch = self . emit ( EmitTo :: First ( n) , false ) ?;
1175
+ let batch = batch. pop ( ) . unwrap ( ) ;
1176
+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1177
+ } else {
1178
+ let batches = self . emit ( EmitTo :: All , false ) ?;
1179
+ let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
1180
+ self . exec_state = ExecutionState :: ProducingPartitionedOutput ( batches) ;
1181
+ }
1144
1182
}
1145
1183
Ok ( ( ) )
1146
1184
}
@@ -1150,7 +1188,9 @@ impl GroupedHashAggregateStream {
1150
1188
/// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
1151
1189
/// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
1152
1190
fn update_merged_stream ( & mut self ) -> Result < ( ) > {
1153
- let batch = self . emit ( EmitTo :: All , true ) ?;
1191
+ let mut batch = self . emit ( EmitTo :: All , true ) ?;
1192
+ assert_eq ! ( batch. len( ) , 1 ) ;
1193
+ let batch = batch. pop ( ) . unwrap ( ) ;
1154
1194
// clear up memory for streaming_merge
1155
1195
self . clear_all ( ) ;
1156
1196
self . update_memory_reservation ( ) ?;
@@ -1198,8 +1238,15 @@ impl GroupedHashAggregateStream {
1198
1238
let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
1199
1239
let timer = elapsed_compute. timer ( ) ;
1200
1240
self . exec_state = if self . spill_state . spills . is_empty ( ) {
1201
- let batch = self . emit ( EmitTo :: All , false ) ?;
1202
- ExecutionState :: ProducingOutput ( batch)
1241
+ if !self . group_values . is_partitioned ( ) {
1242
+ let mut batch = self . emit ( EmitTo :: All , false ) ?;
1243
+ let batch = batch. pop ( ) . unwrap ( ) ;
1244
+ ExecutionState :: ProducingOutput ( batch)
1245
+ } else {
1246
+ let batches = self . emit ( EmitTo :: All , false ) ?;
1247
+ let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
1248
+ ExecutionState :: ProducingPartitionedOutput ( batches)
1249
+ }
1203
1250
} else {
1204
1251
// If spill files exist, stream-merge them.
1205
1252
self . update_merged_stream ( ) ?;
@@ -1231,8 +1278,15 @@ impl GroupedHashAggregateStream {
1231
1278
fn switch_to_skip_aggregation ( & mut self ) -> Result < ( ) > {
1232
1279
if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
1233
1280
if probe. should_skip ( ) {
1234
- let batch = self . emit ( EmitTo :: All , false ) ?;
1235
- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1281
+ if !self . group_values . is_partitioned ( ) {
1282
+ let mut batch = self . emit ( EmitTo :: All , false ) ?;
1283
+ let batch = batch. pop ( ) . unwrap ( ) ;
1284
+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1285
+ } else {
1286
+ let batches = self . emit ( EmitTo :: All , false ) ?;
1287
+ let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
1288
+ self . exec_state = ExecutionState :: ProducingPartitionedOutput ( batches) ;
1289
+ }
1236
1290
}
1237
1291
}
1238
1292
0 commit comments