Skip to content

Commit

Permalink
consider volatile function in simply_expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Oct 26, 2024
1 parent 813220d commit b9fa3ce
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 30 deletions.
51 changes: 28 additions & 23 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -835,40 +835,38 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: Or,
right,
}) if expr_contains(&left, &right, Or) => Transformed::yes(*left),
}) if expr_contains(&left, &right, Or)? => Transformed::yes(*left),
// A OR (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if expr_contains(&right, &left, Or) => Transformed::yes(*right),
}) if expr_contains(&right, &left, Or)? => Transformed::yes(*right),
// A OR (A AND B) --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if is_op_with(And, &right, &left) => Transformed::yes(*left),
}) if is_op_with(And, &right, &left)? => Transformed::yes(*left),
// (A AND B) OR A --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if is_op_with(And, &left, &right) => Transformed::yes(*right),
}) if is_op_with(And, &left, &right)? => Transformed::yes(*right),
// Eliminate common factors in conjunctions e.g
// (A AND B) OR (A AND C) -> A AND (B OR C)
Expr::BinaryExpr(BinaryExpr {
left,
op: Or,
right,
}) if has_common_conjunction(&left, &right) => {
}) if has_common_conjunction(&left, &right)? => {
let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect();
let (common, rhs): (Vec<_>, Vec<_>) =
iter_conjunction_owned(*right).partition(|e| lhs.contains(e));

let new_rhs = rhs.into_iter().reduce(and);
let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and);
let common_conjunction = common.into_iter().reduce(and).unwrap();

let new_expr = match (new_lhs, new_rhs) {
(Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)),
(_, _) => common_conjunction,
Expand Down Expand Up @@ -925,25 +923,25 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: And,
right,
}) if expr_contains(&left, &right, And) => Transformed::yes(*left),
}) if expr_contains(&left, &right, And)? => Transformed::yes(*left),
// A AND (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if expr_contains(&right, &left, And) => Transformed::yes(*right),
}) if expr_contains(&right, &left, And)? => Transformed::yes(*right),
// A AND (A OR B) --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if is_op_with(Or, &right, &left) => Transformed::yes(*left),
}) if is_op_with(Or, &right, &left)? => Transformed::yes(*left),
// (A OR B) AND A --> A
Expr::BinaryExpr(BinaryExpr {
left,
op: And,
right,
}) if is_op_with(Or, &left, &right) => Transformed::yes(*right),
}) if is_op_with(Or, &left, &right)? => Transformed::yes(*right),

//
// Rules for Multiply
Expand Down Expand Up @@ -1109,21 +1107,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: BitwiseAnd,
right,
}) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left),
}) if expr_contains(&left, &right, BitwiseAnd)? => Transformed::yes(*left),

// A & (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right),
}) if expr_contains(&right, &left, BitwiseAnd)? => Transformed::yes(*right),

// A & (A | B) --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseAnd,
right,
}) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => {
}) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left)? => {
Transformed::yes(*left)
}

Expand All @@ -1132,7 +1130,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: BitwiseAnd,
right,
}) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => {
}) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right)? => {
Transformed::yes(*right)
}

Expand Down Expand Up @@ -1195,21 +1193,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: BitwiseOr,
right,
}) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left),
}) if expr_contains(&left, &right, BitwiseOr)? => Transformed::yes(*left),

// A | (..A..) --> (..A..)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right),
}) if expr_contains(&right, &left, BitwiseOr)? => Transformed::yes(*right),

// A | (A & B) --> A (if B not null)
Expr::BinaryExpr(BinaryExpr {
left,
op: BitwiseOr,
right,
}) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => {
}) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left)? => {
Transformed::yes(*left)
}

Expand All @@ -1218,7 +1216,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: BitwiseOr,
right,
}) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => {
}) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right)? => {
Transformed::yes(*right)
}

Expand Down Expand Up @@ -1281,7 +1279,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: BitwiseXor,
right,
}) if expr_contains(&left, &right, BitwiseXor) => {
}) if expr_contains(&left, &right, BitwiseXor)? => {
let expr = delete_xor_in_complex_expr(&left, &right, false);
Transformed::yes(if expr == *right {
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?)
Expand All @@ -1295,7 +1293,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
left,
op: BitwiseXor,
right,
}) if expr_contains(&right, &left, BitwiseXor) => {
}) if expr_contains(&right, &left, BitwiseXor)? => {
let expr = delete_xor_in_complex_expr(&right, &left, true);
Transformed::yes(if expr == *left {
Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?)
Expand Down Expand Up @@ -1681,9 +1679,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
}
}

// Reports whether `lhs` and `rhs` share a conjunct (as yielded by
// `iter_conjunction`) that is safe to factor out. In the `Result`-returning
// version, a shared conjunct only counts when `count_volatile_calls` reports
// zero volatile calls for it; errors from that call are propagated.
fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool {
fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> Result<bool, DataFusionError> {
// Collect the left-hand conjuncts once so each right-hand one is a set lookup.
let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect();
iter_conjunction(rhs).any(|e| lhs.contains(&e))
// `acc` turns true once a shared conjunct without volatile calls is seen and
// stays true afterwards; the fold still visits the remaining conjuncts.
iter_conjunction(rhs).try_fold(false, |acc, e| {
if lhs.contains(&e) {
let count = count_volatile_calls(e)?;
Ok::<bool, DataFusionError>(acc || count == 0)
} else {
// Not shared with the left side: leaves the result unchanged.
Ok(acc)
}
})
}

// TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121
Expand Down
74 changes: 67 additions & 7 deletions datafusion/optimizer/src/simplify_expressions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@

//! Utility functions for expression simplification

use datafusion_common::{internal_err, Result, ScalarValue};
use datafusion_common::{
internal_err,
tree_node::{TreeNode, TreeNodeVisitor},
Result, ScalarValue,
};
use datafusion_expr::{
expr::{Between, BinaryExpr, InList},
expr_fn::{and, bitwise_and, bitwise_or, or},
Expr, Like, Operator,
Expr, Like, Operator, Volatility,
};

pub static POWS_OF_TEN: [i128; 38] = [
Expand Down Expand Up @@ -67,16 +71,24 @@ pub static POWS_OF_TEN: [i128; 38] = [

/// returns true if `needle` is found in a chain of search_op
/// expressions. Such as: (A AND B) AND C
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
match expr {
// Binary node using the operator being searched: keep descending into
// both operands and succeed if either side contains `needle`.
Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => {
expr_contains(left, needle, search_op)
|| expr_contains(right, needle, search_op)
expr_contains_inner(left, needle, search_op)
|| expr_contains_inner(right, needle, search_op)
}
// Any other expression ends the chain: it matches only when it is equal
// to `needle` itself.
_ => expr == needle,
}
}

/// Returns `Ok(true)` if `needle` is found in a chain of `search_op`
/// expressions rooted at `expr` and `needle` contains no volatile function
/// calls.
///
/// The volatility check only runs once the structural search has succeeded;
/// an error from `count_volatile_calls` is returned to the caller.
pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> Result<bool> {
    if !expr_contains_inner(expr, needle, search_op) {
        return Ok(false);
    }
    let volatile_calls = count_volatile_calls(needle)?;
    Ok(volatile_calls == 0)
}

/// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor
/// expressions. Such as: A ^ (A ^ (B ^ A))
pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr {
Expand Down Expand Up @@ -205,8 +217,10 @@ pub fn is_false(expr: &Expr) -> bool {
}

/// returns true if `haystack` looks like (needle OP X) or (X OP needle)
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool {
matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()))
pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> Result<bool> {
    // Only a binary expression whose operator is `target_op` can qualify.
    let (left, right) = match haystack {
        Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op => {
            (left, right)
        }
        _ => return Ok(false),
    };

    // `needle` has to be one of the two direct operands.
    let is_operand = needle == left.as_ref() || needle == right.as_ref();
    if !is_operand {
        return Ok(false);
    }

    // Checked last so the expression walk only happens for structural matches:
    // a needle containing volatile calls never counts as a match.
    Ok(count_volatile_calls(needle)? == 0)
}

/// returns true if `not_expr` is !`expr` (not)
Expand Down Expand Up @@ -341,3 +355,49 @@ pub fn distribute_negation(expr: Expr) -> Expr {
_ => Expr::Negative(Box::new(expr)),
}
}

/// Visitor state that tallies the volatile scalar function calls seen while
/// walking an expression.
struct VolatileFunctionCounter {
    /// Running total of volatile calls encountered so far.
    counter: usize,
}

impl VolatileFunctionCounter {
    /// Creates a counter whose tally starts at zero.
    pub fn new() -> Self {
        VolatileFunctionCounter { counter: 0 }
    }

    /// Returns the tally accumulated so far.
    pub fn get_count(&self) -> usize {
        self.counter
    }
}

impl<'n> TreeNodeVisitor<'n> for VolatileFunctionCounter {
    type Node = Expr;

    /// Does no work in `f_down`; it only tells the traversal to continue.
    fn f_down(
        &mut self,
        _node: &'n Self::Node,
    ) -> Result<datafusion_common::tree_node::TreeNodeRecursion> {
        Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
    }

    /// Bumps the counter when `expr` is a scalar function whose signature is
    /// marked `Volatility::Volatile`, then tells the traversal to continue.
    ///
    /// Only `Expr::ScalarFunction` nodes are inspected; every other kind of
    /// expression leaves the counter untouched.
    fn f_up(
        &mut self,
        expr: &'n Self::Node,
    ) -> Result<datafusion_common::tree_node::TreeNodeRecursion> {
        if let Expr::ScalarFunction(func) = expr {
            let is_volatile =
                matches!(func.func.signature().volatility, Volatility::Volatile);
            if is_volatile {
                self.counter += 1;
            }
        }
        Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue)
    }
}

/// Returns the number of volatile scalar function calls found in `expr`.
///
/// The expression is walked with a fresh `VolatileFunctionCounter`; any error
/// raised by the traversal is returned unchanged.
pub fn count_volatile_calls(expr: &Expr) -> Result<usize> {
    let mut visitor = VolatileFunctionCounter::new();
    expr.visit(&mut visitor).map(|_| visitor.get_count())
}
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,28 @@ logical_plan
physical_plan
01)ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))]
02)--PlaceholderRowExec




query TT
explain select * from VALUES (1), (2) where random() = 0 OR (column1 = 2 AND random() = 0);
----
logical_plan
01)Filter: random() = Float64(0) OR column1 = Int64(2) AND random() = Float64(0)
02)--Values: (Int64(1)), (Int64(2))
physical_plan
01)CoalesceBatchesExec: target_batch_size=8192
02)--FilterExec: random() = 0 OR column1@0 = 2 AND random() = 0
03)----ValuesExec

query TT
explain select * from VALUES (1), (2) where column1 = 2 OR (column1 = 2 AND random() = 0);
----
logical_plan
01)Filter: column1 = Int64(2)
02)--Values: (Int64(1)), (Int64(2))
physical_plan
01)CoalesceBatchesExec: target_batch_size=8192
02)--FilterExec: column1@0 = 2
03)----ValuesExec

0 comments on commit b9fa3ce

Please sign in to comment.