Skip to content

Commit

Permalink
Add example for writing SQL analysis using DataFusion structures (#10938
Browse files Browse the repository at this point in the history
)

* sql analysis example

* update examples readme

* update comments

* Update datafusion-examples/examples/sql_analysis.rs

Co-authored-by: Andrew Lamb <[email protected]>

* apply feedback

* Run tapelo to fix Cargo.toml formatting

* Tweak comments

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
LorrensP-2158466 and alamb authored Jun 18, 2024
1 parent 80f4322 commit 5cb1917
Show file tree
Hide file tree
Showing 3 changed files with 311 additions and 0 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ prost-derive = { version = "0.12", default-features = false }
serde = { version = "1.0.136", features = ["derive"] }
serde_json = { workspace = true }
tempfile = { workspace = true }
test-utils = { path = "../test-utils" }
tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] }
tonic = "0.11"
url = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions datafusion-examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ cargo run --example csv_sql
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF)
- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
- [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with DataFusion structures
- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser`
- [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function
- [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions
Expand Down
309 changes: 309 additions & 0 deletions datafusion-examples/examples/sql_analysis.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! This example shows how to use the structures that DataFusion provides to perform
//! Analysis on SQL queries and their plans.
//!
//! As a motivating example, we show how to count the number of JOINs in a query
//! as well as how many join tree's there are with their respective join count

use std::sync::Arc;

use datafusion::common::Result;
use datafusion::{
datasource::MemTable,
execution::context::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::LogicalPlan;
use test_utils::tpcds::tpcds_schemas;

/// Counts the total number of joins in a plan
fn total_join_count(plan: &LogicalPlan) -> usize {
let mut total = 0;

// We can use the TreeNode API to walk over a LogicalPlan.
plan.apply(|node| {
// if we encounter a join we update the running count
if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
total += 1;
}
Ok(TreeNodeRecursion::Continue)
})
.unwrap();

total
}

/// Counts the total number of joins in a plan and collects every join tree in
/// the plan with their respective join count.
///
/// Join Tree Definition: the largest subtree consisting entirely of joins
///
/// For example, this plan:
///
/// ```text
/// JOIN
/// / \
/// A JOIN
/// / \
/// B C
/// ```
///
/// has a single join tree `(A-B-C)` which will result in `(2, [2])`
///
/// This plan:
///
/// ```text
/// JOIN
/// / \
/// A GROUP
/// |
/// JOIN
/// / \
/// B C
/// ```
///
/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])`
fn count_trees(plan: &LogicalPlan) -> (usize, Vec<usize>) {
// this works the same way as `total_count`, but now when we encounter a Join
// we try to collect it's entire tree
let mut to_visit = vec![plan];
let mut total = 0;
let mut groups = vec![];

while let Some(node) = to_visit.pop() {
// if we encouter a join, we know were at the root of the tree
// count this tree and recurse on it's inputs
if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
let (group_count, inputs) = count_tree(node);
total += group_count;
groups.push(group_count);
to_visit.extend(inputs);
} else {
to_visit.extend(node.inputs());
}
}

(total, groups)
}

/// Count the entire join tree and return its inputs using TreeNode API
///
/// For example, if this function receives following plan:
///
/// ```text
/// JOIN
/// / \
/// A GROUP
/// |
/// JOIN
/// / \
/// B C
/// ```
///
/// It will return `(1, [A, GROUP])`
fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) {
let mut inputs = Vec::new();
let mut total = 0;

join.apply(|node| {
// Some extra knowledge:
//
// optimized plans have their projections pushed down as far as
// possible, which sometimes results in a projection going in between 2
// subsequent joins giving the illusion these joins are not "related",
// when in fact they are.
//
// This plan:
// JOIN
// / \
// A PROJECTION
// |
// JOIN
// / \
// B C
//
// is the same as:
//
// JOIN
// / \
// A JOIN
// / \
// B C
// we can continue the recursion in this case
if let LogicalPlan::Projection(_) = node {
return Ok(TreeNodeRecursion::Continue);
}

// any join we count
if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
total += 1;
Ok(TreeNodeRecursion::Continue)
} else {
inputs.push(node);
// skip children of input node
Ok(TreeNodeRecursion::Jump)
}
})
.unwrap();

(total, inputs)
}

#[tokio::main]
async fn main() -> Result<()> {
// To show how we can count the joins in a sql query we'll be using query 88
// from the TPC-DS benchmark.
//
// q8 has many joins, cross-joins and multiple join-trees, perfect for our
// example:

let tpcds_query_88 = "
select *
from
(select count(*) h8_30_to_9
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 8
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s1,
(select count(*) h9_to_9_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 9
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s2,
(select count(*) h9_30_to_10
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 9
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s3,
(select count(*) h10_to_10_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 10
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s4,
(select count(*) h10_30_to_11
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 10
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s5,
(select count(*) h11_to_11_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 11
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s6,
(select count(*) h11_30_to_12
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 11
and time_dim.t_minute >= 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s7,
(select count(*) h12_to_12_30
from store_sales, household_demographics , time_dim, store
where ss_sold_time_sk = time_dim.t_time_sk
and ss_hdemo_sk = household_demographics.hd_demo_sk
and ss_store_sk = s_store_sk
and time_dim.t_hour = 12
and time_dim.t_minute < 30
and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or
(household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or
(household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2))
and store.s_store_name = 'ese') s8;";

// first set up the config
let config = SessionConfig::default();
let ctx = SessionContext::new_with_config(config);

// register the tables of the TPC-DS query
let tables = tpcds_schemas();
for table in tables {
ctx.register_table(
table.name,
Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?),
)?;
}
// We can create a LogicalPlan from a SQL query like this
let logical_plan = ctx.sql(tpcds_query_88).await?.into_optimized_plan()?;

println!(
"Optimized Logical Plan:\n\n{}\n",
logical_plan.display_indent()
);
// we can get the total count (query 88 has 31 joins: 7 CROSS joins and 24 INNER joins => 40 input relations)
let total_join_count = total_join_count(&logical_plan);
assert_eq!(31, total_join_count);

println!("The plan has {total_join_count} joins.");

// Furthermore the 24 inner joins are 8 groups of 3 joins with the 7
// cross-joins combining them we can get these groups using the
// `count_trees` method
let (total_join_count, trees) = count_trees(&logical_plan);
assert_eq!(
(total_join_count, &trees),
// query 88 is very straightforward, we know the cross-join group is at
// the top of the plan followed by the INNER joins
(31, &vec![7, 3, 3, 3, 3, 3, 3, 3, 3])
);

println!(
"And following join-trees (number represents join amount in tree): {trees:?}"
);

Ok(())
}

0 comments on commit 5cb1917

Please sign in to comment.