Skip to content

Commit 1d94e5c

Browse files
committed
use user defined rules instead
1 parent 25bcb58 commit 1d94e5c

File tree

9 files changed

+268
-71
lines changed

9 files changed

+268
-71
lines changed

datafusion/common/src/types/native.rs

+6
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,12 @@ impl NativeType {
436436
)
437437
}
438438

439+
/// Returns `true` when this native type is `Binary` or `FixedSizeBinary`
/// (whatever payload the latter carries), and `false` for every other variant.
#[inline]
440+
pub fn is_binary(&self) -> bool {
441+
// Bring the variants into scope so the pattern below can name them unqualified.
use NativeType::*;
442+
matches!(self, Binary | FixedSizeBinary(_))
443+
}
444+
439445
#[inline]
440446
pub fn is_timestamp(&self) -> bool {
441447
matches!(self, NativeType::Timestamp(_, _))

datafusion/expr-common/src/signature.rs

-9
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ pub enum TypeSignature {
126126
Exact(Vec<DataType>),
127127
/// One or more arguments belonging to the [`TypeSignatureClass`], in order.
128128
///
129-
/// `Coercible(vec![TypeSignatureClass::AnyNative(...)])` accepts any type castable to the
130-
/// target `NativeType` through the explicit set of type conversion rules defined in
131-
/// `NativeType::default_cast_for`.
132-
///
133-
/// For example, `Coercible(vec![TypeSignatureClass::AnyNative(logical_float64())])` accepts
134-
/// arguments like `vec![Int32]` or `vec![Float32]` since i32 and f32 can be cast to f64.
135-
///
136129
/// `Coercible(vec![TypeSignatureClass::Native(...)])` is designed to cast between the same
137130
/// logical type.
138131
///
@@ -228,7 +221,6 @@ pub enum TypeSignatureClass {
228221
Interval,
229222
Duration,
230223
Native(LogicalTypeRef),
231-
AnyNative(LogicalTypeRef),
232224
Numeric(LogicalTypeRef),
233225
Integer(LogicalTypeRef),
234226
}
@@ -392,7 +384,6 @@ impl TypeSignature {
392384
.iter()
393385
.map(|logical_type| match logical_type {
394386
TypeSignatureClass::Native(l)
395-
| TypeSignatureClass::AnyNative(l)
396387
| TypeSignatureClass::Numeric(l)
397388
| TypeSignatureClass::Integer(l) => get_data_types(l.native()),
398389
TypeSignatureClass::Timestamp => {

datafusion/expr/src/type_coercion/functions.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -552,10 +552,6 @@ fn get_valid_types(
552552
)
553553
}
554554
}
555-
TypeSignatureClass::AnyNative(native_type) => {
556-
let target_type = native_type.native();
557-
target_type.default_cast_for(current_type)
558-
}
559555
TypeSignatureClass::Numeric(native_type) => {
560556
let target_type = native_type.native();
561557
if target_type.is_numeric() && logical_type.is_numeric() {
@@ -614,7 +610,7 @@ fn get_valid_types(
614610
// Following the behavior of `TypeSignature::String`, we find the common string type.
615611
let string_indices: Vec<_> = target_types.iter().enumerate()
616612
.filter(|(_, t)| {
617-
matches!(t, TypeSignatureClass::Native(n) | TypeSignatureClass::AnyNative(n) if n.native() == &NativeType::String)
613+
matches!(t, TypeSignatureClass::Native(n) if n.native() == &NativeType::String)
618614
})
619615
.map(|(i, _)| i)
620616
.collect();

datafusion/functions/src/string/ascii.rs

+31-8
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ use crate::utils::make_scalar_function;
1919
use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array};
2020
use arrow::datatypes::DataType;
2121
use arrow::error::ArrowError;
22-
use datafusion_common::types::logical_string;
23-
use datafusion_common::{internal_err, Result};
22+
use datafusion_common::types::{LogicalType, NativeType};
23+
use datafusion_common::{internal_err, plan_err, Result};
2424
use datafusion_expr::{
25-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
26-
Volatility,
25+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2726
};
2827
use datafusion_macros::user_doc;
2928
use std::any::Any;
@@ -64,10 +63,7 @@ impl Default for AsciiFunc {
6463
impl AsciiFunc {
6564
pub fn new() -> Self {
6665
Self {
67-
signature: Signature::coercible(
68-
vec![TypeSignatureClass::AnyNative(logical_string())],
69-
Volatility::Immutable,
70-
),
66+
signature: Signature::user_defined(Volatility::Immutable),
7167
}
7268
}
7369
}
@@ -99,6 +95,33 @@ impl ScalarUDFImpl for AsciiFunc {
9995
make_scalar_function(ascii, vec![])(args)
10096
}
10197

98+
/// Computes the type the single argument of this function is coerced to.
///
/// # Errors
///
/// Returns a plan error when `arg_types` does not contain exactly one entry,
/// or when the argument's native type is not an integer, a binary type,
/// `String` or `Null`. Any error from `default_cast_for` is propagated.
///
/// # Returns
///
/// A one-element vector holding the `DataType` that
/// `NativeType::String.default_cast_for` produces for the argument.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
99+
// Arity check: exactly one argument is accepted.
if arg_types.len() != 1 {
100+
return plan_err!(
101+
"The {} function requires 1 argument, but got {}.",
102+
self.name(),
103+
arg_types.len()
104+
);
105+
}
106+
107+
let arg_type = &arg_types[0];
108+
// Classify the physical `DataType` by its native (logical) type.
let current_native_type: NativeType = arg_type.into();
109+
let target_native_type = NativeType::String;
110+
// Accepted inputs: integer, binary, string and null native types.
// NOTE(review): `Null` is accepted here although the error message below
// only lists string, integer and binary.
if current_native_type.is_integer()
111+
|| current_native_type.is_binary()
112+
|| current_native_type == NativeType::String
113+
|| current_native_type == NativeType::Null
114+
{
115+
// Presumably yields the concrete string `DataType` to cast to; verify
// against `default_cast_for`.
Ok(vec![target_native_type.default_cast_for(arg_type)?])
116+
} else {
117+
plan_err!(
118+
"The first argument of the {} function can only be a string, integer, or binary but got {:?}.",
119+
self.name(),
120+
arg_type
121+
)
122+
}
123+
}
124+
102125
/// Returns this function's documentation by delegating to `self.doc()`.
// NOTE(review): `doc()` is not defined in the visible source; presumably it is
// generated by the `#[user_doc]` attribute — confirm.
fn documentation(&self) -> Option<&Documentation> {
103126
self.doc()
104127
}

datafusion/functions/src/string/bit_length.rs

+31-8
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@ use arrow::datatypes::DataType;
2020
use std::any::Any;
2121

2222
use crate::utils::utf8_to_int_type;
23-
use datafusion_common::types::logical_string;
24-
use datafusion_common::{exec_err, Result, ScalarValue};
23+
use datafusion_common::types::{LogicalType, NativeType};
24+
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
2525
use datafusion_expr::{
26-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
27-
Volatility,
26+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2827
};
2928
use datafusion_macros::user_doc;
3029

@@ -58,10 +57,7 @@ impl Default for BitLengthFunc {
5857
impl BitLengthFunc {
5958
pub fn new() -> Self {
6059
Self {
61-
signature: Signature::coercible(
62-
vec![TypeSignatureClass::AnyNative(logical_string())],
63-
Volatility::Immutable,
64-
),
60+
signature: Signature::user_defined(Volatility::Immutable),
6561
}
6662
}
6763
}
@@ -112,6 +108,33 @@ impl ScalarUDFImpl for BitLengthFunc {
112108
}
113109
}
114110

111+
/// Computes the type the single argument of this function is coerced to.
///
/// # Errors
///
/// Returns a plan error when `arg_types` does not contain exactly one entry,
/// or when the argument's native type is not an integer, a binary type,
/// `String` or `Null`. Any error from `default_cast_for` is propagated.
///
/// # Returns
///
/// A one-element vector holding the `DataType` that
/// `NativeType::String.default_cast_for` produces for the argument.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
112+
// Arity check: exactly one argument is accepted.
if arg_types.len() != 1 {
113+
return plan_err!(
114+
"The {} function requires 1 argument, but got {}.",
115+
self.name(),
116+
arg_types.len()
117+
);
118+
}
119+
120+
let arg_type = &arg_types[0];
121+
// Classify the physical `DataType` by its native (logical) type.
let current_native_type: NativeType = arg_type.into();
122+
let target_native_type = NativeType::String;
123+
// Accepted inputs: integer, binary, string and null native types.
// NOTE(review): `Null` is accepted here although the error message below
// only lists string, integer and binary.
if current_native_type.is_integer()
124+
|| current_native_type.is_binary()
125+
|| current_native_type == NativeType::String
126+
|| current_native_type == NativeType::Null
127+
{
128+
// Presumably yields the concrete string `DataType` to cast to; verify
// against `default_cast_for`.
Ok(vec![target_native_type.default_cast_for(arg_type)?])
129+
} else {
130+
plan_err!(
131+
"The first argument of the {} function can only be a string, integer, or binary but got {:?}.",
132+
self.name(),
133+
arg_type
134+
)
135+
}
136+
}
137+
115138
/// Returns this function's documentation by delegating to `self.doc()`.
// NOTE(review): `doc()` is not defined in the visible source; presumably it is
// generated by the `#[user_doc]` attribute — confirm.
fn documentation(&self) -> Option<&Documentation> {
116139
self.doc()
117140
}

datafusion/functions/src/string/contains.rs

+56-11
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ use arrow::array::{Array, ArrayRef, AsArray};
2020
use arrow::compute::contains as arrow_contains;
2121
use arrow::datatypes::DataType;
2222
use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
23-
use datafusion_common::exec_err;
24-
use datafusion_common::types::logical_string;
23+
use datafusion_common::types::{LogicalType, NativeType};
2524
use datafusion_common::DataFusionError;
2625
use datafusion_common::Result;
26+
use datafusion_common::{exec_err, plan_err};
2727
use datafusion_expr::{
28-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
29-
Volatility,
28+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3029
};
30+
use datafusion_expr_common::type_coercion::binary::string_coercion;
3131
use datafusion_macros::user_doc;
3232
use std::any::Any;
3333
use std::sync::Arc;
@@ -64,13 +64,7 @@ impl Default for ContainsFunc {
6464
impl ContainsFunc {
6565
pub fn new() -> Self {
6666
Self {
67-
signature: Signature::coercible(
68-
vec![
69-
TypeSignatureClass::AnyNative(logical_string()),
70-
TypeSignatureClass::AnyNative(logical_string()),
71-
],
72-
Volatility::Immutable,
73-
),
67+
signature: Signature::user_defined(Volatility::Immutable),
7468
}
7569
}
7670
}
@@ -100,6 +94,57 @@ impl ScalarUDFImpl for ContainsFunc {
10094
make_scalar_function(contains, vec![])(args)
10195
}
10296

97+
/// Computes the common type both arguments of this function are coerced to.
///
/// Each argument is validated and mapped independently through
/// `NativeType::String.default_cast_for`; the two results are then passed to
/// `string_coercion` to obtain a single type used for both positions.
///
/// # Errors
///
/// Returns a plan error when `arg_types` does not contain exactly two entries,
/// when either argument's native type is not an integer, a binary type,
/// `String` or `Null`, or when `string_coercion` returns `None` for the two
/// per-argument types. Any error from `default_cast_for` is propagated.
///
/// # Returns
///
/// A two-element vector containing the same coerced `DataType` twice.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
98+
// Arity check: exactly two arguments are accepted.
if arg_types.len() != 2 {
99+
return plan_err!(
100+
"The {} function requires 2 arguments, but got {}.",
101+
self.name(),
102+
arg_types.len()
103+
);
104+
}
105+
106+
let first_arg_type = &arg_types[0];
107+
// Classify each physical `DataType` by its native (logical) type.
let first_native_type: NativeType = first_arg_type.into();
108+
let second_arg_type = &arg_types[1];
109+
let second_native_type: NativeType = second_arg_type.into();
110+
let target_native_type = NativeType::String;
111+
112+
// First argument: accept integer, binary, string and null native types.
// NOTE(review): `Null` is accepted although the error message below only
// lists string, integer and binary.
let first_data_type = if first_native_type.is_integer()
113+
|| first_native_type.is_binary()
114+
|| first_native_type == NativeType::String
115+
|| first_native_type == NativeType::Null
116+
{
117+
target_native_type.default_cast_for(first_arg_type)
118+
} else {
119+
plan_err!(
120+
"The first argument of the {} function can only be a string, integer, or binary but got {:?}.",
121+
self.name(),
122+
first_arg_type
123+
)
124+
// Both branches produce a `Result`; the trailing `?` unwraps it or returns early.
}?;
125+
// Second argument: same acceptance rule as the first.
let second_data_type = if second_native_type.is_integer()
126+
|| second_native_type.is_binary()
127+
|| second_native_type == NativeType::String
128+
|| second_native_type == NativeType::Null
129+
{
130+
target_native_type.default_cast_for(second_arg_type)
131+
} else {
132+
plan_err!(
133+
"The second argument of the {} function can only be a string, integer, or binary but got {:?}.",
134+
self.name(),
135+
second_arg_type
136+
)
137+
}?;
138+
139+
// Presumably selects one string type able to represent both inputs
// (e.g. among Utf8 / LargeUtf8 / Utf8View) — verify against `string_coercion`.
if let Some(coerced_type) = string_coercion(&first_data_type, &second_data_type) {
140+
Ok(vec![coerced_type.clone(), coerced_type])
141+
} else {
142+
plan_err!(
143+
"{first_data_type} and {second_data_type} are not coercible to a common string type"
144+
)
145+
}
146+
}
147+
103148
/// Returns this function's documentation by delegating to `self.doc()`.
// NOTE(review): `doc()` is not defined in the visible source; presumably it is
// generated by the `#[user_doc]` attribute — confirm.
fn documentation(&self) -> Option<&Documentation> {
104149
self.doc()
105150
}

datafusion/functions/src/string/ends_with.rs

+56-11
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ use arrow::array::ArrayRef;
2222
use arrow::datatypes::DataType;
2323

2424
use crate::utils::make_scalar_function;
25-
use datafusion_common::types::logical_string;
26-
use datafusion_common::{internal_err, Result};
25+
use datafusion_common::types::{LogicalType, NativeType};
26+
use datafusion_common::{internal_err, plan_err, Result};
2727
use datafusion_expr::{
28-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
29-
Volatility,
28+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3029
};
30+
use datafusion_expr_common::type_coercion::binary::string_coercion;
3131
use datafusion_macros::user_doc;
3232

3333
#[user_doc(
@@ -65,13 +65,7 @@ impl Default for EndsWithFunc {
6565
impl EndsWithFunc {
6666
pub fn new() -> Self {
6767
Self {
68-
signature: Signature::coercible(
69-
vec![
70-
TypeSignatureClass::AnyNative(logical_string()),
71-
TypeSignatureClass::AnyNative(logical_string()),
72-
],
73-
Volatility::Immutable,
74-
),
68+
signature: Signature::user_defined(Volatility::Immutable),
7569
}
7670
}
7771
}
@@ -108,6 +102,57 @@ impl ScalarUDFImpl for EndsWithFunc {
108102
}
109103
}
110104

105+
/// Computes the common type both arguments of this function are coerced to.
///
/// Each argument is validated and mapped independently through
/// `NativeType::String.default_cast_for`; the two results are then passed to
/// `string_coercion` to obtain a single type used for both positions.
///
/// # Errors
///
/// Returns a plan error when `arg_types` does not contain exactly two entries,
/// when either argument's native type is not an integer, a binary type,
/// `String` or `Null`, or when `string_coercion` returns `None` for the two
/// per-argument types. Any error from `default_cast_for` is propagated.
///
/// # Returns
///
/// A two-element vector containing the same coerced `DataType` twice.
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
106+
// Arity check: exactly two arguments are accepted.
if arg_types.len() != 2 {
107+
return plan_err!(
108+
"The {} function requires 2 arguments, but got {}.",
109+
self.name(),
110+
arg_types.len()
111+
);
112+
}
113+
114+
let first_arg_type = &arg_types[0];
115+
// Classify each physical `DataType` by its native (logical) type.
let first_native_type: NativeType = first_arg_type.into();
116+
let second_arg_type = &arg_types[1];
117+
let second_native_type: NativeType = second_arg_type.into();
118+
let target_native_type = NativeType::String;
119+
120+
// First argument: accept integer, binary, string and null native types.
// NOTE(review): `Null` is accepted although the error message below only
// lists string, integer and binary.
let first_data_type = if first_native_type.is_integer()
121+
|| first_native_type.is_binary()
122+
|| first_native_type == NativeType::String
123+
|| first_native_type == NativeType::Null
124+
{
125+
target_native_type.default_cast_for(first_arg_type)
126+
} else {
127+
plan_err!(
128+
"The first argument of the {} function can only be a string, integer, or binary but got {:?}.",
129+
self.name(),
130+
first_arg_type
131+
)
132+
// Both branches produce a `Result`; the trailing `?` unwraps it or returns early.
}?;
133+
// Second argument: same acceptance rule as the first.
let second_data_type = if second_native_type.is_integer()
134+
|| second_native_type.is_binary()
135+
|| second_native_type == NativeType::String
136+
|| second_native_type == NativeType::Null
137+
{
138+
target_native_type.default_cast_for(second_arg_type)
139+
} else {
140+
plan_err!(
141+
"The second argument of the {} function can only be a string, integer, or binary but got {:?}.",
142+
self.name(),
143+
second_arg_type
144+
)
145+
}?;
146+
147+
// Presumably selects one string type able to represent both inputs
// (e.g. among Utf8 / LargeUtf8 / Utf8View) — verify against `string_coercion`.
if let Some(coerced_type) = string_coercion(&first_data_type, &second_data_type) {
148+
Ok(vec![coerced_type.clone(), coerced_type])
149+
} else {
150+
plan_err!(
151+
"{first_data_type} and {second_data_type} are not coercible to a common string type"
152+
)
153+
}
154+
}
155+
111156
fn documentation(&self) -> Option<&Documentation> {
112157
self.doc()
113158
}

0 commit comments

Comments
 (0)