Skip to content

Commit c456926

Browse files
committed
[PR SNAPSHOT] Fix incorrect searched CASE optimization
This cherry pick a snapshot of a PR offered upstream. ------------------------------------------------------ There is an optimization for searched CASE where values are of boolean type. It was converting the expression like CASE WHEN X THEN A WHEN Y THEN B .. [ ELSE D ] END into (X AND A) OR (Y AND NOT X AND B) [ OR (NOT (X OR Y) AND D) ] This had the following problems - does not work for nullable conditions. If X is nullable, we cannot use NOT (X) to compliment it. We need to use `X IS DISTINCT FROM true` - it does not work correctly when some conditions are nullable and other values are false. E.g. X=NULL, A=true, Y=NULL, B=true, D=false, the CASE should return false, but the boolean expression will simplify to `(NULL AND ..) OR (NULL AND ..) OR (false)` which is NULL, not false - thus we use `X` for truthness check of `X`, we need to test `X IS NOT DISTINCT FROM true` - it did not work correctly when default D is missing, but conditions do not evaluate to NULL. CASE's result should be NULL but was false. This commit fixes that optimization.
1 parent 2a0a82c commit c456926

File tree

2 files changed

+85
-24
lines changed

2 files changed

+85
-24
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,29 +1384,26 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
13841384
when_then_expr,
13851385
else_expr,
13861386
}) if !when_then_expr.is_empty()
1387-
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
1387+
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
13881388
&& info.is_boolean_type(&when_then_expr[0].1)? =>
13891389
{
1390-
// The disjunction of all the when predicates encountered so far
1390+
// String disjunction of all the when predicates encountered so far. Not nullable.
13911391
let mut filter_expr = lit(false);
13921392
// The disjunction of all the cases
13931393
let mut out_expr = lit(false);
13941394

13951395
for (when, then) in when_then_expr {
1396-
let case_expr = when
1397-
.as_ref()
1398-
.clone()
1399-
.and(filter_expr.clone().not())
1400-
.and(*then);
1396+
let when = is_exactly_true(*when, info)?;
1397+
let case_expr =
1398+
when.clone().and(filter_expr.clone().not()).and(*then);
14011399

14021400
out_expr = out_expr.or(case_expr);
1403-
filter_expr = filter_expr.or(*when);
1401+
filter_expr = filter_expr.or(when);
14041402
}
14051403

1406-
if let Some(else_expr) = else_expr {
1407-
let case_expr = filter_expr.not().and(*else_expr);
1408-
out_expr = out_expr.or(case_expr);
1409-
}
1404+
let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1405+
let case_expr = filter_expr.not().and(else_expr);
1406+
out_expr = out_expr.or(case_expr);
14101407

14111408
// Do a first pass at simplification
14121409
out_expr.rewrite(self)?
@@ -1826,6 +1823,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
18261823
Ok(Expr::InList(l1))
18271824
}
18281825

1826+
/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
1827+
fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
1828+
if !info.nullable(&expr)? {
1829+
Ok(expr)
1830+
} else {
1831+
Ok(Expr::BinaryExpr(BinaryExpr {
1832+
left: Box::new(expr),
1833+
op: Operator::IsNotDistinctFrom,
1834+
right: Box::new(lit(true)),
1835+
}))
1836+
}
1837+
}
1838+
18291839
#[cfg(test)]
18301840
mod tests {
18311841
use crate::simplify_expressions::SimplifyContext;
@@ -3243,12 +3253,12 @@ mod tests {
32433253
simplify(Expr::Case(Case::new(
32443254
None,
32453255
vec![(
3246-
Box::new(col("c2").not_eq(lit(false))),
3256+
Box::new(col("c2_non_null").not_eq(lit(false))),
32473257
Box::new(lit("ok").eq(lit("not_ok"))),
32483258
)],
3249-
Some(Box::new(col("c2").eq(lit(true)))),
3259+
Some(Box::new(col("c2_non_null").eq(lit(true)))),
32503260
))),
3251-
col("c2").not().and(col("c2")) // #1716
3261+
lit(false) // #1716
32523262
);
32533263

32543264
// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
@@ -3263,12 +3273,12 @@ mod tests {
32633273
simplify(simplify(Expr::Case(Case::new(
32643274
None,
32653275
vec![(
3266-
Box::new(col("c2").not_eq(lit(false))),
3276+
Box::new(col("c2_non_null").not_eq(lit(false))),
32673277
Box::new(lit("ok").eq(lit("ok"))),
32683278
)],
3269-
Some(Box::new(col("c2").eq(lit(true)))),
3279+
Some(Box::new(col("c2_non_null").eq(lit(true)))),
32703280
)))),
3271-
col("c2")
3281+
col("c2_non_null")
32723282
);
32733283

32743284
// CASE WHEN ISNULL(c2) THEN true ELSE c2
@@ -3299,12 +3309,12 @@ mod tests {
32993309
simplify(simplify(Expr::Case(Case::new(
33003310
None,
33013311
vec![
3302-
(Box::new(col("c1")), Box::new(lit(true)),),
3303-
(Box::new(col("c2")), Box::new(lit(false)),),
3312+
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
3313+
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
33043314
],
33053315
Some(Box::new(lit(true))),
33063316
)))),
3307-
col("c1").or(col("c1").not().and(col("c2").not()))
3317+
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
33083318
);
33093319

33103320
// CASE WHEN c1 then true WHEN c2 then true ELSE false
@@ -3318,13 +3328,53 @@ mod tests {
33183328
simplify(simplify(Expr::Case(Case::new(
33193329
None,
33203330
vec![
3321-
(Box::new(col("c1")), Box::new(lit(true)),),
3322-
(Box::new(col("c2")), Box::new(lit(false)),),
3331+
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
3332+
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
33233333
],
33243334
Some(Box::new(lit(true))),
33253335
)))),
3326-
col("c1").or(col("c1").not().and(col("c2").not()))
3336+
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3337+
);
3338+
3339+
// CASE WHEN c > 0 THEN true END AS c1
3340+
assert_eq!(
3341+
simplify(simplify(Expr::Case(Case::new(
3342+
None,
3343+
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3344+
None,
3345+
)))),
3346+
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3347+
col("c3").gt(lit(0_i64)),
3348+
lit(true)
3349+
)
3350+
.and(lit_bool_null()))
33273351
);
3352+
3353+
// CASE WHEN c > 0 THEN true ELSE false END AS c1
3354+
assert_eq!(
3355+
simplify(simplify(Expr::Case(Case::new(
3356+
None,
3357+
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3358+
Some(Box::new(lit(false))),
3359+
)))),
3360+
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3361+
);
3362+
}
3363+
3364+
fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3365+
Expr::BinaryExpr(BinaryExpr {
3366+
left: Box::new(left.into()),
3367+
op: Operator::IsDistinctFrom,
3368+
right: Box::new(right.into()),
3369+
})
3370+
}
3371+
3372+
fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3373+
Expr::BinaryExpr(BinaryExpr {
3374+
left: Box::new(left.into()),
3375+
op: Operator::IsNotDistinctFrom,
3376+
right: Box::new(right.into()),
3377+
})
33283378
}
33293379

33303380
#[test]

datafusion/sqllogictest/test_files/case.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,14 @@ query I
202202
SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END;
203203
----
204204
1
205+
206+
query IBB
207+
SELECT c,
208+
CASE WHEN c > 0 THEN true END AS c1,
209+
CASE WHEN c > 0 THEN true ELSE false END AS c2
210+
FROM (VALUES (1), (0), (-1), (NULL)) AS t(c)
211+
----
212+
1 true true
213+
0 NULL false
214+
-1 NULL false
215+
NULL NULL false

0 commit comments

Comments
 (0)