@@ -65,7 +65,7 @@ pub(crate) enum ExecutionState {
65
65
/// here and are sliced off as needed in batch_size chunks
66
66
ProducingOutput ( RecordBatch ) ,
67
67
68
- ProducingPartitionedOutput ( Vec < Option < RecordBatch > > ) ,
68
+ ProducingPartitionedOutput ( PartitionedOutput ) ,
69
69
/// Produce intermediate aggregate state for each input row without
70
70
/// aggregation.
71
71
///
@@ -78,6 +78,7 @@ pub(crate) enum ExecutionState {
78
78
use super :: order:: GroupOrdering ;
79
79
use super :: AggregateExec ;
80
80
81
+ #[ derive( Debug , Clone , Default ) ]
81
82
struct PartitionedOutput {
82
83
partitions : Vec < Option < RecordBatch > > ,
83
84
start_idx : usize ,
@@ -133,7 +134,7 @@ impl PartitionedOutput {
133
134
// cut off `batch_size` rows as `output`,
134
135
// and set back `remaining`.
135
136
let size = self . batch_size ;
136
- let num_remaining = batch . num_rows ( ) - size;
137
+ let num_remaining = partition_batch . num_rows ( ) - size;
137
138
let remaining = partition_batch. slice ( size, num_remaining) ;
138
139
let output = partition_batch. slice ( 0 , size) ;
139
140
self . partitions [ part_idx] = Some ( remaining) ;
@@ -718,7 +719,7 @@ impl Stream for GroupedHashAggregateStream {
718
719
let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
719
720
720
721
loop {
721
- match & self . exec_state {
722
+ match & mut self . exec_state {
722
723
ExecutionState :: ReadingInput => ' reading_input: {
723
724
match ready ! ( self . input. poll_next_unpin( cx) ) {
724
725
// new batch to aggregate
@@ -800,6 +801,7 @@ impl Stream for GroupedHashAggregateStream {
800
801
ExecutionState :: ProducingOutput ( batch) => {
801
802
// slice off a part of the batch, if needed
802
803
let output_batch;
804
+ let batch = batch. clone ( ) ;
803
805
let size = self . batch_size ;
804
806
( self . exec_state , output_batch) = if batch. num_rows ( ) <= size {
805
807
(
@@ -810,7 +812,7 @@ impl Stream for GroupedHashAggregateStream {
810
812
} else {
811
813
ExecutionState :: ReadingInput
812
814
} ,
813
- batch. clone ( ) ,
815
+ batch,
814
816
)
815
817
} else {
816
818
// output first batch_size rows
@@ -833,7 +835,30 @@ impl Stream for GroupedHashAggregateStream {
833
835
return Poll :: Ready ( None ) ;
834
836
}
835
837
836
- ExecutionState :: ProducingPartitionedOutput ( _) => todo ! ( ) ,
838
+ ExecutionState :: ProducingPartitionedOutput ( parts) => {
839
+ // slice off a part of the batch, if needed
840
+ let batch_opt = parts. next_batch ( ) ;
841
+ if let Some ( batch) = batch_opt {
842
+ // output first batch_size rows
843
+ let size = self . batch_size ;
844
+ let num_remaining = batch. num_rows ( ) - size;
845
+ let remaining = batch. slice ( size, num_remaining) ;
846
+ let output = batch. slice ( 0 , size) ;
847
+ self . exec_state = ExecutionState :: ProducingOutput ( remaining) ;
848
+
849
+ return Poll :: Ready ( Some ( Ok (
850
+ output. record_output ( & self . baseline_metrics )
851
+ ) ) ) ;
852
+ } else {
853
+ self . exec_state = if self . input_done {
854
+ ExecutionState :: Done
855
+ } else if self . should_skip_aggregation ( ) {
856
+ ExecutionState :: SkippingAggregation
857
+ } else {
858
+ ExecutionState :: ReadingInput
859
+ } ;
860
+ }
861
+ }
837
862
}
838
863
}
839
864
}
@@ -1222,8 +1247,12 @@ impl GroupedHashAggregateStream {
1222
1247
self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1223
1248
} else {
1224
1249
let batches = self . emit ( EmitTo :: All , false ) ?;
1225
- let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
1226
- self . exec_state = ExecutionState :: ProducingPartitionedOutput ( batches) ;
1250
+ self . exec_state =
1251
+ ExecutionState :: ProducingPartitionedOutput ( PartitionedOutput :: new (
1252
+ batches,
1253
+ self . batch_size ,
1254
+ self . group_values . num_partitions ( ) ,
1255
+ ) ) ;
1227
1256
}
1228
1257
}
1229
1258
Ok ( ( ) )
@@ -1290,8 +1319,11 @@ impl GroupedHashAggregateStream {
1290
1319
ExecutionState :: ProducingOutput ( batch)
1291
1320
} else {
1292
1321
let batches = self . emit ( EmitTo :: All , false ) ?;
1293
- let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
1294
- ExecutionState :: ProducingPartitionedOutput ( batches)
1322
+ ExecutionState :: ProducingPartitionedOutput ( PartitionedOutput :: new (
1323
+ batches,
1324
+ self . batch_size ,
1325
+ self . group_values . num_partitions ( ) ,
1326
+ ) )
1295
1327
}
1296
1328
} else {
1297
1329
// If spill files exist, stream-merge them.
@@ -1330,8 +1362,13 @@ impl GroupedHashAggregateStream {
1330
1362
self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1331
1363
} else {
1332
1364
let batches = self . emit ( EmitTo :: All , false ) ?;
1333
- let batches = batches. into_iter ( ) . map ( |batch| Some ( batch) ) . collect ( ) ;
1334
- self . exec_state = ExecutionState :: ProducingPartitionedOutput ( batches) ;
1365
+ self . exec_state = ExecutionState :: ProducingPartitionedOutput (
1366
+ PartitionedOutput :: new (
1367
+ batches,
1368
+ self . batch_size ,
1369
+ self . group_values . num_partitions ( ) ,
1370
+ ) ,
1371
+ ) ;
1335
1372
}
1336
1373
}
1337
1374
}
0 commit comments