Skip to content

Commit 2f150f6

Browse files
authored
Introduce TypeSignature::Comparable and update NullIf signature (#13356)
* simplify signature for nullif Signed-off-by: Jay Zhan <[email protected]> * add possible types Signed-off-by: Jay Zhan <[email protected]> * typo Signed-off-by: jayzhan211 <[email protected]> * numeric string Signed-off-by: jayzhan211 <[email protected]> * add doc for signature Signed-off-by: Jay Zhan <[email protected]> * add doc for signature Signed-off-by: Jay Zhan <[email protected]> --------- Signed-off-by: Jay Zhan <[email protected]> Signed-off-by: jayzhan211 <[email protected]>
1 parent a53b974 commit 2f150f6

File tree

5 files changed

+153
-44
lines changed

5 files changed

+153
-44
lines changed

datafusion/expr-common/src/signature.rs

+33-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! Signature module contains foundational types that are used to represent signatures, types,
1919
//! and return types of functions in DataFusion.
2020
21-
use crate::type_coercion::aggregates::{NUMERICS, STRINGS};
21+
use crate::type_coercion::aggregates::NUMERICS;
2222
use arrow::datatypes::DataType;
2323
use datafusion_common::types::{LogicalTypeRef, NativeType};
2424
use itertools::Itertools;
@@ -113,6 +113,15 @@ pub enum TypeSignature {
113113
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
114114
/// since i32 and f32 can be casted to f64
115115
Coercible(Vec<LogicalTypeRef>),
116+
/// The arguments will be coerced to a single type based on the comparison rules.
117+
/// For example, i32 and i64 has coerced type Int64.
118+
///
119+
/// Note:
120+
/// - If compares with numeric and string, numeric is preferred for numeric string cases. For example, nullif('2', 1) has coerced types Int64.
121+
/// - If the result is Null, it will be coerced to String (Utf8View).
122+
///
123+
/// See `comparison_coercion_numeric` for more details.
124+
Comparable(usize),
116125
/// Fixed number of arguments of arbitrary types, number should be larger than 0
117126
Any(usize),
118127
/// Matches exactly one of a list of [`TypeSignature`]s. Coercion is attempted to match
@@ -138,6 +147,13 @@ pub enum TypeSignature {
138147
NullAry,
139148
}
140149

150+
impl TypeSignature {
151+
#[inline]
152+
pub fn is_one_of(&self) -> bool {
153+
matches!(self, TypeSignature::OneOf(_))
154+
}
155+
}
156+
141157
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
142158
pub enum ArrayFunctionSignature {
143159
/// Specialized Signature for ArrayAppend and similar functions
@@ -210,6 +226,9 @@ impl TypeSignature {
210226
TypeSignature::Numeric(num) => {
211227
vec![format!("Numeric({num})")]
212228
}
229+
TypeSignature::Comparable(num) => {
230+
vec![format!("Comparable({num})")]
231+
}
213232
TypeSignature::Coercible(types) => {
214233
vec![Self::join_types(types, ", ")]
215234
}
@@ -284,13 +303,13 @@ impl TypeSignature {
284303
.cloned()
285304
.map(|numeric_type| vec![numeric_type; *arg_count])
286305
.collect(),
287-
TypeSignature::String(arg_count) => STRINGS
288-
.iter()
289-
.cloned()
290-
.map(|string_type| vec![string_type; *arg_count])
291-
.collect(),
306+
TypeSignature::String(arg_count) => get_data_types(&NativeType::String)
307+
.into_iter()
308+
.map(|dt| vec![dt; *arg_count])
309+
.collect::<Vec<_>>(),
292310
// TODO: Implement for other types
293311
TypeSignature::Any(_)
312+
| TypeSignature::Comparable(_)
294313
| TypeSignature::NullAry
295314
| TypeSignature::VariadicAny
296315
| TypeSignature::ArraySignature(_)
@@ -412,6 +431,14 @@ impl Signature {
412431
}
413432
}
414433

434+
/// Used for function that expects comparable data types, it will try to coerced all the types into single final one.
435+
pub fn comparable(arg_count: usize, volatility: Volatility) -> Self {
436+
Self {
437+
type_signature: TypeSignature::Comparable(arg_count),
438+
volatility,
439+
}
440+
}
441+
415442
pub fn nullary(volatility: Volatility) -> Self {
416443
Signature {
417444
type_signature: TypeSignature::NullAry,

datafusion/expr-common/src/type_coercion/binary.rs

+34
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use arrow::datatypes::{
2828
DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION,
2929
DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
3030
};
31+
use datafusion_common::types::NativeType;
3132
use datafusion_common::{
3233
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result,
3334
};
@@ -641,6 +642,21 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
641642
.or_else(|| struct_coercion(lhs_type, rhs_type))
642643
}
643644

645+
// Similar to comparison_coercion but prefer numeric if compares with numeric and string
646+
pub fn comparison_coercion_numeric(
647+
lhs_type: &DataType,
648+
rhs_type: &DataType,
649+
) -> Option<DataType> {
650+
if lhs_type == rhs_type {
651+
// same type => equality is possible
652+
return Some(lhs_type.clone());
653+
}
654+
binary_numeric_coercion(lhs_type, rhs_type)
655+
.or_else(|| string_coercion(lhs_type, rhs_type))
656+
.or_else(|| null_coercion(lhs_type, rhs_type))
657+
.or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type))
658+
}
659+
644660
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
645661
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
646662
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
@@ -654,6 +670,24 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
654670
}
655671
}
656672

673+
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
674+
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
675+
fn string_numeric_coercion_as_numeric(
676+
lhs_type: &DataType,
677+
rhs_type: &DataType,
678+
) -> Option<DataType> {
679+
let lhs_logical_type = NativeType::from(lhs_type);
680+
let rhs_logical_type = NativeType::from(rhs_type);
681+
if lhs_logical_type.is_numeric() && rhs_logical_type == NativeType::String {
682+
return Some(lhs_type.to_owned());
683+
}
684+
if rhs_logical_type.is_numeric() && lhs_logical_type == NativeType::String {
685+
return Some(rhs_type.to_owned());
686+
}
687+
688+
None
689+
}
690+
657691
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
658692
/// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`.
659693
///

datafusion/expr/src/type_coercion/functions.rs

+26-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use datafusion_common::{
2929
};
3030
use datafusion_expr_common::{
3131
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
32-
type_coercion::binary::string_coercion,
32+
type_coercion::binary::{comparison_coercion_numeric, string_coercion},
3333
};
3434
use std::sync::Arc;
3535

@@ -182,6 +182,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
182182
| TypeSignature::Coercible(_)
183183
| TypeSignature::Any(_)
184184
| TypeSignature::NullAry
185+
| TypeSignature::Comparable(_)
185186
)
186187
}
187188

@@ -194,13 +195,18 @@ fn try_coerce_types(
194195

195196
// Well-supported signature that returns exact valid types.
196197
if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
197-
// exact valid types
198-
assert_eq!(valid_types.len(), 1);
198+
// There may be many valid types if valid signature is OneOf
199+
// Otherwise, there should be only one valid type
200+
if !type_signature.is_one_of() {
201+
assert_eq!(valid_types.len(), 1);
202+
}
203+
199204
let valid_types = valid_types.swap_remove(0);
200205
if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
201206
return Ok(t);
202207
}
203208
} else {
209+
// TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
204210
// Try and coerce the argument types to match the signature, returning the
205211
// coerced types from the first matching signature.
206212
for valid_types in valid_types {
@@ -515,6 +521,23 @@ fn get_valid_types(
515521

516522
vec![vec![valid_type; *number]]
517523
}
524+
TypeSignature::Comparable(num) => {
525+
function_length_check(current_types.len(), *num)?;
526+
let mut target_type = current_types[0].to_owned();
527+
for data_type in current_types.iter().skip(1) {
528+
if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
529+
target_type = dt;
530+
} else {
531+
return plan_err!("{target_type} and {data_type} is not comparable");
532+
}
533+
}
534+
// Convert null to String type.
535+
if target_type.is_null() {
536+
vec![vec![DataType::Utf8View; *num]]
537+
} else {
538+
vec![vec![target_type; *num]]
539+
}
540+
}
518541
TypeSignature::Coercible(target_types) => {
519542
function_length_check(current_types.len(), target_types.len())?;
520543

datafusion/functions/src/core/nullif.rs

+15-33
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,6 @@ pub struct NullIfFunc {
3232
signature: Signature,
3333
}
3434

35-
/// Currently supported types by the nullif function.
36-
/// The order of these types correspond to the order on which coercion applies
37-
/// This should thus be from least informative to most informative
38-
static SUPPORTED_NULLIF_TYPES: &[DataType] = &[
39-
DataType::Boolean,
40-
DataType::UInt8,
41-
DataType::UInt16,
42-
DataType::UInt32,
43-
DataType::UInt64,
44-
DataType::Int8,
45-
DataType::Int16,
46-
DataType::Int32,
47-
DataType::Int64,
48-
DataType::Float32,
49-
DataType::Float64,
50-
DataType::Utf8View,
51-
DataType::Utf8,
52-
DataType::LargeUtf8,
53-
];
54-
5535
impl Default for NullIfFunc {
5636
fn default() -> Self {
5737
Self::new()
@@ -61,11 +41,20 @@ impl Default for NullIfFunc {
6141
impl NullIfFunc {
6242
pub fn new() -> Self {
6343
Self {
64-
signature: Signature::uniform(
65-
2,
66-
SUPPORTED_NULLIF_TYPES.to_vec(),
67-
Volatility::Immutable,
68-
),
44+
// Documentation mentioned in Postgres,
45+
// The result has the same type as the first argument — but there is a subtlety.
46+
// What is actually returned is the first argument of the implied = operator,
47+
// and in some cases that will have been promoted to match the second argument's type.
48+
// For example, NULLIF(1, 2.2) yields numeric, because there is no integer = numeric operator, only numeric = numeric
49+
//
50+
// We don't strictly follow Postgres or DuckDB for **simplicity**.
51+
// In this function, we will coerce arguments to the same data type for comparison need. Unlike DuckDB
52+
// we don't return the **original** first argument type but return the final coerced type.
53+
//
54+
// In Postgres, nullif('2', 2) returns Null but nullif('2::varchar', 2) returns error.
55+
// While in DuckDB both query returns Null. We follow DuckDB in this case since I think they are equivalent thing and should
56+
// have the same result as well.
57+
signature: Signature::comparable(2, Volatility::Immutable),
6958
}
7059
}
7160
}
@@ -83,14 +72,7 @@ impl ScalarUDFImpl for NullIfFunc {
8372
}
8473

8574
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
86-
// NULLIF has two args and they might get coerced, get a preview of this
87-
let coerced_types = datafusion_expr::type_coercion::functions::data_types(
88-
arg_types,
89-
&self.signature,
90-
);
91-
coerced_types
92-
.map(|typs| typs[0].clone())
93-
.map_err(|e| e.context("Failed to coerce arguments for NULLIF"))
75+
Ok(arg_types[0].to_owned())
9476
}
9577

9678
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {

datafusion/sqllogictest/test_files/nullif.slt

+45-2
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,54 @@ SELECT NULLIF(1, 3);
9797
----
9898
1
9999

100-
query I
100+
query T
101101
SELECT NULLIF(NULL, NULL);
102102
----
103103
NULL
104104

105+
query R
106+
select nullif(1, 1.2);
107+
----
108+
1
109+
110+
query R
111+
select nullif(1.0, 2);
112+
----
113+
1
114+
115+
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
116+
select nullif(2, 'a');
117+
118+
query T
119+
select nullif('2', '3');
120+
----
121+
2
122+
123+
query I
124+
select nullif(2, '1');
125+
----
126+
2
127+
128+
query I
129+
select nullif('2', 2);
130+
----
131+
NULL
132+
133+
query I
134+
select nullif('1', 2);
135+
----
136+
1
137+
138+
statement ok
139+
create table t(a varchar, b int) as values ('1', 2), ('2', 2), ('3', 2);
140+
141+
query I
142+
select nullif(a, b) from t;
143+
----
144+
1
145+
NULL
146+
3
147+
105148
query T
106149
SELECT NULLIF(arrow_cast('a', 'Utf8View'), 'a');
107150
----
@@ -130,4 +173,4 @@ NULL
130173
query T
131174
SELECT NULLIF(arrow_cast('a', 'Utf8View'), null);
132175
----
133-
a
176+
a

0 commit comments

Comments
 (0)