Skip to content

ScalarUDFImpl invoke improvements #13507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
48 changes: 41 additions & 7 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,6 @@ impl ScalarUDF {
self.inner.simplify(args, info)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
#[allow(deprecated)]
Expand All @@ -216,17 +213,23 @@ impl ScalarUDF {
self.inner.is_nullable(args, schema)
}

/// Invoke the function with `args` and number of rows, returning the appropriate result.
///
/// This wrapper forwards directly to the wrapped implementation's
/// `invoke_batch`. It is deprecated in favor of `invoke_with_args`.
///
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
pub fn invoke_batch(
    &self,
    args: &[ColumnarValue],
    number_rows: usize,
) -> Result<ColumnarValue> {
    // The trait method being forwarded to is itself deprecated; silence the
    // lint here so this compatibility wrapper keeps compiling without warnings.
    #[allow(deprecated)]
    self.inner.invoke_batch(args, number_rows)
}

/// Evaluate this function for the call described by `args`.
///
/// `args` is handed unchanged to the wrapped implementation; see
/// [`ScalarUDFImpl::invoke_with_args`] for the full contract.
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    let implementation = &self.inner;
    implementation.invoke_with_args(args)
}

/// Invoke the function without `args`, given only the number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
Expand Down Expand Up @@ -324,6 +327,18 @@ where
}
}

/// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a
/// scalar function.
///
/// Bundling the inputs in one struct lets further fields be added later
/// without changing the signature of `invoke_with_args`.
#[derive(Debug, Clone)]
pub struct ScalarFunctionArgs<'a> {
    /// The evaluated arguments to the function
    pub args: Vec<ColumnarValue>,
    /// The number of rows in the record batch being evaluated
    pub number_rows: usize,
    /// The return type of the scalar function, as determined (from
    /// `return_type` or `return_type_from_exprs`) when the physical
    /// expression was created from the logical expression
    pub return_type: &'a DataType,
}

/// Trait for implementing [`ScalarUDF`].
///
/// This trait exposes the full API for implementing user defined functions and
Expand Down Expand Up @@ -356,7 +371,7 @@ where
/// }
/// }
/// }
///
///
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
///
/// fn get_doc() -> &'static Documentation {
Expand Down Expand Up @@ -518,6 +533,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand All @@ -537,6 +553,23 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
}
}

/// Invoke the function returning the appropriate result.
///
/// The function will be invoked with a struct `ScalarFunctionArgs`
///
/// # Performance Notes
///
/// For the best performance, the implementations should handle the common case
/// when one or more of their arguments are constant values (aka
/// [`ColumnarValue::Scalar`]).
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.invoke_batch(&args.args, args.number_rows)
}

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
Expand Down Expand Up @@ -767,6 +800,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
}

Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/benches/random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_8192", |b| {
b.iter(|| {
for _ in 0..iterations {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 8192).unwrap());
}
})
Expand All @@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("random_1M_rows_batch_128", |b| {
b.iter(|| {
for _ in 0..iterations_128 {
#[allow(deprecated)] // TODO: migrate to invoke_with_args
black_box(random_func.invoke_batch(&[], 128).unwrap());
}
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/core/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ mod test {
#[tokio::test]
async fn test_version_udf() {
let version_udf = ScalarUDF::from(VersionFunc::new());
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let version = version_udf.invoke_batch(&[], 1).unwrap();

if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
Expand Down
9 changes: 7 additions & 2 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ mod tests {
use arrow::datatypes::{DataType, TimeUnit};
use chrono::NaiveDateTime;
use datafusion_common::ScalarValue;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

use super::{adjust_to_local_time, ToLocalTimeFunc};

Expand Down Expand Up @@ -558,7 +558,11 @@ mod tests {

fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
let res = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
.invoke_with_args(ScalarFunctionArgs {
args: &[ColumnarValue::Scalar(input)],
number_rows: 1,
return_type: &expected.data_type(),
})
.unwrap();
match res {
ColumnarValue::Scalar(res) => {
Expand Down Expand Up @@ -617,6 +621,7 @@ mod tests {
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
.collect::<TimestampNanosecondArray>();
let batch_size = input.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = ToLocalTimeFunc::new()
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
.unwrap();
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, Some(_))));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down Expand Up @@ -1051,7 +1051,7 @@ mod tests {
for array in arrays {
let rt = udf.return_type(&[array.data_type()]).unwrap();
assert!(matches!(rt, Timestamp(_, None)));

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let res = udf
.invoke_batch(&[array.clone()], 1)
.expect("that to_timestamp parsed values without error");
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/datetime/to_unixtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
.cast_to(&DataType::Int64, None),
#[allow(deprecated)] // TODO: migrate to invoke_with_args
DataType::Utf8 => ToTimestampSecondsFunc::new()
.invoke_batch(args, batch_size)?
.cast_to(&DataType::Int64, None),
Expand Down
20 changes: 10 additions & 10 deletions datafusion/functions/src/math/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ mod tests {
]))), // num
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let _ = LogFunc::new().invoke_batch(&args, 4);
}

Expand All @@ -286,7 +286,7 @@ mod tests {
let args = [
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new().invoke_batch(&args, 1);
result.expect_err("expected error");
}
Expand All @@ -296,7 +296,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -320,7 +320,7 @@ mod tests {
let args = [
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -345,7 +345,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -370,7 +370,7 @@ mod tests {
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 1)
.expect("failed to initialize function log");
Expand All @@ -396,7 +396,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -425,7 +425,7 @@ mod tests {
10.0, 100.0, 1000.0, 10000.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -455,7 +455,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down Expand Up @@ -485,7 +485,7 @@ mod tests {
8.0, 4.0, 81.0, 625.0,
]))), // num
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = LogFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function log");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/math/power.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand All @@ -232,7 +232,7 @@ mod tests {
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
];

#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = PowerFunc::new()
.invoke_batch(&args, 4)
.expect("failed to initialize function power");
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions/src/math/signum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ mod test {
f32::NEG_INFINITY,
]));
let batch_size = array.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = SignumFunc::new()
.invoke_batch(&[ColumnarValue::Array(array)], batch_size)
.expect("failed to initialize function signum");
Expand Down Expand Up @@ -207,6 +208,7 @@ mod test {
f64::NEG_INFINITY,
]));
let batch_size = array.len();
#[allow(deprecated)] // TODO: migrate to invoke_with_args
let result = SignumFunc::new()
.invoke_batch(&[ColumnarValue::Array(array)], batch_size)
.expect("failed to initialize function signum");
Expand Down
Loading
Loading