Skip to content

Commit edbd93a

Browse files
joseph-isaacs and alamb authored
Add ScalarUDFImpl::invoke_with_args to support passing the return type created for the udf instance (#13290)
* Added support for `ScalarUDFImpl::invoke_with_return_type` where the invoke is passed the return type created for the udf instance
* Do not yet deprecate invoke_batch, add docs to invoke_with_args
* add ticket reference

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent e7d9504 commit edbd93a

File tree

13 files changed

+107
-75
lines changed

13 files changed

+107
-75
lines changed

datafusion/expr/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
9292
pub use udaf::{
9393
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
9494
};
95-
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
95+
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
9696
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
9797
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
9898
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

datafusion/expr/src/udf.rs

+57-41
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,7 @@ impl ScalarUDF {
203203
self.inner.simplify(args, info)
204204
}
205205

206-
/// Invoke the function on `args`, returning the appropriate result.
207-
///
208-
/// See [`ScalarUDFImpl::invoke`] for more details.
209-
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
206+
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
210207
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
211208
#[allow(deprecated)]
212209
self.inner.invoke(args)
@@ -216,20 +213,27 @@ impl ScalarUDF {
216213
self.inner.is_nullable(args, schema)
217214
}
218215

219-
/// Invoke the function with `args` and number of rows, returning the appropriate result.
220-
///
221-
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
216+
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
222217
pub fn invoke_batch(
223218
&self,
224219
args: &[ColumnarValue],
225220
number_rows: usize,
226221
) -> Result<ColumnarValue> {
222+
#[allow(deprecated)]
227223
self.inner.invoke_batch(args, number_rows)
228224
}
229225

226+
/// Invoke the function on `args`, returning the appropriate result.
227+
///
228+
/// See [`ScalarUDFImpl::invoke_with_args`] for details.
229+
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
230+
self.inner.invoke_with_args(args)
231+
}
232+
230233
/// Invoke the function without `args` but number of rows, returning the appropriate result.
231234
///
232-
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
235+
/// Note: This method is deprecated and will be removed in future releases.
236+
/// User defined functions should implement [`Self::invoke_with_args`] instead.
233237
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
234238
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
235239
#[allow(deprecated)]
@@ -324,26 +328,37 @@ where
324328
}
325329
}
326330

327-
/// Trait for implementing [`ScalarUDF`].
331+
pub struct ScalarFunctionArgs<'a> {
332+
// The evaluated arguments to the function
333+
pub args: &'a [ColumnarValue],
334+
// The number of rows in record batch being evaluated
335+
pub number_rows: usize,
336+
// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
337+
// when creating the physical expression from the logical expression
338+
pub return_type: &'a DataType,
339+
}
340+
341+
/// Trait for implementing user defined scalar functions.
328342
///
329343
/// This trait exposes the full API for implementing user defined functions and
330344
/// can be used to implement any function.
331345
///
332346
/// See [`advanced_udf.rs`] for a full example with complete implementation and
333347
/// [`ScalarUDF`] for other available options.
334348
///
335-
///
336349
/// [`advanced_udf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udf.rs
350+
///
337351
/// # Basic Example
338352
/// ```
339353
/// # use std::any::Any;
340354
/// # use std::sync::OnceLock;
341355
/// # use arrow::datatypes::DataType;
342356
/// # use datafusion_common::{DataFusionError, plan_err, Result};
343-
/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility};
357+
/// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
344358
/// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
345359
/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
346360
///
361+
/// /// This struct for a simple UDF that adds one to an int32
347362
/// #[derive(Debug)]
348363
/// struct AddOne {
349364
/// signature: Signature,
@@ -356,7 +371,7 @@ where
356371
/// }
357372
/// }
358373
/// }
359-
///
374+
///
360375
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
361376
///
362377
/// fn get_doc() -> &'static Documentation {
@@ -383,7 +398,9 @@ where
383398
/// Ok(DataType::Int32)
384399
/// }
385400
/// // The actual implementation would add one to the argument
386-
/// fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { unimplemented!() }
401+
/// fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
402+
/// unimplemented!()
403+
/// }
387404
/// fn documentation(&self) -> Option<&Documentation> {
388405
/// Some(get_doc())
389406
/// }
@@ -479,24 +496,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
479496

480497
/// Invoke the function on `args`, returning the appropriate result
481498
///
482-
/// The function will be invoked passed with the slice of [`ColumnarValue`]
483-
/// (either scalar or array).
484-
///
485-
/// If the function does not take any arguments, please use [invoke_no_args]
486-
/// instead and return [not_impl_err] for this function.
487-
///
488-
///
489-
/// # Performance
490-
///
491-
/// For the best performance, the implementations of `invoke` should handle
492-
/// the common case when one or more of their arguments are constant values
493-
/// (aka [`ColumnarValue::Scalar`]).
494-
///
495-
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
496-
/// to arrays, which will likely be simpler code, but be slower.
497-
///
498-
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
499-
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
499+
/// Note: This method is deprecated and will be removed in future releases.
500+
/// User defined functions should implement [`Self::invoke_with_args`] instead.
501+
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
500502
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
501503
not_impl_err!(
502504
"Function {} does not implement invoke but called",
@@ -507,17 +509,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
507509
/// Invoke the function with `args` and the number of rows,
508510
/// returning the appropriate result.
509511
///
510-
/// The function will be invoked with the slice of [`ColumnarValue`]
511-
/// (either scalar or array).
512-
///
513-
/// # Performance
512+
/// Note: See notes on [`Self::invoke_with_args`]
514513
///
515-
/// For the best performance, the implementations should handle the common case
516-
/// when one or more of their arguments are constant values (aka
517-
/// [`ColumnarValue::Scalar`]).
514+
/// Note: This method is deprecated and will be removed in future releases.
515+
/// User defined functions should implement [`Self::invoke_with_args`] instead.
518516
///
519-
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
520-
/// to arrays, which will likely be simpler code, but be slower.
517+
/// See <https://github.com/apache/datafusion/issues/13515> for more details.
521518
fn invoke_batch(
522519
&self,
523520
args: &[ColumnarValue],
@@ -537,9 +534,27 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
537534
}
538535
}
539536

537+
/// Invoke the function returning the appropriate result.
538+
///
539+
/// # Performance
540+
///
541+
/// For the best performance, the implementations should handle the common case
542+
/// when one or more of their arguments are constant values (aka
543+
/// [`ColumnarValue::Scalar`]).
544+
///
545+
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
546+
/// to arrays, which will likely be simpler code, but be slower.
547+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
548+
#[allow(deprecated)]
549+
self.invoke_batch(args.args, args.number_rows)
550+
}
551+
540552
/// Invoke the function without `args`, instead the number of rows are provided,
541553
/// returning the appropriate result.
542-
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
554+
///
555+
/// Note: This method is deprecated and will be removed in future releases.
556+
/// User defined functions should implement [`Self::invoke_with_args`] instead.
557+
#[deprecated(since = "42.1.0", note = "Use `invoke_with_args` instead")]
543558
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
544559
not_impl_err!(
545560
"Function {} does not implement invoke_no_args but called",
@@ -767,6 +782,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
767782
args: &[ColumnarValue],
768783
number_rows: usize,
769784
) -> Result<ColumnarValue> {
785+
#[allow(deprecated)]
770786
self.inner.invoke_batch(args, number_rows)
771787
}
772788

datafusion/functions/benches/random.rs

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
2929
c.bench_function("random_1M_rows_batch_8192", |b| {
3030
b.iter(|| {
3131
for _ in 0..iterations {
32+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
3233
black_box(random_func.invoke_batch(&[], 8192).unwrap());
3334
}
3435
})
@@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
3940
c.bench_function("random_1M_rows_batch_128", |b| {
4041
b.iter(|| {
4142
for _ in 0..iterations_128 {
43+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
4244
black_box(random_func.invoke_batch(&[], 128).unwrap());
4345
}
4446
})

datafusion/functions/src/core/version.rs

+1
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ mod test {
121121
#[tokio::test]
122122
async fn test_version_udf() {
123123
let version_udf = ScalarUDF::from(VersionFunc::new());
124+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
124125
let version = version_udf.invoke_batch(&[], 1).unwrap();
125126

126127
if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {

datafusion/functions/src/datetime/to_local_time.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ mod tests {
431431
use arrow::datatypes::{DataType, TimeUnit};
432432
use chrono::NaiveDateTime;
433433
use datafusion_common::ScalarValue;
434-
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
434+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
435435

436436
use super::{adjust_to_local_time, ToLocalTimeFunc};
437437

@@ -558,7 +558,11 @@ mod tests {
558558

559559
fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
560560
let res = ToLocalTimeFunc::new()
561-
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
561+
.invoke_with_args(ScalarFunctionArgs {
562+
args: &[ColumnarValue::Scalar(input)],
563+
number_rows: 1,
564+
return_type: &expected.data_type(),
565+
})
562566
.unwrap();
563567
match res {
564568
ColumnarValue::Scalar(res) => {
@@ -617,6 +621,7 @@ mod tests {
617621
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
618622
.collect::<TimestampNanosecondArray>();
619623
let batch_size = input.len();
624+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
620625
let result = ToLocalTimeFunc::new()
621626
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
622627
.unwrap();

datafusion/functions/src/datetime/to_timestamp.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ mod tests {
10081008
for array in arrays {
10091009
let rt = udf.return_type(&[array.data_type()]).unwrap();
10101010
assert!(matches!(rt, Timestamp(_, Some(_))));
1011-
1011+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
10121012
let res = udf
10131013
.invoke_batch(&[array.clone()], 1)
10141014
.expect("that to_timestamp parsed values without error");
@@ -1051,7 +1051,7 @@ mod tests {
10511051
for array in arrays {
10521052
let rt = udf.return_type(&[array.data_type()]).unwrap();
10531053
assert!(matches!(rt, Timestamp(_, None)));
1054-
1054+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
10551055
let res = udf
10561056
.invoke_batch(&[array.clone()], 1)
10571057
.expect("that to_timestamp parsed values without error");

datafusion/functions/src/datetime/to_unixtime.rs

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
8383
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
8484
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
8585
.cast_to(&DataType::Int64, None),
86+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
8687
DataType::Utf8 => ToTimestampSecondsFunc::new()
8788
.invoke_batch(args, batch_size)?
8889
.cast_to(&DataType::Int64, None),

datafusion/functions/src/math/log.rs

+10-10
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ mod tests {
277277
]))), // num
278278
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
279279
];
280-
280+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
281281
let _ = LogFunc::new().invoke_batch(&args, 4);
282282
}
283283

@@ -286,7 +286,7 @@ mod tests {
286286
let args = [
287287
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
288288
];
289-
289+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
290290
let result = LogFunc::new().invoke_batch(&args, 1);
291291
result.expect_err("expected error");
292292
}
@@ -296,7 +296,7 @@ mod tests {
296296
let args = [
297297
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
298298
];
299-
299+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
300300
let result = LogFunc::new()
301301
.invoke_batch(&args, 1)
302302
.expect("failed to initialize function log");
@@ -320,7 +320,7 @@ mod tests {
320320
let args = [
321321
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
322322
];
323-
323+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
324324
let result = LogFunc::new()
325325
.invoke_batch(&args, 1)
326326
.expect("failed to initialize function log");
@@ -345,7 +345,7 @@ mod tests {
345345
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
346346
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
347347
];
348-
348+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
349349
let result = LogFunc::new()
350350
.invoke_batch(&args, 1)
351351
.expect("failed to initialize function log");
@@ -370,7 +370,7 @@ mod tests {
370370
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
371371
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
372372
];
373-
373+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
374374
let result = LogFunc::new()
375375
.invoke_batch(&args, 1)
376376
.expect("failed to initialize function log");
@@ -396,7 +396,7 @@ mod tests {
396396
10.0, 100.0, 1000.0, 10000.0,
397397
]))), // num
398398
];
399-
399+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
400400
let result = LogFunc::new()
401401
.invoke_batch(&args, 4)
402402
.expect("failed to initialize function log");
@@ -425,7 +425,7 @@ mod tests {
425425
10.0, 100.0, 1000.0, 10000.0,
426426
]))), // num
427427
];
428-
428+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
429429
let result = LogFunc::new()
430430
.invoke_batch(&args, 4)
431431
.expect("failed to initialize function log");
@@ -455,7 +455,7 @@ mod tests {
455455
8.0, 4.0, 81.0, 625.0,
456456
]))), // num
457457
];
458-
458+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
459459
let result = LogFunc::new()
460460
.invoke_batch(&args, 4)
461461
.expect("failed to initialize function log");
@@ -485,7 +485,7 @@ mod tests {
485485
8.0, 4.0, 81.0, 625.0,
486486
]))), // num
487487
];
488-
488+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
489489
let result = LogFunc::new()
490490
.invoke_batch(&args, 4)
491491
.expect("failed to initialize function log");

datafusion/functions/src/math/power.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ mod tests {
205205
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
206206
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
207207
];
208-
208+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
209209
let result = PowerFunc::new()
210210
.invoke_batch(&args, 4)
211211
.expect("failed to initialize function power");
@@ -232,7 +232,7 @@ mod tests {
232232
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
233233
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
234234
];
235-
235+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
236236
let result = PowerFunc::new()
237237
.invoke_batch(&args, 4)
238238
.expect("failed to initialize function power");

0 commit comments

Comments
 (0)