Skip to content

Commit 05c0347

Browse files
findepi authored and felipecrv committed
[PR SNAPSHOT] Fix incorrect searched CASE optimization
This cherry-picks a snapshot of a PR offered upstream. ------------------------------------------------------ There is an optimization for searched CASE where values are of boolean type. It was converting an expression like CASE WHEN X THEN A WHEN Y THEN B .. [ ELSE D ] END into (X AND A) OR (Y AND NOT X AND B) [ OR (NOT (X OR Y) AND D) ]. This had the following problems: - It does not work for nullable conditions. If X is nullable, we cannot use NOT (X) to complement it; we need to use `X IS DISTINCT FROM true`. - It does not work correctly when some conditions are nullable and other values are false. E.g. with X=NULL, A=true, Y=NULL, B=true, D=false, the CASE should return false, but the boolean expression simplifies to `(NULL AND ..) OR (NULL AND ..) OR (false)`, which is NULL, not false. Thus we cannot use `X` itself as the truthiness check of `X`; we need to test `X IS NOT DISTINCT FROM true`. - It did not work correctly when the default D is missing and no condition matches, even when the conditions do not evaluate to NULL: the CASE's result should be NULL but was false. This commit fixes that optimization.
1 parent c07d449 commit 05c0347

File tree

2 files changed

+85
-24
lines changed

2 files changed

+85
-24
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,29 +1385,26 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
13851385
when_then_expr,
13861386
else_expr,
13871387
}) if !when_then_expr.is_empty()
1388-
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
1388+
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
13891389
&& info.is_boolean_type(&when_then_expr[0].1)? =>
13901390
{
1391-
// The disjunction of all the when predicates encountered so far
1391+
// Disjunction of all the when predicates encountered so far, each tested for being exactly true. Not nullable.
13921392
let mut filter_expr = lit(false);
13931393
// The disjunction of all the cases
13941394
let mut out_expr = lit(false);
13951395

13961396
for (when, then) in when_then_expr {
1397-
let case_expr = when
1398-
.as_ref()
1399-
.clone()
1400-
.and(filter_expr.clone().not())
1401-
.and(*then);
1397+
let when = is_exactly_true(*when, info)?;
1398+
let case_expr =
1399+
when.clone().and(filter_expr.clone().not()).and(*then);
14021400

14031401
out_expr = out_expr.or(case_expr);
1404-
filter_expr = filter_expr.or(*when);
1402+
filter_expr = filter_expr.or(when);
14051403
}
14061404

1407-
if let Some(else_expr) = else_expr {
1408-
let case_expr = filter_expr.not().and(*else_expr);
1409-
out_expr = out_expr.or(case_expr);
1410-
}
1405+
let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
1406+
let case_expr = filter_expr.not().and(else_expr);
1407+
out_expr = out_expr.or(case_expr);
14111408

14121409
// Do a first pass at simplification
14131410
out_expr.rewrite(self)?
@@ -1881,6 +1878,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
18811878
Ok(Expr::InList(l1))
18821879
}
18831880

1881+
/// Returns an expression that tests whether the boolean `expr` is exactly `true` (neither `false` nor NULL).
1882+
fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
1883+
if !info.nullable(&expr)? {
1884+
Ok(expr)
1885+
} else {
1886+
Ok(Expr::BinaryExpr(BinaryExpr {
1887+
left: Box::new(expr),
1888+
op: Operator::IsNotDistinctFrom,
1889+
right: Box::new(lit(true)),
1890+
}))
1891+
}
1892+
}
1893+
18841894
#[cfg(test)]
18851895
mod tests {
18861896
use crate::simplify_expressions::SimplifyContext;
@@ -3272,12 +3282,12 @@ mod tests {
32723282
simplify(Expr::Case(Case::new(
32733283
None,
32743284
vec![(
3275-
Box::new(col("c2").not_eq(lit(false))),
3285+
Box::new(col("c2_non_null").not_eq(lit(false))),
32763286
Box::new(lit("ok").eq(lit("not_ok"))),
32773287
)],
3278-
Some(Box::new(col("c2").eq(lit(true)))),
3288+
Some(Box::new(col("c2_non_null").eq(lit(true)))),
32793289
))),
3280-
col("c2").not().and(col("c2")) // #1716
3290+
lit(false) // #1716
32813291
);
32823292

32833293
// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
@@ -3292,12 +3302,12 @@ mod tests {
32923302
simplify(simplify(Expr::Case(Case::new(
32933303
None,
32943304
vec![(
3295-
Box::new(col("c2").not_eq(lit(false))),
3305+
Box::new(col("c2_non_null").not_eq(lit(false))),
32963306
Box::new(lit("ok").eq(lit("ok"))),
32973307
)],
3298-
Some(Box::new(col("c2").eq(lit(true)))),
3308+
Some(Box::new(col("c2_non_null").eq(lit(true)))),
32993309
)))),
3300-
col("c2")
3310+
col("c2_non_null")
33013311
);
33023312

33033313
// CASE WHEN ISNULL(c2) THEN true ELSE c2
@@ -3328,12 +3338,12 @@ mod tests {
33283338
simplify(simplify(Expr::Case(Case::new(
33293339
None,
33303340
vec![
3331-
(Box::new(col("c1")), Box::new(lit(true)),),
3332-
(Box::new(col("c2")), Box::new(lit(false)),),
3341+
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
3342+
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
33333343
],
33343344
Some(Box::new(lit(true))),
33353345
)))),
3336-
col("c1").or(col("c1").not().and(col("c2").not()))
3346+
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
33373347
);
33383348

33393349
// CASE WHEN c1 then true WHEN c2 then true ELSE false
@@ -3347,13 +3357,53 @@ mod tests {
33473357
simplify(simplify(Expr::Case(Case::new(
33483358
None,
33493359
vec![
3350-
(Box::new(col("c1")), Box::new(lit(true)),),
3351-
(Box::new(col("c2")), Box::new(lit(false)),),
3360+
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
3361+
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
33523362
],
33533363
Some(Box::new(lit(true))),
33543364
)))),
3355-
col("c1").or(col("c1").not().and(col("c2").not()))
3365+
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
3366+
);
3367+
3368+
// CASE WHEN c > 0 THEN true END AS c1
3369+
assert_eq!(
3370+
simplify(simplify(Expr::Case(Case::new(
3371+
None,
3372+
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3373+
None,
3374+
)))),
3375+
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
3376+
col("c3").gt(lit(0_i64)),
3377+
lit(true)
3378+
)
3379+
.and(lit_bool_null()))
33563380
);
3381+
3382+
// CASE WHEN c > 0 THEN true ELSE false END AS c1
3383+
assert_eq!(
3384+
simplify(simplify(Expr::Case(Case::new(
3385+
None,
3386+
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
3387+
Some(Box::new(lit(false))),
3388+
)))),
3389+
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
3390+
);
3391+
}
3392+
3393+
fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3394+
Expr::BinaryExpr(BinaryExpr {
3395+
left: Box::new(left.into()),
3396+
op: Operator::IsDistinctFrom,
3397+
right: Box::new(right.into()),
3398+
})
3399+
}
3400+
3401+
fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
3402+
Expr::BinaryExpr(BinaryExpr {
3403+
left: Box::new(left.into()),
3404+
op: Operator::IsNotDistinctFrom,
3405+
right: Box::new(right.into()),
3406+
})
33573407
}
33583408

33593409
#[test]

datafusion/sqllogictest/test_files/case.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,14 @@ query I
202202
SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END;
203203
----
204204
1
205+
206+
query IBB
207+
SELECT c,
208+
CASE WHEN c > 0 THEN true END AS c1,
209+
CASE WHEN c > 0 THEN true ELSE false END AS c2
210+
FROM (VALUES (1), (0), (-1), (NULL)) AS t(c)
211+
----
212+
1 true true
213+
0 NULL false
214+
-1 NULL false
215+
NULL NULL false

0 commit comments

Comments
 (0)