 //! Defines the execution plan for the hash aggregate operation
 
 use std::any::Any;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 use std::task::{Context, Poll};
 
 use ahash::RandomState;
@@ -28,7 +28,7 @@ use futures::{
 };
 
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::{Accumulator, AggregateExpr};
+use crate::physical_plan::{Accumulator, AggregateExpr, MetricType, SQLMetric};
 use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning, PhysicalExpr};
 
 use arrow::{
@@ -94,6 +94,8 @@ pub struct HashAggregateExec {
     /// same as input.schema() but for the final aggregate it will be the same as the input
     /// to the partial aggregate
     input_schema: SchemaRef,
+    /// Metric to track number of output rows
+    output_rows: Arc<Mutex<SQLMetric>>,
 }
 
 fn create_schema(
@@ -142,13 +144,19 @@ impl HashAggregateExec {
 
         let schema = Arc::new(schema);
 
+        let output_rows = Arc::new(Mutex::new(SQLMetric::new(
+            "outputRows",
+            MetricType::Counter,
+        )));
+
         Ok(HashAggregateExec {
             mode,
             group_expr,
             aggr_expr,
             input,
             schema,
             input_schema,
+            output_rows,
         })
     }
 
@@ -223,6 +231,7 @@ impl ExecutionPlan for HashAggregateExec {
                 group_expr,
                 self.aggr_expr.clone(),
                 input,
+                self.output_rows.clone(),
             )))
         }
     }
@@ -244,6 +253,15 @@ impl ExecutionPlan for HashAggregateExec {
             )),
         }
     }
+
+    fn metrics(&self) -> HashMap<String, SQLMetric> {
+        let mut metrics = HashMap::new();
+        metrics.insert(
+            "outputRows".to_owned(),
+            self.output_rows.lock().unwrap().clone(),
+        );
+        metrics
+    }
 }
 
 /*
@@ -277,6 +295,7 @@ pin_project! {
         #[pin]
         output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
         finished: bool,
+        output_rows: Arc<Mutex<SQLMetric>>,
     }
 }
 
@@ -628,6 +647,7 @@ impl GroupedHashAggregateStream {
         group_expr: Vec<Arc<dyn PhysicalExpr>>,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
+        output_rows: Arc<Mutex<SQLMetric>>,
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
 
@@ -648,6 +668,7 @@ impl GroupedHashAggregateStream {
             schema,
             output: rx,
             finished: false,
+            output_rows,
         }
     }
 }
@@ -667,6 +688,8 @@ impl Stream for GroupedHashAggregateStream {
             return Poll::Ready(None);
         }
 
+        let output_rows = self.output_rows.clone();
+
         // is the output ready?
         let this = self.project();
         let output_poll = this.output.poll(cx);
@@ -680,6 +703,12 @@ impl Stream for GroupedHashAggregateStream {
                     Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving
                     Ok(result) => result,
                 };
+
+                if let Ok(batch) = &result {
+                    let mut output_rows = output_rows.lock().unwrap();
+                    output_rows.add(batch.num_rows())
+                }
+
                 Poll::Ready(Some(result))
             }
             Poll::Pending => Poll::Pending,
@@ -1255,6 +1284,11 @@ mod tests {
         ];
 
         assert_batches_sorted_eq!(&expected, &result);
+
+        let metrics = merged_aggregate.metrics();
+        let output_rows = metrics.get("outputRows").unwrap();
+        assert_eq!(3, output_rows.value());
+
         Ok(())
     }
 
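Aside (not part of the commit): a minimal, self-contained sketch of the sharing pattern this change relies on. The operator and the stream it spawns hold the same Arc<Mutex<...>> counter, the stream bumps it as each batch is emitted, and metrics() clones out a point-in-time snapshot. The Counter type below is an illustrative stand-in, not DataFusion's SQLMetric.

use std::sync::{Arc, Mutex};

// Illustrative stand-in for a counter-style SQLMetric (not the DataFusion type).
#[derive(Clone, Debug)]
struct Counter {
    name: String,
    value: usize,
}

impl Counter {
    fn new(name: &str) -> Self {
        Self { name: name.to_owned(), value: 0 }
    }
    fn add(&mut self, n: usize) {
        self.value += n;
    }
}

fn main() {
    // Operator side: create the shared counter once at construction time.
    let output_rows = Arc::new(Mutex::new(Counter::new("outputRows")));

    // Stream side: holds a clone of the Arc and adds the row count of each emitted batch.
    let stream_handle = output_rows.clone();
    stream_handle.lock().unwrap().add(3);

    // metrics() side: lock briefly and clone out a snapshot, so callers never hold the lock.
    let snapshot = output_rows.lock().unwrap().clone();
    assert_eq!(3, snapshot.value);
    println!("{} = {}", snapshot.name, snapshot.value);
}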