Skip to content

Move UnwrapCastInComparison into Simplifier #15012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 6, 2025
7 changes: 4 additions & 3 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ async fn csv_explain_verbose() {
async fn csv_explain_inlist_verbose() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4)";
// An IN list with len <= 3 is transformed into a list of ORs, so we test with len = 4 to keep the IN list
let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4,5)";
let actual = execute(&ctx, sql).await;

// Optimized by PreCastLitInComparisonExpressions rule
Expand All @@ -368,12 +369,12 @@ async fn csv_explain_inlist_verbose() {
// before optimization (Int64 literals)
assert_contains!(
&actual,
"aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
"aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4), Int64(5)])"
);
// after optimization (casted to Int8)
assert_contains!(
&actual,
"aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
"aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4), Int8(5)])"
);
}

Expand Down
1 change: 0 additions & 1 deletion datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ pub mod replace_distinct_aggregate;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
pub mod unwrap_cast_in_comparison;
pub mod utils;

#[cfg(test)]
Expand Down
3 changes: 0 additions & 3 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
use crate::simplify_expressions::SimplifyExpressions;
use crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
use crate::utils::log_plan;

/// `OptimizerRule`s transform one [`LogicalPlan`] into another which
Expand Down Expand Up @@ -243,7 +242,6 @@ impl Optimizer {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(EliminateNestedUnion::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is great -- it will also reduce the number of times the entire plan tree gets walked/massaged

Arc::new(ReplaceDistinctWithAggregate::new()),
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
Expand All @@ -266,7 +264,6 @@ impl Optimizer {
// The previous optimizations added expressions and projections,
// that might benefit from the following rules
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateGroupByConstant::new()),
Arc::new(OptimizeProjections::new()),
Expand Down
92 changes: 90 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
WindowFunctionDefinition,
Expand All @@ -42,14 +41,23 @@ use datafusion_expr::{
expr::{InList, InSubquery, WindowFunction},
utils::{iter_conjunction, iter_conjunction_owned},
};
use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};

use super::inlist_simplifier::ShortenInListSimplifier;
use super::utils::*;
use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
use crate::simplify_expressions::regex::simplify_regex_expr;
use crate::simplify_expressions::unwrap_cast::{
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
unwrap_cast_in_comparison_for_binary,
};
use crate::simplify_expressions::SimplifyInfo;
use crate::{
analyzer::type_coercion::TypeCoercionRewriter,
simplify_expressions::unwrap_cast::try_cast_literal_to_type,
};
use indexmap::IndexSet;
use regex::Regex;

Expand Down Expand Up @@ -1742,6 +1750,86 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
}
}

// =======================================
// unwrap_cast_in_comparison
// =======================================
//
// For case:
// try_cast/cast(expr as data_type) op literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
info, &left, &right,
) && op.supports_propagation() =>
{
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
}
// literal op try_cast/cast(expr as data_type)
// -->
// try_cast/cast(expr as data_type) op_swap literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
info, &right, &left,
) && op.supports_propagation()
&& op.swap().is_some() =>
{
unwrap_cast_in_comparison_for_binary(
info,
right,
left,
op.swap().unwrap(),
)?
}
// For case:
// try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
Expr::InList(InList {
expr: mut left,
list,
negated,
}) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
info, &left, &list,
) =>
{
let (Expr::TryCast(TryCast {
expr: left_expr, ..
})
| Expr::Cast(Cast {
expr: left_expr, ..
})) = left.as_mut()
else {
return internal_err!("Expect cast expr, but got {:?}", left)?;
};

let expr_type = info.get_data_type(left_expr)?;
let right_exprs = list
.into_iter()
.map(|right| {
match right {
Expr::Literal(right_lit_value) => {
// if right_lit_value can be cast to the type of left_expr (the expression inside the cast),
// we unwrap the cast/try_cast expr and cast the literal to that type instead
let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
internal_err!(
"Can't cast the list expr {:?} to type {:?}",
right_lit_value, &expr_type
)?
};
Ok(lit(value))
}
other_expr => internal_err!(
"Only support literal expr to optimize, but the expr is {:?}",
&other_expr
),
}
})
.collect::<Result<Vec<_>>>()?;

Transformed::yes(Expr::InList(InList {
expr: std::mem::take(left_expr),
list: right_exprs,
negated,
}))
}

// no additional rewrites possible
expr => Transformed::no(expr),
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/simplify_expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod guarantees;
mod inlist_simplifier;
mod regex;
pub mod simplify_exprs;
mod unwrap_cast;
mod utils;

// backwards compatibility
Expand Down
Loading