Skip to content

Commit 3c1f832

Browse files
committed
Fix ScalarUDF with zero arguments
1 parent d594e62 commit 3c1f832

File tree

8 files changed

+159
-27
lines changed

8 files changed

+159
-27
lines changed

datafusion/core/src/physical_optimizer/projection_pushdown.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,7 @@ mod tests {
12251225
use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics};
12261226
use datafusion_execution::object_store::ObjectStoreUrl;
12271227
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
1228-
use datafusion_expr::{ColumnarValue, Operator};
1228+
use datafusion_expr::{ColumnarValue, Operator, Signature, Volatility};
12291229
use datafusion_physical_expr::expressions::{
12301230
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
12311231
};
@@ -1270,6 +1270,10 @@ mod tests {
12701270
],
12711271
DataType::Int32,
12721272
None,
1273+
Signature::exact(
1274+
vec![DataType::Float32, DataType::Float32],
1275+
Volatility::Immutable,
1276+
),
12731277
)),
12741278
Arc::new(CaseExpr::try_new(
12751279
Some(Arc::new(Column::new("d", 2))),
@@ -1336,6 +1340,10 @@ mod tests {
13361340
],
13371341
DataType::Int32,
13381342
None,
1343+
Signature::exact(
1344+
vec![DataType::Float32, DataType::Float32],
1345+
Volatility::Immutable,
1346+
),
13391347
)),
13401348
Arc::new(CaseExpr::try_new(
13411349
Some(Arc::new(Column::new("d", 3))),
@@ -1405,6 +1413,10 @@ mod tests {
14051413
],
14061414
DataType::Int32,
14071415
None,
1416+
Signature::exact(
1417+
vec![DataType::Float32, DataType::Float32],
1418+
Volatility::Immutable,
1419+
),
14081420
)),
14091421
Arc::new(CaseExpr::try_new(
14101422
Some(Arc::new(Column::new("d", 2))),
@@ -1471,6 +1483,10 @@ mod tests {
14711483
],
14721484
DataType::Int32,
14731485
None,
1486+
Signature::exact(
1487+
vec![DataType::Float32, DataType::Float32],
1488+
Volatility::Immutable,
1489+
),
14741490
)),
14751491
Arc::new(CaseExpr::try_new(
14761492
Some(Arc::new(Column::new("d_new", 3))),

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,22 @@
1616
// under the License.
1717

1818
use arrow::compute::kernels::numeric::add;
19-
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
19+
use arrow_array::{
20+
ArrayRef, Float64Array, Int32Array, Int64Array, RecordBatch, UInt64Array, UInt8Array,
21+
};
22+
use arrow_schema::DataType::Float64;
2023
use arrow_schema::{DataType, Field, Schema};
2124
use datafusion::prelude::*;
2225
use datafusion::{execution::registry::FunctionRegistry, test_util};
2326
use datafusion_common::cast::as_float64_array;
2427
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
28+
use datafusion_expr::TypeSignature::{Any, Variadic};
2529
use datafusion_expr::{
26-
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility,
30+
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
31+
ScalarUDFImpl, Signature, Volatility,
2732
};
33+
use rand::{thread_rng, Rng};
34+
use std::iter;
2835
use std::sync::Arc;
2936

3037
/// test that casting happens on udfs.
@@ -392,6 +399,112 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
392399
Ok(())
393400
}
394401

402+
#[derive(Debug)]
403+
pub struct RandomUDF {
404+
signature: Signature,
405+
}
406+
407+
impl RandomUDF {
408+
pub fn new() -> Self {
409+
Self {
410+
signature: Signature::one_of(
411+
vec![Any(0), Variadic(vec![Float64])],
412+
Volatility::Volatile,
413+
),
414+
}
415+
}
416+
}
417+
418+
impl ScalarUDFImpl for RandomUDF {
419+
fn as_any(&self) -> &dyn std::any::Any {
420+
self
421+
}
422+
423+
fn name(&self) -> &str {
424+
"random_udf"
425+
}
426+
427+
fn signature(&self) -> &Signature {
428+
&self.signature
429+
}
430+
431+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
432+
Ok(Float64)
433+
}
434+
435+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
436+
let len: usize = match &args[0] {
437+
ColumnarValue::Array(array) => array.len(),
438+
_ => {
439+
return Err(datafusion::error::DataFusionError::Internal(
440+
"Invalid argument type".to_string(),
441+
))
442+
}
443+
};
444+
let mut rng = thread_rng();
445+
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len);
446+
let array = Float64Array::from_iter_values(values);
447+
Ok(ColumnarValue::Array(Arc::new(array)))
448+
}
449+
}
450+
451+
#[tokio::test]
452+
async fn test_user_defined_functions_zero_argument() -> Result<()> {
453+
let ctx = SessionContext::new();
454+
455+
let schema = Arc::new(Schema::new(vec![
456+
Field::new("index", DataType::UInt8, false),
457+
Field::new("uint", DataType::UInt64, true),
458+
Field::new("int", DataType::Int64, true),
459+
Field::new("float", DataType::Float64, true),
460+
]));
461+
462+
let batch = RecordBatch::try_new(
463+
schema,
464+
vec![
465+
Arc::new(UInt8Array::from_iter_values([1, 2, 3])),
466+
Arc::new(UInt64Array::from(vec![Some(2), Some(3), None])),
467+
Arc::new(Int64Array::from(vec![Some(-2), Some(3), None])),
468+
Arc::new(Float64Array::from(vec![Some(1.0), Some(3.3), None])),
469+
],
470+
)?;
471+
472+
ctx.register_batch("data_table", batch)?;
473+
474+
let random_normal_udf = ScalarUDF::from(RandomUDF::new());
475+
ctx.register_udf(random_normal_udf);
476+
477+
let result = plan_and_collect(
478+
&ctx,
479+
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
480+
)
481+
.await?;
482+
483+
assert_eq!(result.len(), 1);
484+
let batch = &result[0];
485+
let random_udf = batch
486+
.column(0)
487+
.as_any()
488+
.downcast_ref::<Float64Array>()
489+
.unwrap();
490+
let native_random = batch
491+
.column(1)
492+
.as_any()
493+
.downcast_ref::<Float64Array>()
494+
.unwrap();
495+
496+
assert_eq!(random_udf.len(), native_random.len());
497+
498+
let mut previous = 1.0;
499+
for i in 0..random_udf.len() {
500+
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
501+
assert!(random_udf.value(i) != previous);
502+
previous = random_udf.value(i);
503+
}
504+
505+
Ok(())
506+
}
507+
395508
fn create_udf_context() -> SessionContext {
396509
let ctx = SessionContext::new();
397510
// register a custom UDF

datafusion/physical-expr/src/functions.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pub fn create_physical_expr(
8181
input_phy_exprs.to_vec(),
8282
data_type,
8383
monotonicity,
84+
fun.signature().clone(),
8485
)))
8586
}
8687

datafusion/physical-expr/src/planner.rs

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ pub fn create_physical_expr(
259259
}
260260

261261
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
262-
let mut physical_args = args
262+
let physical_args = args
263263
.iter()
264264
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
265265
.collect::<Result<Vec<_>>>()?;
@@ -272,17 +272,11 @@ pub fn create_physical_expr(
272272
execution_props,
273273
)
274274
}
275-
ScalarFunctionDefinition::UDF(fun) => {
276-
// udfs with zero params expect null array as input
277-
if args.is_empty() {
278-
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
279-
}
280-
udf::create_physical_expr(
281-
fun.clone().as_ref(),
282-
&physical_args,
283-
input_schema,
284-
)
285-
}
275+
ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr(
276+
fun.clone().as_ref(),
277+
&physical_args,
278+
input_schema,
279+
),
286280
ScalarFunctionDefinition::Name(_) => {
287281
internal_err!("Function `Expr` with name should be resolved.")
288282
}

datafusion/physical-expr/src/scalar_function.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use arrow::record_batch::RecordBatch;
4444
use datafusion_common::Result;
4545
use datafusion_expr::{
4646
expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity,
47-
ScalarFunctionImplementation,
47+
ScalarFunctionImplementation, Signature,
4848
};
4949

5050
/// Physical expression of a scalar function
@@ -58,6 +58,8 @@ pub struct ScalarFunctionExpr {
5858
// and it specifies the effect of an increase or decrease in
5959
// the corresponding `arg` to the function value.
6060
monotonicity: Option<FuncMonotonicity>,
61+
// Signature of the function
62+
signature: Signature,
6163
}
6264

6365
impl Debug for ScalarFunctionExpr {
@@ -79,13 +81,15 @@ impl ScalarFunctionExpr {
7981
args: Vec<Arc<dyn PhysicalExpr>>,
8082
return_type: DataType,
8183
monotonicity: Option<FuncMonotonicity>,
84+
signature: Signature,
8285
) -> Self {
8386
Self {
8487
fun,
8588
name: name.to_owned(),
8689
args,
8790
return_type,
8891
monotonicity,
92+
signature,
8993
}
9094
}
9195

@@ -149,6 +153,9 @@ impl PhysicalExpr for ScalarFunctionExpr {
149153
{
150154
vec![ColumnarValue::create_null_array(batch.num_rows())]
151155
}
156+
(0, _) if self.signature.type_signature.supports_zero_argument() => {
157+
vec![ColumnarValue::create_null_array(batch.num_rows())]
158+
}
152159
_ => self
153160
.args
154161
.iter()
@@ -175,6 +182,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
175182
children,
176183
self.return_type().clone(),
177184
self.monotonicity.clone(),
185+
self.signature.clone(),
178186
)))
179187
}
180188

datafusion/physical-expr/src/udf.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pub fn create_physical_expr(
4040
input_phy_exprs.to_vec(),
4141
fun.return_type(&input_exprs_types)?,
4242
fun.monotonicity()?,
43+
fun.signature().clone(),
4344
)))
4445
}
4546

datafusion/proto/src/physical_plan/from_proto.rs

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,21 +340,17 @@ pub fn parse_physical_expr(
340340
// TODO Do not create new the ExecutionProps
341341
let execution_props = ExecutionProps::new();
342342

343-
let fun_expr = functions::create_physical_fun(
343+
functions::create_physical_expr(
344344
&(&scalar_function).into(),
345+
&args,
346+
input_schema,
345347
&execution_props,
346-
)?;
347-
348-
Arc::new(ScalarFunctionExpr::new(
349-
&e.name,
350-
fun_expr,
351-
args,
352-
convert_required!(e.return_type)?,
353-
None,
354-
))
348+
)?
355349
}
356350
ExprType::ScalarUdf(e) => {
357-
let scalar_fun = registry.udf(e.name.as_str())?.fun().clone();
351+
let udf = registry.udf(e.name.as_str())?;
352+
let signature = udf.signature();
353+
let scalar_fun = udf.fun().clone();
358354

359355
let args = e
360356
.args
@@ -368,6 +364,7 @@ pub fn parse_physical_expr(
368364
args,
369365
convert_required!(e.return_type)?,
370366
None,
367+
signature.clone(),
371368
))
372369
}
373370
ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new(

datafusion/proto/tests/cases/roundtrip_physical_plan.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
580580
vec![col("a", &schema)?],
581581
DataType::Int64,
582582
None,
583+
Signature::exact(vec![DataType::Int64], Volatility::Immutable),
583584
);
584585

585586
let project =
@@ -617,6 +618,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
617618
vec![col("a", &schema)?],
618619
DataType::Int64,
619620
None,
621+
Signature::exact(vec![DataType::Int64], Volatility::Immutable),
620622
);
621623

622624
let project =

0 commit comments

Comments
 (0)