From 303d3031b92b928651e12ec8f83c0f8b5044d925 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2023 10:25:27 -0800 Subject: [PATCH 1/5] fix: volatile expressions should not be target of common subexpt elimination --- datafusion/expr/src/expr.rs | 17 +++++++++++++++++ .../optimizer/src/common_subexpr_eliminate.rs | 10 +++++++--- .../sqllogictest/test_files/functions.slt | 5 +++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 958f4f4a3456..3080acbfe88a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1692,6 +1692,23 @@ fn create_names(exprs: &[Expr]) -> Result { .join(", ")) } +/// Whether the given expression is volatile, i.e. whether it can return different results +/// when evaluated multiple times with the same input. +pub fn is_volatile(expr: &Expr) -> bool { + match expr { + Expr::ScalarFunction(func) => match func.func_def { + ScalarFunctionDefinition::BuiltIn(func) => match func { + BuiltinScalarFunction::Random => true, + _ => false, + }, + // TODO: Add volatile flag to UDFs + ScalarFunctionDefinition::UDF(_) => false, + ScalarFunctionDefinition::Name(_) => false, + }, + _ => false, + } +} + #[cfg(test)] mod test { use crate::expr::Cast; diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1d21407a6985..b42b095a8c88 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{is_volatile, Alias}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -113,6 +113,8 @@ impl CommonSubexprEliminate { let Projection { expr, input, .. } = projection; let input_schema = Arc::clone(input.schema()); let mut expr_set = ExprSet::new(); + + // Visit expr list and build expr identifier to occuring count map (`expr_set`). let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?; let (mut new_expr, new_input) = @@ -527,11 +529,13 @@ impl ExprMask { | Expr::Wildcard { .. } ); + let is_volatile = is_volatile(expr); + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, + Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, + Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, } } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 4f55ea316bb9..ad570b3735ae 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -995,3 +995,8 @@ query ? SELECT find_in_set(NULL, NULL) ---- NULL + +query B +SELECT r1 == r2 FROM (SELECT random() r1, random() r2) +---- +false From 1fa466501e1e47bb5620be752f7b07daea823720 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2023 10:33:58 -0800 Subject: [PATCH 2/5] Fix clippy --- datafusion/expr/src/expr.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3080acbfe88a..d52de58ebff1 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1694,6 +1694,7 @@ fn create_names(exprs: &[Expr]) -> Result { /// Whether the given expression is volatile, i.e. whether it can return different results /// when evaluated multiple times with the same input. +#[allow(clippy::match_like_matches_macro)] pub fn is_volatile(expr: &Expr) -> bool { match expr { Expr::ScalarFunction(func) => match func.func_def { From ffe1756d27aa5fcb5f7b09cb3e2744919752f531 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Dec 2023 14:38:27 -0800 Subject: [PATCH 3/5] For review --- datafusion/expr/src/expr.rs | 25 +++++++++++-------- .../sqllogictest/test_files/functions.slt | 3 ++- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index d52de58ebff1..32296a15ded2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -373,6 +373,20 @@ impl ScalarFunctionDefinition { ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), } } + + /// Whether this function is volatile, i.e. whether it can return different results + /// when evaluated multiple times with the same input. + pub fn is_volatile(&self) -> bool { + match self { + ScalarFunctionDefinition::BuiltIn(fun) => { + fun.volatility() == crate::Volatility::Volatile + } + ScalarFunctionDefinition::UDF(udf) => { + udf.signature().volatility == crate::Volatility::Volatile + } + ScalarFunctionDefinition::Name(_) => false, + } + } } impl ScalarFunction { @@ -1694,18 +1708,9 @@ fn create_names(exprs: &[Expr]) -> Result { /// Whether the given expression is volatile, i.e. whether it can return different results /// when evaluated multiple times with the same input. -#[allow(clippy::match_like_matches_macro)] pub fn is_volatile(expr: &Expr) -> bool { match expr { - Expr::ScalarFunction(func) => match func.func_def { - ScalarFunctionDefinition::BuiltIn(func) => match func { - BuiltinScalarFunction::Random => true, - _ => false, - }, - // TODO: Add volatile flag to UDFs - ScalarFunctionDefinition::UDF(_) => false, - ScalarFunctionDefinition::Name(_) => false, - }, + Expr::ScalarFunction(func) => func.func_def.is_volatile(), _ => false, } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index ad570b3735ae..1903088b0748 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -996,7 +996,8 @@ SELECT find_in_set(NULL, NULL) ---- NULL +# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B -SELECT r1 == r2 FROM (SELECT random() r1, random() r2) +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) ---- false From 66304c08781088fa599a1afe5fc2f6d58cc6ba73 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Dec 2023 13:21:53 -0800 Subject: [PATCH 4/5] Return error for unresolved scalar function --- datafusion/expr/src/expr.rs | 62 ++++++++++++++++--- .../optimizer/src/common_subexpr_eliminate.rs | 10 +-- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 32296a15ded2..76b258a8a9e9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -376,15 +376,17 @@ impl ScalarFunctionDefinition { /// Whether this function is volatile, i.e. whether it can return different results /// when evaluated multiple times with the same input. - pub fn is_volatile(&self) -> bool { + pub fn is_volatile(&self) -> Result { match self { ScalarFunctionDefinition::BuiltIn(fun) => { - fun.volatility() == crate::Volatility::Volatile + Ok(fun.volatility() == crate::Volatility::Volatile) } ScalarFunctionDefinition::UDF(udf) => { - udf.signature().volatility == crate::Volatility::Volatile + Ok(udf.signature().volatility == crate::Volatility::Volatile) + } + ScalarFunctionDefinition::Name(_) => { + internal_err!("Cannot determine volatility of unresolved function") } - ScalarFunctionDefinition::Name(_) => false, } } } @@ -1708,10 +1710,10 @@ fn create_names(exprs: &[Expr]) -> Result { /// Whether the given expression is volatile, i.e. whether it can return different results /// when evaluated multiple times with the same input. -pub fn is_volatile(expr: &Expr) -> bool { +pub fn is_volatile(expr: &Expr) -> Result { match expr { Expr::ScalarFunction(func) => func.func_def.is_volatile(), - _ => false, + _ => Ok(false), } } @@ -1719,10 +1721,15 @@ pub fn is_volatile(expr: &Expr) -> bool { mod test { use crate::expr::Cast; use crate::expr_fn::col; - use crate::{case, lit, Expr}; + use crate::{ + case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction, + ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, + }; use arrow::datatypes::DataType; use datafusion_common::Column; use datafusion_common::{Result, ScalarValue}; + use std::sync::Arc; #[test] fn format_case_when() -> Result<()> { @@ -1823,4 +1830,45 @@ mod test { "UInt32(1) OR UInt32(2)" ); } + + #[test] + fn test_is_volatile_scalar_func_definition() { + // BuiltIn + assert!( + ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random) + .is_volatile() + .unwrap() + ); + assert!( + !ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs) + .is_volatile() + .unwrap() + ); + + // UDF + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = + Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a")))); + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + )); + assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + let udf = Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile), + &return_type, + &fun, + )); + assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + + // Unresolved function + ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) + .is_volatile() + .expect_err("Unresolved function should not be resolved"); + } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index b42b095a8c88..1e089257c61a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -518,7 +518,7 @@ enum ExprMask { } impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { + fn ignores(&self, expr: &Expr) -> Result { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -529,14 +529,14 @@ impl ExprMask { | Expr::Wildcard { .. } ); - let is_volatile = is_volatile(expr); + let is_volatile = is_volatile(expr)?; let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - match self { + Ok(match self { Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, - } + }) } } @@ -628,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr) { + if self.expr_mask.ignores(expr)? { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); From b245ad61528090ccad1f30a8932a93f5ebd1349a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Dec 2023 13:58:18 -0800 Subject: [PATCH 5/5] Improve error message --- datafusion/expr/src/expr.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 76b258a8a9e9..f0aab95b8f0d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -384,8 +384,10 @@ impl ScalarFunctionDefinition { ScalarFunctionDefinition::UDF(udf) => { Ok(udf.signature().volatility == crate::Volatility::Volatile) } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Cannot determine volatility of unresolved function") + ScalarFunctionDefinition::Name(func) => { + internal_err!( + "Cannot determine volatility of unresolved function: {func}" + ) } } } @@ -1869,6 +1871,6 @@ mod test { // Unresolved function ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) .is_volatile() - .expect_err("Unresolved function should not be resolved"); + .expect_err("Shouldn't determine volatility of unresolved function"); } }