@@ -64,6 +64,8 @@ pub(crate) enum ExecutionState {
64
64
/// When producing output, the remaining rows to output are stored
65
65
/// here and are sliced off as needed in batch_size chunks
66
66
ProducingOutput ( RecordBatch ) ,
67
+
68
+ ProducingPartitionedOutput ( Vec < RecordBatch > ) ,
67
69
/// Produce intermediate aggregate state for each input row without
68
70
/// aggregation.
69
71
///
@@ -677,7 +679,9 @@ impl Stream for GroupedHashAggregateStream {
677
679
}
678
680
679
681
if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
680
- let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
682
+ let mut batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
683
+ assert_eq ! ( batch. len( ) , 1 ) ;
684
+ let batch = batch. pop ( ) . unwrap ( ) ;
681
685
self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
682
686
timer. done ( ) ;
683
687
// make sure the exec_state just set is not overwritten below
@@ -759,6 +763,7 @@ impl Stream for GroupedHashAggregateStream {
759
763
let _ = self . update_memory_reservation ( ) ;
760
764
return Poll :: Ready ( None ) ;
761
765
}
766
+ ExecutionState :: ProducingPartitionedOutput ( _) => todo ! ( ) ,
762
767
}
763
768
}
764
769
}
@@ -1101,7 +1106,9 @@ impl GroupedHashAggregateStream {
1101
1106
1102
1107
/// Emit all rows, sort them, and store them on disk.
1103
1108
fn spill ( & mut self ) -> Result < ( ) > {
1104
- let emit = self . emit ( EmitTo :: All , true ) ?;
1109
+ let mut emit = self . emit ( EmitTo :: All , true ) ?;
1110
+ assert_eq ! ( emit. len( ) , 1 ) ;
1111
+ let emit = emit. pop ( ) . unwrap ( ) ;
1105
1112
let sorted = sort_batch ( & emit, & self . spill_state . spill_expr , None ) ?;
1106
1113
let spillfile = self . runtime . disk_manager . create_tmp_file ( "HashAggSpill" ) ?;
1107
1114
// TODO: slice large `sorted` and write to multiple files in parallel
@@ -1138,9 +1145,15 @@ impl GroupedHashAggregateStream {
1138
1145
&& matches ! ( self . mode, AggregateMode :: Partial )
1139
1146
&& self . update_memory_reservation ( ) . is_err ( )
1140
1147
{
1141
- let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
1142
- let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
1143
- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1148
+ if !self . group_values . is_partitioned ( ) {
1149
+ let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
1150
+ let mut batch = self . emit ( EmitTo :: First ( n) , false ) ?;
1151
+ let batch = batch. pop ( ) . unwrap ( ) ;
1152
+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1153
+ } else {
1154
+ let batches = self . emit ( EmitTo :: All , false ) ?;
1155
+ self . exec_state = ExecutionState :: ProducingPartitionedOutput ( batches) ;
1156
+ }
1144
1157
}
1145
1158
Ok ( ( ) )
1146
1159
}
@@ -1150,7 +1163,9 @@ impl GroupedHashAggregateStream {
1150
1163
/// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
1151
1164
/// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
1152
1165
fn update_merged_stream ( & mut self ) -> Result < ( ) > {
1153
- let batch = self . emit ( EmitTo :: All , true ) ?;
1166
+ let mut batch = self . emit ( EmitTo :: All , true ) ?;
1167
+ assert_eq ! ( batch. len( ) , 1 ) ;
1168
+ let batch = batch. pop ( ) . unwrap ( ) ;
1154
1169
// clear up memory for streaming_merge
1155
1170
self . clear_all ( ) ;
1156
1171
self . update_memory_reservation ( ) ?;
@@ -1198,8 +1213,14 @@ impl GroupedHashAggregateStream {
1198
1213
let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
1199
1214
let timer = elapsed_compute. timer ( ) ;
1200
1215
self . exec_state = if self . spill_state . spills . is_empty ( ) {
1201
- let batch = self . emit ( EmitTo :: All , false ) ?;
1202
- ExecutionState :: ProducingOutput ( batch)
1216
+ if !self . group_values . is_partitioned ( ) {
1217
+ let mut batch = self . emit ( EmitTo :: All , false ) ?;
1218
+ let batch = batch. pop ( ) . unwrap ( ) ;
1219
+ ExecutionState :: ProducingOutput ( batch)
1220
+ } else {
1221
+ let batches = self . emit ( EmitTo :: All , false ) ?;
1222
+ ExecutionState :: ProducingPartitionedOutput ( batches)
1223
+ }
1203
1224
} else {
1204
1225
// If spill files exist, stream-merge them.
1205
1226
self . update_merged_stream ( ) ?;
@@ -1231,8 +1252,14 @@ impl GroupedHashAggregateStream {
1231
1252
fn switch_to_skip_aggregation ( & mut self ) -> Result < ( ) > {
1232
1253
if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
1233
1254
if probe. should_skip ( ) {
1234
- let batch = self . emit ( EmitTo :: All , false ) ?;
1235
- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1255
+ if !self . group_values . is_partitioned ( ) {
1256
+ let mut batch = self . emit ( EmitTo :: All , false ) ?;
1257
+ let batch = batch. pop ( ) . unwrap ( ) ;
1258
+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1259
+ } else {
1260
+ let batches = self . emit ( EmitTo :: All , false ) ?;
1261
+ self . exec_state = ExecutionState :: ProducingPartitionedOutput ( batches) ;
1262
+ }
1236
1263
}
1237
1264
}
1238
1265
0 commit comments