diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index c490852c6ee3..cb07f15b9d26 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -384,12 +384,8 @@ impl DFSchema { let self_fields = self.fields().iter(); let other_fields = other.fields().iter(); self_fields.zip(other_fields).all(|(f1, f2)| { - // TODO: resolve field when exist alias - // f1.qualifier() == f2.qualifier() - // && f1.name() == f2.name() - // column(t1.a) field is "t1"."a" - // column(x) as t1.a field is ""."t1.a" - f1.qualified_name() == f2.qualified_name() + f1.qualifier() == f2.qualifier() + && f1.name() == f2.name() && Self::datatype_is_semantically_equal(f1.data_type(), f2.data_type()) }) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 9e1a0aae4ceb..3271290b4eb4 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -27,7 +27,9 @@ use crate::window_frame; use crate::window_function; use crate::Operator; use arrow::datatypes::DataType; -use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + plan_err, Column, DFField, DataFusionError, Result, ScalarValue, +}; use std::collections::HashSet; use std::fmt; use std::fmt::{Display, Formatter, Write}; @@ -185,6 +187,7 @@ pub enum Expr { pub struct Alias { pub expr: Box, pub name: String, + field: Option, } impl Alias { @@ -192,6 +195,31 @@ impl Alias { Self { expr: Box::new(expr), name: name.into(), + field: None, + } + } + + pub fn new_with_field( + expr: Box, + name: impl Into, + field: DFField, + ) -> Self { + Self { + expr: expr, + name: name.into(), + field: Some(field), + } + } + + pub fn field(&self) -> &Option { + &self.field + } + + pub fn with_new_expr(self, expr: Expr) -> Self { + Self { + expr: Box::new(expr), + name: self.name, + field: self.field, } } } @@ -785,6 +813,13 @@ impl Expr { } } + pub fn alias_field(&self) -> Option { + match self { + Expr::Alias(alias) => alias.field().clone(), + _ => None, + } + } + /// Ensure `expr` has the name as `original_name` by adding an /// alias if necessary. pub fn alias_if_changed(self, original_name: String) -> Result { @@ -809,6 +844,15 @@ impl Expr { } } + /// Return `self AS name` alias expression with a field + pub fn alias_with_field(self, field: DFField) -> Expr { + Expr::Alias(Alias::new_with_field( + Box::new(self.unalias()), + field.qualified_name(), + field, + )) + } + /// Remove an alias from an expression if one exists. pub fn unalias(self) -> Expr { match self { @@ -930,6 +974,7 @@ macro_rules! expr_vec_fmt { impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + // Expr::Alias(Alias { expr, name, field }) => write!(f, "{expr:?} AS {name} field: {field:?}"), Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({c})"), diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 30537d0fdd81..8826fdb7aeea 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -279,6 +279,9 @@ impl ExprSchemable for Expr { /// placed in an output field **named** col("c1 + c2") fn to_field(&self, input_schema: &DFSchema) -> Result { match self { + Expr::Alias(alias) if alias.field().is_some() => { + Ok(alias.field().as_ref().unwrap().clone()) + } Expr::Column(c) => Ok(DFField::new( c.relation.clone(), &c.name, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e058708701b9..2dbc63fdab75 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -482,7 +482,38 @@ impl LogicalPlan { } pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - from_plan(self, &self.expressions(), inputs) + // with_new_inputs use original expression, + // so we don't need to recompute Schema. + match &self { + LogicalPlan::Projection(projection) => { + Ok(LogicalPlan::Projection(Projection::try_new_with_schema( + projection.expr.to_vec(), + Arc::new(inputs[0].clone()), + projection.schema.clone(), + )?)) + } + LogicalPlan::Window(Window { + window_expr, + schema, + .. + }) => Ok(LogicalPlan::Window(Window { + input: Arc::new(inputs[0].clone()), + window_expr: window_expr.to_vec(), + schema: schema.clone(), + })), + LogicalPlan::Aggregate(Aggregate { + group_expr, + aggr_expr, + schema, + .. + }) => Ok(LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + Arc::new(inputs[0].clone()), + group_expr.to_vec(), + aggr_expr.to_vec(), + schema.clone(), + )?)), + _ => from_plan(self, &self.expressions(), inputs), + } } /// Convert a prepared [`LogicalPlan`] into its inner logical plan diff --git a/datafusion/optimizer/src/merge_projection.rs b/datafusion/optimizer/src/merge_projection.rs index 408055b8e7d4..88a933fa5b82 100644 --- a/datafusion/optimizer/src/merge_projection.rs +++ b/datafusion/optimizer/src/merge_projection.rs @@ -77,9 +77,18 @@ pub(super) fn merge_projection( .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) .enumerate() .map(|(i, e)| match e { - Ok(e) => { - let parent_expr = parent_projection.schema.fields()[i].qualified_name(); - e.alias_if_changed(parent_expr) + Ok(expr) => { + let parent_expr = parent_projection.expr[i].clone(); + match parent_expr { + Expr::Alias(alias) => { + Ok(Expr::Alias(alias.with_new_expr(expr.unalias()))) + } + _ => { + let parent_expr_name = + parent_projection.schema.fields()[i].qualified_name(); + expr.alias_if_changed(parent_expr_name) + } + } } Err(e) => Err(e), }) diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index c65768bb8b11..46a39aedc425 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use super::{ExprSimplifier, SimplifyContext}; -use crate::utils::merge_schema; +use crate::utils::{generate_alias_project, merge_schema}; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::{logical_plan::LogicalPlan, utils::from_plan}; @@ -85,15 +85,11 @@ impl SimplifyExpressions { let expr = plan .expressions() .into_iter() - .map(|e| { - // TODO: unify with `rewrite_preserving_name` - let original_name = e.name_for_alias()?; - let new_e = simplifier.simplify(e)?; - new_e.alias_if_changed(original_name) - }) + .map(|e| simplifier.simplify(e)) .collect::>>()?; - from_plan(plan, &expr, &new_inputs) + let new_plan = from_plan(plan, &expr, &new_inputs)?; + return generate_alias_project(plan.schema(), new_plan); } } @@ -164,14 +160,11 @@ mod tests { } fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { - let rule = SimplifyExpressions::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); - Ok(()) + crate::test::assert_optimized_plan_eq( + Arc::new(SimplifyExpressions::new()), + plan, + expected, + ) } #[test] @@ -343,10 +336,10 @@ mod tests { )? .build()?; - let expected = "\ - Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b) AS MAX(test.b = Boolean(true)), MIN(test.b)]]\ - \n Projection: test.a, test.c, test.b\ - \n TableScan: test"; + let expected = "Projection: test.a, test.c, MAX(test.b) AS MAX(test.b = Boolean(true)), MIN(test.b)\ + \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b), MIN(test.b)]]\ + \n Projection: test.a, test.c, test.b\ + \n TableScan: test"; assert_optimized_plan_eq(&plan, expected) } @@ -366,8 +359,7 @@ mod tests { let values = vec![vec![expr1, expr2]]; let plan = LogicalPlanBuilder::values(values)?.build()?; - let expected = "\ - Values: (Int32(3) AS Int32(1) + Int32(2), Int32(1) AS Int32(2) - Int32(1))"; + let expected = "Values: (Int32(3), Int32(1))"; assert_optimized_plan_eq(&plan, expected) } @@ -832,9 +824,8 @@ mod tests { .build()?; // before simplify: t.g = power(t.f, 1.0) - // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = - "TableScan: test, unsupported_filters=[g = f AS g = power(f,Float64(1))]"; + // after simplify: t.g = t.f" + let expected = "TableScan: test, unsupported_filters=[g = f]"; assert_optimized_plan_eq(&plan, expected) } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index daa695f77144..fd78fecb34e2 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -19,16 +19,15 @@ //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. use crate::optimizer::ApplyOrder; -use crate::utils::merge_schema; +use crate::utils::{generate_alias_project, merge_schema}; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::from_plan; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, @@ -99,15 +98,13 @@ impl OptimizerRule for UnwrapCastInComparison { let new_exprs = plan .expressions() .into_iter() - .map(|expr| rewrite_preserving_name(expr, &mut expr_rewriter)) + .map(|expr| expr.rewrite(&mut expr_rewriter)) .collect::>>()?; let inputs: Vec = plan.inputs().into_iter().cloned().collect(); - Ok(Some(from_plan( - plan, - new_exprs.as_slice(), - inputs.as_slice(), - )?)) + let new_plan = from_plan(plan, new_exprs.as_slice(), inputs.as_slice())?; + + Ok(Some(generate_alias_project(plan.schema(), new_plan)?)) } fn name(&self) -> &str { diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index adb3bf6302fc..3110843e26a0 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,6 +17,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules +use crate::merge_projection::merge_projection; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{plan_err, Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; @@ -25,7 +26,7 @@ use datafusion_expr::expr_rewriter::{replace_col, strip_outer_reference}; use datafusion_expr::{ and, logical_plan::{Filter, LogicalPlan}, - Expr, Operator, + Expr, Operator, Projection, }; use log::{debug, trace}; use std::collections::{BTreeSet, HashMap}; @@ -324,6 +325,38 @@ pub(crate) fn replace_qualified_name( replace_col(expr, &replace_map) } +pub(super) fn generate_alias_project( + original_schema: &DFSchemaRef, + new_plan: LogicalPlan, +) -> Result { + if original_schema == new_plan.schema() { + return Ok(new_plan); + } + + let old_fields = original_schema.fields(); + let new_fields = new_plan.schema().fields(); + let alias_exprs = new_fields + .iter() + .enumerate() + .map(|(i, new_field)| { + let old_field = &old_fields[i]; + let col = Expr::Column(new_field.qualified_column()); + Ok(if old_field == new_field { + col + } else { + col.alias_with_field(old_field.clone()) + }) + }) + .collect::>>()?; + + let alias_project = Projection::try_new(alias_exprs, Arc::new(new_plan))?; + + match alias_project.input.as_ref() { + LogicalPlan::Projection(project) => merge_projection(&alias_project, project), + _ => Ok(LogicalPlan::Projection(alias_project)), + } +} + /// Log the plan in debug/tracing mode after some part of the optimizer runs pub fn log_plan(description: &str, plan: &LogicalPlan) { debug!("{description}:\n{}\n", plan.display_indent()); diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index e2e10428ad63..a07aee3a7e95 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -82,8 +82,8 @@ fn subquery_filter_with_cast() -> Result<()> { fn case_when_aggregate() -> Result<()> { let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; let plan = test_sql(sql)?; - let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ - \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END)]]\ \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(())