Skip to content

Commit 8945462

Browse files
authored
Fix : signum function bug when 0.0 input (#11580)
* add signum unit test * fix: signum function implementation - input zero output zero * fix: run cargo fmt * fix: not specified return type is float64 * fix: sqllogictest
1 parent 1e06b91 commit 8945462

File tree

4 files changed

+218
-7
lines changed

4 files changed

+218
-7
lines changed

datafusion/functions/src/math/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ pub mod pi;
3535
pub mod power;
3636
pub mod random;
3737
pub mod round;
38+
pub mod signum;
3839
pub mod trunc;
3940

4041
// Create UDFs
@@ -81,7 +82,7 @@ make_math_unary_udf!(
8182
);
8283
make_udf_function!(random::RandomFunc, RANDOM, random);
8384
make_udf_function!(round::RoundFunc, ROUND, round);
84-
make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, super::signum_order);
85+
make_udf_function!(signum::SignumFunc, SIGNUM, signum);
8586
make_math_unary_udf!(SinFunc, SIN, sin, sin, super::sin_order);
8687
make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, super::sinh_order);
8788
make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, super::sqrt_order);

datafusion/functions/src/math/monotonicity.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,6 @@ pub fn radians_order(input: &[ExprProperties]) -> Result<SortProperties> {
197197
Ok(input[0].sort_properties)
198198
}
199199

200-
/// Non-decreasing for all real numbers x.
201-
pub fn signum_order(input: &[ExprProperties]) -> Result<SortProperties> {
202-
Ok(input[0].sort_properties)
203-
}
204-
205200
/// Non-decreasing on \[0, π\] and then non-increasing on \[π, 2π\].
206201
/// This pattern repeats periodically with a period of 2π.
207202
// TODO: Implement ordering rule of the SIN function.
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::sync::Arc;
20+
21+
use arrow::array::{ArrayRef, Float32Array, Float64Array};
22+
use arrow::datatypes::DataType;
23+
use arrow::datatypes::DataType::{Float32, Float64};
24+
25+
use datafusion_common::{exec_err, DataFusionError, Result};
26+
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
27+
use datafusion_expr::ColumnarValue;
28+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29+
30+
use crate::utils::make_scalar_function;
31+
32+
#[derive(Debug)]
33+
pub struct SignumFunc {
34+
signature: Signature,
35+
}
36+
37+
impl Default for SignumFunc {
38+
fn default() -> Self {
39+
SignumFunc::new()
40+
}
41+
}
42+
43+
impl SignumFunc {
44+
pub fn new() -> Self {
45+
use DataType::*;
46+
Self {
47+
signature: Signature::uniform(
48+
1,
49+
vec![Float64, Float32],
50+
Volatility::Immutable,
51+
),
52+
}
53+
}
54+
}
55+
56+
impl ScalarUDFImpl for SignumFunc {
57+
fn as_any(&self) -> &dyn Any {
58+
self
59+
}
60+
61+
fn name(&self) -> &str {
62+
"signum"
63+
}
64+
65+
fn signature(&self) -> &Signature {
66+
&self.signature
67+
}
68+
69+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
70+
match &arg_types[0] {
71+
Float32 => Ok(Float32),
72+
_ => Ok(Float64),
73+
}
74+
}
75+
76+
fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
77+
// Non-decreasing for all real numbers x.
78+
Ok(input[0].sort_properties)
79+
}
80+
81+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
82+
make_scalar_function(signum, vec![])(args)
83+
}
84+
}
85+
86+
/// signum SQL function
87+
pub fn signum(args: &[ArrayRef]) -> Result<ArrayRef> {
88+
match args[0].data_type() {
89+
Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!(
90+
&args[0],
91+
"signum",
92+
Float64Array,
93+
Float64Array,
94+
{
95+
|x: f64| {
96+
if x == 0_f64 {
97+
0_f64
98+
} else {
99+
x.signum()
100+
}
101+
}
102+
}
103+
)) as ArrayRef),
104+
105+
Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!(
106+
&args[0],
107+
"signum",
108+
Float32Array,
109+
Float32Array,
110+
{
111+
|x: f32| {
112+
if x == 0_f32 {
113+
0_f32
114+
} else {
115+
x.signum()
116+
}
117+
}
118+
}
119+
)) as ArrayRef),
120+
121+
other => exec_err!("Unsupported data type {other:?} for function signum"),
122+
}
123+
}
124+
125+
#[cfg(test)]
126+
mod test {
127+
use std::sync::Arc;
128+
129+
use arrow::array::{Float32Array, Float64Array};
130+
131+
use datafusion_common::cast::{as_float32_array, as_float64_array};
132+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
133+
134+
use crate::math::signum::SignumFunc;
135+
136+
#[test]
137+
fn test_signum_f32() {
138+
let args = [ColumnarValue::Array(Arc::new(Float32Array::from(vec![
139+
-1.0,
140+
-0.0,
141+
0.0,
142+
1.0,
143+
-0.01,
144+
0.01,
145+
f32::NAN,
146+
f32::INFINITY,
147+
f32::NEG_INFINITY,
148+
])))];
149+
150+
let result = SignumFunc::new()
151+
.invoke(&args)
152+
.expect("failed to initialize function signum");
153+
154+
match result {
155+
ColumnarValue::Array(arr) => {
156+
let floats = as_float32_array(&arr)
157+
.expect("failed to convert result to a Float32Array");
158+
159+
assert_eq!(floats.len(), 9);
160+
assert_eq!(floats.value(0), -1.0);
161+
assert_eq!(floats.value(1), 0.0);
162+
assert_eq!(floats.value(2), 0.0);
163+
assert_eq!(floats.value(3), 1.0);
164+
assert_eq!(floats.value(4), -1.0);
165+
assert_eq!(floats.value(5), 1.0);
166+
assert!(floats.value(6).is_nan());
167+
assert_eq!(floats.value(7), 1.0);
168+
assert_eq!(floats.value(8), -1.0);
169+
}
170+
ColumnarValue::Scalar(_) => {
171+
panic!("Expected an array value")
172+
}
173+
}
174+
}
175+
176+
#[test]
177+
fn test_signum_f64() {
178+
let args = [ColumnarValue::Array(Arc::new(Float64Array::from(vec![
179+
-1.0,
180+
-0.0,
181+
0.0,
182+
1.0,
183+
-0.01,
184+
0.01,
185+
f64::NAN,
186+
f64::INFINITY,
187+
f64::NEG_INFINITY,
188+
])))];
189+
190+
let result = SignumFunc::new()
191+
.invoke(&args)
192+
.expect("failed to initialize function signum");
193+
194+
match result {
195+
ColumnarValue::Array(arr) => {
196+
let floats = as_float64_array(&arr)
197+
.expect("failed to convert result to a Float32Array");
198+
199+
assert_eq!(floats.len(), 9);
200+
assert_eq!(floats.value(0), -1.0);
201+
assert_eq!(floats.value(1), 0.0);
202+
assert_eq!(floats.value(2), 0.0);
203+
assert_eq!(floats.value(3), 1.0);
204+
assert_eq!(floats.value(4), -1.0);
205+
assert_eq!(floats.value(5), 1.0);
206+
assert!(floats.value(6).is_nan());
207+
assert_eq!(floats.value(7), 1.0);
208+
assert_eq!(floats.value(8), -1.0);
209+
}
210+
ColumnarValue::Scalar(_) => {
211+
panic!("Expected an array value")
212+
}
213+
}
214+
}
215+
}

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -794,7 +794,7 @@ select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 21474836
794794
query RRR rowsort
795795
select signum(-2), signum(0), signum(2);
796796
----
797-
-1 1 1
797+
-1 0 1
798798

799799
# signum scalar nulls
800800
query R rowsort

0 commit comments

Comments
 (0)