diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 3bca3345ae4c..4e142ef2803a 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -121,9 +121,10 @@ impl From<&ScalarValue> for Guarantee { bound: value.clone(), open: false, }, - null_status: match value { - ScalarValue::Null => NullStatus::AlwaysNull, - _ => NullStatus::NeverNull, + null_status: if value.is_null() { + NullStatus::AlwaysNull + } else { + NullStatus::NeverNull }, } } @@ -318,6 +319,25 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } // Columns (if bounds are equal and closed and column is not nullable) + Expr::Column(_) => { + if let Some(guarantee) = self.guarantees.get(&expr) { + if guarantee.min == guarantee.max + // Case where column has a single valid value + && ((!guarantee.min.open + && !guarantee.min.bound.is_null() + && guarantee.null_status == NullStatus::NeverNull) + // Case where column is always null + || (guarantee.min.bound.is_null() + && guarantee.null_status == NullStatus::AlwaysNull)) + { + Ok(lit(guarantee.min.bound.clone())) + } else { + Ok(expr) + } + } else { + Ok(expr) + } + } // In list _ => Ok(expr), @@ -336,25 +356,21 @@ mod tests { fn test_null_handling() { // IsNull / IsNotNull can be rewritten to true / false let guarantees = vec![ - (col("x"), Guarantee::new(None, None, NullStatus::AlwaysNull)), - (col("y"), Guarantee::new(None, None, NullStatus::NeverNull)), + // Note: AlwaysNull case handled by test_column_single_value test, + // since it's a special case of a column with a single value. + (col("x"), Guarantee::new(None, None, NullStatus::NeverNull)), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let cases = &[ - (col("x").is_null(), true), - (col("x").is_not_null(), false), - (col("y").is_null(), false), - (col("y").is_not_null(), true), - ]; + // x IS NULL => guaranteed false + let expr = col("x").is_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(false)); - for (expr, expected_value) in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!( - output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) - ); - } + // x IS NOT NULL => guaranteed true + let expr = col("x").is_not_null(); + let output = expr.clone().rewrite(&mut rewriter).unwrap(); + assert_eq!(output, lit(true)); } #[test] @@ -431,35 +447,24 @@ mod tests { #[test] fn test_column_single_value() { - let guarantees = vec![ - // x = 2 - (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))), - // y is Null - (col("y"), Guarantee::from(&ScalarValue::Null)), + let scalars = [ + ScalarValue::Null, + ScalarValue::Int32(Some(1)), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(None), + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("def".to_string())), + ScalarValue::Date32(Some(18628)), + ScalarValue::Date32(None), + ScalarValue::Decimal128(Some(1000), 19, 2), ]; - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // These cases should be simplified - let cases = &[ - (col("x").lt_eq(lit(1)), false), - (col("x").gt(lit(3)), false), - (col("x").eq(lit(1)), false), - (col("x").eq(lit(2)), true), - (col("x").gt(lit(1)), true), - (col("x").lt_eq(lit(2)), true), - (col("x").is_not_null(), true), - (col("x").is_null(), false), - (col("y").is_null(), true), - (col("y").is_not_null(), false), - (col("y").lt_eq(lit(17000)), false), - ]; + for scalar in &scalars { + let guarantees = vec![(col("x"), Guarantee::from(scalar))]; + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - for (expr, expected_value) in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!( - output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) - ); + let output = col("x").rewrite(&mut rewriter).unwrap(); + assert_eq!(output, Expr::Literal(scalar.clone())); } }