@@ -373,6 +373,24 @@ impl ScalarFunctionDefinition {
373
373
ScalarFunctionDefinition :: Name ( func_name) => func_name. as_ref ( ) ,
374
374
}
375
375
}
376
+
377
+ /// Whether this function is volatile, i.e. whether it can return different results
378
+ /// when evaluated multiple times with the same input.
379
+ pub fn is_volatile ( & self ) -> Result < bool > {
380
+ match self {
381
+ ScalarFunctionDefinition :: BuiltIn ( fun) => {
382
+ Ok ( fun. volatility ( ) == crate :: Volatility :: Volatile )
383
+ }
384
+ ScalarFunctionDefinition :: UDF ( udf) => {
385
+ Ok ( udf. signature ( ) . volatility == crate :: Volatility :: Volatile )
386
+ }
387
+ ScalarFunctionDefinition :: Name ( func) => {
388
+ internal_err ! (
389
+ "Cannot determine volatility of unresolved function: {func}"
390
+ )
391
+ }
392
+ }
393
+ }
376
394
}
377
395
378
396
impl ScalarFunction {
@@ -1692,14 +1710,28 @@ fn create_names(exprs: &[Expr]) -> Result<String> {
1692
1710
. join ( ", " ) )
1693
1711
}
1694
1712
1713
+ /// Whether the given expression is volatile, i.e. whether it can return different results
1714
+ /// when evaluated multiple times with the same input.
1715
+ pub fn is_volatile ( expr : & Expr ) -> Result < bool > {
1716
+ match expr {
1717
+ Expr :: ScalarFunction ( func) => func. func_def . is_volatile ( ) ,
1718
+ _ => Ok ( false ) ,
1719
+ }
1720
+ }
1721
+
1695
1722
#[ cfg( test) ]
1696
1723
mod test {
1697
1724
use crate :: expr:: Cast ;
1698
1725
use crate :: expr_fn:: col;
1699
- use crate :: { case, lit, Expr } ;
1726
+ use crate :: {
1727
+ case, lit, BuiltinScalarFunction , ColumnarValue , Expr , ReturnTypeFunction ,
1728
+ ScalarFunctionDefinition , ScalarFunctionImplementation , ScalarUDF , Signature ,
1729
+ Volatility ,
1730
+ } ;
1700
1731
use arrow:: datatypes:: DataType ;
1701
1732
use datafusion_common:: Column ;
1702
1733
use datafusion_common:: { Result , ScalarValue } ;
1734
+ use std:: sync:: Arc ;
1703
1735
1704
1736
#[ test]
1705
1737
fn format_case_when ( ) -> Result < ( ) > {
@@ -1800,4 +1832,45 @@ mod test {
1800
1832
"UInt32(1) OR UInt32(2)"
1801
1833
) ;
1802
1834
}
1835
+
1836
+ #[ test]
1837
+ fn test_is_volatile_scalar_func_definition ( ) {
1838
+ // BuiltIn
1839
+ assert ! (
1840
+ ScalarFunctionDefinition :: BuiltIn ( BuiltinScalarFunction :: Random )
1841
+ . is_volatile( )
1842
+ . unwrap( )
1843
+ ) ;
1844
+ assert ! (
1845
+ !ScalarFunctionDefinition :: BuiltIn ( BuiltinScalarFunction :: Abs )
1846
+ . is_volatile( )
1847
+ . unwrap( )
1848
+ ) ;
1849
+
1850
+ // UDF
1851
+ let return_type: ReturnTypeFunction =
1852
+ Arc :: new ( move |_| Ok ( Arc :: new ( DataType :: Utf8 ) ) ) ;
1853
+ let fun: ScalarFunctionImplementation =
1854
+ Arc :: new ( move |_| Ok ( ColumnarValue :: Scalar ( ScalarValue :: new_utf8 ( "a" ) ) ) ) ;
1855
+ let udf = Arc :: new ( ScalarUDF :: new (
1856
+ "TestScalarUDF" ,
1857
+ & Signature :: uniform ( 1 , vec ! [ DataType :: Float32 ] , Volatility :: Stable ) ,
1858
+ & return_type,
1859
+ & fun,
1860
+ ) ) ;
1861
+ assert ! ( !ScalarFunctionDefinition :: UDF ( udf) . is_volatile( ) . unwrap( ) ) ;
1862
+
1863
+ let udf = Arc :: new ( ScalarUDF :: new (
1864
+ "TestScalarUDF" ,
1865
+ & Signature :: uniform ( 1 , vec ! [ DataType :: Float32 ] , Volatility :: Volatile ) ,
1866
+ & return_type,
1867
+ & fun,
1868
+ ) ) ;
1869
+ assert ! ( ScalarFunctionDefinition :: UDF ( udf) . is_volatile( ) . unwrap( ) ) ;
1870
+
1871
+ // Unresolved function
1872
+ ScalarFunctionDefinition :: Name ( Arc :: from ( "UnresolvedFunc" ) )
1873
+ . is_volatile ( )
1874
+ . expect_err ( "Shouldn't determine volatility of unresolved function" ) ;
1875
+ }
1803
1876
}
0 commit comments