diff --git a/optd-datafusion-repr/src/lib.rs b/optd-datafusion-repr/src/lib.rs index 2c821bf6..3e86bfde 100644 --- a/optd-datafusion-repr/src/lib.rs +++ b/optd-datafusion-repr/src/lib.rs @@ -22,7 +22,7 @@ use properties::{ schema::{Catalog, SchemaPropertyBuilder}, }; use rules::{ - EliminateDuplicatedAggExprRule, EliminateDuplicatedSortExprRule, EliminateFilterRule, EliminateJoinRule, EliminateLimitRule, FilterAggTransposeRule, FilterCrossJoinTransposeRule, FilterInnerJoinTransposeRule, FilterMergeRule, FilterProjectTransposeRule, FilterSortTransposeRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, PhysicalConversionRule, ProjectFilterTransposeRule, ProjectMergeRule, ProjectionPullUpJoin, ProjectionPushDownJoin, SimplifyFilterRule, SimplifyJoinCondRule + EliminateDuplicatedAggExprRule, EliminateDuplicatedSortExprRule, EliminateFilterRule, EliminateJoinRule, EliminateLimitRule, FilterAggTransposeRule, FilterCrossJoinTransposeRule, FilterInnerJoinTransposeRule, FilterMergeRule, FilterProjectTransposeRule, FilterSortTransposeRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, PhysicalConversionRule, ProjectFilterTransposeRule, ProjectMergeRule, ProjectRemoveRule, ProjectionPullUpJoin, ProjectionPushDownJoin, SimplifyFilterRule, SimplifyJoinCondRule }; pub use optd_core::rel_node::Value; @@ -85,6 +85,7 @@ impl DatafusionOptimizer { Arc::new(EliminateDuplicatedAggExprRule::new()), Arc::new(ProjectMergeRule::new()), Arc::new(FilterMergeRule::new()), + Arc::new(ProjectRemoveRule::new()), ] } diff --git a/optd-datafusion-repr/src/rules.rs b/optd-datafusion-repr/src/rules.rs index 9e5c0f09..4e3aee48 100644 --- a/optd-datafusion-repr/src/rules.rs +++ b/optd-datafusion-repr/src/rules.rs @@ -24,4 +24,5 @@ pub use project_transpose::{ project_filter_transpose::{FilterProjectTransposeRule, ProjectFilterTransposeRule}, project_join_transpose::{ProjectionPullUpJoin, ProjectionPushDownJoin}, project_merge::ProjectMergeRule, + project_remove::ProjectRemoveRule, }; diff --git 
a/optd-datafusion-repr/src/rules/project_transpose.rs b/optd-datafusion-repr/src/rules/project_transpose.rs index 5c4f45bb..27c251c1 100644 --- a/optd-datafusion-repr/src/rules/project_transpose.rs +++ b/optd-datafusion-repr/src/rules/project_transpose.rs @@ -1,4 +1,5 @@ pub mod project_filter_transpose; pub mod project_join_transpose; pub mod project_merge; +pub mod project_remove; pub mod project_transpose_common; diff --git a/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs b/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs index 5490307d..16231ff9 100644 --- a/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs +++ b/optd-datafusion-repr/src/rules/project_transpose/project_remove.rs @@ -1,2 +1,83 @@ // intended to remove a projection that outputs the same num of cols -// that are in scan node \ No newline at end of file +// that are in scan node +use std::collections::HashMap; + +use optd_core::rules::{Rule, RuleMatcher}; +use optd_core::{optimizer::Optimizer, rel_node::RelNode}; + +use crate::plan_nodes::{ColumnRefExpr, ExprList, OptRelNode, OptRelNodeTyp, PlanNode}; +use crate::properties::schema::SchemaPropertyBuilder; +use crate::rules::macros::define_rule; + +// Proj (Scan A) -> Scan A +// Removes a projection over a scan when its exprs are exactly the columns #0..#n-1, in order. +// TODO: need to somehow match on just scan node instead +// only works in heuristic optimizer (which may be ok) +// ideally include a pass after for physical proj -> physical scan +define_rule!( + ProjectRemoveRule, + apply_projection_remove, + (Projection, child, [exprs]) +); + +fn apply_projection_remove( + optimizer: &impl Optimizer, + ProjectRemoveRulePicks { + child, + exprs + }: ProjectRemoveRulePicks, +) -> Vec> { + let child_schema = optimizer.get_property::(child.clone().into(), 0); + let child = PlanNode::from_group(child.into()); + if child.typ() != OptRelNodeTyp::Scan { + return vec![]; + } + let exprs = ExprList::from_rel_node(exprs.into()).unwrap().to_vec(); + if exprs.len() !=
child_schema.len() { + return vec![]; + } + let mut exp_col_idx: usize = 0; + for expr in exprs { + // Any non-column expression means the projection is not a no-op: keep it, don't panic. + let col_idx = ColumnRefExpr::from_rel_node(expr.into_rel_node()).map(|col| col.index()); + if col_idx != Some(exp_col_idx) { + return vec![]; + } + exp_col_idx += 1; + } + vec![child.into_rel_node().as_ref().clone()] +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use optd_core::optimizer::Optimizer; + + use crate::{ + plan_nodes::{ + ColumnRefExpr, ExprList, LogicalProjection, LogicalScan, OptRelNode, OptRelNodeTyp, + }, + rules::ProjectRemoveRule, + testing::new_test_optimizer, + }; + + #[test] + fn proj_scan_basic() { + // convert proj -> scan to scan + let mut test_optimizer = new_test_optimizer(Arc::new(ProjectRemoveRule::new())); + + let scan = LogicalScan::new("region".into()); + + let proj_exprs = ExprList::new(vec![ + ColumnRefExpr::new(0).into_expr(), + ColumnRefExpr::new(1).into_expr(), + ColumnRefExpr::new(2).into_expr(), + ]); + + let proj_node: LogicalProjection = LogicalProjection::new(scan.into_plan_node(), proj_exprs); + let plan = test_optimizer.optimize(proj_node.into_rel_node()).unwrap(); + + assert_eq!(plan.typ, OptRelNodeTyp::Scan); + } +} diff --git a/optd-sqlplannertest/tests/basic_nodes.planner.sql b/optd-sqlplannertest/tests/basic_nodes.planner.sql index 301f300e..d9d880c5 100644 --- a/optd-sqlplannertest/tests/basic_nodes.planner.sql +++ b/optd-sqlplannertest/tests/basic_nodes.planner.sql @@ -19,8 +19,7 @@ LogicalLimit { skip: 0(u64), fetch: 1(u64) } └── LogicalProjection { exprs: [ #0, #1 ] } └── LogicalScan { table: t1 } PhysicalLimit { skip: 0(u64), fetch: 1(u64) } -└── PhysicalProjection { exprs: [ #0, #1 ] } - └── PhysicalScan { table: t1 } +└── PhysicalScan { table: t1 } 0 0 0 0 1 1 diff --git a/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql b/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql index b31e774f..8c595a05 100644 ---
a/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql +++ b/optd-sqlplannertest/tests/eliminate_duplicated_expr.planner.sql @@ -12,8 +12,7 @@ select * from t1; /* LogicalProjection { exprs: [ #0, #1 ] } └── LogicalScan { table: t1 } -PhysicalProjection { exprs: [ #0, #1 ] } -└── PhysicalScan { table: t1 } +PhysicalScan { table: t1 } 0 0 1 1 5 2 @@ -45,8 +44,7 @@ PhysicalSort │ │ └── #0 │ └── SortOrder { order: Asc } │ └── #1 -└── PhysicalProjection { exprs: [ #0, #1 ] } - └── PhysicalScan { table: t1 } +└── PhysicalScan { table: t1 } 0 0 0 2 1 1 diff --git a/optd-sqlplannertest/tests/filter.planner.sql b/optd-sqlplannertest/tests/filter.planner.sql index 8ba252fa..f2ad8484 100644 --- a/optd-sqlplannertest/tests/filter.planner.sql +++ b/optd-sqlplannertest/tests/filter.planner.sql @@ -27,8 +27,7 @@ select * from t1 where true; LogicalProjection { exprs: [ #0, #1 ] } └── LogicalFilter { cond: true } └── LogicalScan { table: t1 } -PhysicalProjection { exprs: [ #0, #1 ] } -└── PhysicalScan { table: t1 } +PhysicalScan { table: t1 } 0 0 1 1 2 2 diff --git a/optd-sqlplannertest/tests/tpch.planner.sql b/optd-sqlplannertest/tests/tpch.planner.sql index 8bf88051..710e4a9a 100644 --- a/optd-sqlplannertest/tests/tpch.planner.sql +++ b/optd-sqlplannertest/tests/tpch.planner.sql @@ -384,8 +384,7 @@ PhysicalLimit { skip: 0(u64), fetch: 100(u64) } │ │ │ │ │ └── PhysicalScan { table: part } │ │ │ │ └── PhysicalProjection { exprs: [ #0, #1, #3 ] } │ │ │ │ └── PhysicalScan { table: partsupp } - │ │ │ └── PhysicalProjection { exprs: [ #0, #1, #2, #3, #4, #5, #6 ] } - │ │ │ └── PhysicalScan { table: supplier } + │ │ │ └── PhysicalScan { table: supplier } │ │ └── PhysicalProjection { exprs: [ #0, #1, #2 ] } │ │ └── PhysicalScan { table: nation } │ └── PhysicalProjection { exprs: [ #0 ] } diff --git a/optd-sqlplannertest/tests/verbose.planner.sql b/optd-sqlplannertest/tests/verbose.planner.sql index 910a1663..391306be 100644 --- a/optd-sqlplannertest/tests/verbose.planner.sql 
+++ b/optd-sqlplannertest/tests/verbose.planner.sql @@ -10,16 +10,14 @@ insert into t1 values (0), (1), (2), (3); select * from t1; /* -PhysicalProjection { exprs: [ #0 ] } -└── PhysicalScan { table: t1 } +PhysicalScan { table: t1 } */ -- Test verbose explain select * from t1; /* -PhysicalProjection { exprs: [ #0 ], cost: weighted=1.06,row_cnt=1.00,compute=0.06,io=1.00 } -└── PhysicalScan { table: t1, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 } +PhysicalScan { table: t1, cost: weighted=1.00,row_cnt=1.00,compute=0.00,io=1.00 } */ -- Test verbose explain with aggregation