diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index f9dfadc70826..64b04c38ef6c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -835,40 +835,38 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Or, right, - }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), + }) if expr_contains(&left, &right, Or)? => Transformed::yes(*left), // A OR (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), + }) if expr_contains(&right, &left, Or)? => Transformed::yes(*right), // A OR (A AND B) --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_op_with(And, &right, &left) => Transformed::yes(*left), + }) if is_op_with(And, &right, &left)? => Transformed::yes(*left), // (A AND B) OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_op_with(And, &left, &right) => Transformed::yes(*right), + }) if is_op_with(And, &left, &right)? => Transformed::yes(*right), // Eliminate common factors in conjunctions e.g // (A AND B) OR (A AND C) -> A AND (B OR C) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if has_common_conjunction(&left, &right) => { + }) if has_common_conjunction(&left, &right)? 
=> { let lhs: IndexSet<Expr> = iter_conjunction_owned(*left).collect(); let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right).partition(|e| lhs.contains(e)); - let new_rhs = rhs.into_iter().reduce(and); let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and); let common_conjunction = common.into_iter().reduce(and).unwrap(); - let new_expr = match (new_lhs, new_rhs) { (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)), (_, _) => common_conjunction, @@ -925,25 +923,25 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: And, right, - }) if expr_contains(&left, &right, And) => Transformed::yes(*left), + }) if expr_contains(&left, &right, And)? => Transformed::yes(*left), // A AND (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&right, &left, And) => Transformed::yes(*right), + }) if expr_contains(&right, &left, And)? => Transformed::yes(*right), // A AND (A OR B) --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if is_op_with(Or, &right, &left) => Transformed::yes(*left), + }) if is_op_with(Or, &right, &left)? => Transformed::yes(*left), // (A OR B) AND A --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if is_op_with(Or, &left, &right) => Transformed::yes(*right), + }) if is_op_with(Or, &left, &right)? => Transformed::yes(*right), // // Rules for Multiply @@ -1109,21 +1107,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, - }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), + }) if expr_contains(&left, &right, BitwiseAnd)? => Transformed::yes(*left), // A & (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), + }) if expr_contains(&right, &left, BitwiseAnd)? 
=> Transformed::yes(*right), // A & (A | B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { + }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left)? => { Transformed::yes(*left) } @@ -1132,7 +1130,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, - }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { + }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right)? => { Transformed::yes(*right) } @@ -1195,21 +1193,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, - }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), + }) if expr_contains(&left, &right, BitwiseOr)? => Transformed::yes(*left), // A | (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), + }) if expr_contains(&right, &left, BitwiseOr)? => Transformed::yes(*right), // A | (A & B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { + }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left)? => { Transformed::yes(*left) } @@ -1218,7 +1216,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, - }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { + }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right)? => { Transformed::yes(*right) } @@ -1281,7 +1279,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, - }) if expr_contains(&left, &right, BitwiseXor) => { + }) if expr_contains(&left, &right, BitwiseXor)? 
=> { let expr = delete_xor_in_complex_expr(&left, &right, false); Transformed::yes(if expr == *right { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) @@ -1295,7 +1293,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, - }) if expr_contains(&right, &left, BitwiseXor) => { + }) if expr_contains(&right, &left, BitwiseXor)? => { let expr = delete_xor_in_complex_expr(&right, &left, true); Transformed::yes(if expr == *left { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) @@ -1681,9 +1679,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } -fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { +fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> Result<bool> { let lhs: HashSet<&Expr> = iter_conjunction(lhs).collect(); - iter_conjunction(rhs).any(|e| lhs.contains(&e)) + iter_conjunction(rhs).try_fold(false, |acc, e| { + if lhs.contains(&e) { + let count = count_volatile_calls(e)?; + Ok::<bool, DataFusionError>(acc || count == 0) + } else { + Ok(acc) + } + }) } // TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 38bfc1a93403..3d9a7b659f17 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -17,11 +17,15 @@ //! 
Utility functions for expression simplification -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{ + internal_err, + tree_node::{TreeNode, TreeNodeVisitor}, + Result, ScalarValue, +}; use datafusion_expr::{ expr::{Between, BinaryExpr, InList}, expr_fn::{and, bitwise_and, bitwise_or, or}, - Expr, Like, Operator, + Expr, Like, Operator, Volatility, }; pub static POWS_OF_TEN: [i128; 38] = [ @@ -67,16 +71,24 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// returns true if `needle` is found in a chain of search_op /// expressions. Such as: (A AND B) AND C -pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { +fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { - expr_contains(left, needle, search_op) - || expr_contains(right, needle, search_op) + expr_contains_inner(left, needle, search_op) + || expr_contains_inner(right, needle, search_op) } _ => expr == needle, } } +/// returns true if `needle` is found in a chain of search_op expressions and `needle` contains no volatile function calls +pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> Result<bool> { + Ok( + expr_contains_inner(expr, needle, search_op) + && count_volatile_calls(needle)? == 0, + ) +} + /// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor /// expressions. 
Such as: A ^ (A ^ (B ^ A)) pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr { @@ -205,8 +217,10 @@ pub fn is_false(expr: &Expr) -> bool { } /// returns true if `haystack` looks like (needle OP X) or (X OP needle) -pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) +pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> Result<bool> { + Ok( + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && count_volatile_calls(needle)? == 0), + ) } /// returns true if `not_expr` is !`expr` (not) @@ -341,3 +355,49 @@ pub fn distribute_negation(expr: Expr) -> Expr { _ => Expr::Negative(Box::new(expr)), } } + +struct VolatileFunctionCounter { + counter: usize, +} + +impl VolatileFunctionCounter { + pub fn get_count(&self) -> usize { + self.counter + } + + pub fn new() -> Self { + Self { counter: 0 } + } +} + +impl<'n> TreeNodeVisitor<'n> for VolatileFunctionCounter { + type Node = Expr; + fn f_up( + &mut self, + expr: &'n Self::Node, + ) -> Result<datafusion_common::tree_node::TreeNodeRecursion> { + match expr { + Expr::ScalarFunction(func) + if matches!(func.func.signature().volatility, Volatility::Volatile) => + { + self.counter += 1; + } + _ => {} + } + Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue) + } + + fn f_down( + &mut self, + _node: &'n Self::Node, + ) -> Result<datafusion_common::tree_node::TreeNodeRecursion> { + Ok(datafusion_common::tree_node::TreeNodeRecursion::Continue) + } +} + +// get the number of volatile calls in an expression +pub fn count_volatile_calls(expr: &Expr) -> Result<usize> { + let mut volatile_visitor = VolatileFunctionCounter::new(); + expr.visit(&mut volatile_visitor)?; + Ok(volatile_visitor.get_count()) +} diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 
1340fd490e06..5f023bb8f34e 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -411,3 +411,28 @@ logical_plan physical_plan 01)ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))] 02)--PlaceholderRowExec + + + + +query TT +explain select * from VALUES (1), (2) where random() = 0 OR (column1 = 2 AND random() = 0); +---- +logical_plan +01)Filter: random() = Float64(0) OR column1 = Int64(2) AND random() = Float64(0) +02)--Values: (Int64(1)), (Int64(2)) +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: random() = 0 OR column1@0 = 2 AND random() = 0 +03)----ValuesExec + +query TT +explain select * from VALUES (1), (2) where column1 = 2 OR (column1 = 2 AND random() = 0); +---- +logical_plan +01)Filter: column1 = Int64(2) +02)--Values: (Int64(1)), (Int64(2)) +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column1@0 = 2 +03)----ValuesExec