diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index b7839c4873af..54f948534856 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -54,8 +54,7 @@ use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, - FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema, - UnnestOptions, + Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; use indexmap::IndexSet; @@ -1508,27 +1507,10 @@ pub fn validate_unique_names<'a>( /// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union /// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { - if left_plan.schema().fields().len() != right_plan.schema().fields().len() { - return plan_err!( - "UNION queries have different number of columns: \ - left has {} columns whereas right has {} columns", - left_plan.schema().fields().len(), - right_plan.schema().fields().len() - ); - } - - // Temporarily use the schema from the left input and later rely on the analyzer to - // coerce the two schemas into a common one. 
- - // Functional Dependencies doesn't preserve after UNION operation - let schema = (**left_plan.schema()).clone(); - let schema = - Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?); - - Ok(LogicalPlan::Union(Union { - inputs: vec![Arc::new(left_plan), Arc::new(right_plan)], - schema, - })) + Ok(LogicalPlan::Union(Union::try_new_with_loose_types(vec![ + Arc::new(left_plan), + Arc::new(right_plan), + ])?)) } /// Create Projection diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 191a42e38e3a..32660d5d6159 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -685,15 +685,13 @@ impl LogicalPlan { })) } LogicalPlan::Union(Union { inputs, schema }) => { - let input_schema = inputs[0].schema(); - // If inputs are not pruned do not change schema - // TODO this seems wrong (shouldn't we always use the schema of the input?) - let schema = if schema.fields().len() == input_schema.fields().len() { - Arc::clone(&schema) + let first_input_schema = inputs[0].schema(); + if schema.fields().len() == first_input_schema.fields().len() { + // If inputs are not pruned do not change schema + Ok(LogicalPlan::Union(Union { inputs, schema })) } else { - Arc::clone(input_schema) - }; - Ok(LogicalPlan::Union(Union { inputs, schema })) + Ok(LogicalPlan::Union(Union::try_new(inputs)?)) + } } LogicalPlan::Distinct(distinct) => { let distinct = match distinct { @@ -2598,6 +2596,106 @@ pub struct Union { pub schema: DFSchemaRef, } +impl Union { + /// Constructs new Union instance deriving schema from inputs. + fn try_new(inputs: Vec>) -> Result { + let schema = Self::derive_schema_from_inputs(&inputs, false)?; + Ok(Union { inputs, schema }) + } + + /// Constructs new Union instance deriving schema from inputs. + /// Inputs do not have to have matching types and produced schema will + /// take type from the first input. 
+ pub fn try_new_with_loose_types(inputs: Vec>) -> Result { + let schema = Self::derive_schema_from_inputs(&inputs, true)?; + Ok(Union { inputs, schema }) + } + + /// Constructs new Union instance deriving schema from inputs. + /// + /// `loose_types` if true, inputs do not have to have matching types and produced schema will + /// take type from the first input. TODO this is not necessarily reasonable behavior. + fn derive_schema_from_inputs( + inputs: &[Arc], + loose_types: bool, + ) -> Result { + if inputs.len() < 2 { + return plan_err!("UNION requires at least two inputs"); + } + let first_schema = inputs[0].schema(); + let fields_count = first_schema.fields().len(); + for input in inputs.iter().skip(1) { + if fields_count != input.schema().fields().len() { + return plan_err!( + "UNION queries have different number of columns: \ + left has {} columns whereas right has {} columns", + fields_count, + input.schema().fields().len() + ); + } + } + + let union_fields = (0..fields_count) + .map(|i| { + let fields = inputs + .iter() + .map(|input| input.schema().field(i)) + .collect::>(); + let first_field = fields[0]; + let name = first_field.name(); + let data_type = if loose_types { + // TODO apply type coercion here, or document why it's better to defer + // temporarily use the data type from the left input and later rely on the analyzer to + // coerce the two schemas into a common one. + first_field.data_type() + } else { + fields.iter().skip(1).try_fold( + first_field.data_type(), + |acc, field| { + if acc != field.data_type() { + return plan_err!( + "UNION field {i} have different type in inputs: \ + left has {} whereas right has {}", + first_field.data_type(), + field.data_type() + ); + } + Ok(acc) + }, + )? 
+ }; + let nullable = fields.iter().any(|field| field.is_nullable()); + let mut field = Field::new(name, data_type.clone(), nullable); + let field_metadata = + intersect_maps(fields.iter().map(|field| field.metadata())); + field.set_metadata(field_metadata); + // TODO reusing table reference from the first schema is probably wrong + let table_reference = first_schema.qualified_field(i).0.cloned(); + Ok((table_reference, Arc::new(field))) + }) + .collect::>()?; + let union_schema_metadata = + intersect_maps(inputs.iter().map(|input| input.schema().metadata())); + + // Functional dependencies are not preserved by a UNION operation + let schema = DFSchema::new_with_metadata(union_fields, union_schema_metadata)?; + let schema = Arc::new(schema); + + Ok(schema) + } +} + +fn intersect_maps<'a>( + inputs: impl IntoIterator>, +) -> HashMap { + let mut inputs = inputs.into_iter(); + let mut merged: HashMap = inputs.next().cloned().unwrap_or_default(); + for input in inputs { + merged.retain(|k, v| input.get(k) == Some(v)); + } + merged +} + // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Union { fn partial_cmp(&self, other: &Self) -> Option { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e0df6a3a68ce..72ca2276b4fa 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1384,29 +1384,26 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { when_then_expr, else_expr, }) if !when_then_expr.is_empty() - && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number + && when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number && info.is_boolean_type(&when_then_expr[0].1)? 
=> { - // The disjunction of all the when predicates encountered so far + // Disjunction of all the when predicates encountered so far, each tested for being exactly true. Not nullable. let mut filter_expr = lit(false); // The disjunction of all the cases let mut out_expr = lit(false); for (when, then) in when_then_expr { - let case_expr = when - .as_ref() - .clone() - .and(filter_expr.clone().not()) - .and(*then); + let when = is_exactly_true(*when, info)?; + let case_expr = + when.clone().and(filter_expr.clone().not()).and(*then); out_expr = out_expr.or(case_expr); - filter_expr = filter_expr.or(*when); + filter_expr = filter_expr.or(when); } - if let Some(else_expr) = else_expr { - let case_expr = filter_expr.not().and(*else_expr); - out_expr = out_expr.or(case_expr); - } + let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null); + let case_expr = filter_expr.not().and(else_expr); + out_expr = out_expr.or(case_expr); // Do a first pass at simplification out_expr.rewrite(self)? @@ -1826,6 +1823,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result { Ok(Expr::InList(l1)) } +/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL). +fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result { + if !info.nullable(&expr)? 
{ + Ok(expr) + } else { + Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), + op: Operator::IsNotDistinctFrom, + right: Box::new(lit(true)), + })) + } +} + #[cfg(test)] mod tests { use crate::simplify_expressions::SimplifyContext; @@ -3243,12 +3253,12 @@ mod tests { simplify(Expr::Case(Case::new( None, vec![( - Box::new(col("c2").not_eq(lit(false))), + Box::new(col("c2_non_null").not_eq(lit(false))), Box::new(lit("ok").eq(lit("not_ok"))), )], - Some(Box::new(col("c2").eq(lit(true)))), + Some(Box::new(col("c2_non_null").eq(lit(true)))), ))), - col("c2").not().and(col("c2")) // #1716 + lit(false) // #1716 ); // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2 @@ -3263,12 +3273,12 @@ mod tests { simplify(simplify(Expr::Case(Case::new( None, vec![( - Box::new(col("c2").not_eq(lit(false))), + Box::new(col("c2_non_null").not_eq(lit(false))), Box::new(lit("ok").eq(lit("ok"))), )], - Some(Box::new(col("c2").eq(lit(true)))), + Some(Box::new(col("c2_non_null").eq(lit(true)))), )))), - col("c2") + col("c2_non_null") ); // CASE WHEN ISNULL(c2) THEN true ELSE c2 @@ -3299,12 +3309,12 @@ mod tests { simplify(simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("c1")), Box::new(lit(true)),), - (Box::new(col("c2")), Box::new(lit(false)),), + (Box::new(col("c1_non_null")), Box::new(lit(true)),), + (Box::new(col("c2_non_null")), Box::new(lit(false)),), ], Some(Box::new(lit(true))), )))), - col("c1").or(col("c1").not().and(col("c2").not())) + col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not())) ); // CASE WHEN c1 then true WHEN c2 then true ELSE false @@ -3318,13 +3328,53 @@ mod tests { simplify(simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("c1")), Box::new(lit(true)),), - (Box::new(col("c2")), Box::new(lit(false)),), + (Box::new(col("c1_non_null")), Box::new(lit(true)),), + (Box::new(col("c2_non_null")), Box::new(lit(false)),), ], Some(Box::new(lit(true))), )))), - col("c1").or(col("c1").not().and(col("c2").not())) + 
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not())) + ); + + // CASE WHEN c > 0 THEN true END AS c1 + assert_eq!( + simplify(simplify(Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))], + None, + )))), + not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from( + col("c3").gt(lit(0_i64)), + lit(true) + ) + .and(lit_bool_null())) ); + + // CASE WHEN c > 0 THEN true ELSE false END AS c1 + assert_eq!( + simplify(simplify(Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))], + Some(Box::new(lit(false))), + )))), + not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)) + ); + } + + fn distinct_from(left: impl Into, right: impl Into) -> Expr { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(left.into()), + op: Operator::IsDistinctFrom, + right: Box::new(right.into()), + }) + } + + fn not_distinct_from(left: impl Into, right: impl Into) -> Expr { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(left.into()), + op: Operator::IsNotDistinctFrom, + right: Box::new(right.into()), + }) } #[test] diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 3c967eed219a..41f60cd379af 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -202,3 +202,14 @@ query I SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END; ---- 1 + +query IBB +SELECT c, + CASE WHEN c > 0 THEN true END AS c1, + CASE WHEN c > 0 THEN true ELSE false END AS c2 +FROM (VALUES (1), (0), (-1), (NULL)) AS t(c) +---- +1 true true +0 NULL false +-1 NULL false +NULL NULL false diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index fb7afdda2ea8..beceade43232 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ 
b/datafusion/sqllogictest/test_files/union.slt @@ -761,3 +761,18 @@ SELECT NULL WHERE FALSE; ---- 0.5 1 + +# test for https://github.com/apache/datafusion/issues/14352 +query TB rowsort +SELECT + a, + a IS NOT NULL +FROM ( + -- second column, even though it's not selected, was necessary to reproduce the bug linked above + SELECT 'foo' AS a, 3 AS b + UNION ALL + SELECT NULL AS a, 4 AS b +) +---- +NULL false +foo true