From cf152af6515f0808d840e1fe9c63b02802595826 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 16 Aug 2023 17:28:35 +0900 Subject: [PATCH] Add isnan and iszero (#7274) * Add isnan and iszero. * Modified doc. * f64 doesn't need high priority. --------- Co-authored-by: Andrew Lamb --- datafusion/core/tests/sql/expr.rs | 2 + datafusion/expr/src/built_in_function.rs | 16 ++ datafusion/expr/src/expr_fn.rs | 14 ++ datafusion/physical-expr/src/functions.rs | 6 + .../physical-expr/src/math_expressions.rs | 141 +++++++++++++++++- datafusion/proto/proto/datafusion.proto | 2 + datafusion/proto/src/generated/pbjson.rs | 6 + datafusion/proto/src/generated/prost.rs | 6 + .../proto/src/logical_plan/from_proto.rs | 6 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 + datafusion/sqllogictest/test_files/math.slt | 12 ++ datafusion/sqllogictest/test_files/scalar.slt | 46 ++++++ docs/source/user-guide/expressions.md | 2 + .../source/user-guide/sql/scalar_functions.md | 28 ++++ 14 files changed, 286 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs index 36786bd0798f..b88a0d8df128 100644 --- a/datafusion/core/tests/sql/expr.rs +++ b/datafusion/core/tests/sql/expr.rs @@ -63,6 +63,8 @@ async fn test_mathematical_expressions_with_null() -> Result<()> { test_expression!("nanvl(NULL, NULL)", "NULL"); test_expression!("nanvl(1, NULL)", "NULL"); test_expression!("nanvl(NULL, 1)", "NULL"); + test_expression!("isnan(NULL)", "NULL"); + test_expression!("iszero(NULL)", "NULL"); Ok(()) } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 6cc7fd540251..a0c7a839e7a2 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -84,6 +84,10 @@ pub enum BuiltinScalarFunction { Gcd, /// lcm, Least common multiple Lcm, + /// isnan + Isnan, + /// iszero + Iszero, /// ln, Natural logarithm Ln, /// log, same as log10 @@ -334,6 +338,8 @@ 
impl BuiltinScalarFunction { BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Floor => Volatility::Immutable, BuiltinScalarFunction::Gcd => Volatility::Immutable, + BuiltinScalarFunction::Isnan => Volatility::Immutable, + BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Lcm => Volatility::Immutable, BuiltinScalarFunction::Ln => Volatility::Immutable, BuiltinScalarFunction::Log => Volatility::Immutable, @@ -774,6 +780,8 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, + BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => Ok(Boolean), + BuiltinScalarFunction::ArrowTypeof => Ok(Utf8), BuiltinScalarFunction::Abs @@ -1184,6 +1192,12 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::CurrentTime => { Signature::uniform(0, vec![], self.volatility()) } + BuiltinScalarFunction::Isnan | BuiltinScalarFunction::Iszero => { + Signature::one_of( + vec![Exact(vec![Float32]), Exact(vec![Float64])], + self.volatility(), + ) + } } } } @@ -1208,6 +1222,8 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Factorial => &["factorial"], BuiltinScalarFunction::Floor => &["floor"], BuiltinScalarFunction::Gcd => &["gcd"], + BuiltinScalarFunction::Isnan => &["isnan"], + BuiltinScalarFunction::Iszero => &["iszero"], BuiltinScalarFunction::Lcm => &["lcm"], BuiltinScalarFunction::Ln => &["ln"], BuiltinScalarFunction::Log => &["log"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 47767c23b363..3ca6aa5d6e95 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -811,6 +811,18 @@ scalar_expr!(CurrentDate, current_date, ,"returns current UTC date as a [`DataTy scalar_expr!(Now, now, ,"returns current timestamp in nanoseconds, using the same value for all instances of now() in same statement"); scalar_expr!(CurrentTime, current_time, , "returns current UTC time as a [`DataType::Time64`] value"); 
scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); +scalar_expr!( + Isnan, + isnan, + num, + "returns true if a given number is +NaN or -NaN otherwise returns false" +); +scalar_expr!( + Iszero, + iszero, + num, + "returns true if a given number is +0.0 or -0.0 otherwise returns false" +); scalar_expr!(ArrowTypeof, arrow_typeof, val, "data type"); @@ -1003,6 +1015,8 @@ mod test { test_unary_scalar_expr!(Ln, ln); test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Nanvl, nanvl, x, y); + test_scalar_expr!(Isnan, isnan, input); + test_scalar_expr!(Iszero, iszero, input); test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index d1a5119ee8a3..81c673ea834b 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -374,6 +374,12 @@ pub fn create_physical_fun( BuiltinScalarFunction::Gcd => { Arc::new(|args| make_scalar_function(math_expressions::gcd)(args)) } + BuiltinScalarFunction::Isnan => { + Arc::new(|args| make_scalar_function(math_expressions::isnan)(args)) + } + BuiltinScalarFunction::Iszero => { + Arc::new(|args| make_scalar_function(math_expressions::iszero)(args)) + } BuiltinScalarFunction::Lcm => { Arc::new(|args| make_scalar_function(math_expressions::lcm)(args)) } diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs index 03e0bb64551b..e62d24497bed 100644 --- a/datafusion/physical-expr/src/math_expressions.rs +++ b/datafusion/physical-expr/src/math_expressions.rs @@ -18,7 +18,7 @@ //! 
Math expressions use arrow::array::ArrayRef; -use arrow::array::{Float32Array, Float64Array, Int64Array}; +use arrow::array::{BooleanArray, Float32Array, Float64Array, Int64Array}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::{Float32, Int64}; @@ -142,6 +142,19 @@ macro_rules! make_function_inputs2 { }}; } +macro_rules! make_function_scalar_inputs_return_type { + ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ + let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); + + arg.iter() + .map(|a| match a { + Some(a) => Some($FUNC(a)), + _ => None, + }) + .collect::<$RETURN_TYPE>() + }}; +} + math_unary_function!("sqrt", sqrt); math_unary_function!("cbrt", cbrt); math_unary_function!("sin", sin); @@ -306,6 +319,56 @@ pub fn nanvl(args: &[ArrayRef]) -> Result { } } +/// Isnan SQL function +pub fn isnan(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { f64::is_nan } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { f32::is_nan } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function isnan" + ))), + } +} + +/// Iszero SQL function +pub fn iszero(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float64Array, + BooleanArray, + { |x: f64| { x == 0_f64 } } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( + &args[0], + "x", + Float32Array, + BooleanArray, + { |x: f32| { x == 0_f32 } } + )) as ArrayRef), + + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function iszero" + ))), + } +} + /// Pi 
SQL function pub fn pi(args: &[ColumnarValue]) -> Result { if !matches!(&args[0], ColumnarValue::Array(_)) { @@ -650,7 +713,9 @@ mod tests { use super::*; use arrow::array::{Float64Array, NullArray}; - use datafusion_common::cast::{as_float32_array, as_float64_array, as_int64_array}; + use datafusion_common::cast::{ + as_boolean_array, as_float32_array, as_float64_array, as_int64_array, + }; #[test] fn test_random_expression() { @@ -1041,4 +1106,76 @@ mod tests { assert_eq!(floats.value(2), 3.0); assert!(floats.value(3).is_nan()); } + + #[test] + fn test_isnan_f64() { + let args: Vec = vec![Arc::new(Float64Array::from(vec![ + 1.0, + f64::NAN, + 3.0, + -f64::NAN, + ]))]; + + let result = isnan(&args).expect("failed to initialize function isnan"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function isnan"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_isnan_f32() { + let args: Vec = vec![Arc::new(Float32Array::from(vec![ + 1.0, + f32::NAN, + 3.0, + f32::NAN, + ]))]; + + let result = isnan(&args).expect("failed to initialize function isnan"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function isnan"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_iszero_f64() { + let args: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } + + #[test] + fn test_iszero_f32() { + let args: Vec = + 
vec![Arc::new(Float32Array::from(vec![1.0, 0.0, 3.0, -0.0]))]; + + let result = iszero(&args).expect("failed to initialize function iszero"); + let booleans = + as_boolean_array(&result).expect("failed to initialize function iszero"); + + assert_eq!(booleans.len(), 4); + assert!(!booleans.value(0)); + assert!(booleans.value(1)); + assert!(!booleans.value(2)); + assert!(booleans.value(3)); + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1a8ad093b965..e4ef7b1bd448 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -595,6 +595,8 @@ enum ScalarFunction { ArrayReplaceAll = 110; Nanvl = 111; Flatten = 112; + Isnan = 113; + Iszero = 114; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index a33d80be9ddb..f1a9e9c7bb74 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18945,6 +18945,8 @@ impl serde::Serialize for ScalarFunction { Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", Self::Flatten => "Flatten", + Self::Isnan => "Isnan", + Self::Iszero => "Iszero", }; serializer.serialize_str(variant) } @@ -19069,6 +19071,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceAll", "Nanvl", "Flatten", + "Isnan", + "Iszero", ]; struct GeneratedVisitor; @@ -19224,6 +19228,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), "Flatten" => Ok(ScalarFunction::Flatten), + "Isnan" => Ok(ScalarFunction::Isnan), + "Iszero" => Ok(ScalarFunction::Iszero), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 519cd002df3f..6cf402fe66e9 100644 --- a/datafusion/proto/src/generated/prost.rs +++ 
b/datafusion/proto/src/generated/prost.rs @@ -2375,6 +2375,8 @@ pub enum ScalarFunction { ArrayReplaceAll = 110, Nanvl = 111, Flatten = 112, + Isnan = 113, + Iszero = 114, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2496,6 +2498,8 @@ impl ScalarFunction { ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Flatten => "Flatten", + ScalarFunction::Isnan => "Isnan", + ScalarFunction::Iszero => "Iszero", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2614,6 +2618,8 @@ impl ScalarFunction { "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), "Flatten" => Some(Self::Flatten), + "Isnan" => Some(Self::Isnan), + "Iszero" => Some(Self::Iszero), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d2e037aa4b1b..43f1d44b7ded 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -45,7 +45,7 @@ use datafusion_expr::{ concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, floor, from_unixtime, gcd, lcm, left, ln, log, log10, log2, + factorial, floor, from_unixtime, gcd, isnan, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, @@ -525,6 +525,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::FromUnixtime => Self::FromUnixtime, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, + ScalarFunction::Isnan => Self::Isnan, + ScalarFunction::Iszero => Self::Iszero, ScalarFunction::ArrowTypeof => Self::ArrowTypeof, } 
} @@ -1577,6 +1579,8 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::Isnan => Ok(isnan(parse_expr(&args[0], registry)?)), + ScalarFunction::Iszero => Ok(iszero(parse_expr(&args[0], registry)?)), _ => Err(proto_error( "Protobuf deserialization error: Unsupported scalar function", )), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index cdb9b008036c..cb3296438165 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1524,6 +1524,8 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::FromUnixtime => Self::FromUnixtime, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, + BuiltinScalarFunction::Isnan => Self::Isnan, + BuiltinScalarFunction::Iszero => Self::Iszero, BuiltinScalarFunction::ArrowTypeof => Self::ArrowTypeof, }; diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 9965821e52bd..cd55e018e99c 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -99,3 +99,15 @@ query RRR SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) ---- 1 1 NaN + +# isnan +query BBBB +SELECT isnan(1.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +---- +false true true NULL + +# iszero +query BBBB +SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) +---- +false true true NULL diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 1f6d926aea2d..49782164c026 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -695,6 +695,52 @@ select round(nanvl(asin(f + a), 2), 5), round(nanvl(asin(b + c), 3), 5), round(n 2 -1.11977 4 NULL NULL NULL +## isnan + +# isnan scalar 
function +query BBB +select isnan(10.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE) +---- +false true true + +# isnan scalar nulls +query B +select isnan(NULL) +---- +NULL + +# isnan with columns +query BBBB +select isnan(asin(a + b + c)), isnan(-asin(a + b + c)), isnan(asin(d + e + f)), isnan(-asin(d + e + f)) from small_floats; +---- +true true false false +false false true true +true true false false +NULL NULL NULL NULL + +## iszero + +# iszero scalar function +query BBB +select iszero(10.0), iszero(0.0), iszero(-0.0) +---- +false true true + +# iszero scalar nulls +query B +select iszero(NULL) +---- +NULL + +# iszero with columns +query BBBB +select iszero(floor(a + b + c)), iszero(-floor(a + b + c)), iszero(floor(d + e + f)), iszero(-floor(d + e + f)) from small_floats; +---- +false false false false +true true false false +false false true true +NULL NULL NULL NULL + ## pi # pi scalar function diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 88a5a73a6df3..630b158092d8 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -94,6 +94,8 @@ expressions such as `col("a") + col("b")` to be used. 
| factorial(x) | factorial | | floor(x) | nearest integer less than or equal to argument | | gcd(x, y) | greatest common divisor | +| isnan(x) | predicate determining whether the argument is NaN (+NaN or -NaN) | +| iszero(x) | predicate determining whether the argument is zero (+0.0 or -0.0) | | lcm(x, y) | least common multiple | | ln(x) | natural logarithm | | log(base, x) | logarithm of x for a particular base | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 9bcf2ae0b09b..6dbe5c05f6d1 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -38,6 +38,8 @@ - [factorial](#factorial) - [floor](#floor) - [gcd](#gcd) +- [isnan](#isnan) +- [iszero](#iszero) - [lcm](#lcm) - [ln](#ln) - [log](#log) @@ -283,6 +285,32 @@ gcd(expression_x, expression_y) - **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +### `isnan` + +Returns true if a given number is +NaN or -NaN; otherwise returns false. + +``` +isnan(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + +### `iszero` + +Returns true if a given number is +0.0 or -0.0; otherwise returns false. + +``` +iszero(numeric_expression) +``` + +#### Arguments + +- **numeric_expression**: Numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. + ### `lcm` Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.