@@ -85,7 +85,11 @@ pub(crate) struct GroupedHashAggregateStreamV2 {
     baseline_metrics: BaselineMetrics,
     random_state: RandomState,
-    finished: bool,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+    /// if the result is chunked into batches,
+    /// last offset is preserved for continuation.
+    row_group_skip_position: usize,
 }

 fn aggr_state_schema(aggr_expr: &[Arc<dyn AggregateExpr>]) -> Result<SchemaRef> {
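A rough, self-contained sketch of the bookkeeping these two fields introduce in place of `finished: bool` (plain Rust, not the DataFusion types; `ChunkedEmitter` and `next_chunk` are invented names for illustration): the stream remembers how far into the accumulated group states it has already emitted and hands out windows of at most `batch_size` rows until the skip position moves past the end.

// Hypothetical stand-in for the stream's new state: `batch_size` and
// `row_group_skip_position` mirror the fields added above.
struct ChunkedEmitter {
    batch_size: usize,
    row_group_skip_position: usize,
}

impl ChunkedEmitter {
    /// Return the next window of group values, or `None` once the skip
    /// position has moved past the end (the analogue of the old `finished`).
    fn next_chunk<'a>(&mut self, group_values: &'a [u64]) -> Option<&'a [u64]> {
        if self.row_group_skip_position > group_values.len() {
            return None;
        }
        let start = self.row_group_skip_position;
        let end = (start + self.batch_size).min(group_values.len());
        self.row_group_skip_position += self.batch_size;
        Some(&group_values[start..end])
    }
}

fn main() {
    let groups: Vec<u64> = (0..10).collect();
    let mut emitter = ChunkedEmitter {
        batch_size: 4,
        row_group_skip_position: 0,
    };
    while let Some(chunk) = emitter.next_chunk(&groups) {
        println!("{chunk:?}"); // [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]
    }
}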
@@ -105,6 +109,7 @@ impl GroupedHashAggregateStreamV2 {
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
+        batch_size: usize,
     ) -> Result<Self> {
         let timer = baseline_metrics.elapsed_compute().timer();
@@ -135,7 +140,8 @@ impl GroupedHashAggregateStreamV2 {
             aggregate_expressions,
             aggr_state: Default::default(),
             random_state: Default::default(),
-            finished: false,
+            batch_size,
+            row_group_skip_position: 0,
         })
     }
 }
@@ -148,56 +154,62 @@ impl Stream for GroupedHashAggregateStreamV2 {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let this = &mut *self;
-        if this.finished {
-            return Poll::Ready(None);
-        }

         let elapsed_compute = this.baseline_metrics.elapsed_compute();

         loop {
-            let result = match ready!(this.input.poll_next_unpin(cx)) {
-                Some(Ok(batch)) => {
-                    let timer = elapsed_compute.timer();
-                    let result = group_aggregate_batch(
-                        &this.mode,
-                        &this.random_state,
-                        &this.group_by,
-                        &mut this.accumulators,
-                        &this.group_schema,
-                        this.aggr_layout.clone(),
-                        batch,
-                        &mut this.aggr_state,
-                        &this.aggregate_expressions,
-                    );
-
-                    timer.done();
-
-                    match result {
-                        Ok(_) => continue,
-                        Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+            let result: ArrowResult<Option<RecordBatch>> =
+                match ready!(this.input.poll_next_unpin(cx)) {
+                    Some(Ok(batch)) => {
+                        let timer = elapsed_compute.timer();
+                        let result = group_aggregate_batch(
+                            &this.mode,
+                            &this.random_state,
+                            &this.group_by,
+                            &mut this.accumulators,
+                            &this.group_schema,
+                            this.aggr_layout.clone(),
+                            batch,
+                            &mut this.aggr_state,
+                            &this.aggregate_expressions,
+                        );
+
+                        timer.done();
+
+                        match result {
+                            Ok(_) => continue,
+                            Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
+                        }
                     }
+                    Some(Err(e)) => Err(e),
+                    None => {
+                        let timer = this.baseline_metrics.elapsed_compute().timer();
+                        let result = create_batch_from_map(
+                            &this.mode,
+                            &this.group_schema,
+                            &this.aggr_schema,
+                            this.batch_size,
+                            this.row_group_skip_position,
+                            &mut this.aggr_state,
+                            &mut this.accumulators,
+                            &this.schema,
+                        );
+
+                        timer.done();
+                        result
+                    }
+                };
+
+            this.row_group_skip_position += this.batch_size;
+            match result {
+                Ok(Some(result)) => {
+                    return Poll::Ready(Some(Ok(
+                        result.record_output(&this.baseline_metrics)
+                    )))
                 }
-                Some(Err(e)) => Err(e),
-                None => {
-                    this.finished = true;
-                    let timer = this.baseline_metrics.elapsed_compute().timer();
-                    let result = create_batch_from_map(
-                        &this.mode,
-                        &this.group_schema,
-                        &this.aggr_schema,
-                        &mut this.aggr_state,
-                        &mut this.accumulators,
-                        &this.schema,
-                    )
-                    .record_output(&this.baseline_metrics);
-
-                    timer.done();
-                    result
-                }
-            };
-
-            this.finished = true;
-            return Poll::Ready(Some(result));
+                Ok(None) => return Poll::Ready(None),
+                Err(error) => return Poll::Ready(Some(Err(error))),
+            }
         }
     }
 }
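For reference, a minimal sketch of how the tail `match` above maps the chunk result onto the stream protocol, using generic `Result`/`Option` types rather than Arrow's `ArrowResult`/`RecordBatch` (the helper name `to_poll` is invented): a produced batch becomes an item, an exhausted producer ends the stream with `Poll::Ready(None)`, and an error is surfaced as an item so the consumer sees it. The successful-input-batch case, which simply `continue`s the loop, is not covered here.

use std::task::Poll;

// Map a "maybe one more chunk" result onto the Stream poll contract.
fn to_poll<T, E>(result: Result<Option<T>, E>) -> Poll<Option<Result<T, E>>> {
    match result {
        Ok(Some(batch)) => Poll::Ready(Some(Ok(batch))),
        Ok(None) => Poll::Ready(None),
        Err(error) => Poll::Ready(Some(Err(error))),
    }
}

fn main() {
    assert!(matches!(to_poll::<i32, ()>(Ok(Some(1))), Poll::Ready(Some(Ok(1)))));
    assert!(matches!(to_poll::<i32, ()>(Ok(None)), Poll::Ready(None)));
    assert!(matches!(to_poll::<i32, &str>(Err("boom")), Poll::Ready(Some(Err("boom")))));
}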
@@ -419,23 +431,34 @@ fn create_group_rows(arrays: Vec<ArrayRef>, schema: &Schema) -> Vec<Vec<u8>> {
 }

 /// Create a RecordBatch with all group keys and accumulator' states or values.
+#[allow(clippy::too_many_arguments)]
 fn create_batch_from_map(
     mode: &AggregateMode,
     group_schema: &Schema,
     aggr_schema: &Schema,
+    batch_size: usize,
+    skip_items: usize,
     aggr_state: &mut AggregationState,
     accumulators: &mut [AccumulatorItemV2],
     output_schema: &Schema,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<Option<RecordBatch>> {
+    if skip_items > aggr_state.group_states.len() {
+        return Ok(None);
+    }
+
     if aggr_state.group_states.is_empty() {
-        return Ok(RecordBatch::new_empty(Arc::new(output_schema.to_owned())));
+        return Ok(Some(RecordBatch::new_empty(Arc::new(
+            output_schema.to_owned(),
+        ))));
     }

     let mut state_accessor = RowAccessor::new(aggr_schema, RowType::WordAligned);

     let (group_buffers, mut state_buffers): (Vec<_>, Vec<_>) = aggr_state
         .group_states
         .iter()
+        .skip(skip_items)
+        .take(batch_size)
         .map(|gs| (gs.group_by_values.clone(), gs.aggregation_buffer.clone()))
         .unzip();
@@ -471,7 +494,7 @@ fn create_batch_from_map(
         .map(|(col, desired_field)| cast(col, desired_field.data_type()))
         .collect::<ArrowResult<Vec<_>>>()?;

-    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns)
+    RecordBatch::try_new(Arc::new(output_schema.to_owned()), columns).map(Some)
 }

 fn read_as_batch(rows: &[Vec<u8>], schema: &Schema, row_type: RowType) -> Vec<ArrayRef> {
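As a standalone illustration of the pagination contract `create_batch_from_map` now follows (the `window` helper and its slice-of-`&str` input are invented for this sketch, not DataFusion code): `.skip(n).take(m)` yields one batch-sized window of the group states, a skip position past the end signals completion, and any produced window is wrapped in `Some`, which is what the trailing `.map(Some)` on `RecordBatch::try_new` accomplishes.

// Hypothetical helper mirroring the skip/take windowing and the
// "None means done" convention used above.
fn window<T: Clone>(items: &[T], skip_items: usize, batch_size: usize) -> Option<Vec<T>> {
    // Past the end: the caller should stop asking for more batches.
    if skip_items > items.len() {
        return None;
    }
    // Otherwise take at most `batch_size` items starting at `skip_items`.
    Some(
        items
            .iter()
            .skip(skip_items)
            .take(batch_size)
            .cloned()
            .collect(),
    )
}

fn main() {
    let groups = vec!["a", "b", "c", "d", "e"];
    assert_eq!(window(&groups, 0, 2), Some(vec!["a", "b"]));
    assert_eq!(window(&groups, 4, 2), Some(vec!["e"]));
    assert_eq!(window(&groups, 6, 2), None);
}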