Skip to content

Commit c176533

Browse files
authored
Support aliases in ConstEvaluator (#14734)
It is unclear why aliases were not previously supported in ConstEvaluator. If we are not careful, some transformations can introduce aliases nested inside other expressions, so the evaluator needs to handle them.
1 parent 6a036ae commit c176533

File tree

3 files changed

+48
-26
lines changed

3 files changed

+48
-26
lines changed

datafusion/core/tests/expr_api/simplification.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,33 @@ fn test_const_evaluator() {
365365
);
366366
}
367367

368+
#[test]
369+
fn test_const_evaluator_alias() {
370+
// true --> true
371+
test_evaluate(lit(true).alias("a"), lit(true));
372+
// true or true --> true
373+
test_evaluate(lit(true).alias("a").or(lit(true).alias("b")), lit(true));
374+
// "foo" == "foo" --> true
375+
test_evaluate(lit("foo").alias("a").eq(lit("foo").alias("b")), lit(true));
376+
// c = 1 + 2 --> c = 3
377+
test_evaluate(
378+
col("c")
379+
.alias("a")
380+
.eq(lit(1).alias("b") + lit(2).alias("c")),
381+
col("c").alias("a").eq(lit(3)),
382+
);
383+
// (foo != foo) OR (c = 1) --> false OR (c = 1) --> c = 1
384+
test_evaluate(
385+
lit("foo")
386+
.alias("a")
387+
.not_eq(lit("foo").alias("b"))
388+
.alias("c")
389+
.or(col("c").alias("d").eq(lit(1).alias("e")))
390+
.alias("f"),
391+
col("c").alias("d").eq(lit(1)).alias("f"),
392+
);
393+
}
394+
368395
#[test]
369396
fn test_const_evaluator_scalar_functions() {
370397
// concat("foo", "bar") --> "foobar"

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,15 @@ use datafusion_expr::{
4444
};
4545
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
4646

47+
use super::inlist_simplifier::ShortenInListSimplifier;
48+
use super::utils::*;
4749
use crate::analyzer::type_coercion::TypeCoercionRewriter;
4850
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
4951
use crate::simplify_expressions::regex::simplify_regex_expr;
5052
use crate::simplify_expressions::SimplifyInfo;
5153
use indexmap::IndexSet;
5254
use regex::Regex;
5355

54-
use super::inlist_simplifier::ShortenInListSimplifier;
55-
use super::utils::*;
56-
5756
/// This structure handles API for expression simplification
5857
///
5958
/// Provides simplification information based on DFSchema and
@@ -515,30 +514,27 @@ impl TreeNodeRewriter for ConstEvaluator<'_> {
515514

516515
// NB: do not short circuit recursion even if we find a non
517516
// evaluatable node (so we can fold other children, args to
518-
// functions, etc)
517+
// functions, etc.)
519518
Ok(Transformed::no(expr))
520519
}
521520

522521
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
523522
match self.can_evaluate.pop() {
524-
// Certain expressions such as `CASE` and `COALESCE` are short circuiting
525-
// and may not evaluate all their sub expressions. Thus if
526-
// if any error is countered during simplification, return the original
523+
// Certain expressions such as `CASE` and `COALESCE` are short-circuiting
524+
// and may not evaluate all their sub expressions. Thus, if
525+
// any error is encountered during simplification, return the original expression
527526
// so that normal evaluation can occur
528-
Some(true) => {
529-
let result = self.evaluate_to_scalar(expr);
530-
match result {
531-
ConstSimplifyResult::Simplified(s) => {
532-
Ok(Transformed::yes(Expr::Literal(s)))
533-
}
534-
ConstSimplifyResult::NotSimplified(s) => {
535-
Ok(Transformed::no(Expr::Literal(s)))
536-
}
537-
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
538-
Ok(Transformed::yes(expr))
539-
}
527+
Some(true) => match self.evaluate_to_scalar(expr) {
528+
ConstSimplifyResult::Simplified(s) => {
529+
Ok(Transformed::yes(Expr::Literal(s)))
540530
}
541-
}
531+
ConstSimplifyResult::NotSimplified(s) => {
532+
Ok(Transformed::no(Expr::Literal(s)))
533+
}
534+
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
535+
Ok(Transformed::yes(expr))
536+
}
537+
},
542538
Some(false) => Ok(Transformed::no(expr)),
543539
_ => internal_err!("Failed to pop can_evaluate"),
544540
}
@@ -586,9 +582,7 @@ impl<'a> ConstEvaluator<'a> {
586582
// added they can be checked for their ability to be evaluated
587583
// at plan time
588584
match expr {
589-
// Has no runtime cost, but needed during planning
590-
Expr::Alias(..)
591-
| Expr::AggregateFunction { .. }
585+
Expr::AggregateFunction { .. }
592586
| Expr::ScalarVariable(_, _)
593587
| Expr::Column(_)
594588
| Expr::OuterReferenceColumn(_, _)
@@ -603,6 +597,7 @@ impl<'a> ConstEvaluator<'a> {
603597
Self::volatility_ok(func.signature().volatility)
604598
}
605599
Expr::Literal(_)
600+
| Expr::Alias(..)
606601
| Expr::Unnest(_)
607602
| Expr::BinaryExpr { .. }
608603
| Expr::Not(_)

datafusion/sqllogictest/test_files/subquery.slt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ query TT
834834
explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1
835835
----
836836
logical_plan
837-
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) AS _cnt ELSE __scalar_sq_1._cnt END AS cnt
837+
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1._cnt END AS cnt
838838
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
839839
03)----TableScan: t1 projection=[t1_id, t1_int]
840840
04)----SubqueryAlias: __scalar_sq_1
@@ -855,7 +855,7 @@ query TT
855855
explain SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1
856856
----
857857
logical_plan
858-
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS _cnt ELSE __scalar_sq_1._cnt END AS _cnt
858+
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) ELSE __scalar_sq_1._cnt END AS _cnt
859859
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
860860
03)----TableScan: t1 projection=[t1_id, t1_int]
861861
04)----SubqueryAlias: __scalar_sq_1
@@ -922,7 +922,7 @@ query TT
922922
explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1
923923
----
924924
logical_plan
925-
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
925+
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
926926
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
927927
03)----TableScan: t1 projection=[t1_id, t1_int]
928928
04)----SubqueryAlias: __scalar_sq_1

0 commit comments

Comments
 (0)