Skip to content

Commit 5901df5

Browse files
authored
feat: add bounds for unary math scalar functions (#11584)
* feat: unary udf function bounds * feat: add bounds for more types * feat: remove eprint * fix: add missing bounds file * tests: add tests for unary udf bounds * tests: test f32 and f64 * build: remove unrelated changes * refactor: better unbounded func name * tests: fix tests * refactor: use data_type method * refactor: add more useful intervals to Interval * refactor: use typed bounds for (-inf, inf) * refactor: inf to unbounded * refactor: add lower/upper pi bounds * refactor: consts to consts module * fix: add missing file * fix: docstring typo * refactor: remove unused signum bounds
1 parent 1356934 commit 5901df5

File tree

7 files changed

+595
-34
lines changed

7 files changed

+595
-34
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
// Constants defined for scalar construction.
19+
20+
// PI ~ 3.1415927 in f32
21+
#[allow(clippy::approx_constant)]
22+
pub(super) const PI_UPPER_F32: f32 = 3.141593_f32;
23+
24+
// PI ~ 3.141592653589793 in f64
25+
pub(super) const PI_UPPER_F64: f64 = 3.141592653589794_f64;
26+
27+
// -PI ~ -3.1415927 in f32
28+
#[allow(clippy::approx_constant)]
29+
pub(super) const NEGATIVE_PI_LOWER_F32: f32 = -3.141593_f32;
30+
31+
// -PI ~ -3.141592653589793 in f64
32+
pub(super) const NEGATIVE_PI_LOWER_F64: f64 = -3.141592653589794_f64;
33+
34+
// PI / 2 ~ 1.5707964 in f32
35+
pub(super) const FRAC_PI_2_UPPER_F32: f32 = 1.5707965_f32;
36+
37+
// PI / 2 ~ 1.5707963267948966 in f64
38+
pub(super) const FRAC_PI_2_UPPER_F64: f64 = 1.5707963267948967_f64;
39+
40+
// -PI / 2 ~ -1.5707964 in f32
41+
pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = -1.5707965_f32;
42+
43+
// -PI / 2 ~ -1.5707963267948966 in f64
44+
pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = -1.5707963267948967_f64;

datafusion/common/src/scalar/mod.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
//! [`ScalarValue`]: stores single values
1919
20+
mod consts;
2021
mod struct_builder;
22+
2123
use std::borrow::Borrow;
2224
use std::cmp::Ordering;
2325
use std::collections::{HashSet, VecDeque};
@@ -1007,6 +1009,123 @@ impl ScalarValue {
10071009
}
10081010
}
10091011

1012+
/// Returns a [`ScalarValue`] representing PI
1013+
pub fn new_pi(datatype: &DataType) -> Result<ScalarValue> {
1014+
match datatype {
1015+
DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)),
1016+
DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)),
1017+
_ => _internal_err!("PI is not supported for data type: {:?}", datatype),
1018+
}
1019+
}
1020+
1021+
/// Returns a [`ScalarValue`] representing PI's upper bound
1022+
pub fn new_pi_upper(datatype: &DataType) -> Result<ScalarValue> {
1023+
// TODO: replace the constants with next_up/next_down when
1024+
// they are stabilized: https://doc.rust-lang.org/std/primitive.f64.html#method.next_up
1025+
match datatype {
1026+
DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)),
1027+
DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)),
1028+
_ => {
1029+
_internal_err!("PI_UPPER is not supported for data type: {:?}", datatype)
1030+
}
1031+
}
1032+
}
1033+
1034+
/// Returns a [`ScalarValue`] representing -PI's lower bound
1035+
pub fn new_negative_pi_lower(datatype: &DataType) -> Result<ScalarValue> {
1036+
match datatype {
1037+
DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)),
1038+
DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)),
1039+
_ => {
1040+
_internal_err!("-PI_LOWER is not supported for data type: {:?}", datatype)
1041+
}
1042+
}
1043+
}
1044+
1045+
/// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound
1046+
pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result<ScalarValue> {
1047+
match datatype {
1048+
DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)),
1049+
DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)),
1050+
_ => {
1051+
_internal_err!(
1052+
"PI_UPPER/2 is not supported for data type: {:?}",
1053+
datatype
1054+
)
1055+
}
1056+
}
1057+
}
1058+
1059+
// Returns a [`ScalarValue`] representing FRAC_PI_2's lower bound
1060+
pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result<ScalarValue> {
1061+
match datatype {
1062+
DataType::Float32 => {
1063+
Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32))
1064+
}
1065+
DataType::Float64 => {
1066+
Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F64))
1067+
}
1068+
_ => {
1069+
_internal_err!(
1070+
"-PI/2_LOWER is not supported for data type: {:?}",
1071+
datatype
1072+
)
1073+
}
1074+
}
1075+
}
1076+
1077+
/// Returns a [`ScalarValue`] representing -PI
1078+
pub fn new_negative_pi(datatype: &DataType) -> Result<ScalarValue> {
1079+
match datatype {
1080+
DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)),
1081+
DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)),
1082+
_ => _internal_err!("-PI is not supported for data type: {:?}", datatype),
1083+
}
1084+
}
1085+
1086+
/// Returns a [`ScalarValue`] representing PI/2
1087+
pub fn new_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
1088+
match datatype {
1089+
DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)),
1090+
DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)),
1091+
_ => _internal_err!("PI/2 is not supported for data type: {:?}", datatype),
1092+
}
1093+
}
1094+
1095+
/// Returns a [`ScalarValue`] representing -PI/2
1096+
pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result<ScalarValue> {
1097+
match datatype {
1098+
DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)),
1099+
DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)),
1100+
_ => _internal_err!("-PI/2 is not supported for data type: {:?}", datatype),
1101+
}
1102+
}
1103+
1104+
/// Returns a [`ScalarValue`] representing infinity
1105+
pub fn new_infinity(datatype: &DataType) -> Result<ScalarValue> {
1106+
match datatype {
1107+
DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)),
1108+
DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)),
1109+
_ => {
1110+
_internal_err!("Infinity is not supported for data type: {:?}", datatype)
1111+
}
1112+
}
1113+
}
1114+
1115+
/// Returns a [`ScalarValue`] representing negative infinity
1116+
pub fn new_neg_infinity(datatype: &DataType) -> Result<ScalarValue> {
1117+
match datatype {
1118+
DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)),
1119+
DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)),
1120+
_ => {
1121+
_internal_err!(
1122+
"Negative Infinity is not supported for data type: {:?}",
1123+
datatype
1124+
)
1125+
}
1126+
}
1127+
}
1128+
10101129
/// Create a zero value in the given type.
10111130
pub fn new_zero(datatype: &DataType) -> Result<ScalarValue> {
10121131
Ok(match datatype {

datafusion/expr/src/interval_arithmetic.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,38 @@ impl Interval {
332332
Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint))
333333
}
334334

335+
/// Creates an interval between -1 to 1.
336+
pub fn make_symmetric_unit_interval(data_type: &DataType) -> Result<Self> {
337+
Self::try_new(
338+
ScalarValue::new_negative_one(data_type)?,
339+
ScalarValue::new_one(data_type)?,
340+
)
341+
}
342+
343+
/// Create an interval from -π to π.
344+
pub fn make_symmetric_pi_interval(data_type: &DataType) -> Result<Self> {
345+
Self::try_new(
346+
ScalarValue::new_negative_pi_lower(data_type)?,
347+
ScalarValue::new_pi_upper(data_type)?,
348+
)
349+
}
350+
351+
/// Create an interval from -π/2 to π/2.
352+
pub fn make_symmetric_half_pi_interval(data_type: &DataType) -> Result<Self> {
353+
Self::try_new(
354+
ScalarValue::new_neg_frac_pi_2_lower(data_type)?,
355+
ScalarValue::new_frac_pi_2_upper(data_type)?,
356+
)
357+
}
358+
359+
/// Create an interval from 0 to infinity.
360+
pub fn make_non_negative_infinity_interval(data_type: &DataType) -> Result<Self> {
361+
Self::try_new(
362+
ScalarValue::new_zero(data_type)?,
363+
ScalarValue::try_from(data_type)?,
364+
)
365+
}
366+
335367
/// Returns a reference to the lower bound.
336368
pub fn lower(&self) -> &ScalarValue {
337369
&self.lower

datafusion/functions/src/macros.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ macro_rules! downcast_arg {
162162
/// $UNARY_FUNC: the unary function to apply to the argument
163163
/// $OUTPUT_ORDERING: the output ordering calculation method of the function
164164
macro_rules! make_math_unary_udf {
165-
($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr) => {
165+
($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr) => {
166166
make_udf_function!($NAME::$UDF, $GNAME, $NAME);
167167

168168
mod $NAME {
@@ -172,6 +172,7 @@ macro_rules! make_math_unary_udf {
172172
use arrow::array::{ArrayRef, Float32Array, Float64Array};
173173
use arrow::datatypes::DataType;
174174
use datafusion_common::{exec_err, DataFusionError, Result};
175+
use datafusion_expr::interval_arithmetic::Interval;
175176
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
176177
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
177178

@@ -222,6 +223,10 @@ macro_rules! make_math_unary_udf {
222223
$OUTPUT_ORDERING(input)
223224
}
224225

226+
fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result<Interval> {
227+
$EVALUATE_BOUNDS(inputs)
228+
}
229+
225230
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
226231
let args = ColumnarValue::values_to_arrays(args)?;
227232

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_common::ScalarValue;
19+
use datafusion_expr::interval_arithmetic::Interval;
20+
21+
pub(super) fn unbounded_bounds(input: &[&Interval]) -> crate::Result<Interval> {
22+
let data_type = input[0].data_type();
23+
24+
Interval::make_unbounded(&data_type)
25+
}
26+
27+
pub(super) fn sin_bounds(input: &[&Interval]) -> crate::Result<Interval> {
28+
// sin(x) is bounded by [-1, 1]
29+
let data_type = input[0].data_type();
30+
31+
Interval::make_symmetric_unit_interval(&data_type)
32+
}
33+
34+
pub(super) fn asin_bounds(input: &[&Interval]) -> crate::Result<Interval> {
35+
// asin(x) is bounded by [-π/2, π/2]
36+
let data_type = input[0].data_type();
37+
38+
Interval::make_symmetric_half_pi_interval(&data_type)
39+
}
40+
41+
pub(super) fn atan_bounds(input: &[&Interval]) -> crate::Result<Interval> {
42+
// atan(x) is bounded by [-π/2, π/2]
43+
let data_type = input[0].data_type();
44+
45+
Interval::make_symmetric_half_pi_interval(&data_type)
46+
}
47+
48+
pub(super) fn acos_bounds(input: &[&Interval]) -> crate::Result<Interval> {
49+
// acos(x) is bounded by [0, π]
50+
let data_type = input[0].data_type();
51+
52+
Interval::try_new(
53+
ScalarValue::new_zero(&data_type)?,
54+
ScalarValue::new_pi_upper(&data_type)?,
55+
)
56+
}
57+
58+
pub(super) fn acosh_bounds(input: &[&Interval]) -> crate::Result<Interval> {
59+
// acosh(x) is bounded by [0, ∞)
60+
let data_type = input[0].data_type();
61+
62+
Interval::make_non_negative_infinity_interval(&data_type)
63+
}
64+
65+
pub(super) fn cos_bounds(input: &[&Interval]) -> crate::Result<Interval> {
66+
// cos(x) is bounded by [-1, 1]
67+
let data_type = input[0].data_type();
68+
69+
Interval::make_symmetric_unit_interval(&data_type)
70+
}
71+
72+
pub(super) fn cosh_bounds(input: &[&Interval]) -> crate::Result<Interval> {
73+
// cosh(x) is bounded by [1, ∞)
74+
let data_type = input[0].data_type();
75+
76+
Interval::try_new(
77+
ScalarValue::new_one(&data_type)?,
78+
ScalarValue::try_from(&data_type)?,
79+
)
80+
}
81+
82+
pub(super) fn exp_bounds(input: &[&Interval]) -> crate::Result<Interval> {
83+
// exp(x) is bounded by [0, ∞)
84+
let data_type = input[0].data_type();
85+
86+
Interval::make_non_negative_infinity_interval(&data_type)
87+
}
88+
89+
pub(super) fn radians_bounds(input: &[&Interval]) -> crate::Result<Interval> {
90+
// radians(x) is bounded by (-π, π)
91+
let data_type = input[0].data_type();
92+
93+
Interval::make_symmetric_pi_interval(&data_type)
94+
}
95+
96+
pub(super) fn sqrt_bounds(input: &[&Interval]) -> crate::Result<Interval> {
97+
// sqrt(x) is bounded by [0, ∞)
98+
let data_type = input[0].data_type();
99+
100+
Interval::make_non_negative_infinity_interval(&data_type)
101+
}
102+
103+
pub(super) fn tanh_bounds(input: &[&Interval]) -> crate::Result<Interval> {
104+
// tanh(x) is bounded by (-1, 1)
105+
let data_type = input[0].data_type();
106+
107+
Interval::make_symmetric_unit_interval(&data_type)
108+
}

0 commit comments

Comments
 (0)