Skip to content

Commit 66c8ba2

Browse files
authored
move Atan2, Atan, Acosh, Asinh, Atanh to datafusion-function (#9872)
* Refactor math functions in datafusion code * fix ci * fix: avoid regression * refactor: move atan2 function * chore: move atan2 test
1 parent 2cb6f73 commit 66c8ba2

File tree

13 files changed

+201
-184
lines changed

13 files changed

+201
-184
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,6 @@ use strum_macros::EnumIter;
3737
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
3838
pub enum BuiltinScalarFunction {
3939
// math functions
40-
/// atan
41-
Atan,
42-
/// atan2
43-
Atan2,
44-
/// acosh
45-
Acosh,
46-
/// asinh
47-
Asinh,
48-
/// atanh
49-
Atanh,
5040
/// cbrt
5141
Cbrt,
5242
/// ceil
@@ -159,11 +149,6 @@ impl BuiltinScalarFunction {
159149
pub fn volatility(&self) -> Volatility {
160150
match self {
161151
// Immutable scalar builtins
162-
BuiltinScalarFunction::Atan => Volatility::Immutable,
163-
BuiltinScalarFunction::Atan2 => Volatility::Immutable,
164-
BuiltinScalarFunction::Acosh => Volatility::Immutable,
165-
BuiltinScalarFunction::Asinh => Volatility::Immutable,
166-
BuiltinScalarFunction::Atanh => Volatility::Immutable,
167152
BuiltinScalarFunction::Ceil => Volatility::Immutable,
168153
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
169154
BuiltinScalarFunction::Cos => Volatility::Immutable,
@@ -238,11 +223,6 @@ impl BuiltinScalarFunction {
238223
_ => Ok(Float64),
239224
},
240225

241-
BuiltinScalarFunction::Atan2 => match &input_expr_types[0] {
242-
Float32 => Ok(Float32),
243-
_ => Ok(Float64),
244-
},
245-
246226
BuiltinScalarFunction::Log => match &input_expr_types[0] {
247227
Float32 => Ok(Float32),
248228
_ => Ok(Float64),
@@ -255,11 +235,7 @@ impl BuiltinScalarFunction {
255235

256236
BuiltinScalarFunction::Iszero => Ok(Boolean),
257237

258-
BuiltinScalarFunction::Atan
259-
| BuiltinScalarFunction::Acosh
260-
| BuiltinScalarFunction::Asinh
261-
| BuiltinScalarFunction::Atanh
262-
| BuiltinScalarFunction::Ceil
238+
BuiltinScalarFunction::Ceil
263239
| BuiltinScalarFunction::Cos
264240
| BuiltinScalarFunction::Cosh
265241
| BuiltinScalarFunction::Degrees
@@ -332,10 +308,7 @@ impl BuiltinScalarFunction {
332308
],
333309
self.volatility(),
334310
),
335-
BuiltinScalarFunction::Atan2 => Signature::one_of(
336-
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
337-
self.volatility(),
338-
),
311+
339312
BuiltinScalarFunction::Log => Signature::one_of(
340313
vec![
341314
Exact(vec![Float32]),
@@ -355,11 +328,7 @@ impl BuiltinScalarFunction {
355328
BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => {
356329
Signature::uniform(2, vec![Int64], self.volatility())
357330
}
358-
BuiltinScalarFunction::Atan
359-
| BuiltinScalarFunction::Acosh
360-
| BuiltinScalarFunction::Asinh
361-
| BuiltinScalarFunction::Atanh
362-
| BuiltinScalarFunction::Cbrt
331+
BuiltinScalarFunction::Cbrt
363332
| BuiltinScalarFunction::Ceil
364333
| BuiltinScalarFunction::Cos
365334
| BuiltinScalarFunction::Cosh
@@ -392,11 +361,7 @@ impl BuiltinScalarFunction {
392361
pub fn monotonicity(&self) -> Option<FuncMonotonicity> {
393362
if matches!(
394363
&self,
395-
BuiltinScalarFunction::Atan
396-
| BuiltinScalarFunction::Acosh
397-
| BuiltinScalarFunction::Asinh
398-
| BuiltinScalarFunction::Atanh
399-
| BuiltinScalarFunction::Ceil
364+
BuiltinScalarFunction::Ceil
400365
| BuiltinScalarFunction::Degrees
401366
| BuiltinScalarFunction::Exp
402367
| BuiltinScalarFunction::Factorial
@@ -421,11 +386,6 @@ impl BuiltinScalarFunction {
421386
/// Returns all names that can be used to call this function
422387
pub fn aliases(&self) -> &'static [&'static str] {
423388
match self {
424-
BuiltinScalarFunction::Acosh => &["acosh"],
425-
BuiltinScalarFunction::Asinh => &["asinh"],
426-
BuiltinScalarFunction::Atan => &["atan"],
427-
BuiltinScalarFunction::Atanh => &["atanh"],
428-
BuiltinScalarFunction::Atan2 => &["atan2"],
429389
BuiltinScalarFunction::Cbrt => &["cbrt"],
430390
BuiltinScalarFunction::Ceil => &["ceil"],
431391
BuiltinScalarFunction::Cos => &["cos"],

datafusion/expr/src/expr_fn.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine");
541541
scalar_expr!(Cot, cot, num, "cotangent");
542542
scalar_expr!(Sinh, sinh, num, "hyperbolic sine");
543543
scalar_expr!(Cosh, cosh, num, "hyperbolic cosine");
544-
scalar_expr!(Atan, atan, num, "inverse tangent");
545-
scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine");
546-
scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine");
547-
scalar_expr!(Atanh, atanh, num, "inverse hyperbolic tangent");
548544
scalar_expr!(Factorial, factorial, num, "factorial");
549545
scalar_expr!(
550546
Floor,
@@ -571,7 +567,6 @@ scalar_expr!(Exp, exp, num, "exponential");
571567
scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor");
572568
scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple");
573569
scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`");
574-
scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument");
575570
scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`");
576571

577572
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
@@ -979,10 +974,6 @@ mod test {
979974
test_unary_scalar_expr!(Cot, cot);
980975
test_unary_scalar_expr!(Sinh, sinh);
981976
test_unary_scalar_expr!(Cosh, cosh);
982-
test_unary_scalar_expr!(Atan, atan);
983-
test_unary_scalar_expr!(Asinh, asinh);
984-
test_unary_scalar_expr!(Acosh, acosh);
985-
test_unary_scalar_expr!(Atanh, atanh);
986977
test_unary_scalar_expr!(Factorial, factorial);
987978
test_unary_scalar_expr!(Floor, floor);
988979
test_unary_scalar_expr!(Ceil, ceil);
@@ -994,7 +985,6 @@ mod test {
994985
test_nary_scalar_expr!(Trunc, trunc, num, precision);
995986
test_unary_scalar_expr!(Signum, signum);
996987
test_unary_scalar_expr!(Exp, exp);
997-
test_scalar_expr!(Atan2, atan2, y, x);
998988
test_scalar_expr!(Nanvl, nanvl, x, y);
999989
test_scalar_expr!(Iszero, iszero, input);
1000990

datafusion/functions/src/macros.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ macro_rules! downcast_arg {
156156
/// $GNAME: a singleton instance of the UDF
157157
/// $NAME: the name of the function
158158
/// $UNARY_FUNC: the unary function to apply to the argument
159+
/// $MONOTONICITY: the monotonicity of the function
159160
macro_rules! make_math_unary_udf {
160161
($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => {
161162
make_udf_function!($NAME::$UDF, $GNAME, $NAME);
@@ -249,3 +250,31 @@ macro_rules! make_math_unary_udf {
249250
}
250251
};
251252
}
253+
254+
#[macro_export]
255+
macro_rules! make_function_inputs2 {
256+
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
257+
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
258+
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE);
259+
260+
arg1.iter()
261+
.zip(arg2.iter())
262+
.map(|(a1, a2)| match (a1, a2) {
263+
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
264+
_ => None,
265+
})
266+
.collect::<$ARRAY_TYPE>()
267+
}};
268+
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
269+
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
270+
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);
271+
272+
arg1.iter()
273+
.zip(arg2.iter())
274+
.map(|(a1, a2)| match (a1, a2) {
275+
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
276+
_ => None,
277+
})
278+
.collect::<$ARRAY_TYPE1>()
279+
}};
280+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Math function: `atan2()`.
19+
20+
use arrow::array::{ArrayRef, Float32Array, Float64Array};
21+
use arrow::datatypes::DataType;
22+
use datafusion_common::DataFusionError;
23+
use datafusion_common::{exec_err, Result};
24+
use datafusion_expr::ColumnarValue;
25+
use datafusion_expr::TypeSignature::*;
26+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
27+
use std::any::Any;
28+
use std::sync::Arc;
29+
30+
use crate::make_function_inputs2;
31+
use crate::utils::make_scalar_function;
32+
33+
#[derive(Debug)]
34+
pub(super) struct Atan2 {
35+
signature: Signature,
36+
}
37+
38+
impl Atan2 {
39+
pub fn new() -> Self {
40+
use DataType::*;
41+
Self {
42+
signature: Signature::one_of(
43+
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
44+
Volatility::Immutable,
45+
),
46+
}
47+
}
48+
}
49+
50+
impl ScalarUDFImpl for Atan2 {
51+
fn as_any(&self) -> &dyn Any {
52+
self
53+
}
54+
fn name(&self) -> &str {
55+
"atan2"
56+
}
57+
58+
fn signature(&self) -> &Signature {
59+
&self.signature
60+
}
61+
62+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
63+
use self::DataType::*;
64+
match &arg_types[0] {
65+
Float32 => Ok(Float32),
66+
_ => Ok(Float64),
67+
}
68+
}
69+
70+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
71+
make_scalar_function(atan2, vec![])(args)
72+
}
73+
}
74+
75+
/// Atan2 SQL function
76+
pub fn atan2(args: &[ArrayRef]) -> Result<ArrayRef> {
77+
match args[0].data_type() {
78+
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
79+
&args[0],
80+
&args[1],
81+
"y",
82+
"x",
83+
Float64Array,
84+
{ f64::atan2 }
85+
)) as ArrayRef),
86+
87+
DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
88+
&args[0],
89+
&args[1],
90+
"y",
91+
"x",
92+
Float32Array,
93+
{ f32::atan2 }
94+
)) as ArrayRef),
95+
96+
other => exec_err!("Unsupported data type {other:?} for function atan2"),
97+
}
98+
}
99+
100+
#[cfg(test)]
101+
mod test {
102+
use super::*;
103+
use datafusion_common::cast::{as_float32_array, as_float64_array};
104+
105+
#[test]
106+
fn test_atan2_f64() {
107+
let args: Vec<ArrayRef> = vec![
108+
Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
109+
Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
110+
];
111+
112+
let result = atan2(&args).expect("failed to initialize function atan2");
113+
let floats =
114+
as_float64_array(&result).expect("failed to initialize function atan2");
115+
116+
assert_eq!(floats.len(), 4);
117+
assert_eq!(floats.value(0), (2.0_f64).atan2(1.0));
118+
assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0));
119+
assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0));
120+
assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0));
121+
}
122+
123+
#[test]
124+
fn test_atan2_f32() {
125+
let args: Vec<ArrayRef> = vec![
126+
Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
127+
Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
128+
];
129+
130+
let result = atan2(&args).expect("failed to initialize function atan2");
131+
let floats =
132+
as_float32_array(&result).expect("failed to initialize function atan2");
133+
134+
assert_eq!(floats.len(), 4);
135+
assert_eq!(floats.value(0), (2.0_f32).atan2(1.0));
136+
assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0));
137+
assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0));
138+
assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0));
139+
}
140+
}

datafusion/functions/src/math/mod.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
//! "math" DataFusion functions
1919
2020
mod abs;
21+
mod atan2;
2122
mod nans;
2223

2324
// Create UDFs
2425
make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
2526
make_udf_function!(abs::AbsFunc, ABS, abs);
27+
make_udf_function!(atan2::Atan2, ATAN2, atan2);
2628

2729
make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)]));
2830
make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)]));
@@ -33,6 +35,11 @@ make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None);
3335
make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None);
3436
make_math_unary_udf!(TanFunc, TAN, tan, tan, None);
3537

38+
make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)]));
39+
make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)]));
40+
make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)]));
41+
make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)]));
42+
3643
// Export the functions out of this package, both as expr_fn as well as a list of functions
3744
export_functions!(
3845
(
@@ -55,5 +62,10 @@ export_functions!(
5562
"returns the arc sine or inverse sine of a number"
5663
),
5764
(tan, num, "returns the tangent of a number"),
58-
(tanh, num, "returns the hyperbolic tangent of a number")
65+
(tanh, num, "returns the hyperbolic tangent of a number"),
66+
(atanh, num, "returns inverse hyperbolic tangent"),
67+
(asinh, num, "returns inverse hyperbolic sine"),
68+
(acosh, num, "returns inverse hyperbolic cosine"),
69+
(atan, num, "returns inverse tangent"),
70+
(atan2, y x, "returns inverse tangent of a division given in the argument")
5971
);

datafusion/functions/src/utils.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
6868
// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size.
6969
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
7070

71+
/// Creates a scalar function implementation for the given function.
72+
/// * `inner` - the function to be executed
73+
/// * `hints` - hints to be used when expanding scalars to arrays
7174
pub(super) fn make_scalar_function<F>(
7275
inner: F,
7376
hints: Vec<Hint>,

datafusion/physical-expr/src/functions.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,6 @@ pub fn create_physical_fun(
179179
) -> Result<ScalarFunctionImplementation> {
180180
Ok(match fun {
181181
// math functions
182-
BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
183-
BuiltinScalarFunction::Acosh => Arc::new(math_expressions::acosh),
184-
BuiltinScalarFunction::Asinh => Arc::new(math_expressions::asinh),
185-
BuiltinScalarFunction::Atanh => Arc::new(math_expressions::atanh),
186182
BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil),
187183
BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos),
188184
BuiltinScalarFunction::Cosh => Arc::new(math_expressions::cosh),
@@ -221,9 +217,6 @@ pub fn create_physical_fun(
221217
BuiltinScalarFunction::Power => {
222218
Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args))
223219
}
224-
BuiltinScalarFunction::Atan2 => {
225-
Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args))
226-
}
227220
BuiltinScalarFunction::Log => {
228221
Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args))
229222
}

0 commit comments

Comments
 (0)