Skip to content

Commit 87f9950

Browse files
andygroveGeorgeAp
authored andcommitted
ARROW-12402: [Rust] [DataFusion] Implement SQL metrics example
This introduces a new method on `ExecutionPlan` to be able to access generic metrics from any physical operator. One metric is implemented to demonstrate usage. Closes apache#10049 from andygrove/ARROW-12402 Authored-by: Andy Grove <[email protected]> Signed-off-by: Krisztián Szűcs <[email protected]>
1 parent 8e9199e commit 87f9950

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed

rust/datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! Defines the execution plan for the hash aggregate operation
1919
2020
use std::any::Any;
21-
use std::sync::Arc;
21+
use std::sync::{Arc, Mutex};
2222
use std::task::{Context, Poll};
2323

2424
use ahash::RandomState;
@@ -28,7 +28,7 @@ use futures::{
2828
};
2929

3030
use crate::error::{DataFusionError, Result};
31-
use crate::physical_plan::{Accumulator, AggregateExpr};
31+
use crate::physical_plan::{Accumulator, AggregateExpr, MetricType, SQLMetric};
3232
use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning, PhysicalExpr};
3333

3434
use arrow::{
@@ -94,6 +94,8 @@ pub struct HashAggregateExec {
9494
/// same as input.schema() but for the final aggregate it will be the same as the input
9595
/// to the partial aggregate
9696
input_schema: SchemaRef,
97+
/// Metric to track number of output rows
98+
output_rows: Arc<Mutex<SQLMetric>>,
9799
}
98100

99101
fn create_schema(
@@ -142,13 +144,19 @@ impl HashAggregateExec {
142144

143145
let schema = Arc::new(schema);
144146

147+
let output_rows = Arc::new(Mutex::new(SQLMetric::new(
148+
"outputRows",
149+
MetricType::Counter,
150+
)));
151+
145152
Ok(HashAggregateExec {
146153
mode,
147154
group_expr,
148155
aggr_expr,
149156
input,
150157
schema,
151158
input_schema,
159+
output_rows,
152160
})
153161
}
154162

@@ -223,6 +231,7 @@ impl ExecutionPlan for HashAggregateExec {
223231
group_expr,
224232
self.aggr_expr.clone(),
225233
input,
234+
self.output_rows.clone(),
226235
)))
227236
}
228237
}
@@ -244,6 +253,15 @@ impl ExecutionPlan for HashAggregateExec {
244253
)),
245254
}
246255
}
256+
257+
fn metrics(&self) -> HashMap<String, SQLMetric> {
258+
let mut metrics = HashMap::new();
259+
metrics.insert(
260+
"outputRows".to_owned(),
261+
self.output_rows.lock().unwrap().clone(),
262+
);
263+
metrics
264+
}
247265
}
248266

249267
/*
@@ -277,6 +295,7 @@ pin_project! {
277295
#[pin]
278296
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
279297
finished: bool,
298+
output_rows: Arc<Mutex<SQLMetric>>,
280299
}
281300
}
282301

@@ -628,6 +647,7 @@ impl GroupedHashAggregateStream {
628647
group_expr: Vec<Arc<dyn PhysicalExpr>>,
629648
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
630649
input: SendableRecordBatchStream,
650+
output_rows: Arc<Mutex<SQLMetric>>,
631651
) -> Self {
632652
let (tx, rx) = futures::channel::oneshot::channel();
633653

@@ -648,6 +668,7 @@ impl GroupedHashAggregateStream {
648668
schema,
649669
output: rx,
650670
finished: false,
671+
output_rows,
651672
}
652673
}
653674
}
@@ -667,6 +688,8 @@ impl Stream for GroupedHashAggregateStream {
667688
return Poll::Ready(None);
668689
}
669690

691+
let output_rows = self.output_rows.clone();
692+
670693
// is the output ready?
671694
let this = self.project();
672695
let output_poll = this.output.poll(cx);
@@ -680,6 +703,12 @@ impl Stream for GroupedHashAggregateStream {
680703
Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving
681704
Ok(result) => result,
682705
};
706+
707+
if let Ok(batch) = &result {
708+
let mut output_rows = output_rows.lock().unwrap();
709+
output_rows.add(batch.num_rows())
710+
}
711+
683712
Poll::Ready(Some(result))
684713
}
685714
Poll::Pending => Poll::Pending,
@@ -1255,6 +1284,11 @@ mod tests {
12551284
];
12561285

12571286
assert_batches_sorted_eq!(&expected, &result);
1287+
1288+
let metrics = merged_aggregate.metrics();
1289+
let output_rows = metrics.get("outputRows").unwrap();
1290+
assert_eq!(3, output_rows.value());
1291+
12581292
Ok(())
12591293
}
12601294

rust/datafusion/src/physical_plan/mod.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use async_trait::async_trait;
3333
use futures::stream::Stream;
3434

3535
use self::merge::MergeExec;
36+
use hashbrown::HashMap;
3637

3738
/// Trait for types that stream [arrow::record_batch::RecordBatch]
3839
pub trait RecordBatchStream: Stream<Item = ArrowResult<RecordBatch>> {
@@ -46,6 +47,46 @@ pub trait RecordBatchStream: Stream<Item = ArrowResult<RecordBatch>> {
4647
/// Trait for a stream of record batches.
4748
pub type SendableRecordBatchStream = Pin<Box<dyn RecordBatchStream + Send + Sync>>;
4849

50+
/// SQL metric type
51+
#[derive(Debug, Clone)]
52+
pub enum MetricType {
53+
/// Simple counter
54+
Counter,
55+
}
56+
57+
/// SQL metric such as counter (number of input or output rows) or timing information about
58+
/// a physical operator.
59+
#[derive(Debug, Clone)]
60+
pub struct SQLMetric {
61+
/// Metric name
62+
name: String,
63+
/// Metric value
64+
value: usize,
65+
/// Metric type
66+
metric_type: MetricType,
67+
}
68+
69+
impl SQLMetric {
70+
/// Create a new SQLMetric
71+
pub fn new(name: &str, metric_type: MetricType) -> Self {
72+
Self {
73+
name: name.to_owned(),
74+
value: 0,
75+
metric_type,
76+
}
77+
}
78+
79+
/// Add to the value
80+
pub fn add(&mut self, n: usize) {
81+
self.value += n;
82+
}
83+
84+
/// Get the current value
85+
pub fn value(&self) -> usize {
86+
self.value
87+
}
88+
}
89+
4990
/// Physical query planner that converts a `LogicalPlan` to an
5091
/// `ExecutionPlan` suitable for execution.
5192
pub trait PhysicalPlanner {
@@ -84,6 +125,11 @@ pub trait ExecutionPlan: Debug + Send + Sync {
84125

85126
/// creates an iterator
86127
async fn execute(&self, partition: usize) -> Result<SendableRecordBatchStream>;
128+
129+
/// Return a snapshot of the metrics collected during execution
130+
fn metrics(&self) -> HashMap<String, SQLMetric> {
131+
HashMap::new()
132+
}
87133
}
88134

89135
/// Execute the [ExecutionPlan] and collect the results in memory

0 commit comments

Comments
 (0)