diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 27b2d71b1f42..d8b829f27e7d 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{ aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, }; -pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl}; +pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; pub use udf_docs::{DocSection, Documentation, DocumentationBuilder}; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 1a5d50477b1c..8ba83d1712d2 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -203,9 +203,6 @@ impl ScalarUDF { self.inner.simplify(args, info) } - /// Invoke the function on `args`, returning the appropriate result. - /// - /// See [`ScalarUDFImpl::invoke`] for more details. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { #[allow(deprecated)] @@ -216,17 +213,23 @@ impl ScalarUDF { self.inner.is_nullable(args, schema) } - /// Invoke the function with `args` and number of rows, returning the appropriate result. - /// - /// See [`ScalarUDFImpl::invoke_batch`] for more details. + #[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")] pub fn invoke_batch( &self, args: &[ColumnarValue], number_rows: usize, ) -> Result<ColumnarValue> { + #[allow(deprecated)] self.inner.invoke_batch(args, number_rows) } + /// Invoke the function on `args`, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_with_args`] for more details. 
+ pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + self.inner.invoke_with_args(args) + } + /// Invoke the function without `args` but number of rows, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke_no_args`] for more details. @@ -324,6 +327,18 @@ where } } +/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a +/// scalar function. +pub struct ScalarFunctionArgs<'a> { + /// The evaluated arguments to the function + pub args: Vec<ColumnarValue>, + /// The number of rows in the record batch being evaluated + pub number_rows: usize, + /// The return type of the scalar function (as computed by `return_type` or `return_type_from_exprs`) + /// when creating the physical expression from the logical expression + pub return_type: &'a DataType, +} + /// Trait for implementing [`ScalarUDF`]. /// /// This trait exposes the full API for implementing user defined functions and @@ -356,7 +371,7 @@ where /// } /// } /// } -/// +/// /// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new(); /// /// fn get_doc() -> &'static Documentation { @@ -518,6 +533,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to arrays, which will likely be simpler code, but be slower. + #[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")] fn invoke_batch( &self, args: &[ColumnarValue], @@ -537,6 +553,23 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { } } + /// Invoke the function returning the appropriate result. + /// + /// The function is invoked with a [`ScalarFunctionArgs`] struct. + /// + /// # Performance Notes + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. 
+ fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + #[allow(deprecated)] + self.invoke_batch(&args.args, args.number_rows) + } + /// Invoke the function without `args`, instead the number of rows are provided, /// returning the appropriate result. #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] @@ -767,6 +800,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { args: &[ColumnarValue], number_rows: usize, ) -> Result<ColumnarValue> { + #[allow(deprecated)] self.inner.invoke_batch(args, number_rows) } diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 5df5d9c7dee2..bc20e0ff11c1 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_8192", |b| { b.iter(|| { for _ in 0..iterations { + #[allow(deprecated)] // TODO: migrate to invoke_with_args black_box(random_func.invoke_batch(&[], 8192).unwrap()); } }) @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("random_1M_rows_batch_128", |b| { b.iter(|| { for _ in 0..iterations_128 { + #[allow(deprecated)] // TODO: migrate to invoke_with_args black_box(random_func.invoke_batch(&[], 128).unwrap()); } }) diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 36cf07e9e5da..eac0aa38f058 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -121,6 +121,7 @@ mod test { #[tokio::test] async fn test_version_udf() { let version_udf = ScalarUDF::from(VersionFunc::new()); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let version = version_udf.invoke_batch(&[], 1).unwrap(); if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version { diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 
fef1eb9a60c8..5048b8fd47ec 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -431,7 +431,7 @@ mod tests { use arrow::datatypes::{DataType, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use super::{adjust_to_local_time, ToLocalTimeFunc}; @@ -558,7 +558,11 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() - .invoke_batch(&[ColumnarValue::Scalar(input)], 1) + .invoke_with_args(ScalarFunctionArgs { + args: vec![ColumnarValue::Scalar(input)], + number_rows: 1, + return_type: &expected.data_type(), + }) .unwrap(); match res { ColumnarValue::Scalar(res) => { @@ -617,6 +621,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = ToLocalTimeFunc::new() .invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size) .unwrap(); diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index f15fad701c55..78a7bf505dac 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1008,7 +1008,7 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, Some(_)))); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let res = udf .invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); @@ -1051,7 +1051,7 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let res = udf 
.invoke_batch(&[array.clone()], 1) .expect("that to_timestamp parsed values without error"); diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index dd90ce6a6c96..c291596c2520 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc { DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0] .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)? .cast_to(&DataType::Int64, None), + #[allow(deprecated)] // TODO: migrate to invoke_with_args DataType::Utf8 => ToTimestampSecondsFunc::new() .invoke_batch(args, batch_size)? .cast_to(&DataType::Int64, None), diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 9110f9f532d8..14b6dc3e054e 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -277,7 +277,7 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let _ = LogFunc::new().invoke_batch(&args, 4); } @@ -286,7 +286,7 @@ mod tests { let args = [ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new().invoke_batch(&args, 1); result.expect_err("expected error"); } @@ -296,7 +296,7 @@ mod tests { let args = [ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -320,7 +320,7 @@ mod tests { let args = [ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to 
initialize function log"); @@ -345,7 +345,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -370,7 +370,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 1) .expect("failed to initialize function log"); @@ -396,7 +396,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); @@ -425,7 +425,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); @@ -455,7 +455,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); @@ -485,7 +485,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = LogFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function log"); diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index a24c613f5259..acf5f84df92b 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -205,7 +205,7 @@ mod tests { ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base 
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = PowerFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function power"); @@ -232,7 +232,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ]; - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = PowerFunc::new() .invoke_batch(&args, 4) .expect("failed to initialize function power"); diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index 7f21297712c7..33ff630f309f 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -167,6 +167,7 @@ mod test { f32::NEG_INFINITY, ])); let batch_size = array.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = SignumFunc::new() .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); @@ -207,6 +208,7 @@ mod test { f64::NEG_INFINITY, ])); let batch_size = array.len(); + #[allow(deprecated)] // TODO: migrate to invoke_with_args let result = SignumFunc::new() .invoke_batch(&[ColumnarValue::Array(array)], batch_size) .expect("failed to initialize function signum"); diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8da154430fc5..819463795b7f 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -655,7 +655,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ColumnarValue::Scalar(v_sv), 
ColumnarValue::Scalar(regex_sv)], 1, @@ -670,7 +670,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], 1, @@ -685,7 +685,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], 1, @@ -711,7 +711,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -730,7 +730,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -749,7 +749,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -781,7 +781,7 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -802,7 +802,7 @@ mod tests { 
let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -823,7 +823,7 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -905,7 +905,7 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -926,7 +926,7 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), @@ -947,7 +947,7 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - + #[allow(deprecated)] // TODO: migrate to invoke_with_args let re = RegexpCountFunc::new().invoke_batch( &[ ColumnarValue::Scalar(v_sv), diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index b76d70d7e9d2..e3cd89b43257 100644 --- a/datafusion/functions/src/string/ascii.rs +++ 
b/datafusion/functions/src/string/ascii.rs @@ -155,7 +155,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + [ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, @@ -164,7 +164,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + [ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], $EXPECTED, i32, Int32, @@ -173,7 +173,7 @@ mod tests { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + [ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 87180cb77de7..cf549e5efbf0 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -146,9 +146,11 @@ pub mod test { match expected { Ok(expected) => { assert_eq!(return_type.is_ok(), true); - assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + let return_type = return_type.unwrap(); + assert_eq!(return_type, $EXPECTED_DATA_TYPE); - let result = func.invoke_batch($ARGS, cardinality); + let args: Vec<_> = $ARGS.into_iter().collect(); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args, number_rows: cardinality, return_type: &return_type}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().clone().into_array(cardinality).expect("Failed to convert to array"); @@ -169,7 +171,7 @@ pub mod test { } else { // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_batch($ARGS, cardinality) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS.into_iter().collect(), number_rows: cardinality, return_type: &return_type.unwrap()}) { Ok(_) => assert!(false, "expected error"), Err(error) => { 
assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 9bf168e8a199..3f6aa2d90aae 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -43,7 +43,7 @@ use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; -use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF}; +use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF}; /// Physical expression of a scalar function #[derive(Eq, PartialEq, Hash)] @@ -140,18 +140,23 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::<Result<Vec<_>>>()?; + let inputs_all_scalars = !inputs.is_empty() && inputs + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + // evaluate the function - let output = self.fun.invoke_batch(&inputs, batch.num_rows())?; + let output = self.fun.invoke_with_args(ScalarFunctionArgs { + args: inputs, + number_rows: batch.num_rows(), + return_type: &self.return_type, + })?; + if let ColumnarValue::Array(array) = &output { if array.len() != batch.num_rows() { - // If the arguments are a non-empty slice of scalar values, we can assume that + // If the arguments are a non-empty list of scalar values only, we can assume that // returning a one-element array is equivalent to returning a scalar. - let preserve_scalar = array.len() == 1 - && !inputs.is_empty() - && inputs - .iter() - .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let preserve_scalar = array.len() == 1 && inputs_all_scalars; return if preserve_scalar { ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) } else {