diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/expand_view.rs b/wren-modeling-rs/core/src/logical_plan/analyze/expand_view.rs new file mode 100644 index 000000000..547bc82ba --- /dev/null +++ b/wren-modeling-rs/core/src/logical_plan/analyze/expand_view.rs @@ -0,0 +1,62 @@ +use crate::logical_plan::utils::belong_to_mdl; +use crate::mdl::utils::quoted; +use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; +use datafusion::common::tree_node::Transformed; +use datafusion::common::Result; +use datafusion::config::ConfigOptions; +use datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder}; +use datafusion::optimizer::AnalyzerRule; +use std::sync::Arc; + +pub struct ExpandWrenViewRule { + analyzed_wren_mdl: Arc, + session_state: SessionStateRef, +} + +impl ExpandWrenViewRule { + pub fn new( + analyzed_wren_mdl: Arc, + session_state: SessionStateRef, + ) -> Self { + Self { + analyzed_wren_mdl, + session_state, + } + } +} + +impl AnalyzerRule for ExpandWrenViewRule { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + let plan = plan + .transform_up_with_subqueries(|plan| match &plan { + LogicalPlan::TableScan(table_scan) => { + if belong_to_mdl( + &self.analyzed_wren_mdl.wren_mdl(), + table_scan.table_name.clone(), + Arc::clone(&self.session_state), + ) && self + .analyzed_wren_mdl + .wren_mdl() + .get_view(table_scan.table_name.table()) + .is_some() + { + if let Some(logical_plan) = table_scan.source.get_logical_plan() { + let subquery = LogicalPlanBuilder::from(logical_plan.clone()) + .alias(quoted(table_scan.table_name.table()))? + .build()?; + return Ok(Transformed::yes(subquery)); + } + } + Ok(Transformed::no(plan)) + } + _ => Ok(Transformed::no(plan)), + })? + .map_data(|plan| plan.recompute_schema())? + .data; + Ok(plan) + } + + fn name(&self) -> &str { + "ExpandWrenViewRule" + } +} diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/mod.rs b/wren-modeling-rs/core/src/logical_plan/analyze/mod.rs index 11bee973a..283fd714d 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/mod.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/mod.rs @@ -1,3 +1,4 @@ +pub mod expand_view; pub mod model_anlayze; pub mod model_generation; pub mod plan; diff --git a/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs b/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs index 574b7a022..b1467dc6b 100644 --- a/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs +++ b/wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs @@ -1,31 +1,62 @@ -use std::cell::{RefCell, RefMut}; -use std::collections::{HashMap, HashSet}; -use std::sync::Arc; - +use crate::logical_plan::analyze::plan::ModelPlanNode; +use crate::logical_plan::utils::{belong_to_mdl, expr_to_columns}; +use crate::mdl::utils::quoted; +use crate::mdl::{AnalyzedWrenMDL, Dataset, SessionStateRef}; use datafusion::catalog_common::TableReference; use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion::common::{Column, DFSchemaRef, Result}; +use datafusion::common::{internal_err, plan_err, Column, DFSchemaRef, Result}; use datafusion::config::ConfigOptions; +use datafusion::error::DataFusionError; use datafusion::logical_expr::expr::Alias; use datafusion::logical_expr::{ - col, ident, utils, Aggregate, Distinct, DistinctOn, Expr, Extension, Filter, Join, + col, ident, Aggregate, Distinct, DistinctOn, Expr, Extension, Filter, Join, LogicalPlan, LogicalPlanBuilder, Projection, Subquery, SubqueryAlias, TableScan, Window, }; use datafusion::optimizer::AnalyzerRule; - -use crate::logical_plan::analyze::plan::ModelPlanNode; -use crate::mdl::utils::quoted; -use crate::mdl::{AnalyzedWrenMDL, SessionStateRef, WrenMDL}; +use std::cell::{RefCell, RefMut}; +use std::collections::{HashMap, HashSet, VecDeque}; +use std::sync::Arc; /// [ModelAnalyzeRule] responsible for analyzing the model plan node. Turn TableScan from a model to a ModelPlanNode. /// We collect the required fields from the projection, filter, aggregation, and join, /// and pass them to the ModelPlanNode. +/// +/// There are three main steps in this rule: +/// 1. Analyze the scope of the logical plan and collect the required columns for models and visited tables. (button-up and depth-first) +/// 2. Analyze the model and generate the ModelPlanNode according to the scope analysis. (button-up and depth-first) +/// 3. Remove the catalog and schema prefix of Wren for the column and refresh the schema. (top-down) +/// +/// The traverse path of step 1 and step 2 should be same. +/// The corresponding scope will be pushed to or popped from the scope_queue sequentially. pub struct ModelAnalyzeRule { analyzed_wren_mdl: Arc, session_state: SessionStateRef, } +impl AnalyzerRule for ModelAnalyzeRule { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + let scope_queue = RefCell::new(VecDeque::new()); + let root = RefCell::new(Scope::new()); + self.analyze_scope(plan, &root, &scope_queue)? + .map_data(|plan| self.analyze_model(plan, &root, &scope_queue).data())? + .map_data(|plan| { + plan.transform_up_with_subqueries(&|plan| -> Result< + Transformed, + > { + self.remove_wren_catalog_schema_prefix_and_refresh_schema(plan) + }) + .data() + })? + .map_data(|plan| plan.recompute_schema()) + .data() + } + + fn name(&self) -> &str { + "ModelAnalyzeRule" + } +} + impl ModelAnalyzeRule { pub fn new( analyzed_wren_mdl: Arc, @@ -41,66 +72,307 @@ impl ModelAnalyzeRule { Arc::clone(&self.session_state) } - fn analyze_model_internal( + /// The goal of this function is to analyze the scope of the logical plan and collect the required columns for models and visited tables. + /// If the plan contains subquery, we should create a new child scope and analyze the subquery recursively. + /// After leaving the subquery, we should push(push_back) the child scope to the scope_queue. + fn analyze_scope( &self, plan: LogicalPlan, - analysis: &RefCell, + root: &RefCell, + scope_queue: &RefCell>>, ) -> Result> { - match plan { + plan.transform_up(&|plan| -> Result> { + let plan = self.analyze_scope_internal(plan, root)?.data; + plan.map_subqueries(|plan| { + if let LogicalPlan::Subquery(Subquery { + subquery, + outer_ref_columns, + }) = &plan + { + outer_ref_columns.iter().try_for_each(|expr| { + let mut scope_mut = root.borrow_mut(); + self.collect_required_column(expr.clone(), &mut scope_mut) + })?; + let child_scope = + RefCell::new(Scope::new_child(RefCell::clone(root))); + self.analyze_scope( + Arc::unwrap_or_clone(Arc::clone(subquery)), + &child_scope, + scope_queue, + )?; + let mut scope_queue = scope_queue.borrow_mut(); + scope_queue.push_back(child_scope); + } + Ok(Transformed::no(plan)) + }) + }) + } + + /// Collect the visited dataset and required columns + fn analyze_scope_internal( + &self, + plan: LogicalPlan, + scope: &RefCell, + ) -> Result> { + match &plan { + LogicalPlan::TableScan(table_scan) => { + if belong_to_mdl( + &self.analyzed_wren_mdl.wren_mdl(), + table_scan.table_name.clone(), + Arc::clone(&self.session_state), + ) { + let mut scope_mut = scope.borrow_mut(); + if let Some(model) = self + .analyzed_wren_mdl + .wren_mdl + .get_model(table_scan.table_name.table()) + { + scope_mut.add_visited_dataset( + table_scan.table_name.clone(), + Dataset::Model(model), + ); + } + scope_mut.add_visited_table(table_scan.table_name.clone()); + Ok(Transformed::no(plan)) + } else { + Ok(Transformed::no(plan)) + } + } + LogicalPlan::Join(Join { on, filter, .. }) => { + let mut scope_mut = scope.borrow_mut(); + let mut accum = HashSet::new(); + on.iter().try_for_each(|expr| { + expr_to_columns(&expr.0, &mut accum)?; + expr_to_columns(&expr.1, &mut accum)?; + Ok::<_, DataFusionError>(()) + })?; + if let Some(filter_expr) = &filter { + expr_to_columns(filter_expr, &mut accum)?; + } + accum.iter().try_for_each(|expr| { + self.collect_required_column( + Expr::Column(expr.clone()), + &mut scope_mut, + ) + })?; + Ok(Transformed::no(plan)) + } LogicalPlan::Projection(projection) => { - let mut analysis_mut = analysis.borrow_mut(); - let buffer = analysis_mut.required_columns_mut(); + let mut scope_mut = scope.borrow_mut(); projection.expr.iter().try_for_each(|expr| { let mut acuum = HashSet::new(); - utils::expr_to_columns(expr, &mut acuum)?; - acuum.iter().try_for_each(|expr| { - self.collect_column(Expr::Column(expr.clone()), buffer) + expr_to_columns(expr, &mut acuum)?; + acuum.into_iter().try_for_each(|expr| { + self.collect_required_column(Expr::Column(expr), &mut scope_mut) }) })?; - Ok(Transformed::no(LogicalPlan::Projection(projection))) + Ok(Transformed::no(plan)) } LogicalPlan::Filter(filter) => { + let mut scope_mut = scope.borrow_mut(); let mut acuum = HashSet::new(); - utils::expr_to_columns(&filter.predicate, &mut acuum)?; - let mut analysis_mut = analysis.borrow_mut(); - let buffer = analysis_mut.required_columns_mut(); - acuum.iter().try_for_each(|expr| { - self.collect_column(Expr::Column(expr.clone()), buffer) + expr_to_columns(&filter.predicate, &mut acuum)?; + acuum.into_iter().try_for_each(|expr| { + self.collect_required_column(Expr::Column(expr), &mut scope_mut) })?; - Ok(Transformed::no(LogicalPlan::Filter(filter))) + Ok(Transformed::no(plan)) } LogicalPlan::Aggregate(aggregate) => { - let mut analysis_mut = analysis.borrow_mut(); - let buffer = analysis_mut.required_columns_mut(); + let mut scope_mut = scope.borrow_mut(); let mut accum = HashSet::new(); - let _ = &aggregate.aggr_expr.iter().for_each(|expr| { + aggregate.aggr_expr.iter().for_each(|expr| { Expr::add_column_refs(expr, &mut accum); }); - let _ = &aggregate.group_expr.iter().for_each(|expr| { + aggregate.group_expr.iter().for_each(|expr| { Expr::add_column_refs(expr, &mut accum); }); accum.iter().try_for_each(|expr| { - self.collect_column(Expr::Column(expr.to_owned().clone()), buffer) + self.collect_required_column( + Expr::Column(expr.to_owned().clone()), + &mut scope_mut, + ) })?; - Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))) + Ok(Transformed::no(plan)) } - LogicalPlan::Subquery(Subquery { - subquery, - outer_ref_columns, + LogicalPlan::SubqueryAlias(subquery_alias) => { + let mut scope_mut = scope.borrow_mut(); + if let LogicalPlan::TableScan(table_scan) = + Arc::unwrap_or_clone(Arc::clone(&subquery_alias.input)) + { + if belong_to_mdl( + &self.analyzed_wren_mdl.wren_mdl(), + table_scan.table_name.clone(), + Arc::clone(&self.session_state), + ) { + if let Some(model) = self + .analyzed_wren_mdl + .wren_mdl + .get_model(table_scan.table_name.table()) + { + scope_mut.add_visited_dataset( + subquery_alias.alias.clone(), + Dataset::Model(model), + ); + } + } + } + scope_mut.add_visited_table(subquery_alias.alias.clone()); + Ok(Transformed::no(plan)) + } + _ => Ok(Transformed::no(plan)), + } + } + + /// This function only collects the model required columns + fn collect_required_column( + &self, + expr: Expr, + scope: &mut RefMut, + ) -> Result<()> { + match expr { + Expr::Column(Column { + relation: Some(relation), + name, }) => { - let mut analysis_mut = analysis.borrow_mut(); - let buffer = analysis_mut.required_columns_mut(); - outer_ref_columns - .iter() - .try_for_each(|expr| self.collect_column(expr.clone(), buffer))?; - Ok(Transformed::no(LogicalPlan::Subquery(Subquery { - subquery, - outer_ref_columns, - }))) + // only collect the required column if the relation belongs to the mdl + if belong_to_mdl( + &self.analyzed_wren_mdl.wren_mdl(), + relation.clone(), + Arc::clone(&self.session_state), + ) && self + .analyzed_wren_mdl + .wren_mdl() + .get_view(relation.table()) + .is_none() + { + scope.add_required_column( + relation.clone(), + Expr::Column(Column::new(Some(relation), name)), + )?; + } } + // It is possible that the column is a rebase column from the aggregation or join + // e.g. Column { + // relation: None, + // name: "min(wrenai.public.order_items_model.price)", + // }, + Expr::Column(Column { relation: None, .. }) => { + // do nothing + } + Expr::OuterReferenceColumn(_, column) => { + self.collect_required_column(Expr::Column(column), scope)?; + } + _ => return plan_err!("Invalid column expression: {}", expr), + } + Ok(()) + } + + /// Analyze the table scan and rewrite the table scan to the ModelPlanNode according to the scope analysis. + /// If the plan contains subquery, we should analyze the subquery recursively. + /// Before enter the subquery, the corresponding child scope should be popped (pop_front) from the scope_queue. + fn analyze_model( + &self, + plan: LogicalPlan, + root: &RefCell, + scope_queue: &RefCell>>, + ) -> Result> { + plan.transform_up(&|plan| -> Result> { + let plan = self.analyze_model_internal(plan, root, scope_queue)?.data; + // If the plan contains subquery, we should analyze the subquery recursively + plan.map_subqueries(|plan| { + if let LogicalPlan::Subquery(subquery) = &plan { + let mut scope_queue_mut = scope_queue.borrow_mut(); + let Some(child_scope) = scope_queue_mut.pop_front() else { + return internal_err!("No child scope found for subquery"); + }; + let transformed = self + .analyze_model( + Arc::unwrap_or_clone(Arc::clone(&subquery.subquery)), + &child_scope, + scope_queue, + )? + .data; + return Ok(Transformed::yes(LogicalPlan::Subquery( + subquery.with_plan(Arc::new(transformed)), + ))); + } + Ok(Transformed::no(plan)) + }) + }) + } + + /// Analyze the model and generate the ModelPlanNode + fn analyze_model_internal( + &self, + plan: LogicalPlan, + scope: &RefCell, + scope_queue: &RefCell>>, + ) -> Result> { + match plan { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { - let mut analysis_mut = analysis.borrow_mut(); - match Arc::unwrap_or_clone(input) { + // Because the bottom-up transformation is used, the table_scan is already transformed + // to the ModelPlanNode before the SubqueryAlias. We should check the patten of Wren-generated model plan like: + // SubqueryAlias -> SubqueryAlias -> Extension -> ModelPlanNode + // to get the correct required columns + match Arc::unwrap_or_clone(Arc::clone(&input)) { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => { + if let LogicalPlan::Extension(Extension { node }) = + Arc::unwrap_or_clone(Arc::clone(&input)) + { + if let Some(model_node) = + node.as_any().downcast_ref::() + { + if let Some(model) = self + .analyzed_wren_mdl + .wren_mdl() + .get_model(model_node.plan_name()) + { + let scope = scope.borrow(); + let field: Vec = if let Some(used_columns) = + scope.try_get_required_columns(&alias) + { + used_columns.iter().cloned().collect() + } else { + // If the required columns are not found in the current scope but the table is visited, + // it could be a count(*) query + if scope.try_get_visited_dataset(&alias).is_none() + { + return internal_err!( + "Table {} not found in the visited dataset and required columns map", + alias); + }; + vec![] + }; + let model_plan = LogicalPlan::Extension(Extension { + node: Arc::new(ModelPlanNode::new( + Arc::clone(&model), + field, + None, + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + )?), + }); + let subquery = LogicalPlanBuilder::from(model_plan) + .alias(alias)? + .build()?; + Ok(Transformed::yes(subquery)) + } else { + internal_err!( + "Model {} not found in the WrenMDL", + model_node.plan_name() + ) + } + } else { + internal_err!( + "ModelPlanNode not found in the Extension node" + ) + } + } else { + Ok(Transformed::no(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(input, alias)?, + ))) + } + } LogicalPlan::TableScan(table_scan) => { let model_plan = self .analyze_table_scan( @@ -108,15 +380,15 @@ impl ModelAnalyzeRule { Arc::clone(&self.session_state), table_scan, Some(alias.clone()), - &mut analysis_mut, + scope, )? .data; - Ok(Transformed::yes(LogicalPlan::SubqueryAlias( - SubqueryAlias::try_new(Arc::new(model_plan), alias)?, - ))) + let subquery = + LogicalPlanBuilder::from(model_plan).alias(alias)?.build()?; + Ok(Transformed::yes(subquery)) } - ignore => Ok(Transformed::no(LogicalPlan::SubqueryAlias( - SubqueryAlias::try_new(Arc::new(ignore), alias)?, + _ => Ok(Transformed::no(LogicalPlan::SubqueryAlias( + SubqueryAlias::try_new(input, alias)?, ))), } } @@ -125,23 +397,9 @@ impl ModelAnalyzeRule { Arc::clone(&self.session_state), table_scan, None, - &mut analysis.borrow_mut(), + scope, ), LogicalPlan::Join(join) => { - let mut analysis_mut = analysis.borrow_mut(); - let buffer = analysis_mut.required_columns_mut(); - let mut accum = HashSet::new(); - join.on.iter().for_each(|expr| { - let _ = utils::expr_to_columns(&expr.0, &mut accum); - let _ = utils::expr_to_columns(&expr.1, &mut accum); - }); - if let Some(filter_expr) = &join.filter { - let _ = utils::expr_to_columns(filter_expr, &mut accum); - } - accum.iter().try_for_each(|expr| { - self.collect_column(Expr::Column(expr.clone()), buffer) - })?; - let left = match Arc::unwrap_or_clone(join.left) { LogicalPlan::TableScan(table_scan) => { self.analyze_table_scan( @@ -149,7 +407,7 @@ impl ModelAnalyzeRule { Arc::clone(&self.session_state), table_scan, None, - &mut analysis_mut, + scope, )? .data } @@ -163,13 +421,13 @@ impl ModelAnalyzeRule { Arc::clone(&self.session_state), table_scan, None, - &mut analysis_mut, + scope, )? .data } ignore => ignore, }; - Ok(Transformed::no(LogicalPlan::Join(Join { + Ok(Transformed::yes(LogicalPlan::Join(Join { left: Arc::new(left), right: Arc::new(right), on: join.on, @@ -180,40 +438,19 @@ impl ModelAnalyzeRule { null_equals_null: join.null_equals_null, }))) } - _ => Ok(Transformed::no(plan)), - } - } - - fn collect_column( - &self, - expr: Expr, - buffer: &mut HashMap>, - ) -> Result<()> { - match expr { - Expr::Column(Column { - relation: Some(relation), - name, - }) => { - if belong_to_mdl( - &self.analyzed_wren_mdl.wren_mdl(), - relation.clone(), - self.session_state(), - ) { - buffer - .entry(relation.clone()) - .or_default() - .insert(Expr::Column(Column { - relation: Some(relation), - name, - })); - } - } - Expr::OuterReferenceColumn(_, column) => { - self.collect_column(Expr::Column(column), buffer)?; + LogicalPlan::Subquery(Subquery { subquery, .. }) => { + let mut scope_queue_mut = scope_queue.borrow_mut(); + let Some(child_scope) = scope_queue_mut.pop_front() else { + return internal_err!("No child scope found for subquery"); + }; + self.analyze_model( + Arc::unwrap_or_clone(subquery), + &child_scope, + scope_queue, + ) } - _ => {} + _ => Ok(Transformed::no(plan)), } - Ok(()) } fn analyze_table_scan( @@ -222,7 +459,7 @@ impl ModelAnalyzeRule { session_state_ref: SessionStateRef, table_scan: TableScan, alias: Option, - analysis: &mut RefMut, + scope: &RefCell, ) -> Result> { if belong_to_mdl( &analyzed_wren_mdl.wren_mdl(), @@ -230,21 +467,24 @@ impl ModelAnalyzeRule { Arc::clone(&session_state_ref), ) { let table_name = table_scan.table_name.table(); - // transform ViewTable to a subquery plan - if let Some(logical_plan) = table_scan.source.get_logical_plan() { - let subquery = LogicalPlanBuilder::from(logical_plan.clone()) - .alias(quoted(table_name))? - .build()?; - return Ok(Transformed::yes(subquery)); - } - if let Some(model) = analyzed_wren_mdl.wren_mdl.get_model(table_name) { let table_ref = alias.unwrap_or(table_scan.table_name.clone()); - let used_columns = analysis.required_columns_mut(); - let buffer = used_columns.get(&table_ref); - let field: Vec = buffer - .map(|s| s.iter().cloned().collect()) - .unwrap_or_default(); + let scope = scope.borrow(); + let field: Vec = if let Some(used_columns) = + scope.try_get_required_columns(&table_ref) + { + used_columns.iter().cloned().collect() + } else { + // If the required columns are not found in the current scope but the table is visited, + // it could be a count(*) query + if scope.try_get_visited_dataset(&table_ref).is_none() { + return internal_err!( + "Table {} not found in the visited dataset and required columns map", + table_ref + ); + }; + vec![] + }; let model_plan = LogicalPlan::Extension(Extension { node: Arc::new(ModelPlanNode::new( Arc::clone(&model), @@ -257,9 +497,6 @@ impl ModelAnalyzeRule { let subquery = LogicalPlanBuilder::from(model_plan) .alias(quoted(model.name()))? .build()?; - if let Some(buffer) = used_columns.get_mut(&table_ref) { - buffer.clear(); - } Ok(Transformed::yes(subquery)) } else { Ok(Transformed::no(LogicalPlan::TableScan(table_scan))) @@ -269,14 +506,20 @@ impl ModelAnalyzeRule { } } - fn replace_model_prefix_and_refresh_schema( + /// Remove the catalog and schema prefix of Wren for the column and refresh the schema. + /// The plan created by DataFusion is always with the Wren prefix for the column name. + /// Something like "wrenai.public.order_items_model.price". However, the model plan will be rewritten to a subquery alias + /// The catalog and schema are invalid for the subquery alias. We should remove the prefix and refresh the schema. + fn remove_wren_catalog_schema_prefix_and_refresh_schema( &self, plan: LogicalPlan, ) -> Result> { match plan { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { let subquery = self - .replace_model_prefix_and_refresh_schema(Arc::unwrap_or_clone(input))? + .remove_wren_catalog_schema_prefix_and_refresh_schema( + Arc::unwrap_or_clone(input), + )? .data; Ok(Transformed::yes(LogicalPlan::SubqueryAlias( SubqueryAlias::try_new(Arc::new(subquery), alias)?, @@ -287,9 +530,9 @@ impl ModelAnalyzeRule { outer_ref_columns, }) => { let subquery = self - .replace_model_prefix_and_refresh_schema(Arc::unwrap_or_clone( - subquery, - ))? + .remove_wren_catalog_schema_prefix_and_refresh_schema( + Arc::unwrap_or_clone(subquery), + )? .data; Ok(Transformed::yes(LogicalPlan::Subquery(Subquery { subquery: Arc::new(subquery), @@ -403,12 +646,10 @@ impl ModelAnalyzeRule { if let Some(relation) = relation { Ok(self.rewrite_column_qualifier(relation, name, alias_model)) } else { - let catalog_schema = format!( - "{}.{}.", - self.analyzed_wren_mdl.wren_mdl().catalog(), - self.analyzed_wren_mdl.wren_mdl().schema() + let name = name.replace( + self.analyzed_wren_mdl.wren_mdl().catalog_schema_prefix(), + "", ); - let name = name.replace(&catalog_schema, ""); let ident = ident(&name); Ok(Transformed::yes(ident)) } @@ -452,12 +693,10 @@ impl ModelAnalyzeRule { Transformed::yes(col(format!("{}.{}", alias_model, quoted(&name)))) } else { // handle Wren View - let catalog_schema = format!( - "{}.{}.", - self.analyzed_wren_mdl.wren_mdl().catalog(), - self.analyzed_wren_mdl.wren_mdl().schema() + let name = name.replace( + self.analyzed_wren_mdl.wren_mdl().catalog_schema_prefix(), + "", ); - let name = name.replace(&catalog_schema, ""); Transformed::yes(Expr::Column(Column::new( Some(TableReference::bare(relation.table())), &name, @@ -504,58 +743,120 @@ impl ModelAnalyzeRule { } } -fn belong_to_mdl( - mdl: &WrenMDL, - table_reference: TableReference, - session: SessionStateRef, -) -> bool { - let session = session.read(); - let catalog = table_reference - .catalog() - .unwrap_or(&session.config_options().catalog.default_catalog); - let catalog_match = catalog == mdl.catalog(); +/// [Scope] is used to collect the required columns for models and visited tables in a query scope. +/// A query scope means is a full query body contain projection, relation. e.g. +/// SELECT a, b, c FROM table +/// +/// To avoid the table name be ambiguous, the relation name should be unique in the scope. +/// The relation of parent scope can be accessed by the child scope. +/// The child scope can also add the required columns to the parent scope. +#[derive(Clone, Debug, Default)] +pub struct Scope { + /// The columns required by the dataset + required_columns: HashMap>, + /// The Wren dataset visited in the scope (only the Wren dataset) + visited_dataset: HashMap, + /// The table name visited in the scope (not only the Wren dataset) + visited_tables: HashSet, + /// The parent scope + parent: Option>>, +} - let schema = table_reference - .schema() - .unwrap_or(&session.config_options().catalog.default_schema); - let schema_match = schema == mdl.schema(); +impl Scope { + pub fn new() -> Self { + Self { + required_columns: HashMap::new(), + visited_dataset: HashMap::new(), + visited_tables: HashSet::new(), + parent: None, + } + } - catalog_match && schema_match -} + pub fn new_child(parent: RefCell) -> Self { + Self { + required_columns: HashMap::new(), + visited_dataset: HashMap::new(), + visited_tables: HashSet::new(), + parent: Some(Box::new(parent)), + } + } -impl AnalyzerRule for ModelAnalyzeRule { - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - let analysis = RefCell::new(Analysis::default()); - plan.transform_down_with_subqueries( - &|plan| -> Result> { - self.analyze_model_internal(plan, &analysis) - }, - )? - .map_data(|plan| { - plan.transform_up_with_subqueries( - &|plan| -> Result> { - self.replace_model_prefix_and_refresh_schema(plan) - }, - ) - })? - .map_data(|plan| plan.data.recompute_schema()) - .data() + pub fn add_required_column( + &mut self, + table_ref: TableReference, + expr: Expr, + ) -> Result<()> { + if self.visited_dataset.contains_key(&table_ref) { + self.required_columns + .entry(table_ref) + .or_default() + .insert(expr); + Ok(()) + } else if let Some(ref parent) = &self.parent { + parent + .clone() + .borrow_mut() + .add_required_column(table_ref, expr)?; + Ok(()) + } else if self.visited_tables.contains(&table_ref) { + // If the table is visited but the dataset is not found, it could be a subquery alias + Ok(()) + } else { + plan_err!("Table {} not found in the visited dataset", table_ref) + } } - fn name(&self) -> &str { - "ModelAnalyzeRule" + pub fn add_visited_dataset(&mut self, table_ref: TableReference, dataset: Dataset) { + self.visited_dataset.insert(table_ref, dataset); } -} -/// The context of the analysis -#[derive(Debug, Default)] -struct Analysis { - /// The columns required by the dataset - required_columns: HashMap>, -} + pub fn add_visited_table(&mut self, table_ref: TableReference) { + self.visited_tables.insert(table_ref); + } + + pub fn try_get_required_columns( + &self, + table_ref: &TableReference, + ) -> Option> { + let try_local = self.required_columns.get(table_ref).cloned(); + + if try_local.is_some() { + return try_local; + } + + if let Some(ref parent) = &self.parent { + let scope = parent.borrow(); + scope.try_get_required_columns(table_ref) + } else { + None + } + } + + pub fn try_get_visited_dataset(&self, table_ref: &TableReference) -> Option { + let try_local = self.visited_dataset.get(table_ref).cloned(); -impl Analysis { - fn required_columns_mut(&mut self) -> &mut HashMap> { - &mut self.required_columns + if try_local.is_some() { + return try_local; + } + + if let Some(ref parent) = &self.parent { + let scope = parent.borrow(); + scope.try_get_visited_dataset(table_ref) + } else { + None + } + } + + pub fn try_get_visited_table(&self, table_ref: &TableReference) -> bool { + if self.visited_tables.contains(table_ref) { + return true; + } + + if let Some(ref parent) = &self.parent { + let scope = parent.borrow(); + scope.try_get_visited_table(table_ref) + } else { + false + } } } diff --git a/wren-modeling-rs/core/src/logical_plan/utils.rs b/wren-modeling-rs/core/src/logical_plan/utils.rs index 087f570e8..34072f247 100644 --- a/wren-modeling-rs/core/src/logical_plan/utils.rs +++ b/wren-modeling-rs/core/src/logical_plan/utils.rs @@ -1,22 +1,24 @@ -use std::{collections::HashMap, sync::Arc}; - use datafusion::arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; +use datafusion::catalog_common::TableReference; +use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::datasource::DefaultTableSource; use datafusion::error::Result; -use datafusion::logical_expr::{builder::LogicalTableSource, TableSource}; +use datafusion::logical_expr::{builder::LogicalTableSource, Expr, TableSource}; use log::debug; use petgraph::dot::{Config, Dot}; use petgraph::Graph; +use std::collections::HashSet; +use std::{collections::HashMap, sync::Arc}; use crate::mdl::lineage::DatasetLink; use crate::mdl::utils::quoted; -use crate::mdl::Dataset; use crate::mdl::{ manifest::{Column, Model}, WrenMDL, }; +use crate::mdl::{Dataset, SessionStateRef}; fn create_mock_list_type() -> DataType { let string_filed = Arc::new(Field::new("string", DataType::Utf8, false)); @@ -125,6 +127,7 @@ pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc) { println!("graph: {:?}", dot); } +/// Check if the table reference belongs to the mdl +pub fn belong_to_mdl( + mdl: &WrenMDL, + table_reference: TableReference, + session: SessionStateRef, +) -> bool { + let session = session.read(); + let catalog = table_reference + .catalog() + .unwrap_or(&session.config_options().catalog.default_catalog); + let catalog_match = catalog == mdl.catalog(); + + let schema = table_reference + .schema() + .unwrap_or(&session.config_options().catalog.default_schema); + let schema_match = schema == mdl.schema(); + + catalog_match && schema_match +} + +/// Collect all the Columns and OuterReferenceColumns in the expression +pub fn expr_to_columns( + expr: &Expr, + accum: &mut HashSet, +) -> Result<()> { + expr.apply(|expr| { + match expr { + Expr::Column(qc) => { + accum.insert(qc.clone()); + } + Expr::OuterReferenceColumn(_, column) => { + accum.insert(column.clone()); + } + // Use explicit pattern match instead of a default + // implementation, so that in the future if someone adds + // new Expr types, they will check here as well + Expr::Unnest(_) + | Expr::ScalarVariable(_, _) + | Expr::Alias(_) + | Expr::Literal(_) + | Expr::BinaryExpr { .. } + | Expr::Like { .. } + | Expr::SimilarTo { .. } + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Negative(_) + | Expr::Between { .. } + | Expr::Case { .. } + | Expr::Cast { .. } + | Expr::TryCast { .. } + | Expr::Sort { .. } + | Expr::ScalarFunction(..) + | Expr::WindowFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::GroupingSet(_) + | Expr::InList { .. } + | Expr::Exists { .. } + | Expr::InSubquery(_) + | Expr::ScalarSubquery(_) + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => {} + } + Ok(TreeNodeRecursion::Continue) + }) + .map(|_| ()) +} + #[cfg(test)] mod test { use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit}; diff --git a/wren-modeling-rs/core/src/mdl/context.rs b/wren-modeling-rs/core/src/mdl/context.rs index 347686c2f..4c9b48047 100644 --- a/wren-modeling-rs/core/src/mdl/context.rs +++ b/wren-modeling-rs/core/src/mdl/context.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::ops::Deref; use std::sync::Arc; +use crate::logical_plan::analyze::expand_view::ExpandWrenViewRule; use crate::logical_plan::analyze::model_anlayze::ModelAnalyzeRule; use crate::logical_plan::analyze::model_generation::ModelGenerationRule; use crate::logical_plan::utils::create_schema; @@ -45,6 +46,11 @@ pub async fn create_ctx_with_mdl( reset_default_catalog_schema.clone().read().deref().clone(), ) .with_analyzer_rules(vec![ + // expand the view should be the first rule + Arc::new(ExpandWrenViewRule::new( + Arc::clone(&analyzed_mdl), + Arc::clone(&reset_default_catalog_schema), + )), Arc::new(ModelAnalyzeRule::new( Arc::clone(&analyzed_mdl), reset_default_catalog_schema, diff --git a/wren-modeling-rs/core/src/mdl/mod.rs b/wren-modeling-rs/core/src/mdl/mod.rs index dcba52be3..ecf3b7ed0 100644 --- a/wren-modeling-rs/core/src/mdl/mod.rs +++ b/wren-modeling-rs/core/src/mdl/mod.rs @@ -9,11 +9,12 @@ use log::{debug, info}; use parking_lot::RwLock; use std::{collections::HashMap, sync::Arc}; +use crate::logical_plan::analyze::expand_view::ExpandWrenViewRule; use crate::logical_plan::analyze::model_anlayze::ModelAnalyzeRule; use crate::logical_plan::analyze::model_generation::ModelGenerationRule; use crate::logical_plan::utils::from_qualified_name_str; use crate::mdl::context::{create_ctx_with_mdl, register_table_with_mdl}; -use crate::mdl::manifest::{Column, Manifest, Model}; +use crate::mdl::manifest::{Column, Manifest, Model, View}; pub use dataset::Dataset; use manifest::Relationship; use regex::Regex; @@ -69,6 +70,7 @@ pub struct WrenMDL { pub manifest: Manifest, pub qualified_references: HashMap, pub register_tables: RegisterTables, + pub catalog_schema_prefix: String, } impl WrenMDL { @@ -122,6 +124,7 @@ impl WrenMDL { }); WrenMDL { + catalog_schema_prefix: format!("{}.{}.", &manifest.catalog, &manifest.schema), manifest, qualified_references: qualifed_references, register_tables: HashMap::new(), @@ -160,6 +163,14 @@ impl WrenMDL { .cloned() } + pub fn get_view(&self, name: &str) -> Option> { + self.manifest + .views + .iter() + .find(|view| view.name == name) + .cloned() + } + pub fn get_relationship(&self, name: &str) -> Option> { self.manifest .relationships @@ -174,6 +185,10 @@ impl WrenMDL { ) -> Option { self.qualified_references.get(column).cloned() } + + pub fn catalog_schema_prefix(&self) -> &str { + &self.catalog_schema_prefix + } } /// Transform the SQL based on the MDL @@ -194,13 +209,8 @@ pub async fn transform_sql_with_ctx( analyzed_mdl: Arc, sql: &str, ) -> Result { - let catalog_schema = format!( - "{}.{}.", - analyzed_mdl.wren_mdl().catalog(), - analyzed_mdl.wren_mdl().schema() - ); info!("wren-core received SQL: {}", sql); - let ctx = create_ctx_with_mdl(ctx, analyzed_mdl).await?; + let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl)).await?; let plan = ctx.state().create_logical_plan(sql).await?; debug!("wren-core original plan:\n {plan:?}"); let analyzed = ctx.state().optimize(&plan)?; @@ -211,7 +221,9 @@ pub async fn transform_sql_with_ctx( match unparser.plan_to_sql(&analyzed) { Ok(sql) => { // TODO: workaround to remove unnecessary catalog and schema of mdl - let replaced = sql.to_string().replace(&catalog_schema, ""); + let replaced = sql + .to_string() + .replace(analyzed_mdl.wren_mdl().catalog_schema_prefix(), ""); info!("wren-core planned SQL: {}", replaced); Ok(replaced) } @@ -251,6 +263,11 @@ pub async fn apply_wren_rules( ctx: &SessionContext, analyzed_wren_mdl: Arc, ) -> Result<()> { + // expand the view should be the first rule + ctx.add_analyzer_rule(Arc::new(ExpandWrenViewRule::new( + Arc::clone(&analyzed_wren_mdl), + ctx.state_ref(), + ))); ctx.add_analyzer_rule(Arc::new(ModelAnalyzeRule::new( Arc::clone(&analyzed_wren_mdl), ctx.state_ref(), diff --git a/wren-modeling-rs/sqllogictest/src/test_context.rs b/wren-modeling-rs/sqllogictest/src/test_context.rs index a31e802e0..cdcae8c5b 100644 --- a/wren-modeling-rs/sqllogictest/src/test_context.rs +++ b/wren-modeling-rs/sqllogictest/src/test_context.rs @@ -307,7 +307,7 @@ async fn register_ecommerce_mdl( pub async fn register_tpch_table(ctx: &SessionContext) -> Result { let path = PathBuf::from(TEST_RESOURCES).join("tpch"); - let data = read_dir_recursive(&path).unwrap(); + let data = read_dir_recursive(&path)?; // register parquet file with the execution context for file in data.iter() { @@ -317,8 +317,7 @@ pub async fn register_tpch_table(ctx: &SessionContext) -> Result { file.to_str().unwrap(), ParquetReadOptions::default(), ) - .await - .unwrap(); + .await?; } let (ctx, mdl) = register_tpch_mdl(ctx).await?; @@ -446,7 +445,92 @@ async fn register_tpch_mdl( .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let mut register_tables = HashMap::new(); + register_tables.insert( + "datafusion.public.customer".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("customer") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.orders".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("orders") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.lineitem".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("lineitem") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.part".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("part") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.supplier".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("supplier") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.partsupp".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("partsupp") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.nation".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("nation") + .await? + .unwrap(), + ); + register_tables.insert( + "datafusion.public.region".to_string(), + ctx.catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("region") + .await? + .unwrap(), + ); + + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables( + manifest, + register_tables, + )?); // TODO: there're some conflicts for datafusion optimization rules. // let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl)).await?; Ok((ctx.to_owned(), analyzed_mdl)) diff --git a/wren-modeling-rs/sqllogictest/test_files/model.slt b/wren-modeling-rs/sqllogictest/test_files/model.slt index b1f0e431b..c5318d480 100644 --- a/wren-modeling-rs/sqllogictest/test_files/model.slt +++ b/wren-modeling-rs/sqllogictest/test_files/model.slt @@ -46,3 +46,10 @@ query IR select "Id", "Price" from "Order_items" where "Order_id" in (SELECT "Order_id" FROM "Orders" WHERE "Customer_id" = 'f6c39f83de772dd502809cee2fee4c41') ---- 105 287.4 + +# TODO: DataFusion has some case sensitivity issue with the outer reference column name +# Test the query with outer reference column +# query I +# select "Customer_id" from wrenai.public."Orders" where not exists (select 1 from wrenai.public."Order_items" where "Orders"."Order_id" = "Order_items"."Order_id") +# ---- +# 1 diff --git a/wren-modeling-rs/sqllogictest/test_files/tpch/q22.slt.part b/wren-modeling-rs/sqllogictest/test_files/tpch/q22.slt.part index 531358cc0..171b22b50 100644 --- a/wren-modeling-rs/sqllogictest/test_files/tpch/q22.slt.part +++ b/wren-modeling-rs/sqllogictest/test_files/tpch/q22.slt.part @@ -21,8 +21,7 @@ #caused by #Schema error: No field named customer.c_phone. -# query TIR -query error +query TIR select cntrycode, count(*) as numcust, @@ -60,11 +59,11 @@ group by cntrycode order by cntrycode; -#---- -#13 10 75359.29 -#17 8 62288.98 -#18 14 111072.45 -#23 5 40458.86 -#29 11 88722.85 -#30 17 122189.33 -#31 8 66313.16 +---- +13 10 75359.29 +17 8 62288.98 +18 14 111072.45 +23 5 40458.86 +29 11 88722.85 +30 17 122189.33 +31 8 66313.16