Skip to content

Commit 85ceb9d

Browse files
authored
ScalarUDF with zero arguments should be provided with one null array as parameter (#9031)
* Fix ScalaUDF with zero arguments * Fix test * Fix clippy * Fix * Exclude built-in scalar functions * For review
1 parent efd2fd2 commit 85ceb9d

File tree

8 files changed

+146
-32
lines changed

8 files changed

+146
-32
lines changed

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ mod tests {
12701270
],
12711271
DataType::Int32,
12721272
None,
1273+
false,
12731274
)),
12741275
Arc::new(CaseExpr::try_new(
12751276
Some(Arc::new(Column::new("d", 2))),
@@ -1336,6 +1337,7 @@ mod tests {
13361337
],
13371338
DataType::Int32,
13381339
None,
1340+
false,
13391341
)),
13401342
Arc::new(CaseExpr::try_new(
13411343
Some(Arc::new(Column::new("d", 3))),
@@ -1405,6 +1407,7 @@ mod tests {
14051407
],
14061408
DataType::Int32,
14071409
None,
1410+
false,
14081411
)),
14091412
Arc::new(CaseExpr::try_new(
14101413
Some(Arc::new(Column::new("d", 2))),
@@ -1471,6 +1474,7 @@ mod tests {
14711474
],
14721475
DataType::Int32,
14731476
None,
1477+
false,
14741478
)),
14751479
Arc::new(CaseExpr::try_new(
14761480
Some(Arc::new(Column::new("d_new", 3))),

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,19 @@
1616
// under the License.
1717

1818
use arrow::compute::kernels::numeric::add;
19-
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
19+
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
20+
use arrow_schema::DataType::Float64;
2021
use arrow_schema::{DataType, Field, Schema};
2122
use datafusion::prelude::*;
2223
use datafusion::{execution::registry::FunctionRegistry, test_util};
2324
use datafusion_common::cast::as_float64_array;
2425
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
2526
use datafusion_expr::{
26-
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility,
27+
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
28+
ScalarUDFImpl, Signature, Volatility,
2729
};
30+
use rand::{thread_rng, Rng};
31+
use std::iter;
2832
use std::sync::Arc;
2933

3034
/// test that casting happens on udfs.
@@ -166,10 +170,7 @@ async fn scalar_udf_zero_params() -> Result<()> {
166170

167171
ctx.register_batch("t", batch)?;
168172
// create function just returns 100 regardless of inp
169-
let myfunc = Arc::new(|args: &[ColumnarValue]| {
170-
let ColumnarValue::Scalar(_) = &args[0] else {
171-
panic!("expect scalar")
172-
};
173+
let myfunc = Arc::new(|_args: &[ColumnarValue]| {
173174
Ok(ColumnarValue::Array(
174175
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
175176
))
@@ -392,6 +393,107 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
392393
Ok(())
393394
}
394395

396+
#[derive(Debug)]
397+
pub struct RandomUDF {
398+
signature: Signature,
399+
}
400+
401+
impl RandomUDF {
402+
pub fn new() -> Self {
403+
Self {
404+
signature: Signature::any(0, Volatility::Volatile),
405+
}
406+
}
407+
}
408+
409+
impl ScalarUDFImpl for RandomUDF {
410+
fn as_any(&self) -> &dyn std::any::Any {
411+
self
412+
}
413+
414+
fn name(&self) -> &str {
415+
"random_udf"
416+
}
417+
418+
fn signature(&self) -> &Signature {
419+
&self.signature
420+
}
421+
422+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
423+
Ok(Float64)
424+
}
425+
426+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
427+
let len: usize = match &args[0] {
428+
// This udf is always invoked with zero argument so its argument
429+
// is a null array indicating the batch size.
430+
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(),
431+
_ => {
432+
return Err(datafusion::error::DataFusionError::Internal(
433+
"Invalid argument type".to_string(),
434+
))
435+
}
436+
};
437+
let mut rng = thread_rng();
438+
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len);
439+
let array = Float64Array::from_iter_values(values);
440+
Ok(ColumnarValue::Array(Arc::new(array)))
441+
}
442+
}
443+
444+
/// Ensure that a user defined function with zero argument will be invoked
445+
/// with a null array indicating the batch size.
446+
#[tokio::test]
447+
async fn test_user_defined_functions_zero_argument() -> Result<()> {
448+
let ctx = SessionContext::new();
449+
450+
let schema = Arc::new(Schema::new(vec![Field::new(
451+
"index",
452+
DataType::UInt8,
453+
false,
454+
)]));
455+
456+
let batch = RecordBatch::try_new(
457+
schema,
458+
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
459+
)?;
460+
461+
ctx.register_batch("data_table", batch)?;
462+
463+
let random_normal_udf = ScalarUDF::from(RandomUDF::new());
464+
ctx.register_udf(random_normal_udf);
465+
466+
let result = plan_and_collect(
467+
&ctx,
468+
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
469+
)
470+
.await?;
471+
472+
assert_eq!(result.len(), 1);
473+
let batch = &result[0];
474+
let random_udf = batch
475+
.column(0)
476+
.as_any()
477+
.downcast_ref::<Float64Array>()
478+
.unwrap();
479+
let native_random = batch
480+
.column(1)
481+
.as_any()
482+
.downcast_ref::<Float64Array>()
483+
.unwrap();
484+
485+
assert_eq!(random_udf.len(), native_random.len());
486+
487+
let mut previous = -1.0;
488+
for i in 0..random_udf.len() {
489+
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
490+
assert!(random_udf.value(i) != previous);
491+
previous = random_udf.value(i);
492+
}
493+
494+
Ok(())
495+
}
496+
395497
fn create_udf_context() -> SessionContext {
396498
let ctx = SessionContext::new();
397499
// register a custom UDF

datafusion/physical-expr/src/functions.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pub fn create_physical_expr(
8181
input_phy_exprs.to_vec(),
8282
data_type,
8383
monotonicity,
84+
fun.signature().type_signature.supports_zero_argument(),
8485
)))
8586
}
8687

datafusion/physical-expr/src/planner.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ pub fn create_physical_expr(
259259
}
260260

261261
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
262-
let mut physical_args = args
262+
let physical_args = args
263263
.iter()
264264
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
265265
.collect::<Result<Vec<_>>>()?;
@@ -272,17 +272,11 @@ pub fn create_physical_expr(
272272
execution_props,
273273
)
274274
}
275-
ScalarFunctionDefinition::UDF(fun) => {
276-
// udfs with zero params expect null array as input
277-
if args.is_empty() {
278-
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
279-
}
280-
udf::create_physical_expr(
281-
fun.clone().as_ref(),
282-
&physical_args,
283-
input_schema,
284-
)
285-
}
275+
ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr(
276+
fun.clone().as_ref(),
277+
&physical_args,
278+
input_schema,
279+
),
286280
ScalarFunctionDefinition::Name(_) => {
287281
internal_err!("Function `Expr` with name should be resolved.")
288282
}

datafusion/physical-expr/src/scalar_function.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ pub struct ScalarFunctionExpr {
5858
// and it specifies the effect of an increase or decrease in
5959
// the corresponding `arg` to the function value.
6060
monotonicity: Option<FuncMonotonicity>,
61+
// Whether this function can be invoked with zero arguments
62+
supports_zero_argument: bool,
6163
}
6264

6365
impl Debug for ScalarFunctionExpr {
@@ -79,13 +81,15 @@ impl ScalarFunctionExpr {
7981
args: Vec<Arc<dyn PhysicalExpr>>,
8082
return_type: DataType,
8183
monotonicity: Option<FuncMonotonicity>,
84+
supports_zero_argument: bool,
8285
) -> Self {
8386
Self {
8487
fun,
8588
name: name.to_owned(),
8689
args,
8790
return_type,
8891
monotonicity,
92+
supports_zero_argument,
8993
}
9094
}
9195

@@ -138,9 +142,12 @@ impl PhysicalExpr for ScalarFunctionExpr {
138142
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
139143
// evaluate the arguments, if there are no arguments we'll instead pass in a null array
140144
// indicating the batch size (as a convention)
141-
let inputs = match (self.args.len(), self.name.parse::<BuiltinScalarFunction>()) {
145+
let inputs = match (
146+
self.args.is_empty(),
147+
self.name.parse::<BuiltinScalarFunction>(),
148+
) {
142149
// MakeArray support zero argument but has the different behavior from the array with one null.
143-
(0, Ok(scalar_fun))
150+
(true, Ok(scalar_fun))
144151
if scalar_fun
145152
.signature()
146153
.type_signature
@@ -149,6 +156,11 @@ impl PhysicalExpr for ScalarFunctionExpr {
149156
{
150157
vec![ColumnarValue::create_null_array(batch.num_rows())]
151158
}
159+
// If the function supports zero argument, we pass in a null array indicating the batch size.
160+
// This is for user-defined functions.
161+
(true, Err(_)) if self.supports_zero_argument => {
162+
vec![ColumnarValue::create_null_array(batch.num_rows())]
163+
}
152164
_ => self
153165
.args
154166
.iter()
@@ -175,6 +187,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
175187
children,
176188
self.return_type().clone(),
177189
self.monotonicity.clone(),
190+
self.supports_zero_argument,
178191
)))
179192
}
180193

datafusion/physical-expr/src/udf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub fn create_physical_expr(
4040
input_phy_exprs.to_vec(),
4141
fun.return_type(&input_exprs_types)?,
4242
fun.monotonicity()?,
43+
fun.signature().type_signature.supports_zero_argument(),
4344
)))
4445
}
4546

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,21 +340,17 @@ pub fn parse_physical_expr(
340340
// TODO Do not create new the ExecutionProps
341341
let execution_props = ExecutionProps::new();
342342

343-
let fun_expr = functions::create_physical_fun(
343+
functions::create_physical_expr(
344344
&(&scalar_function).into(),
345+
&args,
346+
input_schema,
345347
&execution_props,
346-
)?;
347-
348-
Arc::new(ScalarFunctionExpr::new(
349-
&e.name,
350-
fun_expr,
351-
args,
352-
convert_required!(e.return_type)?,
353-
None,
354-
))
348+
)?
355349
}
356350
ExprType::ScalarUdf(e) => {
357-
let scalar_fun = registry.udf(e.name.as_str())?.fun().clone();
351+
let udf = registry.udf(e.name.as_str())?;
352+
let signature = udf.signature();
353+
let scalar_fun = udf.fun().clone();
358354

359355
let args = e
360356
.args
@@ -368,6 +364,7 @@ pub fn parse_physical_expr(
368364
args,
369365
convert_required!(e.return_type)?,
370366
None,
367+
signature.type_signature.supports_zero_argument(),
371368
))
372369
}
373370
ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new(

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,9 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
578578
"acos",
579579
fun_expr,
580580
vec![col("a", &schema)?],
581-
DataType::Int64,
581+
DataType::Float64,
582582
None,
583+
false,
583584
);
584585

585586
let project =
@@ -617,6 +618,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
617618
vec![col("a", &schema)?],
618619
DataType::Int64,
619620
None,
621+
false,
620622
);
621623

622624
let project =

0 commit comments

Comments
 (0)