@@ -20,14 +20,14 @@ use arrow::array::{Array, ArrayRef, AsArray};
20
20
use arrow:: compute:: contains as arrow_contains;
21
21
use arrow:: datatypes:: DataType ;
22
22
use arrow:: datatypes:: DataType :: { Boolean , LargeUtf8 , Utf8 , Utf8View } ;
23
- use datafusion_common:: exec_err;
24
- use datafusion_common:: types:: logical_string;
23
+ use datafusion_common:: types:: { LogicalType , NativeType } ;
25
24
use datafusion_common:: DataFusionError ;
26
25
use datafusion_common:: Result ;
26
+ use datafusion_common:: { exec_err, plan_err} ;
27
27
use datafusion_expr:: {
28
- ColumnarValue , Documentation , ScalarUDFImpl , Signature , TypeSignatureClass ,
29
- Volatility ,
28
+ ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility ,
30
29
} ;
30
+ use datafusion_expr_common:: type_coercion:: binary:: string_coercion;
31
31
use datafusion_macros:: user_doc;
32
32
use std:: any:: Any ;
33
33
use std:: sync:: Arc ;
@@ -64,13 +64,7 @@ impl Default for ContainsFunc {
64
64
impl ContainsFunc {
65
65
pub fn new ( ) -> Self {
66
66
Self {
67
- signature : Signature :: coercible (
68
- vec ! [
69
- TypeSignatureClass :: AnyNative ( logical_string( ) ) ,
70
- TypeSignatureClass :: AnyNative ( logical_string( ) ) ,
71
- ] ,
72
- Volatility :: Immutable ,
73
- ) ,
67
+ signature : Signature :: user_defined ( Volatility :: Immutable ) ,
74
68
}
75
69
}
76
70
}
@@ -100,6 +94,57 @@ impl ScalarUDFImpl for ContainsFunc {
100
94
make_scalar_function ( contains, vec ! [ ] ) ( args)
101
95
}
102
96
97
+ fn coerce_types ( & self , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
98
+ if arg_types. len ( ) != 2 {
99
+ return plan_err ! (
100
+ "The {} function requires 2 arguments, but got {}." ,
101
+ self . name( ) ,
102
+ arg_types. len( )
103
+ ) ;
104
+ }
105
+
106
+ let first_arg_type = & arg_types[ 0 ] ;
107
+ let first_native_type: NativeType = first_arg_type. into ( ) ;
108
+ let second_arg_type = & arg_types[ 1 ] ;
109
+ let second_native_type: NativeType = second_arg_type. into ( ) ;
110
+ let target_native_type = NativeType :: String ;
111
+
112
+ let first_data_type = if first_native_type. is_integer ( )
113
+ || first_native_type. is_binary ( )
114
+ || first_native_type == NativeType :: String
115
+ || first_native_type == NativeType :: Null
116
+ {
117
+ target_native_type. default_cast_for ( first_arg_type)
118
+ } else {
119
+ plan_err ! (
120
+ "The first argument of the {} function can only be a string, integer, or binary but got {:?}." ,
121
+ self . name( ) ,
122
+ first_arg_type
123
+ )
124
+ } ?;
125
+ let second_data_type = if second_native_type. is_integer ( )
126
+ || second_native_type. is_binary ( )
127
+ || second_native_type == NativeType :: String
128
+ || second_native_type == NativeType :: Null
129
+ {
130
+ target_native_type. default_cast_for ( second_arg_type)
131
+ } else {
132
+ plan_err ! (
133
+ "The second argument of the {} function can only be a string, integer, or binary but got {:?}." ,
134
+ self . name( ) ,
135
+ second_arg_type
136
+ )
137
+ } ?;
138
+
139
+ if let Some ( coerced_type) = string_coercion ( & first_data_type, & second_data_type) {
140
+ Ok ( vec ! [ coerced_type. clone( ) , coerced_type] )
141
+ } else {
142
+ plan_err ! (
143
+ "{first_data_type} and {second_data_type} are not coercible to a common string type"
144
+ )
145
+ }
146
+ }
147
+
103
148
fn documentation ( & self ) -> Option < & Documentation > {
104
149
self . doc ( )
105
150
}
0 commit comments