diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 4e142ef2803a..0772eaab50f3 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -18,7 +18,7 @@ //! Logic to inject guarantees with expressions. //! use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue}; -use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator}; +use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator}; use std::collections::HashMap; /// A bound on the value of an expression. @@ -108,6 +108,11 @@ impl Guarantee { fn less_than_or_eq(&self, value: &ScalarValue) -> bool { self.max.bound <= *value } + + /// Whether the guarantee could contain the given value. + fn contains(&self, value: &ScalarValue) -> bool { + !self.less_than(value) && !self.greater_than(value) + } } impl From<&ScalarValue> for Guarantee { @@ -237,6 +242,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { _ => return Ok(expr), }; + // TODO: can this be simplified? if let Some(guarantee) = self.guarantees.get(col.as_ref()) { match op { Operator::Eq => { @@ -339,7 +345,35 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { } } - // In list + Expr::InList(InList { + expr: inner, + list, + negated, + }) => { + if let Some(guarantee) = self.guarantees.get(inner.as_ref()) { + // Can remove items from the list that don't match the guarantee + let new_list: Vec = list + .iter() + .filter(|item| { + if let Expr::Literal(item) = item { + guarantee.contains(item) + } else { + true + } + }) + .cloned() + .collect(); + + Ok(Expr::InList(InList { + expr: inner.clone(), + list: new_list, + negated: *negated, + })) + } else { + Ok(expr) + } + } + _ => Ok(expr), } } @@ -471,59 +505,53 @@ mod tests { #[test] fn test_in_list() { let guarantees = vec![ - // x = 2 - (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))), - // 1 <= y < 10 + // 1 <= x < 10 ( - col("y"), + col("x"), Guarantee::new( Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)), Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)), NullStatus::NeverNull, ), ), - // z is null - (col("z"), Guarantee::from(&ScalarValue::Null)), ]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - // These cases should be simplified + // These cases should be simplified so the list doesn't contain any + // values the guarantee says are outside the range. + // (column_name, starting_list, negated, expected_list) let cases = &[ - // x IN () - (col("x").in_list(vec![], false), false), - // x IN (10, 11) - (col("x").in_list(vec![lit(10), lit(11)], false), false), - // x IN (10, 2) - (col("x").in_list(vec![lit(10), lit(2)], false), true), - // x NOT IN (10, 2) - (col("x").in_list(vec![lit(10), lit(2)], true), false), - // y IN (10, 11) - (col("y").in_list(vec![lit(10), lit(11)], false), false), - // y NOT IN (0, 22) - (col("y").in_list(vec![lit(0), lit(22)], true), true), - // z IN (10, 11) - (col("z").in_list(vec![lit(10), lit(11)], false), false), + // x IN (9, 11) => x IN (9) + ("x", vec![9, 11], false, vec![9]), + // x IN (10, 2) => x IN (2) + ("x", vec![10, 2], false, vec![2]), + // x NOT IN (9, 11) => x NOT IN (9) + ("x", vec![9, 11], true, vec![9]), + // x NOT IN (0, 22) => x NOT IN () + ("x", vec![0, 22], true, vec![]), ]; - for (expr, expected_value) in cases { + for (column_name, starting_list, negated, expected_list) in cases { + let expr = col(*column_name).in_list( + starting_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(), + *negated, + ); let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let expected_list = expected_list + .iter() + .map(|v| lit(ScalarValue::Int32(Some(*v)))) + .collect(); assert_eq!( output, - Expr::Literal(ScalarValue::Boolean(Some(*expected_value))) + Expr::InList(InList { + expr: Box::new(col(*column_name)), + list: expected_list, + negated: *negated, + }) ); } - - // These cases should be left as-is - let cases = &[ - // y IN (10, 2) - col("y").in_list(vec![lit(10), lit(2)], false), - // y NOT IN (10, 2) - col("y").in_list(vec![lit(10), lit(2)], true), - ]; - - for expr in cases { - let output = expr.clone().rewrite(&mut rewriter).unwrap(); - assert_eq!(&output, expr); - } } }