Skip to content

Commit 85f7a8e

Browse files
authored
Move abs to datafusion_functions (#9313)
* feat: move abs to datafusion_functions * fix proto * fix proto * fix CI vendored code * Fix proto * add support type * fix signature * fix typo * fix test cases * disable a test case * remove old code from math_expressions * feat: add test * fix clippy * use unknown for proto * fix unknown proto
1 parent b55d0ed commit 85f7a8e

File tree

12 files changed

+198
-123
lines changed

12 files changed

+198
-123
lines changed

datafusion/expr/src/built_in_function.rs

-7
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ use strum_macros::EnumIter;
4242
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
4343
pub enum BuiltinScalarFunction {
4444
// math functions
45-
/// abs
46-
Abs,
4745
/// acos
4846
Acos,
4947
/// asin
@@ -364,7 +362,6 @@ impl BuiltinScalarFunction {
364362
pub fn volatility(&self) -> Volatility {
365363
match self {
366364
// Immutable scalar builtins
367-
BuiltinScalarFunction::Abs => Volatility::Immutable,
368365
BuiltinScalarFunction::Acos => Volatility::Immutable,
369366
BuiltinScalarFunction::Asin => Volatility::Immutable,
370367
BuiltinScalarFunction::Atan => Volatility::Immutable,
@@ -868,8 +865,6 @@ impl BuiltinScalarFunction {
868865

869866
BuiltinScalarFunction::ArrowTypeof => Ok(Utf8),
870867

871-
BuiltinScalarFunction::Abs => Ok(input_expr_types[0].clone()),
872-
873868
BuiltinScalarFunction::OverLay => {
874869
utf8_to_str_type(&input_expr_types[0], "overlay")
875870
}
@@ -1338,7 +1333,6 @@ impl BuiltinScalarFunction {
13381333
Signature::uniform(2, vec![Int64], self.volatility())
13391334
}
13401335
BuiltinScalarFunction::ArrowTypeof => Signature::any(1, self.volatility()),
1341-
BuiltinScalarFunction::Abs => Signature::any(1, self.volatility()),
13421336
BuiltinScalarFunction::OverLay => Signature::one_of(
13431337
vec![
13441338
Exact(vec![Utf8, Utf8, Int64, Int64]),
@@ -1444,7 +1438,6 @@ impl BuiltinScalarFunction {
14441438
/// Returns all names that can be used to call this function
14451439
pub fn aliases(&self) -> &'static [&'static str] {
14461440
match self {
1447-
BuiltinScalarFunction::Abs => &["abs"],
14481441
BuiltinScalarFunction::Acos => &["acos"],
14491442
BuiltinScalarFunction::Acosh => &["acosh"],
14501443
BuiltinScalarFunction::Asin => &["asin"],

datafusion/expr/src/expr.rs

-5
Original file line numberDiff line numberDiff line change
@@ -2033,11 +2033,6 @@ mod test {
20332033
.is_volatile()
20342034
.unwrap()
20352035
);
2036-
assert!(
2037-
!ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Abs)
2038-
.is_volatile()
2039-
.unwrap()
2040-
);
20412036

20422037
// UDF
20432038
#[derive(Debug)]

datafusion/expr/src/expr_fn.rs

-2
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,6 @@ nary_scalar_expr!(
557557
trunc,
558558
"truncate toward zero, with optional precision"
559559
);
560-
scalar_expr!(Abs, abs, num, "absolute value");
561560
scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) ");
562561
scalar_expr!(Exp, exp, num, "exponential");
563562
scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor");
@@ -1354,7 +1353,6 @@ mod test {
13541353
test_nary_scalar_expr!(Round, round, input, decimal_places);
13551354
test_nary_scalar_expr!(Trunc, trunc, num);
13561355
test_nary_scalar_expr!(Trunc, trunc, num, precision);
1357-
test_unary_scalar_expr!(Abs, abs);
13581356
test_unary_scalar_expr!(Signum, signum);
13591357
test_unary_scalar_expr!(Exp, exp);
13601358
test_unary_scalar_expr!(Log2, log2);

datafusion/functions/src/math/abs.rs

+177
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Implementation of the `abs` math function
19+
20+
use arrow::array::Decimal128Array;
21+
use arrow::array::Decimal256Array;
22+
use arrow::array::Int16Array;
23+
use arrow::array::Int32Array;
24+
use arrow::array::Int64Array;
25+
use arrow::array::Int8Array;
26+
use arrow::datatypes::DataType;
27+
use datafusion_common::not_impl_err;
28+
use datafusion_common::plan_datafusion_err;
29+
use datafusion_common::{internal_err, Result, DataFusionError};
30+
use datafusion_expr::utils;
31+
use datafusion_expr::ColumnarValue;
32+
33+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
34+
use std::any::Any;
35+
use std::sync::Arc;
36+
use arrow::array::{ArrayRef, Float32Array, Float64Array};
37+
use arrow::error::ArrowError;
38+
39+
/// Function-pointer type for an array kernel: takes the evaluated argument arrays and
/// returns the resulting array. `create_abs_function` returns a value of this type.
type MathArrayFunction = fn(&Vec<ArrayRef>) -> Result<ArrayRef>;
40+
41+
/// Expands to a closure that downcasts the first argument to `$ARRAY_TYPE` and applies
/// the element type's `abs()` to every value.
///
/// `create_abs_function` uses this for the floating-point array types.
macro_rules! make_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |args: &Vec<ArrayRef>| {
            // NOTE(review): `downcast_arg!` is defined outside this file; presumably it
            // returns early with an error when the downcast fails — confirm.
            let input = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
            let output: $ARRAY_TYPE = input.unary(|value| value.abs());
            Ok(Arc::new(output) as ArrayRef)
        }
    }};
}
50+
51+
/// Expands to a closure that downcasts the first argument to `$ARRAY_TYPE` and computes
/// a checked absolute value for every element.
///
/// When `checked_abs()` yields `None` for a value, the closure fails with an
/// `ArrowError::ComputeError` naming the array type and the offending value.
macro_rules! make_try_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |args: &Vec<ArrayRef>| {
            let input = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
            let output: $ARRAY_TYPE = input.try_unary(|value| match value.checked_abs() {
                Some(magnitude) => Ok(magnitude),
                None => Err(ArrowError::ComputeError(format!(
                    "{} overflow on abs({})",
                    stringify!($ARRAY_TYPE),
                    value
                ))),
            })?;
            Ok(Arc::new(output) as ArrayRef)
        }
    }};
}
68+
69+
/// Expands to a closure that downcasts the first argument to the decimal array type
/// `$ARRAY_TYPE` and applies `wrapping_abs()` to every value.
///
/// The result is tagged with the data type of the input array.
macro_rules! make_decimal_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |args: &Vec<ArrayRef>| {
            let input = downcast_arg!(&args[0], "abs arg", $ARRAY_TYPE);
            // Re-apply the input's data type to the computed array; presumably this keeps
            // the original precision and scale on the result — confirm against arrow docs.
            let decimal_type = args[0].data_type().clone();
            let magnitudes: $ARRAY_TYPE = input.unary(|value| value.wrapping_abs());
            Ok(Arc::new(magnitudes.with_data_type(decimal_type)) as ArrayRef)
        }
    }};
}
80+
81+
/// Abs SQL function
82+
/// Return different implementations based on input datatype to reduce branches during execution
83+
fn create_abs_function(
84+
input_data_type: &DataType,
85+
) -> Result<MathArrayFunction> {
86+
match input_data_type {
87+
DataType::Float32 => Ok(make_abs_function!(Float32Array)),
88+
DataType::Float64 => Ok(make_abs_function!(Float64Array)),
89+
90+
// Types that may overflow, such as abs(-128_i8).
91+
DataType::Int8 => Ok(make_try_abs_function!(Int8Array)),
92+
DataType::Int16 => Ok(make_try_abs_function!(Int16Array)),
93+
DataType::Int32 => Ok(make_try_abs_function!(Int32Array)),
94+
DataType::Int64 => Ok(make_try_abs_function!(Int64Array)),
95+
96+
// Types of results are the same as the input.
97+
DataType::Null
98+
| DataType::UInt8
99+
| DataType::UInt16
100+
| DataType::UInt32
101+
| DataType::UInt64 => Ok(|args: &Vec<ArrayRef>| Ok(args[0].clone())),
102+
103+
// Decimal types
104+
DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)),
105+
DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)),
106+
107+
other => not_impl_err!("Unsupported data type {other:?} for function abs"),
108+
}
109+
}
110+
#[derive(Debug)]
111+
pub(super) struct AbsFunc {
112+
signature: Signature,
113+
}
114+
115+
impl AbsFunc {
116+
pub fn new() -> Self {
117+
Self {
118+
signature: Signature::any(1, Volatility::Immutable)
119+
}
120+
}
121+
}
122+
123+
impl ScalarUDFImpl for AbsFunc {
124+
fn as_any(&self) -> &dyn Any {
125+
self
126+
}
127+
fn name(&self) -> &str {
128+
"abs"
129+
}
130+
131+
fn signature(&self) -> &Signature {
132+
&self.signature
133+
}
134+
135+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136+
if arg_types.len() != 1 {
137+
return Err(plan_datafusion_err!(
138+
"{}",
139+
utils::generate_signature_error_msg(
140+
self.name(),
141+
self.signature().clone(),
142+
arg_types,
143+
)
144+
));
145+
}
146+
match arg_types[0] {
147+
DataType::Float32 => Ok(DataType::Float32),
148+
DataType::Float64 => Ok(DataType::Float64),
149+
DataType::Int8 => Ok(DataType::Int8),
150+
DataType::Int16 => Ok(DataType::Int16),
151+
DataType::Int32 => Ok(DataType::Int32),
152+
DataType::Int64 => Ok(DataType::Int64),
153+
DataType::Null => Ok(DataType::Null),
154+
DataType::UInt8 => Ok(DataType::UInt8),
155+
DataType::UInt16 => Ok(DataType::UInt16),
156+
DataType::UInt32 => Ok(DataType::UInt32),
157+
DataType::UInt64 => Ok(DataType::UInt64),
158+
DataType::Decimal128(precision, scale) => Ok(DataType::Decimal128(precision, scale)),
159+
DataType::Decimal256(precision, scale) => Ok(DataType::Decimal256(precision, scale)),
160+
_ => not_impl_err!("Unsupported data type {} for function abs", arg_types[0].to_string()),
161+
}
162+
}
163+
164+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
165+
let args = ColumnarValue::values_to_arrays(args)?;
166+
167+
if args.len() != 1 {
168+
return internal_err!("abs function requires 1 argument, got {}", args.len());
169+
}
170+
171+
let input_data_type = args[0].data_type();
172+
let abs_fun = create_abs_function(input_data_type)?;
173+
174+
let arr = abs_fun(&args)?;
175+
Ok(ColumnarValue::Array(arr))
176+
}
177+
}

datafusion/functions/src/math/mod.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
//! "math" DataFusion functions
1919
2020
mod nans;
21+
mod abs;
2122

2223
// create UDFs
2324
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
25+
make_udf_function!(abs::AbsFunc, ABS, abs);
2426

2527
// Export the functions out of this package, both as expr_fn as well as a list of functions
2628
export_functions!(
27-
(isnan, num, "returns true if a given number is +NaN or -NaN otherwise returns false")
28-
);
29-
29+
(isnan, num, "returns true if a given number is +NaN or -NaN otherwise returns false"),
30+
(abs, num, "returns the absolute value of a given number")
31+
);

datafusion/physical-expr/src/functions.rs

-4
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,6 @@ pub fn create_physical_fun(
260260
) -> Result<ScalarFunctionImplementation> {
261261
Ok(match fun {
262262
// math functions
263-
BuiltinScalarFunction::Abs => Arc::new(|args| {
264-
make_scalar_function_inner(math_expressions::abs_invoke)(args)
265-
}),
266263
BuiltinScalarFunction::Acos => Arc::new(math_expressions::acos),
267264
BuiltinScalarFunction::Asin => Arc::new(math_expressions::asin),
268265
BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
@@ -3075,7 +3072,6 @@ mod tests {
30753072
let funs = [
30763073
BuiltinScalarFunction::Concat,
30773074
BuiltinScalarFunction::ToTimestamp,
3078-
BuiltinScalarFunction::Abs,
30793075
BuiltinScalarFunction::Repeat,
30803076
];
30813077

0 commit comments

Comments
 (0)