Skip to content

Commit 7c0d77f

Browse files
viirya authored and appletreeisyellow committed
fix: volatile expressions should not be target of common subexpr elimination (apache#8520)
* fix: volatile expressions should not be target of common subexpr elimination * Fix clippy * For review * Return error for unresolved scalar function * Improve error message
1 parent 06bbe12 commit 7c0d77f

File tree

3 files changed

+91
-9
lines changed

3 files changed

+91
-9
lines changed

datafusion/expr/src/expr.rs

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,24 @@ impl ScalarFunctionDefinition {
378378
ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(),
379379
}
380380
}
381+
382+
/// Whether this function is volatile, i.e. whether it can return different results
383+
/// when evaluated multiple times with the same input. `BuiltIn` and `UDF` variants are classified by comparing their declared volatility with `Volatility::Volatile`; the unresolved `Name` variant yields an internal error because its volatility cannot be determined.
384+
pub fn is_volatile(&self) -> Result<bool> {
385+
match self {
386+
ScalarFunctionDefinition::BuiltIn(fun) => {
387+
Ok(fun.volatility() == crate::Volatility::Volatile)
388+
}
389+
ScalarFunctionDefinition::UDF(udf) => {
390+
Ok(udf.signature().volatility == crate::Volatility::Volatile)
391+
}
392+
ScalarFunctionDefinition::Name(func) => {
393+
internal_err!(
394+
"Cannot determine volatility of unresolved function: {func}"
395+
)
396+
}
397+
}
398+
}
381399
}
382400

383401
impl ScalarFunction {
@@ -1678,14 +1696,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
16781696
.join(", "))
16791697
}
16801698

1699+
/// Whether the given expression node is volatile, i.e. whether it can return different results
1700+
/// when evaluated multiple times with the same input. Only the top-level node is inspected: an `Expr::ScalarFunction` defers to `func_def.is_volatile()` (propagating its error), and every other variant yields `Ok(false)` without visiting child expressions.
1701+
pub fn is_volatile(expr: &Expr) -> Result<bool> {
1702+
match expr {
1703+
Expr::ScalarFunction(func) => func.func_def.is_volatile(),
1704+
_ => Ok(false),
1705+
}
1706+
}
1707+
16811708
#[cfg(test)]
16821709
mod test {
16831710
use crate::expr::Cast;
16841711
use crate::expr_fn::col;
1685-
use crate::{case, lit, Expr};
1712+
use crate::{
1713+
case, lit, BuiltinScalarFunction, ColumnarValue, Expr, ReturnTypeFunction,
1714+
ScalarFunctionDefinition, ScalarFunctionImplementation, ScalarUDF, Signature,
1715+
Volatility,
1716+
};
16861717
use arrow::datatypes::DataType;
16871718
use datafusion_common::Column;
16881719
use datafusion_common::{Result, ScalarValue};
1720+
use std::sync::Arc;
16891721

16901722
#[test]
16911723
fn format_case_when() -> Result<()> {
@@ -1786,4 +1818,45 @@ mod test {
17861818
"UInt32(1) OR UInt32(2)"
17871819
);
17881820
}
1821+
1822+
#[test]
1823+
fn test_is_volatile_scalar_func_definition() {
1824+
// Built-in functions: `Random` must be reported as volatile, `Abs` must not.
1825+
assert!(
1826+
ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Random)
1827+
.is_volatile()
1828+
.unwrap()
1829+
);
1830+
assert!(
1831+
!ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
1832+
.is_volatile()
1833+
.unwrap()
1834+
);
1835+
1836+
// UDFs: the result must follow the volatility declared in the signature (first `Stable`, then `Volatile`).
1837+
let return_type: ReturnTypeFunction =
1838+
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
1839+
let fun: ScalarFunctionImplementation =
1840+
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::new_utf8("a"))));
1841+
let udf = Arc::new(ScalarUDF::new(
1842+
"TestScalarUDF",
1843+
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
1844+
&return_type,
1845+
&fun,
1846+
));
1847+
assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
1848+
1849+
let udf = Arc::new(ScalarUDF::new(
1850+
"TestScalarUDF",
1851+
&Signature::uniform(1, vec![DataType::Float32], Volatility::Volatile),
1852+
&return_type,
1853+
&fun,
1854+
));
1855+
assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
1856+
1857+
// Unresolved function: one known only by name cannot be classified, so an error is expected.
1858+
ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc"))
1859+
.is_volatile()
1860+
.expect_err("Shouldn't determine volatility of unresolved function");
1861+
}
17891862
}

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use datafusion_common::tree_node::{
2929
use datafusion_common::{
3030
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
3131
};
32-
use datafusion_expr::expr::Alias;
32+
use datafusion_expr::expr::{is_volatile, Alias};
3333
use datafusion_expr::logical_plan::{
3434
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
3535
};
@@ -113,6 +113,8 @@ impl CommonSubexprEliminate {
113113
let Projection { expr, input, .. } = projection;
114114
let input_schema = Arc::clone(input.schema());
115115
let mut expr_set = ExprSet::new();
116+
117+
// Visit expr list and build expr identifier to occurring count map (`expr_set`).
116118
let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
117119

118120
let (mut new_expr, new_input) =
@@ -517,7 +519,7 @@ enum ExprMask {
517519
}
518520

519521
impl ExprMask {
520-
fn ignores(&self, expr: &Expr) -> bool {
522+
fn ignores(&self, expr: &Expr) -> Result<bool> {
521523
let is_normal_minus_aggregates = matches!(
522524
expr,
523525
Expr::Literal(..)
@@ -528,15 +530,16 @@ impl ExprMask {
528530
| Expr::Wildcard { .. }
529531
);
530532

533+
let is_volatile = is_volatile(expr)?;
534+
531535
let is_aggr = matches!(
532536
expr,
533537
Expr::AggregateFunction(..) | Expr::AggregateUDF { .. }
534538
);
535-
536-
match self {
537-
Self::Normal => is_normal_minus_aggregates || is_aggr,
538-
Self::NormalAndAggregates => is_normal_minus_aggregates,
539-
}
539+
Ok(match self {
540+
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
541+
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
542+
})
540543
}
541544
}
542545

@@ -628,7 +631,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
628631

629632
let (idx, sub_expr_desc) = self.pop_enter_mark();
630633
// Skip expressions that should not be recognized as common subexpressions.
631-
if self.expr_mask.ignores(expr) {
634+
if self.expr_mask.ignores(expr)? {
632635
self.id_array[idx].0 = self.series_number;
633636
let desc = Self::desc_expr(expr);
634637
self.visit_stack.push(VisitRecord::ExprItem(desc));

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,3 +952,9 @@ query ?
952952
SELECT substr_index(NULL, NULL, NULL)
953953
----
954954
NULL
955+
956+
# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
957+
query B
958+
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
959+
----
960+
false

0 commit comments

Comments
 (0)