Skip to content

Backport: Fix incorrect searched CASE optimization, Fix UNION field nullability tracking #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 5 additions & 23 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema,
UnnestOptions,
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
use indexmap::IndexSet;
Expand Down Expand Up @@ -1508,27 +1507,10 @@ pub fn validate_unique_names<'a>(
/// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union
/// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html
pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result<LogicalPlan> {
if left_plan.schema().fields().len() != right_plan.schema().fields().len() {
return plan_err!(
"UNION queries have different number of columns: \
left has {} columns whereas right has {} columns",
left_plan.schema().fields().len(),
right_plan.schema().fields().len()
);
}

// Temporarily use the schema from the left input and later rely on the analyzer to
// coerce the two schemas into a common one.

// Functional Dependencies doesn't preserve after UNION operation
let schema = (**left_plan.schema()).clone();
let schema =
Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?);

Ok(LogicalPlan::Union(Union {
inputs: vec![Arc::new(left_plan), Arc::new(right_plan)],
schema,
}))
Ok(LogicalPlan::Union(Union::try_new_with_loose_types(vec![
Arc::new(left_plan),
Arc::new(right_plan),
])?))
}

/// Create Projection
Expand Down
114 changes: 106 additions & 8 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -685,15 +685,13 @@ impl LogicalPlan {
}))
}
LogicalPlan::Union(Union { inputs, schema }) => {
let input_schema = inputs[0].schema();
// If inputs are not pruned do not change schema
// TODO this seems wrong (shouldn't we always use the schema of the input?)
let schema = if schema.fields().len() == input_schema.fields().len() {
Arc::clone(&schema)
let first_input_schema = inputs[0].schema();
if schema.fields().len() == first_input_schema.fields().len() {
// If inputs are not pruned do not change schema
Ok(LogicalPlan::Union(Union { inputs, schema }))
} else {
Arc::clone(input_schema)
};
Ok(LogicalPlan::Union(Union { inputs, schema }))
Ok(LogicalPlan::Union(Union::try_new(inputs)?))
}
}
LogicalPlan::Distinct(distinct) => {
let distinct = match distinct {
Expand Down Expand Up @@ -2598,6 +2596,106 @@ pub struct Union {
pub schema: DFSchemaRef,
}

impl Union {
    /// Creates a `Union` whose schema is computed from `inputs`.
    ///
    /// Every input must expose the same data type at each column position;
    /// otherwise a plan error is returned.
    fn try_new(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> {
        Self::derive_schema_from_inputs(&inputs, false)
            .map(|schema| Union { inputs, schema })
    }

    /// Creates a `Union` whose schema is computed from `inputs` without
    /// requiring the inputs to agree on column data types.
    ///
    /// Each output column takes its data type from the first input.
    pub fn try_new_with_loose_types(inputs: Vec<Arc<LogicalPlan>>) -> Result<Self> {
        Self::derive_schema_from_inputs(&inputs, true)
            .map(|schema| Union { inputs, schema })
    }

    /// Computes the output schema of a union over `inputs`.
    ///
    /// * Column names, data types and table references come from the first input.
    /// * A column is nullable if it is nullable in any input.
    /// * Field and schema metadata keep only the entries shared by all inputs.
    ///
    /// When `loose_types` is `false`, a plan error is returned if any input has a
    /// different data type than the first input at the same column position.
    /// When it is `true`, no type check is performed.
    /// TODO taking the type from the first input is not necessarily reasonable behavior.
    ///
    /// A plan error is also returned when fewer than two inputs are supplied or
    /// when the inputs do not all have the same number of columns.
    fn derive_schema_from_inputs(
        inputs: &[Arc<LogicalPlan>],
        loose_types: bool,
    ) -> Result<DFSchemaRef> {
        if inputs.len() < 2 {
            return plan_err!("UNION requires at least two inputs");
        }

        let first_schema = inputs[0].schema();
        let column_count = first_schema.fields().len();

        // Report the first input whose column count differs from the first input's.
        let mismatched_count = inputs
            .iter()
            .skip(1)
            .map(|input| input.schema().fields().len())
            .find(|&count| count != column_count);
        if let Some(other_count) = mismatched_count {
            return plan_err!(
                "UNION queries have different number of columns: \
                 left has {} columns whereas right has {} columns",
                column_count,
                other_count
            );
        }

        let mut union_fields = Vec::with_capacity(column_count);
        for i in 0..column_count {
            // The i-th field of every input, in input order.
            let column_fields = inputs
                .iter()
                .map(|input| input.schema().field(i))
                .collect::<Vec<_>>();
            let first_field = column_fields[0];

            if !loose_types {
                let conflicting = column_fields
                    .iter()
                    .skip(1)
                    .find(|field| field.data_type() != first_field.data_type());
                if let Some(other_field) = conflicting {
                    return plan_err!(
                        "UNION field {i} have different type in inputs: \
                         left has {} whereas right has {}",
                        first_field.data_type(),
                        other_field.data_type()
                    );
                }
            }
            // TODO apply type coercion here, or document why it's better to defer.
            // In loose mode the left input's type is used for now and the analyzer
            // is relied upon to coerce the schemas into a common one later.
            let data_type = first_field.data_type();

            let nullable = column_fields.iter().any(|field| field.is_nullable());
            let shared_metadata =
                intersect_maps(column_fields.iter().map(|field| field.metadata()));

            let mut field = Field::new(first_field.name(), data_type.clone(), nullable);
            field.set_metadata(shared_metadata);

            // TODO reusing table reference from the first schema is probably wrong
            let (qualifier, _) = first_schema.qualified_field(i);
            union_fields.push((qualifier.cloned(), Arc::new(field)));
        }

        let schema_metadata =
            intersect_maps(inputs.iter().map(|input| input.schema().metadata()));

        // The schema is built without functional dependencies: they are not
        // preserved by a UNION operation.
        DFSchema::new_with_metadata(union_fields, schema_metadata).map(Arc::new)
    }
}

/// Returns the entries that appear, with an identical value, in every map of `inputs`.
///
/// The result starts as a copy of the first map and is narrowed by each
/// following map. An empty `inputs` yields an empty map.
fn intersect_maps<'a>(
    inputs: impl IntoIterator<Item = &'a HashMap<String, String>>,
) -> HashMap<String, String> {
    let mut maps = inputs.into_iter();
    let seed = match maps.next() {
        Some(first) => first.clone(),
        None => HashMap::new(),
    };
    maps.fold(seed, |mut common, other| {
        // Keep a key only if `other` maps it to the same value.
        common.retain(|key, value| other.get(key) == Some(&*value));
        common
    })
}

// Manual implementation needed because of `schema` field. Comparison excludes this field.
impl PartialOrd for Union {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Expand Down
98 changes: 74 additions & 24 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1384,29 +1384,26 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
// String disjunction of all the when predicates encountered so far. Not nullable.
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);

for (when, then) in when_then_expr {
let case_expr = when
.as_ref()
.clone()
.and(filter_expr.clone().not())
.and(*then);
let when = is_exactly_true(*when, info)?;
let case_expr =
when.clone().and(filter_expr.clone().not()).and(*then);

out_expr = out_expr.or(case_expr);
filter_expr = filter_expr.or(*when);
filter_expr = filter_expr.or(when);
}

if let Some(else_expr) = else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
let case_expr = filter_expr.not().and(else_expr);
out_expr = out_expr.or(case_expr);

// Do a first pass at simplification
out_expr.rewrite(self)?
Expand Down Expand Up @@ -1826,6 +1823,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

/// Builds an expression that holds only when the boolean `expr` is exactly `true`
/// (that is, neither `false` nor NULL).
///
/// If `info` reports `expr` as non-nullable, `expr` is returned unchanged.
/// Otherwise it is wrapped as `expr IS NOT DISTINCT FROM true`.
fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
    let may_be_null = info.nullable(&expr)?;
    if may_be_null {
        let wrapped = BinaryExpr {
            left: Box::new(expr),
            op: Operator::IsNotDistinctFrom,
            right: Box::new(lit(true)),
        };
        return Ok(Expr::BinaryExpr(wrapped));
    }
    Ok(expr)
}

#[cfg(test)]
mod tests {
use crate::simplify_expressions::SimplifyContext;
Expand Down Expand Up @@ -3243,12 +3253,12 @@ mod tests {
simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("not_ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
))),
col("c2").not().and(col("c2")) // #1716
lit(false) // #1716
);

// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
Expand All @@ -3263,12 +3273,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
)))),
col("c2")
col("c2_non_null")
);

// CASE WHEN ISNULL(c2) THEN true ELSE c2
Expand Down Expand Up @@ -3299,12 +3309,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c1 then true WHEN c2 then true ELSE false
Expand All @@ -3318,13 +3328,53 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c > 0 THEN true END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
None,
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
col("c3").gt(lit(0_i64)),
lit(true)
)
.and(lit_bool_null()))
);

// CASE WHEN c > 0 THEN true ELSE false END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
Some(Box::new(lit(false))),
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
);
}

/// Test helper: builds the binary expression `left IS DISTINCT FROM right`.
fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
    let left = Box::new(left.into());
    let right = Box::new(right.into());
    let op = Operator::IsDistinctFrom;
    Expr::BinaryExpr(BinaryExpr { left, op, right })
}

/// Test helper: builds the binary expression `left IS NOT DISTINCT FROM right`.
fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
    let left = Box::new(left.into());
    let right = Box::new(right.into());
    let op = Operator::IsNotDistinctFrom;
    Expr::BinaryExpr(BinaryExpr { left, op, right })
}

#[test]
Expand Down
11 changes: 11 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,14 @@ query I
SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END;
----
1

query IBB
SELECT c,
CASE WHEN c > 0 THEN true END AS c1,
CASE WHEN c > 0 THEN true ELSE false END AS c2
FROM (VALUES (1), (0), (-1), (NULL)) AS t(c)
----
1 true true
0 NULL false
-1 NULL false
NULL NULL false
15 changes: 15 additions & 0 deletions datafusion/sqllogictest/test_files/union.slt
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,18 @@ SELECT NULL WHERE FALSE;
----
0.5
1

# test for https://github.com/apache/datafusion/issues/14352
query TB rowsort
SELECT
a,
a IS NOT NULL
FROM (
-- second column, even though it's not selected, was necessary to reproduce the bug linked above
SELECT 'foo' AS a, 3 AS b
UNION ALL
SELECT NULL AS a, 4 AS b
)
----
NULL false
foo true