diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 183bb1f7fb49..ebbe4ee54ee9 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -17,7 +17,7 @@ use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ - expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, + expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, }; use sqlparser::ast::{self, SetExpr}; @@ -131,6 +131,87 @@ impl Unparser<'_> { Ok(ast::SetExpr::Select(Box::new(select_builder.build()?))) } + /// Reconstructs a SELECT SQL statement from a logical plan by unprojecting column expressions + /// found in a [Projection] node. This requires scanning the plan tree for relevant Aggregate + /// and Window nodes and matching column expressions to the appropriate agg or window expressions. + fn reconstruct_select_statement( + &self, + plan: &LogicalPlan, + p: &Projection, + select: &mut SelectBuilder, + ) -> Result<()> { + match find_agg_node_within_select(plan, None, true) { + Some(AggVariant::Aggregate(agg)) => { + let items = p + .expr + .iter() + .map(|proj_expr| { + let unproj = unproject_agg_exprs(proj_expr, agg)?; + self.select_item_to_sql(&unproj) + }) + .collect::>>()?; + + select.projection(items); + select.group_by(ast::GroupByExpr::Expressions( + agg.group_expr + .iter() + .map(|expr| self.expr_to_sql(expr)) + .collect::>>()?, + )); + } + Some(AggVariant::Window(window)) => { + let items = p + .expr + .iter() + .map(|proj_expr| { + let unproj = unproject_window_exprs(proj_expr, &window)?; + self.select_item_to_sql(&unproj) + }) + .collect::>>()?; + + select.projection(items); + } + None => { + let items = p + .expr + .iter() + .map(|e| self.select_item_to_sql(e)) + .collect::>>()?; + select.projection(items); + } + } + Ok(()) + } + + fn projection_to_sql( + &self, + plan: &LogicalPlan, + p: &Projection, + query: &mut Option, + select: &mut SelectBuilder, + relation: &mut RelationBuilder, + ) -> Result<()> { + // A second projection implies a derived tablefactor + if !select.already_projected() { + self.reconstruct_select_statement(plan, p, select)?; + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) + } else { + let mut derived_builder = DerivedRelationBuilder::default(); + derived_builder.lateral(false).alias(None).subquery({ + let inner_statment = self.plan_to_sql(plan)?; + if let ast::Statement::Query(inner_query) = inner_statment { + inner_query + } else { + return internal_err!( + "Subquery must be a Query, but found {inner_statment:?}" + ); + } + }); + relation.derived(derived_builder); + Ok(()) + } + } + fn select_to_sql_recursively( &self, plan: &LogicalPlan, @@ -159,74 +240,7 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::Projection(p) => { - // A second projection implies a derived tablefactor - if !select.already_projected() { - // Special handling when projecting an agregation plan - if let Some(aggvariant) = - find_agg_node_within_select(plan, None, true) - { - match aggvariant { - AggVariant::Aggregate(agg) => { - let items = p - .expr - .iter() - .map(|proj_expr| { - let unproj = unproject_agg_exprs(proj_expr, agg)?; - self.select_item_to_sql(&unproj) - }) - .collect::>>()?; - - select.projection(items); - select.group_by(ast::GroupByExpr::Expressions( - agg.group_expr - .iter() - .map(|expr| self.expr_to_sql(expr)) - .collect::>>()?, - )); - } - AggVariant::Window(window) => { - let items = p - .expr - .iter() - .map(|proj_expr| { - let unproj = - unproject_window_exprs(proj_expr, &window)?; - self.select_item_to_sql(&unproj) - }) - .collect::>>()?; - - select.projection(items); - } - } - } else { - let items = p - .expr - .iter() - .map(|e| self.select_item_to_sql(e)) - .collect::>>()?; - select.projection(items); - } - self.select_to_sql_recursively( - p.input.as_ref(), - query, - select, - relation, - ) - } else { - let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ - let inner_statment = self.plan_to_sql(plan)?; - if let ast::Statement::Query(inner_query) = inner_statment { - inner_query - } else { - return internal_err!( - "Subquery must be a Query, but found {inner_statment:?}" - ); - } - }); - relation.derived(derived_builder); - Ok(()) - } + self.projection_to_sql(plan, p, query, select, relation) } LogicalPlan::Filter(filter) => { if let Some(AggVariant::Aggregate(agg)) = diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 326cd15ba140..331da9773f16 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -46,29 +46,30 @@ pub(crate) fn find_agg_node_within_select<'a>( } else { input.first()? }; + // Agg nodes explicitly return immediately with a single node // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection - if let LogicalPlan::Aggregate(agg) = input { - Some(AggVariant::Aggregate(agg)) - } else if let LogicalPlan::Window(window) = input { - prev_windows = match &mut prev_windows { - Some(AggVariant::Window(windows)) => { - windows.push(window); + match input { + LogicalPlan::Aggregate(agg) => Some(AggVariant::Aggregate(agg)), + LogicalPlan::Window(window) => { + prev_windows = match &mut prev_windows { + Some(AggVariant::Window(windows)) => { + windows.push(window); + prev_windows + } + _ => Some(AggVariant::Window(vec![window])), + }; + find_agg_node_within_select(input, prev_windows, already_projected) + } + LogicalPlan::Projection(_) => { + if already_projected { prev_windows + } else { + find_agg_node_within_select(input, prev_windows, true) } - _ => Some(AggVariant::Window(vec![window])), - }; - find_agg_node_within_select(input, prev_windows, already_projected) - } else if let LogicalPlan::TableScan(_) = input { - prev_windows - } else if let LogicalPlan::Projection(_) = input { - if already_projected { - prev_windows - } else { - find_agg_node_within_select(input, prev_windows, true) } - } else { - find_agg_node_within_select(input, prev_windows, already_projected) + LogicalPlan::TableScan(_) => prev_windows, + _ => find_agg_node_within_select(input, prev_windows, already_projected), } }