Skip to content

Commit 7089c64

Browse files
authored
Minor: Unify downcast_arg method (#13865)
1 parent 4118c43 commit 7089c64

File tree

10 files changed

+55
-46
lines changed

10 files changed

+55
-46
lines changed

datafusion/functions-nested/src/distance.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,23 @@
1717

1818
//! [ScalarUDFImpl] definitions for array_distance function.
1919
20-
use crate::utils::{downcast_arg, make_scalar_function};
20+
use crate::utils::make_scalar_function;
2121
use arrow_array::{
2222
Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
2323
};
2424
use arrow_schema::DataType;
2525
use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List};
26-
use core::any::type_name;
2726
use datafusion_common::cast::{
2827
as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
2928
as_int64_array,
3029
};
3130
use datafusion_common::utils::coerced_fixed_size_list_to_list;
32-
use datafusion_common::DataFusionError;
33-
use datafusion_common::{exec_err, Result};
31+
use datafusion_common::{exec_err, internal_datafusion_err, Result};
3432
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
3533
use datafusion_expr::{
3634
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3735
};
36+
use datafusion_functions::{downcast_arg, downcast_named_arg};
3837
use std::any::Any;
3938
use std::sync::{Arc, OnceLock};
4039

datafusion/functions-nested/src/length.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,19 @@
1717

1818
//! [`ScalarUDFImpl`] definitions for array_length function.
1919
20-
use crate::utils::{downcast_arg, make_scalar_function};
20+
use crate::utils::make_scalar_function;
2121
use arrow_array::{
2222
Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array,
2323
};
2424
use arrow_schema::DataType;
2525
use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64};
26-
use core::any::type_name;
2726
use datafusion_common::cast::{as_generic_list_array, as_int64_array};
28-
use datafusion_common::DataFusionError;
29-
use datafusion_common::{exec_err, plan_err, Result};
27+
use datafusion_common::{exec_err, internal_datafusion_err, plan_err, Result};
3028
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
3129
use datafusion_expr::{
3230
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3331
};
32+
use datafusion_functions::{downcast_arg, downcast_named_arg};
3433
use std::any::Any;
3534
use std::sync::{Arc, OnceLock};
3635

datafusion/functions-nested/src/string.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ use arrow::array::{
2626
use arrow::datatypes::{DataType, Field};
2727
use datafusion_expr::TypeSignature;
2828

29-
use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
29+
use datafusion_common::{
30+
internal_datafusion_err, not_impl_err, plan_err, DataFusionError, Result,
31+
};
3032

31-
use std::any::{type_name, Any};
33+
use std::any::Any;
3234

33-
use crate::utils::{downcast_arg, make_scalar_function};
35+
use crate::utils::make_scalar_function;
3436
use arrow::compute::cast;
3537
use arrow_array::builder::{ArrayBuilder, LargeStringBuilder, StringViewBuilder};
3638
use arrow_array::cast::AsArray;
@@ -45,6 +47,7 @@ use datafusion_expr::{
4547
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
4648
};
4749
use datafusion_functions::strings::StringArrayType;
50+
use datafusion_functions::{downcast_arg, downcast_named_arg};
4851
use std::sync::{Arc, OnceLock};
4952

5053
macro_rules! call_array_function {

datafusion/functions-nested/src/utils.rs

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,12 @@ use arrow_array::{
2828
use arrow_buffer::OffsetBuffer;
2929
use arrow_schema::{Field, Fields};
3030
use datafusion_common::cast::{as_large_list_array, as_list_array};
31-
use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
31+
use datafusion_common::{
32+
exec_err, internal_datafusion_err, internal_err, plan_err, Result, ScalarValue,
33+
};
3234

33-
use core::any::type_name;
34-
use datafusion_common::DataFusionError;
3535
use datafusion_expr::ColumnarValue;
36-
37-
macro_rules! downcast_arg {
38-
($ARG:expr, $ARRAY_TYPE:ident) => {{
39-
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
40-
DataFusionError::Internal(format!(
41-
"could not cast to {}",
42-
type_name::<$ARRAY_TYPE>()
43-
))
44-
})?
45-
}};
46-
}
47-
pub(crate) use downcast_arg;
36+
use datafusion_functions::{downcast_arg, downcast_named_arg};
4837

4938
pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> {
5039
let data_type = args[0].data_type();

datafusion/functions/src/macros.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,24 +109,37 @@ macro_rules! make_stub_package {
109109
};
110110
}
111111

112-
/// Downcast an argument to a specific array type, returning an internal error
112+
/// Downcast a named argument to a specific array type, returning an internal error
113113
/// if the cast fails
114114
///
115115
/// $ARG: ArrayRef
116116
/// $NAME: name of the argument (for error messages)
117117
/// $ARRAY_TYPE: the type of array to cast the argument to
118-
macro_rules! downcast_arg {
118+
#[macro_export]
119+
macro_rules! downcast_named_arg {
119120
($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
120121
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
121-
DataFusionError::Internal(format!(
122+
internal_datafusion_err!(
122123
"could not cast {} to {}",
123124
$NAME,
124125
std::any::type_name::<$ARRAY_TYPE>()
125-
))
126+
)
126127
})?
127128
}};
128129
}
129130

131+
/// Downcast an argument to a specific array type, returning an internal error
132+
/// if the cast fails
133+
///
134+
/// $ARG: ArrayRef
135+
/// $ARRAY_TYPE: the type of array to cast the argument to
136+
#[macro_export]
137+
macro_rules! downcast_arg {
138+
($ARG:expr, $ARRAY_TYPE:ident) => {{
139+
downcast_named_arg!($ARG, "", $ARRAY_TYPE)
140+
}};
141+
}
142+
130143
/// Macro to create a unary math UDF.
131144
///
132145
/// A unary math function takes an argument of type Float32 or Float64,

datafusion/functions/src/math/abs.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use arrow::array::{
2626
};
2727
use arrow::datatypes::DataType;
2828
use arrow::error::ArrowError;
29-
use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result};
29+
use datafusion_common::{exec_err, internal_datafusion_err, not_impl_err, Result};
3030
use datafusion_expr::interval_arithmetic::Interval;
3131
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
3232
use datafusion_expr::{
@@ -39,7 +39,7 @@ type MathArrayFunction = fn(&Vec<ArrayRef>) -> Result<ArrayRef>;
3939
macro_rules! make_abs_function {
4040
($ARRAY_TYPE:ident) => {{
4141
|args: &Vec<ArrayRef>| {
42-
let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
42+
let array = downcast_named_arg!(&args[0], "abs arg", $ARRAY_TYPE);
4343
let res: $ARRAY_TYPE = array.unary(|x| x.abs());
4444
Ok(Arc::new(res) as ArrayRef)
4545
}
@@ -49,7 +49,7 @@ macro_rules! make_abs_function {
4949
macro_rules! make_try_abs_function {
5050
($ARRAY_TYPE:ident) => {{
5151
|args: &Vec<ArrayRef>| {
52-
let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
52+
let array = downcast_named_arg!(&args[0], "abs arg", $ARRAY_TYPE);
5353
let res: $ARRAY_TYPE = array.try_unary(|x| {
5454
x.checked_abs().ok_or_else(|| {
5555
ArrowError::ComputeError(format!(
@@ -67,7 +67,7 @@ macro_rules! make_try_abs_function {
6767
macro_rules! make_decimal_abs_function {
6868
($ARRAY_TYPE:ident) => {{
6969
|args: &Vec<ArrayRef>| {
70-
let array = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
70+
let array = downcast_named_arg!(&args[0], "abs arg", $ARRAY_TYPE);
7171
let res: $ARRAY_TYPE = array
7272
.unary(|x| x.wrapping_abs())
7373
.with_data_type(args[0].data_type().clone());

datafusion/functions/src/math/factorial.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ use arrow::datatypes::DataType;
2626
use arrow::datatypes::DataType::Int64;
2727

2828
use crate::utils::make_scalar_function;
29-
use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result};
29+
use datafusion_common::{
30+
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
31+
};
3032
use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
3133
use datafusion_expr::{
3234
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
@@ -99,7 +101,7 @@ fn get_factorial_doc() -> &'static Documentation {
99101
fn factorial(args: &[ArrayRef]) -> Result<ArrayRef> {
100102
match args[0].data_type() {
101103
Int64 => {
102-
let arg = downcast_arg!((&args[0]), "value", Int64Array);
104+
let arg = downcast_named_arg!((&args[0]), "value", Int64Array);
103105
Ok(arg
104106
.iter()
105107
.map(|a| match a {

datafusion/functions/src/math/gcd.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ use arrow::datatypes::DataType;
2525
use arrow::datatypes::DataType::Int64;
2626

2727
use crate::utils::make_scalar_function;
28-
use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result};
28+
use datafusion_common::{
29+
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
30+
};
2931
use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
3032
use datafusion_expr::{
3133
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
@@ -100,8 +102,8 @@ fn get_gcd_doc() -> &'static Documentation {
100102
fn gcd(args: &[ArrayRef]) -> Result<ArrayRef> {
101103
match args[0].data_type() {
102104
Int64 => {
103-
let arg1 = downcast_arg!(&args[0], "x", Int64Array);
104-
let arg2 = downcast_arg!(&args[1], "y", Int64Array);
105+
let arg1 = downcast_named_arg!(&args[0], "x", Int64Array);
106+
let arg2 = downcast_named_arg!(&args[1], "y", Int64Array);
105107

106108
Ok(arg1
107109
.iter()

datafusion/functions/src/math/lcm.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ use arrow::datatypes::DataType;
2323
use arrow::datatypes::DataType::Int64;
2424

2525
use arrow::error::ArrowError;
26-
use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result};
26+
use datafusion_common::{
27+
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
28+
};
2729
use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
2830
use datafusion_expr::{
2931
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
@@ -121,8 +123,8 @@ fn lcm(args: &[ArrayRef]) -> Result<ArrayRef> {
121123

122124
match args[0].data_type() {
123125
Int64 => {
124-
let arg1 = downcast_arg!(&args[0], "x", Int64Array);
125-
let arg2 = downcast_arg!(&args[1], "y", Int64Array);
126+
let arg1 = downcast_named_arg!(&args[0], "x", Int64Array);
127+
let arg2 = downcast_named_arg!(&args[1], "y", Int64Array);
126128

127129
Ok(arg1
128130
.iter()

datafusion/functions/src/math/power.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ use super::log::LogFunc;
2424
use arrow::array::{ArrayRef, AsArray, Int64Array};
2525
use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type};
2626
use datafusion_common::{
27-
arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err,
28-
DataFusionError, Result, ScalarValue,
27+
arrow_datafusion_err, exec_datafusion_err, exec_err, internal_datafusion_err,
28+
plan_datafusion_err, DataFusionError, Result, ScalarValue,
2929
};
3030
use datafusion_expr::expr::ScalarFunction;
3131
use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
@@ -103,8 +103,8 @@ impl ScalarUDFImpl for PowerFunc {
103103
Arc::new(result) as _
104104
}
105105
DataType::Int64 => {
106-
let bases = downcast_arg!(&args[0], "base", Int64Array);
107-
let exponents = downcast_arg!(&args[1], "exponent", Int64Array);
106+
let bases = downcast_named_arg!(&args[0], "base", Int64Array);
107+
let exponents = downcast_named_arg!(&args[1], "exponent", Int64Array);
108108
bases
109109
.iter()
110110
.zip(exponents.iter())

0 commit comments

Comments
 (0)