Skip to content

Commit

Permalink
null and comparison support
Browse files Browse the repository at this point in the history
  • Loading branch information
wjones127 committed Sep 4, 2023
1 parent 44d1b48 commit 4c1c3a9
Showing 1 changed file with 203 additions and 15 deletions.
218 changes: 203 additions & 15 deletions datafusion/optimizer/src/simplify_expressions/guarantees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
//! Logic to inject guarantees with expressions.
//!
use datafusion_common::{tree_node::TreeNodeRewriter, Result, ScalarValue};
use datafusion_expr::Expr;
use datafusion_expr::{lit, Between, BinaryExpr, Expr, Operator};
use std::collections::HashMap;

/// A bound on the value of an expression.
#[derive(Debug, Clone, PartialEq)]
pub struct GuaranteeBound {
/// The value of the bound.
pub bound: ScalarValue,
Expand Down Expand Up @@ -52,6 +53,7 @@ impl Default for GuaranteeBound {
/// This might be populated by null count statistics, for example. A null count
/// of zero would mean `NeverNull`, while a null count equal to row count would
/// mean `AlwaysNull`.
#[derive(Debug, Clone, PartialEq)]
pub enum NullStatus {
/// The expression is guaranteed to be non-null.
NeverNull,
Expand All @@ -66,6 +68,7 @@ pub enum NullStatus {
/// This is similar to [datafusion_physical_expr::intervals::Interval], except
/// that this is designed for working with logical expressions and also handles
/// nulls.
#[derive(Debug, Clone, PartialEq)]
pub struct Guarantee {
/// The min values that the expression can take on. If `min.bound` is
pub min: GuaranteeBound,
Expand All @@ -88,6 +91,23 @@ impl Guarantee {
null_status,
}
}

/// Returns `true` when every value admitted by this guarantee is strictly
/// greater than `value`.
///
/// That holds when the lower bound itself exceeds `value`, or when the lower
/// bound equals `value` but is open (so `value` itself is excluded).
fn greater_than(&self, value: &ScalarValue) -> bool {
    let lower = &self.min;
    if lower.bound > *value {
        return true;
    }
    lower.open && lower.bound == *value
}

/// Returns `true` when every value admitted by this guarantee is greater than
/// or equal to `value`, i.e. the lower bound is at least `value`.
///
/// Whether the lower bound is open or closed does not matter for this check.
fn greater_than_or_eq(&self, value: &ScalarValue) -> bool {
    let lower_bound = &self.min.bound;
    lower_bound >= value
}

/// Returns `true` when every value admitted by this guarantee is strictly
/// less than `value`.
///
/// That holds when the upper bound itself is below `value`, or when the upper
/// bound equals `value` but is open (so `value` itself is excluded).
fn less_than(&self, value: &ScalarValue) -> bool {
    let upper = &self.max;
    if upper.bound < *value {
        return true;
    }
    upper.open && upper.bound == *value
}

/// Returns `true` when every value admitted by this guarantee is less than or
/// equal to `value`, i.e. the upper bound is at most `value`.
///
/// Whether the upper bound is open or closed does not matter for this check.
fn less_than_or_eq(&self, value: &ScalarValue) -> bool {
    let upper_bound = &self.max.bound;
    upper_bound <= value
}
}

impl From<&ScalarValue> for Guarantee {
Expand Down Expand Up @@ -128,15 +148,180 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
type N = Expr;

/// Replaces `expr` with a boolean literal when the registered guarantees are
/// enough to decide it; otherwise returns `expr` unchanged.
///
/// Handled shapes:
/// * `<e> IS NULL` / `<e> IS NOT NULL`, decided by the guarantee's `null_status`.
/// * `<e> [NOT] BETWEEN <low> AND <high>` where at least one bound is a literal.
/// * `<column> <cmp> <literal>` (or the mirrored form) for `=`, `!=`, `<`,
///   `<=`, `>`, `>=`.
///
/// Literals that are NULL are never compared against the bounds, because a
/// comparison with NULL does not evaluate to `true` or `false`.
///
/// NOTE(review): the range-based rewrites do not consult `null_status`. For an
/// expression that may be NULL the original predicate would evaluate to NULL on
/// those rows rather than to the literal produced here — confirm this is only
/// relied on where NULL and `false` are interchangeable (e.g. filters).
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
    // `Some(b)` when the guarantees prove that `expr` always evaluates to `b`.
    // Computed from a borrow of `expr` so that `expr` can be returned untouched
    // when nothing could be proven.
    let decided: Option<bool> = match &expr {
        // IS NULL / IS NOT NULL
        Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
            Some(guarantee) => match guarantee.null_status {
                NullStatus::AlwaysNull => Some(true),
                NullStatus::NeverNull => Some(false),
                NullStatus::MaybeNull => None,
            },
            None => None,
        },
        Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) {
            Some(guarantee) => match guarantee.null_status {
                NullStatus::AlwaysNull => Some(false),
                NullStatus::NeverNull => Some(true),
                NullStatus::MaybeNull => None,
            },
            None => None,
        },

        // Range expressions
        Expr::Between(Between {
            expr: inner,
            negated,
            low,
            high,
        }) => match self.guarantees.get(inner.as_ref()) {
            Some(guarantee) => match (low.as_ref(), high.as_ref()) {
                (Expr::Literal(low), Expr::Literal(high)) => {
                    if low.is_null() || high.is_null() {
                        // A NULL bound cannot be ordered against the guarantee.
                        None
                    } else if guarantee.greater_than_or_eq(low)
                        && guarantee.less_than_or_eq(high)
                    {
                        // All values are between the bounds
                        Some(!*negated)
                    } else if guarantee.greater_than(high) || guarantee.less_than(low) {
                        // All values are outside the bounds
                        Some(*negated)
                    } else {
                        None
                    }
                }
                // Only one bound is a literal: the predicate can still be
                // refuted when that bound alone excludes every possible value.
                (Expr::Literal(low), _)
                    if !*negated && !low.is_null() && guarantee.less_than(low) =>
                {
                    // All values are below the lower bound
                    Some(false)
                }
                (_, Expr::Literal(high))
                    if !*negated && !high.is_null() && guarantee.greater_than(high) =>
                {
                    // All values are above the upper bound
                    Some(false)
                }
                _ => None,
            },
            None => None,
        },

        // Comparisons between a column and a literal
        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
            // Normalize to `column <op> literal`, flipping the operator when
            // the literal is on the left-hand side. Operators that cannot be
            // flipped are not comparisons and are left alone.
            let normalized = match (left.as_ref(), right.as_ref()) {
                (Expr::Column(_), Expr::Literal(value)) => Some((left, *op, value)),
                (Expr::Literal(value), Expr::Column(_)) => {
                    op.swap().map(|swapped| (right, swapped, value))
                }
                _ => None,
            };

            match normalized {
                Some((col, op, value)) if !value.is_null() => {
                    match self.guarantees.get(col.as_ref()) {
                        Some(guarantee) => match op {
                            Operator::Eq | Operator::NotEq => {
                                let is_eq = matches!(op, Operator::Eq);
                                if guarantee.greater_than(value)
                                    || guarantee.less_than(value)
                                {
                                    // The literal lies outside the allowed range
                                    Some(!is_eq)
                                } else if guarantee.greater_than_or_eq(value)
                                    && guarantee.less_than_or_eq(value)
                                {
                                    // The range collapses to exactly the literal
                                    Some(is_eq)
                                } else {
                                    None
                                }
                            }
                            Operator::Gt => {
                                if guarantee.less_than_or_eq(value) {
                                    Some(false)
                                } else if guarantee.greater_than(value) {
                                    Some(true)
                                } else {
                                    None
                                }
                            }
                            Operator::GtEq => {
                                if guarantee.less_than(value) {
                                    Some(false)
                                } else if guarantee.greater_than_or_eq(value) {
                                    Some(true)
                                } else {
                                    None
                                }
                            }
                            Operator::Lt => {
                                if guarantee.greater_than_or_eq(value) {
                                    Some(false)
                                } else if guarantee.less_than(value) {
                                    Some(true)
                                } else {
                                    None
                                }
                            }
                            Operator::LtEq => {
                                if guarantee.greater_than(value) {
                                    Some(false)
                                } else if guarantee.less_than_or_eq(value) {
                                    Some(true)
                                } else {
                                    None
                                }
                            }
                            // Not a comparison operator
                            _ => None,
                        },
                        None => None,
                    }
                }
                _ => None,
            }
        }

        // TODO: Columns (if bounds are equal and closed and column is not nullable)

        // TODO: In list
        _ => None,
    };

    Ok(match decided {
        Some(value) => lit(value),
        None => expr,
    })
}
}

Expand Down Expand Up @@ -214,10 +399,11 @@ mod tests {
// These cases should be simplified
let cases = &[
(col("x").lt_eq(lit(1)), false),
(col("x").lt_eq(lit(3)), true),
(col("x").gt(lit(3)), false),
(col("y").gt_eq(lit(18628)), true),
(col("y").gt(lit(19000)), true),
(col("y").lt_eq(lit(17000)), false),
(col("y").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
(col("y").gt_eq(lit(ScalarValue::Date32(Some(17000)))), true),
(col("y").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
];

for (expr, expected_value) in cases {
Expand All @@ -231,8 +417,10 @@ mod tests {
// These cases should be left as-is
let cases = &[
col("x").gt(lit(2)),
col("x").lt_eq(lit(3)),
col("y").gt_eq(lit(17000)),
col("x").lt_eq(lit(2)),
col("x").between(lit(2), lit(5)),
col("x").not_between(lit(3), lit(10)),
col("y").gt(lit(ScalarValue::Date32(Some(19000)))),
];

for expr in cases {
Expand Down

0 comments on commit 4c1c3a9

Please sign in to comment.