Skip to content

Commit 7dd209f

Browse files
viirya authored and appletreeisyellow committed
fix: volatile expressions should not be target of common subexpr elimination (apache#8520)
* fix: volatile expressions should not be target of common subexpr elimination
* Fix clippy
* For review
* Return error for unresolved scalar function
* Improve error message
1 parent c0c9e88 commit 7dd209f

File tree

3 files changed

+91
-8
lines changed

3 files changed

+91
-8
lines changed

datafusion/expr/src/expr.rs

+74-1
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,24 @@ impl ScalarFunctionDefinition {
373373
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
374374
}
375375
}
376+
377+
/// Whether this function is volatile, i.e. whether it can return different results
378+
/// when evaluated multiple times with the same input.
///
/// # Errors
///
/// Returns an internal error for [`ScalarFunctionDefinition::Name`], because the
/// volatility of an unresolved function cannot be determined.
379+
pub fn is_volatile(&self) -> Result<bool> {
380+
match self {
381+
// Built-in functions: compare the value returned by `volatility()`.
ScalarFunctionDefinition::BuiltIn(fun) => {
382+
Ok(fun.volatility() == crate::Volatility::Volatile)
383+
}
384+
// UDFs: compare the `volatility` field of the UDF's signature.
ScalarFunctionDefinition::UDF(udf) => {
385+
Ok(udf.signature().volatility == crate::Volatility::Volatile)
386+
}
387+
// Only a name is available here, so report an error instead of guessing.
ScalarFunctionDefinition::Name(func) => {
388+
internal_err!(
389+
"Cannot determine volatility of unresolved function: {func}"
390+
)
391+
}
392+
}
393+
}
376394
}
377395

378396
impl ScalarFunction {
@@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
16921710
.join(", "))
16931711
}
16941712

1713+
/// Whether the given expression is volatile, i.e. whether it can return different results
1714+
/// when evaluated multiple times with the same input.
///
/// Only `expr` itself is inspected: an [`Expr::ScalarFunction`] defers to its
/// function definition, and every other variant is reported as non-volatile.
/// Child expressions are not traversed by this function.
///
/// # Errors
///
/// Propagates any error returned by the function definition's `is_volatile`.
1715+
pub fn is_volatile(expr: &Expr) -> Result<bool> {
1716+
match expr {
1717+
Expr::ScalarFunction(func) => func.func_def.is_volatile(),
1718+
// NOTE(review): an expression that merely contains a volatile call (e.g. as an
// argument) also lands here and yields `false` — confirm callers check every node.
_ => Ok(false),
1719+
}
1720+
}
1721+
16951722
#[cfg(test)]
16961723
mod test {
16971724
use crate::expr::Cast;
16981725
use crate::expr_fn::col;
1699-
use crate::{case, lit, Expr};
1726+
use crate::{
1727+
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
1728+
ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
1729+
Volatility,
1730+
};
17001731
use arrow::datatypes::DataType;
17011732
use datafusion_common::Column;
17021733
use datafusion_common::{Result, ScalarValue};
1734+
use std::sync::Arc;
17031735

17041736
#[test]
17051737
fn format_case_when() -> Result<()> {
@@ -1800,4 +1832,45 @@ mod test {
18001832
"UInt32(1) OR UInt32(2)"
18011833
);
18021834
}
1835+
1836+
/// Checks `ScalarFunctionDefinition::is_volatile` for each variant: built-in
/// functions, UDFs with stable and volatile signatures, and an unresolved name.
#[test]
1837+
fn test_is_volatile_scalar_func_definition() {
1838+
// Built-in functions: `Random` must be reported as volatile, `Abs` must not.
1839+
assert!(
1840+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
1841+
.is_volatile()
1842+
.unwrap()
1843+
);
1844+
assert!(
1845+
!ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
1846+
.is_volatile()
1847+
.unwrap()
1848+
);
1849+
1850+
// UDF: the return-type and implementation closures are shared by both UDFs
// below; only the `Volatility` passed to the signature differs.
1851+
let return_type: ReturnTypeFunction =
1852+
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
1853+
let fun: ScalarFunctionImplementation =
1854+
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
1855+
// A signature declared `Stable` must not be reported as volatile.
let udf = Arc::new(ScalarUDF::new(
1856+
"TestScalarUDF",
1857+
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1858+
&return_type,
1859+
&fun,
1860+
));
1861+
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
1862+
1863+
// A signature declared `Volatile` must be reported as volatile.
let udf = Arc::new(ScalarUDF::new(
1864+
"TestScalarUDF",
1865+
&Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
1866+
&return_type,
1867+
&fun,
1868+
));
1869+
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
1870+
1871+
// Unresolved function: `is_volatile` must return an error, not a boolean.
1872+
ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
1873+
.is_volatile()
1874+
.expect_err("Shouldn't determine volatility of unresolved function");
1875+
}
18031876
}

datafusion/optimizer/src/common_subexpr_eliminate.rs

+11-7
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
2929
use datafusion_common::{
3030
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
3131
};
32-
use datafusion_expr::expr::Alias;
32+
use datafusion_expr::expr::{is_volatile, Alias};
3333
use datafusion_expr::logical_plan::{
3434
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
3535
};
@@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
113113
let Projection { expr, input, .. } = projection;
114114
let input_schema = Arc::clone(input.schema());
115115
let mut expr_set = ExprSet::new();
116+
117+
// Visit the expr list and build a map from expr identifier to occurrence count (`expr_set`).
116118
let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
117119

118120
let (mut new_expr, new_input) =
@@ -516,7 +518,7 @@ enum ExprMask {
516518
}
517519

518520
impl ExprMask {
519-
fn ignores(&self, expr: &Expr) -> bool {
521+
fn ignores(&self, expr: &Expr) -> Result<bool> {
520522
let is_normal_minus_aggregates = matches!(
521523
expr,
522524
Expr::Literal(..)
@@ -527,12 +529,14 @@ impl ExprMask {
527529
| Expr::Wildcard { .. }
528530
);
529531

532+
let is_volatile = is_volatile(expr)?;
533+
530534
let is_aggr = matches!(expr, Expr::AggregateFunction(..));
531535

532-
match self {
533-
Self::Normal => is_normal_minus_aggregates || is_aggr,
534-
Self::NormalAndAggregates => is_normal_minus_aggregates,
535-
}
536+
Ok(match self {
537+
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
538+
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
539+
})
536540
}
537541
}
538542

@@ -624,7 +628,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
624628

625629
let (idx, sub_expr_desc) = self.pop_enter_mark();
626630
// Skip exprs that should not be recognized.
627-
if self.expr_mask.ignores(expr) {
631+
if self.expr_mask.ignores(expr)? {
628632
self.id_array[idx].0 = self.series_number;
629633
let desc = Self::desc_expr(expr);
630634
self.visit_stack.push(VisitRecord::ExprItem(desc));

datafusion/sqllogictest/test_files/functions.slt

+6
Original file line numberDiff line numberDiff line change
@@ -995,3 +995,9 @@ query ?
995995
SELECT find_in_set(NULL, NULL)
996996
----
997997
NULL
998+
999+
# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
1000+
query B
1001+
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
1002+
----
1003+
false

0 commit comments

Comments
 (0)