Skip to content

Commit 6dfbdba

Browse files
authored
chore: Migrate Core Functions to invoke_with_args (#14725)
* migrate version * migrate more * also migrate invoke
1 parent 8f2f537 commit 6dfbdba

File tree

10 files changed

+44
-64
lines changed

10 files changed

+44
-64
lines changed

datafusion/functions/src/core/arrow_cast.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ use std::any::Any;
2929

3030
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
3131
use datafusion_expr::{
32-
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl,
33-
Signature, Volatility,
32+
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
33+
ScalarUDFImpl, Signature, Volatility,
3434
};
3535
use datafusion_macros::user_doc;
3636

@@ -138,11 +138,7 @@ impl ScalarUDFImpl for ArrowCastFunc {
138138
)
139139
}
140140

141-
fn invoke_batch(
142-
&self,
143-
_args: &[ColumnarValue],
144-
_number_rows: usize,
145-
) -> Result<ColumnarValue> {
141+
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
146142
internal_err!("arrow_cast should have been simplified to cast")
147143
}
148144

datafusion/functions/src/core/arrowtypeof.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use arrow::datatypes::DataType;
1919
use datafusion_common::{utils::take_function_args, Result, ScalarValue};
20-
use datafusion_expr::{ColumnarValue, Documentation};
20+
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
2121
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2222
use datafusion_macros::user_doc;
2323
use std::any::Any;
@@ -75,12 +75,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc {
7575
Ok(DataType::Utf8)
7676
}
7777

78-
fn invoke_batch(
79-
&self,
80-
args: &[ColumnarValue],
81-
_number_rows: usize,
82-
) -> Result<ColumnarValue> {
83-
let [arg] = take_function_args(self.name(), args)?;
78+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
79+
let [arg] = take_function_args(self.name(), args.args)?;
8480
let input_data_type = arg.data_type();
8581
Ok(ColumnarValue::Scalar(ScalarValue::from(format!(
8682
"{input_data_type}"

datafusion/functions/src/core/coalesce.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ use arrow::compute::{and, is_not_null, is_null};
2121
use arrow::datatypes::DataType;
2222
use datafusion_common::{exec_err, internal_err, Result};
2323
use datafusion_expr::binary::try_type_union_resolution;
24-
use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs};
24+
use datafusion_expr::{
25+
ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
26+
};
2527
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2628
use datafusion_macros::user_doc;
2729
use itertools::Itertools;
@@ -93,11 +95,8 @@ impl ScalarUDFImpl for CoalesceFunc {
9395
}
9496

9597
/// coalesce evaluates to the first value which is not NULL
96-
fn invoke_batch(
97-
&self,
98-
args: &[ColumnarValue],
99-
_number_rows: usize,
100-
) -> Result<ColumnarValue> {
98+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99+
let args = args.args;
101100
// do not accept 0 arguments.
102101
if args.is_empty() {
103102
return exec_err!(

datafusion/functions/src/core/getfield.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ use datafusion_common::{
2424
exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result,
2525
ScalarValue,
2626
};
27-
use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs};
27+
use datafusion_expr::{
28+
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
29+
};
2830
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2931
use datafusion_macros::user_doc;
3032
use std::any::Any;
@@ -170,12 +172,8 @@ impl ScalarUDFImpl for GetFieldFunc {
170172
}
171173
}
172174

173-
fn invoke_batch(
174-
&self,
175-
args: &[ColumnarValue],
176-
_number_rows: usize,
177-
) -> Result<ColumnarValue> {
178-
let [base, field_name] = take_function_args(self.name(), args)?;
175+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
176+
let [base, field_name] = take_function_args(self.name(), args.args)?;
179177

180178
if base.data_type().is_null() {
181179
return Ok(ColumnarValue::Scalar(ScalarValue::Null));
@@ -229,7 +227,7 @@ impl ScalarUDFImpl for GetFieldFunc {
229227
}
230228
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
231229
let as_struct_array = as_struct_array(&array)?;
232-
match as_struct_array.column_by_name(k) {
230+
match as_struct_array.column_by_name(&k) {
233231
None => exec_err!("get indexed field {k} not found in struct"),
234232
Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))),
235233
}

datafusion/functions/src/core/greatest.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use arrow::compute::SortOptions;
2323
use arrow::datatypes::DataType;
2424
use datafusion_common::{internal_err, Result, ScalarValue};
2525
use datafusion_doc::Documentation;
26-
use datafusion_expr::ColumnarValue;
26+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2727
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2828
use datafusion_macros::user_doc;
2929
use std::any::Any;
@@ -143,8 +143,8 @@ impl ScalarUDFImpl for GreatestFunc {
143143
Ok(arg_types[0].clone())
144144
}
145145

146-
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
147-
super::greatest_least_utils::execute_conditional::<Self>(args)
146+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
147+
super::greatest_least_utils::execute_conditional::<Self>(&args.args)
148148
}
149149

150150
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {

datafusion/functions/src/core/least.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use arrow::compute::SortOptions;
2323
use arrow::datatypes::DataType;
2424
use datafusion_common::{internal_err, Result, ScalarValue};
2525
use datafusion_doc::Documentation;
26-
use datafusion_expr::ColumnarValue;
26+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
2727
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2828
use datafusion_macros::user_doc;
2929
use std::any::Any;
@@ -156,8 +156,8 @@ impl ScalarUDFImpl for LeastFunc {
156156
Ok(arg_types[0].clone())
157157
}
158158

159-
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
160-
super::greatest_least_utils::execute_conditional::<Self>(args)
159+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
160+
super::greatest_least_utils::execute_conditional::<Self>(&args.args)
161161
}
162162

163163
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {

datafusion/functions/src/core/nullif.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::DataType;
19-
use datafusion_expr::{ColumnarValue, Documentation};
19+
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
2020

2121
use arrow::compute::kernels::cmp::eq;
2222
use arrow::compute::kernels::nullif::nullif;
@@ -101,12 +101,8 @@ impl ScalarUDFImpl for NullIfFunc {
101101
Ok(arg_types[0].to_owned())
102102
}
103103

104-
fn invoke_batch(
105-
&self,
106-
args: &[ColumnarValue],
107-
_number_rows: usize,
108-
) -> Result<ColumnarValue> {
109-
nullif_func(args)
104+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
105+
nullif_func(&args.args)
110106
}
111107

112108
fn documentation(&self) -> Option<&Documentation> {

datafusion/functions/src/core/nvl.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ use arrow::compute::kernels::zip::zip;
2121
use arrow::datatypes::DataType;
2222
use datafusion_common::{utils::take_function_args, Result};
2323
use datafusion_expr::{
24-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
24+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25+
Volatility,
2526
};
2627
use datafusion_macros::user_doc;
2728
use std::sync::Arc;
@@ -116,12 +117,8 @@ impl ScalarUDFImpl for NVLFunc {
116117
Ok(arg_types[0].clone())
117118
}
118119

119-
fn invoke_batch(
120-
&self,
121-
args: &[ColumnarValue],
122-
_number_rows: usize,
123-
) -> Result<ColumnarValue> {
124-
nvl_func(args)
120+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
121+
nvl_func(&args.args)
125122
}
126123

127124
fn aliases(&self) -> &[String] {

datafusion/functions/src/core/nvl2.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use arrow::datatypes::DataType;
2222
use datafusion_common::{internal_err, utils::take_function_args, Result};
2323
use datafusion_expr::{
2424
type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
25-
ScalarUDFImpl, Signature, Volatility,
25+
ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
2626
};
2727
use datafusion_macros::user_doc;
2828
use std::sync::Arc;
@@ -95,12 +95,8 @@ impl ScalarUDFImpl for NVL2Func {
9595
Ok(arg_types[1].clone())
9696
}
9797

98-
fn invoke_batch(
99-
&self,
100-
args: &[ColumnarValue],
101-
_number_rows: usize,
102-
) -> Result<ColumnarValue> {
103-
nvl2_func(args)
98+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99+
nvl2_func(&args.args)
104100
}
105101

106102
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {

datafusion/functions/src/core/version.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
use arrow::datatypes::DataType;
2121
use datafusion_common::{utils::take_function_args, Result, ScalarValue};
2222
use datafusion_expr::{
23-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
23+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
24+
Volatility,
2425
};
2526
use datafusion_macros::user_doc;
2627
use std::any::Any;
@@ -75,12 +76,8 @@ impl ScalarUDFImpl for VersionFunc {
7576
Ok(DataType::Utf8)
7677
}
7778

78-
fn invoke_batch(
79-
&self,
80-
args: &[ColumnarValue],
81-
_number_rows: usize,
82-
) -> Result<ColumnarValue> {
83-
let [] = take_function_args(self.name(), args)?;
79+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
80+
let [] = take_function_args(self.name(), args.args)?;
8481
// TODO it would be great to add rust version and arrow version,
8582
// but that requires a `build.rs` script and/or adding a version const to arrow-rs
8683
let version = format!(
@@ -105,8 +102,13 @@ mod test {
105102
#[tokio::test]
106103
async fn test_version_udf() {
107104
let version_udf = ScalarUDF::from(VersionFunc::new());
108-
#[allow(deprecated)] // TODO: migrate to invoke_with_args
109-
let version = version_udf.invoke_batch(&[], 1).unwrap();
105+
let version = version_udf
106+
.invoke_with_args(ScalarFunctionArgs {
107+
args: vec![],
108+
number_rows: 0,
109+
return_type: &DataType::Utf8,
110+
})
111+
.unwrap();
110112

111113
if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {
112114
assert!(version.starts_with("Apache DataFusion"));

0 commit comments

Comments
 (0)