Skip to content

Commit eb4ee62

Browse files
jayzhan211alamb
andauthored
Move UnwrapCastInComparison into Simplifier (#15012)
* add unwrap in simplify expr * rm unwrap cast * return err * rename * fix * fmt * add unwrap_cast module to simplify expressions * tweak comment * Move tests * Rewrite to use simplifier schema * Update tests for simplify logic --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 6d5f5cd commit eb4ee62

File tree

7 files changed

+291
-278
lines changed

7 files changed

+291
-278
lines changed

datafusion/core/tests/sql/explain_analyze.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ async fn csv_explain_verbose() {
355355
async fn csv_explain_inlist_verbose() {
356356
let ctx = SessionContext::new();
357357
register_aggregate_csv_by_sql(&ctx).await;
358-
let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4)";
358+
// Inlist len <=3 case will be transformed to OR List so we test with len=4
359+
let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4,5)";
359360
let actual = execute(&ctx, sql).await;
360361

361362
// Optimized by PreCastLitInComparisonExpressions rule
@@ -368,12 +369,12 @@ async fn csv_explain_inlist_verbose() {
368369
// before optimization (Int64 literals)
369370
assert_contains!(
370371
&actual,
371-
"aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
372+
"aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4), Int64(5)])"
372373
);
373374
// after optimization (casted to Int8)
374375
assert_contains!(
375376
&actual,
376-
"aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
377+
"aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4), Int8(5)])"
377378
);
378379
}
379380

datafusion/optimizer/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ pub mod replace_distinct_aggregate;
6060
pub mod scalar_subquery_to_join;
6161
pub mod simplify_expressions;
6262
pub mod single_distinct_to_groupby;
63-
pub mod unwrap_cast_in_comparison;
6463
pub mod utils;
6564

6665
#[cfg(test)]

datafusion/optimizer/src/optimizer.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
5454
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
5555
use crate::simplify_expressions::SimplifyExpressions;
5656
use crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
57-
use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
5857
use crate::utils::log_plan;
5958

6059
/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which
@@ -243,7 +242,6 @@ impl Optimizer {
243242
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
244243
Arc::new(EliminateNestedUnion::new()),
245244
Arc::new(SimplifyExpressions::new()),
246-
Arc::new(UnwrapCastInComparison::new()),
247245
Arc::new(ReplaceDistinctWithAggregate::new()),
248246
Arc::new(EliminateJoin::new()),
249247
Arc::new(DecorrelatePredicateSubquery::new()),
@@ -266,7 +264,6 @@ impl Optimizer {
266264
// The previous optimizations added expressions and projections,
267265
// that might benefit from the following rules
268266
Arc::new(SimplifyExpressions::new()),
269-
Arc::new(UnwrapCastInComparison::new()),
270267
Arc::new(CommonSubexprEliminate::new()),
271268
Arc::new(EliminateGroupByConstant::new()),
272269
Arc::new(OptimizeProjections::new()),

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ use datafusion_common::{
3232
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
3333
};
3434
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
35-
use datafusion_expr::simplify::ExprSimplifyResult;
3635
use datafusion_expr::{
3736
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
3837
WindowFunctionDefinition,
@@ -42,14 +41,23 @@ use datafusion_expr::{
4241
expr::{InList, InSubquery, WindowFunction},
4342
utils::{iter_conjunction, iter_conjunction_owned},
4443
};
44+
use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
4545
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
4646

4747
use super::inlist_simplifier::ShortenInListSimplifier;
4848
use super::utils::*;
49-
use crate::analyzer::type_coercion::TypeCoercionRewriter;
5049
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
5150
use crate::simplify_expressions::regex::simplify_regex_expr;
51+
use crate::simplify_expressions::unwrap_cast::{
52+
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
53+
is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
54+
unwrap_cast_in_comparison_for_binary,
55+
};
5256
use crate::simplify_expressions::SimplifyInfo;
57+
use crate::{
58+
analyzer::type_coercion::TypeCoercionRewriter,
59+
simplify_expressions::unwrap_cast::try_cast_literal_to_type,
60+
};
5361
use indexmap::IndexSet;
5462
use regex::Regex;
5563

@@ -1742,6 +1750,86 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
17421750
}
17431751
}
17441752

1753+
// =======================================
1754+
// unwrap_cast_in_comparison
1755+
// =======================================
1756+
//
1757+
// For case:
1758+
// try_cast/cast(expr as data_type) op literal
1759+
Expr::BinaryExpr(BinaryExpr { left, op, right })
1760+
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1761+
info, &left, &right,
1762+
) && op.supports_propagation() =>
1763+
{
1764+
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
1765+
}
1766+
// literal op try_cast/cast(expr as data_type)
1767+
// -->
1768+
// try_cast/cast(expr as data_type) op_swap literal
1769+
Expr::BinaryExpr(BinaryExpr { left, op, right })
1770+
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1771+
info, &right, &left,
1772+
) && op.supports_propagation()
1773+
&& op.swap().is_some() =>
1774+
{
1775+
unwrap_cast_in_comparison_for_binary(
1776+
info,
1777+
right,
1778+
left,
1779+
op.swap().unwrap(),
1780+
)?
1781+
}
1782+
// For case:
1783+
// try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
1784+
Expr::InList(InList {
1785+
expr: mut left,
1786+
list,
1787+
negated,
1788+
}) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
1789+
info, &left, &list,
1790+
) =>
1791+
{
1792+
let (Expr::TryCast(TryCast {
1793+
expr: left_expr, ..
1794+
})
1795+
| Expr::Cast(Cast {
1796+
expr: left_expr, ..
1797+
})) = left.as_mut()
1798+
else {
1799+
return internal_err!("Expect cast expr, but got {:?}", left)?;
1800+
};
1801+
1802+
let expr_type = info.get_data_type(left_expr)?;
1803+
let right_exprs = list
1804+
.into_iter()
1805+
.map(|right| {
1806+
match right {
1807+
Expr::Literal(right_lit_value) => {
1808+
// if the right_lit_value can be casted to the type of internal_left_expr
1809+
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
1810+
let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
1811+
internal_err!(
1812+
"Can't cast the list expr {:?} to type {:?}",
1813+
right_lit_value, &expr_type
1814+
)?
1815+
};
1816+
Ok(lit(value))
1817+
}
1818+
other_expr => internal_err!(
1819+
"Only support literal expr to optimize, but the expr is {:?}",
1820+
&other_expr
1821+
),
1822+
}
1823+
})
1824+
.collect::<Result<Vec<_>>>()?;
1825+
1826+
Transformed::yes(Expr::InList(InList {
1827+
expr: std::mem::take(left_expr),
1828+
list: right_exprs,
1829+
negated,
1830+
}))
1831+
}
1832+
17451833
// no additional rewrites possible
17461834
expr => Transformed::no(expr),
17471835
})

datafusion/optimizer/src/simplify_expressions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ mod guarantees;
2323
mod inlist_simplifier;
2424
mod regex;
2525
pub mod simplify_exprs;
26+
mod unwrap_cast;
2627
mod utils;
2728

2829
// backwards compatibility

0 commit comments

Comments
 (0)