Skip to content

Commit

Permalink
implement inlist guarantee use
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Sep 4, 2023
1 parent 2134f2f commit caa738f
Showing 1 changed file with 66 additions and 38 deletions.
104 changes: 66 additions & 38 deletions datafusion/optimizer/src/simplify_expressions/guarantees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Logic to inject guarantees with expressions.
//!
use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue};
use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator};
use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr, Operator};
use std::collections::HashMap;

/// A bound on the value of an expression.
Expand Down Expand Up @@ -108,6 +108,11 @@ impl Guarantee {
fn less_than_or_eq(&self, value: &ScalarValue) -> bool {
// Compare the value stored in the `max` bound against `value` by reference;
// the result is true exactly when that bound value is `<= value`.
let upper = &self.max.bound;
upper <= value
}

/// Reports whether `value` may fall inside this guarantee.
///
/// Returns `true` only when neither `self.less_than(value)` nor
/// `self.greater_than(value)` holds.
fn contains(&self, value: &ScalarValue) -> bool {
// `greater_than` is only consulted when `less_than` is false, matching
// the short-circuit order of the equivalent `!a && !b` form.
!(self.less_than(value) || self.greater_than(value))
}
}

impl From<&ScalarValue> for Guarantee {
Expand Down Expand Up @@ -237,6 +242,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
_ => return Ok(expr),
};

// TODO: can this be simplified?
if let Some(guarantee) = self.guarantees.get(col.as_ref()) {
match op {
Operator::Eq => {
Expand Down Expand Up @@ -339,7 +345,35 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
}
}

// In list: prune literal items that the guarantee on the tested expression
// rules out. The `negated` flag is carried over unchanged.
Expr::InList(InList {
expr: inner,
list,
negated,
}) => {
// Only rewrite when a guarantee is registered for the tested expression.
if let Some(guarantee) = self.guarantees.get(inner.as_ref()) {
// Can remove items from the list that don't match the guarantee
let new_list: Vec<Expr> = list
.iter()
.filter(|item| {
if let Expr::Literal(item) = item {
guarantee.contains(item)
} else {
// Non-literal items are not checked against the guarantee,
// so they are always kept.
true
}
})
.cloned()
.collect();

// Rebuild the IN-list from the surviving items; the list may end up empty.
Ok(Expr::InList(InList {
expr: inner.clone(),
list: new_list,
negated: *negated,
}))
} else {
// No guarantee for this expression: return it as-is.
Ok(expr)
}
}

_ => Ok(expr),
}
}
Expand Down Expand Up @@ -471,59 +505,53 @@ mod tests {
#[test]
fn test_in_list() {
let guarantees = vec![
// x = 2
(col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))),
// 1 <= y < 10
// 1 <= x < 10
(
col("y"),
col("x"),
Guarantee::new(
Some(GuaranteeBound::new(ScalarValue::Int32(Some(1)), false)),
Some(GuaranteeBound::new(ScalarValue::Int32(Some(10)), true)),
NullStatus::NeverNull,
),
),
// z is null
(col("z"), Guarantee::from(&ScalarValue::Null)),
];
let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

// These cases should be simplified
// These cases should be simplified so the list doesn't contain any
// values the guarantee says are outside the range.
// (column_name, starting_list, negated, expected_list)
let cases = &[
// x IN ()
(col("x").in_list(vec![], false), false),
// x IN (10, 11)
(col("x").in_list(vec![lit(10), lit(11)], false), false),
// x IN (10, 2)
(col("x").in_list(vec![lit(10), lit(2)], false), true),
// x NOT IN (10, 2)
(col("x").in_list(vec![lit(10), lit(2)], true), false),
// y IN (10, 11)
(col("y").in_list(vec![lit(10), lit(11)], false), false),
// y NOT IN (0, 22)
(col("y").in_list(vec![lit(0), lit(22)], true), true),
// z IN (10, 11)
(col("z").in_list(vec![lit(10), lit(11)], false), false),
// x IN (9, 11) => x IN (9)
("x", vec![9, 11], false, vec![9]),
// x IN (10, 2) => x IN (2)
("x", vec![10, 2], false, vec![2]),
// x NOT IN (9, 11) => x NOT IN (9)
("x", vec![9, 11], true, vec![9]),
// x NOT IN (0, 22) => x NOT IN ()
("x", vec![0, 22], true, vec![]),
];

for (expr, expected_value) in cases {
for (column_name, starting_list, negated, expected_list) in cases {
let expr = col(*column_name).in_list(
starting_list
.iter()
.map(|v| lit(ScalarValue::Int32(Some(*v))))
.collect(),
*negated,
);
let output = expr.clone().rewrite(&mut rewriter).unwrap();
let expected_list = expected_list
.iter()
.map(|v| lit(ScalarValue::Int32(Some(*v))))
.collect();
assert_eq!(
output,
Expr::Literal(ScalarValue::Boolean(Some(*expected_value)))
Expr::InList(InList {
expr: Box::new(col(*column_name)),
list: expected_list,
negated: *negated,
})
);
}

// These cases should be left as-is
let cases = &[
// y IN (10, 2)
col("y").in_list(vec![lit(10), lit(2)], false),
// y NOT IN (10, 2)
col("y").in_list(vec![lit(10), lit(2)], true),
];

for expr in cases {
let output = expr.clone().rewrite(&mut rewriter).unwrap();
assert_eq!(&output, expr);
}
}
}

0 comments on commit caa738f

Please sign in to comment.