Skip to content

Commit 08d3b65

Browse files
authored
Simplify type signatures using TypeSignatureClass for mixed type function signature (#13372)
* add type sig class Signed-off-by: jayzhan211 <[email protected]> * timestamp Signed-off-by: jayzhan211 <[email protected]> * date part Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * taplo format Signed-off-by: jayzhan211 <[email protected]> * tpch test Signed-off-by: jayzhan211 <[email protected]> * msrc issue Signed-off-by: jayzhan211 <[email protected]> * msrc issue Signed-off-by: jayzhan211 <[email protected]> * explicit hash Signed-off-by: jayzhan211 <[email protected]> * Enhance type coercion and function signatures - Added logic to prevent unnecessary casting of string types in `native.rs`. - Introduced `Comparable` variant in `TypeSignature` to define coercion rules for comparisons. - Updated imports in `functions.rs` and `signature.rs` for better organization. - Modified `date_part.rs` to improve handling of timestamp extraction and fixed query tests in `expr.slt`. - Added `datafusion-macros` dependency in `Cargo.toml` and `Cargo.lock`. These changes improve type handling and ensure more accurate function behavior in SQL expressions. * fix comment Signed-off-by: Jay Zhan <[email protected]> * fix signature Signed-off-by: Jay Zhan <[email protected]> * fix test Signed-off-by: Jay Zhan <[email protected]> * Enhance type coercion for timestamps to allow implicit casting from strings. Update SQL logic tests to reflect changes in timestamp handling, including expected outputs for queries involving nanoseconds and seconds. * Refactor type coercion logic for timestamps to improve readability and maintainability. Update the `TypeSignatureClass` documentation to clarify its purpose in function signatures, particularly regarding coercible types. This change enhances the handling of implicit casting from strings to timestamps. * Fix SQL logic tests to correct query error handling for timestamp functions. Updated expected outputs for `date_part` and `extract` functions to reflect proper behavior with nanoseconds and seconds. This change improves the accuracy of test cases in the `expr.slt` file. * Enhance timestamp handling in TypeSignature to support timezone specification. Updated the logic to include an additional DataType for timestamps with a timezone wildcard, improving flexibility in timestamp operations. * Refactor date_part function: remove redundant imports and add missing not_impl_err import for better error handling --------- Signed-off-by: jayzhan211 <[email protected]> Signed-off-by: Jay Zhan <[email protected]>
1 parent 68ead28 commit 08d3b65

File tree

8 files changed

+187
-108
lines changed

8 files changed

+187
-108
lines changed

datafusion-cli/Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/src/types/native.rs

+27
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ impl LogicalType for NativeType {
245245
(Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size),
246246
(Self::String, LargeBinary) => LargeUtf8,
247247
(Self::String, BinaryView) => Utf8View,
248+
// We don't cast to another kind of string type if the origin one is already a string type
249+
(Self::String, Utf8 | LargeUtf8 | Utf8View) => origin.to_owned(),
248250
(Self::String, data_type) if can_cast_types(data_type, &Utf8View) => Utf8View,
249251
(Self::String, data_type) if can_cast_types(data_type, &LargeUtf8) => {
250252
LargeUtf8
@@ -433,4 +435,29 @@ impl NativeType {
433435
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
434436
)
435437
}
438+
439+
#[inline]
440+
pub fn is_timestamp(&self) -> bool {
441+
matches!(self, NativeType::Timestamp(_, _))
442+
}
443+
444+
#[inline]
445+
pub fn is_date(&self) -> bool {
446+
matches!(self, NativeType::Date)
447+
}
448+
449+
#[inline]
450+
pub fn is_time(&self) -> bool {
451+
matches!(self, NativeType::Time(_))
452+
}
453+
454+
#[inline]
455+
pub fn is_interval(&self) -> bool {
456+
matches!(self, NativeType::Interval(_))
457+
}
458+
459+
#[inline]
460+
pub fn is_duration(&self) -> bool {
461+
matches!(self, NativeType::Duration(_))
462+
}
436463
}

datafusion/expr-common/src/signature.rs

+65-8
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
//! Signature module contains foundational types that are used to represent signatures, types,
1919
//! and return types of functions in DataFusion.
2020
21+
use std::fmt::Display;
22+
2123
use crate::type_coercion::aggregates::NUMERICS;
22-
use arrow::datatypes::DataType;
24+
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
2325
use datafusion_common::types::{LogicalTypeRef, NativeType};
2426
use itertools::Itertools;
2527

@@ -112,7 +114,7 @@ pub enum TypeSignature {
112114
/// For example, `Coercible(vec![logical_float64()])` accepts
113115
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
114116
/// since i32 and f32 can be casted to f64
115-
Coercible(Vec<LogicalTypeRef>),
117+
Coercible(Vec<TypeSignatureClass>),
116118
/// The arguments will be coerced to a single type based on the comparison rules.
117119
/// For example, i32 and i64 has coerced type Int64.
118120
///
@@ -154,6 +156,33 @@ impl TypeSignature {
154156
}
155157
}
156158

159+
/// Represents the class of types that can be used in a function signature.
160+
///
161+
/// This is used to specify what types are valid for function arguments in a more flexible way than
162+
/// just listing specific DataTypes. For example, TypeSignatureClass::Timestamp matches any timestamp
163+
/// type regardless of timezone or precision.
164+
///
165+
/// Used primarily with TypeSignature::Coercible to define function signatures that can accept
166+
/// arguments that can be coerced to a particular class of types.
167+
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Hash)]
168+
pub enum TypeSignatureClass {
169+
Timestamp,
170+
Date,
171+
Time,
172+
Interval,
173+
Duration,
174+
Native(LogicalTypeRef),
175+
// TODO:
176+
// Numeric
177+
// Integer
178+
}
179+
180+
impl Display for TypeSignatureClass {
181+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182+
write!(f, "TypeSignatureClass::{self:?}")
183+
}
184+
}
185+
157186
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
158187
pub enum ArrayFunctionSignature {
159188
/// Specialized Signature for ArrayAppend and similar functions
@@ -180,7 +209,7 @@ pub enum ArrayFunctionSignature {
180209
MapArray,
181210
}
182211

183-
impl std::fmt::Display for ArrayFunctionSignature {
212+
impl Display for ArrayFunctionSignature {
184213
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185214
match self {
186215
ArrayFunctionSignature::ArrayAndElement => {
@@ -255,7 +284,7 @@ impl TypeSignature {
255284
}
256285

257286
/// Helper function to join types with specified delimiter.
258-
pub fn join_types<T: std::fmt::Display>(types: &[T], delimiter: &str) -> String {
287+
pub fn join_types<T: Display>(types: &[T], delimiter: &str) -> String {
259288
types
260289
.iter()
261290
.map(|t| t.to_string())
@@ -290,7 +319,30 @@ impl TypeSignature {
290319
.collect(),
291320
TypeSignature::Coercible(types) => types
292321
.iter()
293-
.map(|logical_type| get_data_types(logical_type.native()))
322+
.map(|logical_type| match logical_type {
323+
TypeSignatureClass::Native(l) => get_data_types(l.native()),
324+
TypeSignatureClass::Timestamp => {
325+
vec![
326+
DataType::Timestamp(TimeUnit::Nanosecond, None),
327+
DataType::Timestamp(
328+
TimeUnit::Nanosecond,
329+
Some(TIMEZONE_WILDCARD.into()),
330+
),
331+
]
332+
}
333+
TypeSignatureClass::Date => {
334+
vec![DataType::Date64]
335+
}
336+
TypeSignatureClass::Time => {
337+
vec![DataType::Time64(TimeUnit::Nanosecond)]
338+
}
339+
TypeSignatureClass::Interval => {
340+
vec![DataType::Interval(IntervalUnit::DayTime)]
341+
}
342+
TypeSignatureClass::Duration => {
343+
vec![DataType::Duration(TimeUnit::Nanosecond)]
344+
}
345+
})
294346
.multi_cartesian_product()
295347
.collect(),
296348
TypeSignature::Variadic(types) => types
@@ -424,7 +476,10 @@ impl Signature {
424476
}
425477
}
426478
/// Target coerce types in order
427-
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
479+
pub fn coercible(
480+
target_types: Vec<TypeSignatureClass>,
481+
volatility: Volatility,
482+
) -> Self {
428483
Self {
429484
type_signature: TypeSignature::Coercible(target_types),
430485
volatility,
@@ -618,8 +673,10 @@ mod tests {
618673
]
619674
);
620675

621-
let type_signature =
622-
TypeSignature::Coercible(vec![logical_string(), logical_int64()]);
676+
let type_signature = TypeSignature::Coercible(vec![
677+
TypeSignatureClass::Native(logical_string()),
678+
TypeSignatureClass::Native(logical_int64()),
679+
]);
623680
let possible_types = type_signature.get_possible_types();
624681
assert_eq!(
625682
possible_types,

datafusion/expr/src/type_coercion/functions.rs

+58-24
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,18 @@ use arrow::{
2222
datatypes::{DataType, TimeUnit},
2323
};
2424
use datafusion_common::{
25-
exec_err, internal_datafusion_err, internal_err, plan_err,
25+
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
2626
types::{LogicalType, NativeType},
2727
utils::{coerced_fixed_size_list_to_list, list_ndims},
2828
Result,
2929
};
3030
use datafusion_expr_common::{
31-
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
32-
type_coercion::binary::{comparison_coercion_numeric, string_coercion},
31+
signature::{
32+
ArrayFunctionSignature, TypeSignatureClass, FIXED_SIZE_LIST_WILDCARD,
33+
TIMEZONE_WILDCARD,
34+
},
35+
type_coercion::binary::comparison_coercion_numeric,
36+
type_coercion::binary::string_coercion,
3337
};
3438
use std::sync::Arc;
3539

@@ -568,35 +572,65 @@ fn get_valid_types(
568572
// Make sure the corresponding test is covered
569573
// If this function becomes COMPLEX, create another new signature!
570574
fn can_coerce_to(
571-
logical_type: &NativeType,
572-
target_type: &NativeType,
573-
) -> bool {
574-
if logical_type == target_type {
575-
return true;
576-
}
575+
current_type: &DataType,
576+
target_type_class: &TypeSignatureClass,
577+
) -> Result<DataType> {
578+
let logical_type: NativeType = current_type.into();
577579

578-
if logical_type == &NativeType::Null {
579-
return true;
580-
}
580+
match target_type_class {
581+
TypeSignatureClass::Native(native_type) => {
582+
let target_type = native_type.native();
583+
if &logical_type == target_type {
584+
return target_type.default_cast_for(current_type);
585+
}
581586

582-
if target_type.is_integer() && logical_type.is_integer() {
583-
return true;
584-
}
587+
if logical_type == NativeType::Null {
588+
return target_type.default_cast_for(current_type);
589+
}
590+
591+
if target_type.is_integer() && logical_type.is_integer() {
592+
return target_type.default_cast_for(current_type);
593+
}
585594

586-
false
595+
internal_err!(
596+
"Expect {} but received {}",
597+
target_type_class,
598+
current_type
599+
)
600+
}
601+
// Not consistent with Postgres and DuckDB but to avoid regression we implicit cast string to timestamp
602+
TypeSignatureClass::Timestamp
603+
if logical_type == NativeType::String =>
604+
{
605+
Ok(DataType::Timestamp(TimeUnit::Nanosecond, None))
606+
}
607+
TypeSignatureClass::Timestamp if logical_type.is_timestamp() => {
608+
Ok(current_type.to_owned())
609+
}
610+
TypeSignatureClass::Date if logical_type.is_date() => {
611+
Ok(current_type.to_owned())
612+
}
613+
TypeSignatureClass::Time if logical_type.is_time() => {
614+
Ok(current_type.to_owned())
615+
}
616+
TypeSignatureClass::Interval if logical_type.is_interval() => {
617+
Ok(current_type.to_owned())
618+
}
619+
TypeSignatureClass::Duration if logical_type.is_duration() => {
620+
Ok(current_type.to_owned())
621+
}
622+
_ => {
623+
not_impl_err!("Got logical_type: {logical_type} with target_type_class: {target_type_class}")
624+
}
625+
}
587626
}
588627

589628
let mut new_types = Vec::with_capacity(current_types.len());
590-
for (current_type, target_type) in
629+
for (current_type, target_type_class) in
591630
current_types.iter().zip(target_types.iter())
592631
{
593-
let logical_type: NativeType = current_type.into();
594-
let target_logical_type = target_type.native();
595-
if can_coerce_to(&logical_type, target_logical_type) {
596-
let target_type =
597-
target_logical_type.default_cast_for(current_type)?;
598-
new_types.push(target_type);
599-
}
632+
let target_type = can_coerce_to(current_type, target_type_class)?;
633+
new_types.push(target_type);
600634
}
601635

602636
vec![new_types]

datafusion/functions/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ datafusion-common = { workspace = true }
7575
datafusion-doc = { workspace = true }
7676
datafusion-execution = { workspace = true }
7777
datafusion-expr = { workspace = true }
78+
datafusion-expr-common = { workspace = true }
7879
datafusion-macros = { workspace = true }
7980
hashbrown = { workspace = true, optional = true }
8081
hex = { version = "0.4", optional = true }

0 commit comments

Comments
 (0)