diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 0f287ab36dad..45ca076e754f 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,10 +16,11 @@ // under the License. use arrow::array::ArrayRef; +use arrow::datatypes::DataType; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::concat; use std::sync::Arc; @@ -39,8 +40,16 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { - // TODO use invoke_with_args - criterion::black_box(concat().invoke_batch(&args, size).unwrap()) + let args_cloned = args.clone(); + criterion::black_box( + concat() + .invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + }) + .unwrap(), + ) }) }); group.finish(); diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 114ac4a16fe5..534e5739225d 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; +use arrow::datatypes::DataType; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -125,8 +126,12 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args1(size, 32); c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }); @@ -135,8 +140,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_the_first_value_is_nonascii: {}", size), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }, ); @@ -146,8 +155,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_the_middle_value_is_nonascii: {}", size), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }, ); @@ -167,8 +180,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), |b| b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs{ + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }), ); @@ -177,8 +194,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), |b| b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs{ + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }), ); @@ -187,8 +208,12 @@ fn criterion_benchmark(c: &mut Criterion) { &format!("lower_some_values_are_nonascii_string_views: size: {}, str_len: {}, non_ascii_density: {}, null_density: {}, mixed: {}", size, str_len, 0.1, null_density, mixed), |b| b.iter(|| { - // TODO use invoke_with_args - black_box(lower.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(lower.invoke_with_args(ScalarFunctionArgs{ + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }), ); } diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index fed455eeac91..457fb499f5a1 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,12 +18,13 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; +use arrow::datatypes::DataType; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, }; use datafusion_common::ScalarValue; -use datafusion_expr::{ColumnarValue, ScalarUDF}; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDF}; use datafusion_functions::string; use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; use std::{fmt, sync::Arc}; @@ -141,8 +142,12 @@ fn run_with_string_type( ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(ltrim.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(ltrim.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }, ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 71207a0548fa..5cc6a177d9d9 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; use std::time::Duration; @@ -73,8 +74,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -87,8 +92,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -101,8 +110,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -124,8 +137,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -138,8 +155,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -152,8 +173,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, repeat_times as usize)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); @@ -175,8 +200,12 @@ fn criterion_benchmark(c: &mut Criterion) { ), |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(repeat.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(repeat.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: repeat_times as usize, + return_type: &DataType::Utf8, + })) }) }, ); diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index ce3767cc4839..a45d936c0a52 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -17,12 +17,10 @@ extern crate criterion; -use arrow::{ - datatypes::{Int32Type, Int64Type}, - util::bench_util::create_primitive_array, -}; +use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -33,13 +31,33 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = i32_array.len(); let i32_args = vec![ColumnarValue::Array(i32_array)]; c.bench_function(&format!("to_hex i32 array: {}", size), |b| { - b.iter(|| black_box(hex.invoke_batch(&i32_args, batch_len).unwrap())) + b.iter(|| { + let args_cloned = i32_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: batch_len, + return_type: &DataType::Utf8, + }) + .unwrap(), + ) + }) }); let i64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = i64_array.len(); let i64_args = vec![ColumnarValue::Array(i64_array)]; c.bench_function(&format!("to_hex i64 array: {}", size), |b| { - b.iter(|| black_box(hex.invoke_batch(&i64_args, batch_len).unwrap())) + b.iter(|| { + let args_cloned = i64_args.clone(); + black_box( + hex.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: batch_len, + return_type: &DataType::Utf8, + }) + .unwrap(), + ) + }) }); } diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 9b41a15b11c7..f0bee89c7d37 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -17,9 +17,10 @@ extern crate criterion; +use arrow::datatypes::DataType; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -38,8 +39,12 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args(size, 32); c.bench_function("upper_all_values_are_ascii", |b| { b.iter(|| { - // TODO use invoke_with_args - black_box(upper.invoke_batch(&args, size)) + let args_cloned = args.clone(); + black_box(upper.invoke_with_args(ScalarFunctionArgs { + args: args_cloned, + number_rows: size, + return_type: &DataType::Utf8, + })) }) }); } diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 95cf77de3190..7b8d156fec21 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -17,13 +17,21 @@ extern crate criterion; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; fn criterion_benchmark(c: &mut Criterion) { let uuid = string::uuid(); c.bench_function("uuid", |b| { - b.iter(|| black_box(uuid.invoke_batch(&[], 1024))) + b.iter(|| { + black_box(uuid.invoke_with_args(ScalarFunctionArgs { + args: vec![], + number_rows: 1024, + return_type: &DataType::Utf8, + })) + }) }); } diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 3832ad2a341d..006492a0e07a 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -22,7 +22,7 @@ use arrow::error::ArrowError; use datafusion_common::types::logical_string; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_expr_common::signature::Coercion; use datafusion_macros::user_doc; use std::any::Any; @@ -92,12 +92,8 @@ impl ScalarUDFImpl for AsciiFunc { Ok(Int32) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(ascii, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(ascii, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index f7e9fce960fe..2a782c59963e 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -22,7 +22,7 @@ use std::any::Any; use crate::utils::utf8_to_int_type; use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -77,12 +77,8 @@ impl ScalarUDFImpl for BitLengthFunc { utf8_to_int_type(&arg_types[0], "bit_length") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let [array] = take_function_args(self.name(), args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array] = take_function_args(self.name(), &args.args)?; match array { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 05a2f646e969..89bffa25698e 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -22,7 +22,8 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -101,20 +102,16 @@ impl ScalarUDFImpl for BTrimFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( btrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), DataType::LargeUtf8 => make_scalar_function( btrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), other => exec_err!( "Unsupported data type {other:?} for function btrim,\ expected Utf8, LargeUtf8 or Utf8View." diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index 3530e3f22c0f..58aa7ede74c4 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -28,7 +28,7 @@ use crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. @@ -111,12 +111,8 @@ impl ScalarUDFImpl for ChrFunc { Ok(Utf8) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(chr, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(chr, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 9ce732efa0c7..c47d08d579e4 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -30,7 +30,7 @@ use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -105,11 +105,9 @@ impl ScalarUDFImpl for ConcatFunc { /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + let mut return_datatype = DataType::Utf8; args.iter().for_each(|col| { if col.data_type() == DataType::Utf8View { @@ -169,7 +167,7 @@ impl ScalarUDFImpl for ConcatFunc { let mut data_size = 0; let mut columns = Vec::with_capacity(args.len()); - for arg in args { + for arg in &args { match arg { ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) @@ -470,10 +468,14 @@ mod tests { None, Some("b"), ]))); - let args = &[c0, c1, c2, c3, c4]; - #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch - let result = ConcatFunc::new().invoke_batch(args, 3)?; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2, c3, c4], + number_rows: 3, + return_type: &Utf8, + }; + + let result = ConcatFunc::new().invoke_with_args(args)?; let expected = Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) as ArrayRef; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 026d167cccd5..c2bad206db15 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -30,7 +30,7 @@ use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -102,11 +102,9 @@ impl ScalarUDFImpl for ConcatWsFunc { /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + // do not accept 0 arguments. if args.len() < 2 { return exec_err!( @@ -411,7 +409,7 @@ mod tests { use crate::string::concat_ws::ConcatWsFunc; use datafusion_common::Result; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use crate::utils::test::test_function; @@ -482,10 +480,14 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; - #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch - let result = ConcatWsFunc::new().invoke_batch(args, 3)?; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + number_rows: 3, + return_type: &Utf8, + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; let expected = Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; match &result { @@ -508,10 +510,14 @@ mod tests { Some("y"), Some("z"), ]))); - let args = &[c0, c1, c2]; - #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch - let result = ConcatWsFunc::new().invoke_batch(args, 3)?; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + number_rows: 3, + return_type: &Utf8, + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; let expected = Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) as ArrayRef; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 36871f0c3282..77774cdb5e1d 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -24,7 +24,8 @@ use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; use std::any::Any; @@ -81,12 +82,8 @@ impl ScalarUDFImpl for ContainsFunc { Ok(Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(contains, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(contains, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { @@ -125,8 +122,9 @@ pub fn contains(args: &[ArrayRef]) -> Result { mod test { use super::ContainsFunc; use arrow::array::{BooleanArray, StringArray}; + use arrow::datatypes::DataType; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] @@ -137,8 +135,14 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); - #[allow(deprecated)] // TODO migrate UDF to invoke - let actual = udf.invoke_batch(&[array, scalar], 2).unwrap(); + + let args = ScalarFunctionArgs { + args: vec![array, scalar], + number_rows: 2, + return_type: &DataType::Boolean, + }; + + let actual = udf.invoke_with_args(args).unwrap(); let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ Some(true), Some(false), diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 0a77ec9ebd2c..5cca79de14ff 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -24,7 +24,7 @@ use arrow::datatypes::DataType; use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -84,14 +84,10 @@ impl ScalarUDFImpl for EndsWithFunc { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(ends_with, vec![])(args) + make_scalar_function(ends_with, vec![])(&args.args) } other => { internal_err!("Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View")? diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index c2e5dc52f82f..a19fcc5b476c 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -26,7 +26,7 @@ use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, utils::take_function_args, Result}; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -86,16 +86,14 @@ impl ScalarUDFImpl for LevenshteinFunc { utf8_to_int_type(&arg_types[0], "levenshtein") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 => { - make_scalar_function(levenshtein::, vec![])(args) + make_scalar_function(levenshtein::, vec![])(&args.args) + } + DataType::LargeUtf8 => { + make_scalar_function(levenshtein::, vec![])(&args.args) } - DataType::LargeUtf8 => make_scalar_function(levenshtein::, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function levenshtein") } diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index e90c3804b1ee..375717e23d6d 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -22,7 +22,7 @@ use crate::string::common::to_lower; use crate::utils::utf8_to_str_type; use datafusion_common::Result; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -77,12 +77,8 @@ impl ScalarUDFImpl for LowerFunc { utf8_to_str_type(&arg_types[0], "lower") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - to_lower(args, "lower") + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + to_lower(&args.args, "lower") } fn documentation(&self) -> Option<&Documentation> { @@ -98,10 +94,14 @@ mod tests { fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); - let batch_len = input.len(); - let args = vec![ColumnarValue::Array(input)]; - #[allow(deprecated)] // TODO migrate UDF to invoke - let result = match func.invoke_batch(&args, batch_len)? { + + let args = ScalarFunctionArgs { + number_rows: input.len(), + args: vec![ColumnarValue::Array(input)], + return_type: &DataType::Utf8, + }; + + let result = match func.invoke_with_args(args)? { ColumnarValue::Array(result) => result, _ => unreachable!("lower"), }; diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 0bc62ee5000d..75c4ff25b7df 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -24,7 +24,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. @@ -104,20 +104,16 @@ impl ScalarUDFImpl for LtrimFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( ltrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), DataType::LargeUtf8 => make_scalar_function( ltrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), other => exec_err!( "Unsupported data type {other:?} for function ltrim,\ expected Utf8, LargeUtf8 or Utf8View." diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 7e0187c0b1be..46175c96cdc6 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -22,7 +22,7 @@ use std::any::Any; use crate::utils::utf8_to_int_type; use datafusion_common::{utils::take_function_args, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -77,12 +77,8 @@ impl ScalarUDFImpl for OctetLengthFunc { utf8_to_int_type(&arg_types[0], "octet_length") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - let [array] = take_function_args(self.name(), args)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let [array] = take_function_args(self.name(), &args.args)?; match array { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 3389da0968f7..0ea5359e9621 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -27,7 +27,7 @@ use datafusion_common::cast::{ }; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -100,16 +100,14 @@ impl ScalarUDFImpl for OverlayFunc { utf8_to_str_type(&arg_types[0], "overlay") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 => { - make_scalar_function(overlay::, vec![])(args) + make_scalar_function(overlay::, vec![])(&args.args) + } + DataType::LargeUtf8 => { + make_scalar_function(overlay::, vec![])(&args.args) } - DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), other => exec_err!("Unsupported data type {other:?} for function overlay"), } } diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 8fdbc3dd296f..2d36cb8356a0 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -29,7 +29,7 @@ use datafusion_common::cast::as_int64_array; use datafusion_common::types::{logical_int64, logical_string, NativeType}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; @@ -98,12 +98,8 @@ impl ScalarUDFImpl for RepeatFunc { utf8_to_str_type(&arg_types[0], "repeat") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - make_scalar_function(repeat, vec![])(args) + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(repeat, vec![])(&args.args) } fn documentation(&self) -> Option<&Documentation> { diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 9b6afc546994..a3488b561fd2 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -25,7 +25,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( doc_section(label = "String Functions"), @@ -82,15 +82,13 @@ impl ScalarUDFImpl for ReplaceFunc { utf8_to_str_type(&arg_types[0], "replace") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(replace::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), - DataType::Utf8View => make_scalar_function(replace_view, vec![])(args), + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { + DataType::Utf8 => make_scalar_function(replace::, vec![])(&args.args), + DataType::LargeUtf8 => { + make_scalar_function(replace::, vec![])(&args.args) + } + DataType::Utf8View => make_scalar_function(replace_view, vec![])(&args.args), other => { exec_err!("Unsupported data type {other:?} for function replace") } diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 3fb208bb7198..71c4286150e5 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -24,7 +24,7 @@ use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::function::Hint; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. @@ -104,20 +104,16 @@ impl ScalarUDFImpl for RtrimFunc { } } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8 | DataType::Utf8View => make_scalar_function( rtrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), DataType::LargeUtf8 => make_scalar_function( rtrim::, vec![Hint::Pad, Hint::AcceptsSingular], - )(args), + )(&args.args), other => exec_err!( "Unsupported data type {other:?} for function rtrim,\ expected Utf8, LargeUtf8 or Utf8View." diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index a597e1be5d02..724d9c278cca 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -26,7 +26,7 @@ use datafusion_common::cast::as_int64_array; use datafusion_common::ScalarValue; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; use std::any::Any; use std::sync::Arc; @@ -97,11 +97,9 @@ impl ScalarUDFImpl for SplitPartFunc { utf8_to_str_type(&arg_types[0], "split_part") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let ScalarFunctionArgs { args, .. } = args; + // First, determine if any of the arguments is an Array let len = args.iter().find_map(|arg| match arg { ColumnarValue::Array(a) => Some(a.len()), diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 74d0fbdc4033..f1344780eb4c 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -25,7 +25,7 @@ use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::utils::make_scalar_function; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Documentation, Expr, Like}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; /// Returns true if string starts with prefix. @@ -86,14 +86,10 @@ impl ScalarUDFImpl for StartsWithFunc { Ok(DataType::Boolean) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { - make_scalar_function(starts_with, vec![])(args) + make_scalar_function(starts_with, vec![])(&args.args) } _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?, } diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 5c7c92cc34ed..a3a1acfcf1f0 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -30,7 +30,7 @@ use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; /// Converts the number to its equivalent hexadecimal representation. @@ -127,14 +127,14 @@ impl ScalarUDFImpl for ToHexFunc { }) } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - match args[0].data_type() { - DataType::Int32 => make_scalar_function(to_hex::, vec![])(args), - DataType::Int64 => make_scalar_function(to_hex::, vec![])(args), + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + match args.args[0].data_type() { + DataType::Int32 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } + DataType::Int64 => { + make_scalar_function(to_hex::, vec![])(&args.args) + } other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 7bab33e68a4d..d27b54d29bc6 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -20,7 +20,7 @@ use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ColumnarValue, Documentation}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -76,12 +76,8 @@ impl ScalarUDFImpl for UpperFunc { utf8_to_str_type(&arg_types[0], "upper") } - fn invoke_batch( - &self, - args: &[ColumnarValue], - _number_rows: usize, - ) -> Result { - to_upper(args, "upper") + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + to_upper(&args.args, "upper") } fn documentation(&self) -> Option<&Documentation> { @@ -97,10 +93,14 @@ mod tests { fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); - let batch_len = input.len(); - let args = vec![ColumnarValue::Array(input)]; - #[allow(deprecated)] // TODO migrate UDF to invoke - let result = match func.invoke_batch(&args, batch_len)? { + + let args = ScalarFunctionArgs { + number_rows: input.len(), + args: vec![ColumnarValue::Array(input)], + return_type: &DataType::Utf8, + }; + + let result = match func.invoke_with_args(args)? { ColumnarValue::Array(result) => result, _ => unreachable!("upper"), }; diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index 64065c26b7d4..d1f43d548066 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -26,7 +26,7 @@ use uuid::Uuid; use datafusion_common::{internal_err, Result}; use datafusion_expr::{ColumnarValue, Documentation, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; #[user_doc( @@ -80,22 +80,20 @@ impl ScalarUDFImpl for UuidFunc { /// Prints random (v4) uuid values per row /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' - fn invoke_batch( - &self, - args: &[ColumnarValue], - num_rows: usize, - ) -> Result { - if !args.is_empty() { + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + if !args.args.is_empty() { return internal_err!("{} function does not accept arguments", self.name()); } // Generate random u128 values let mut rng = rand::thread_rng(); - let mut randoms = vec![0u128; num_rows]; + let mut randoms = vec![0u128; args.number_rows]; rng.fill(&mut randoms[..]); - let mut builder = - GenericStringBuilder::::with_capacity(num_rows, num_rows * 36); + let mut builder = GenericStringBuilder::::with_capacity( + args.number_rows, + args.number_rows * 36, + ); let mut buffer = [0u8; 36]; for x in &mut randoms {