15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
+ use crate :: regexp_common:: regexp_is_match_utf8;
18
19
use crate :: utils:: make_scalar_function;
19
- use arrow:: array:: { ArrayRef , OffsetSizeTrait } ;
20
+
21
+ use arrow:: array:: { Array , ArrayRef , AsArray , GenericStringArray , StringViewArray } ;
20
22
use arrow:: datatypes:: DataType ;
21
- use arrow:: datatypes:: DataType :: Boolean ;
22
- use datafusion_common:: cast :: as_generic_string_array ;
23
+ use arrow:: datatypes:: DataType :: { Boolean , LargeUtf8 , Utf8 , Utf8View } ;
24
+ use datafusion_common:: exec_err ;
23
25
use datafusion_common:: DataFusionError ;
24
26
use datafusion_common:: Result ;
25
- use datafusion_common:: { arrow_datafusion_err, exec_err} ;
26
27
use datafusion_expr:: ScalarUDFImpl ;
27
28
use datafusion_expr:: TypeSignature :: Exact ;
28
29
use datafusion_expr:: { ColumnarValue , Signature , Volatility } ;
30
+
29
31
use std:: any:: Any ;
30
32
use std:: sync:: Arc ;
33
+
31
34
#[ derive( Debug ) ]
32
35
pub struct ContainsFunc {
33
36
signature : Signature ,
@@ -44,7 +47,17 @@ impl ContainsFunc {
44
47
use DataType :: * ;
45
48
Self {
46
49
signature : Signature :: one_of (
47
- vec ! [ Exact ( vec![ Utf8 , Utf8 ] ) , Exact ( vec![ LargeUtf8 , LargeUtf8 ] ) ] ,
50
+ vec ! [
51
+ Exact ( vec![ Utf8View , Utf8View ] ) ,
52
+ Exact ( vec![ Utf8View , Utf8 ] ) ,
53
+ Exact ( vec![ Utf8View , LargeUtf8 ] ) ,
54
+ Exact ( vec![ Utf8 , Utf8View ] ) ,
55
+ Exact ( vec![ Utf8 , Utf8 ] ) ,
56
+ Exact ( vec![ Utf8 , LargeUtf8 ] ) ,
57
+ Exact ( vec![ LargeUtf8 , Utf8View ] ) ,
58
+ Exact ( vec![ LargeUtf8 , Utf8 ] ) ,
59
+ Exact ( vec![ LargeUtf8 , LargeUtf8 ] ) ,
60
+ ] ,
48
61
Volatility :: Immutable ,
49
62
) ,
50
63
}
@@ -69,28 +82,116 @@ impl ScalarUDFImpl for ContainsFunc {
69
82
}
70
83
71
84
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
72
- match args[ 0 ] . data_type ( ) {
73
- DataType :: Utf8 => make_scalar_function ( contains :: < i32 > , vec ! [ ] ) ( args) ,
74
- DataType :: LargeUtf8 => make_scalar_function ( contains :: < i64 > , vec ! [ ] ) ( args) ,
75
- other => {
76
- exec_err ! ( "unsupported data type {other:?} for function contains" )
77
- }
78
- }
85
+ make_scalar_function ( contains, vec ! [ ] ) ( args)
79
86
}
80
87
}
81
88
82
89
/// use regexp_is_match_utf8_scalar to do the calculation for contains
83
- pub fn contains < T : OffsetSizeTrait > (
84
- args : & [ ArrayRef ] ,
85
- ) -> Result < ArrayRef , DataFusionError > {
86
- let mod_str = as_generic_string_array :: < T > ( & args[ 0 ] ) ?;
87
- let match_str = as_generic_string_array :: < T > ( & args[ 1 ] ) ?;
88
- let res = arrow:: compute:: kernels:: comparison:: regexp_is_match_utf8 (
89
- mod_str, match_str, None ,
90
- )
91
- . map_err ( |e| arrow_datafusion_err ! ( e) ) ?;
92
-
93
- Ok ( Arc :: new ( res) as ArrayRef )
90
+ pub fn contains ( args : & [ ArrayRef ] ) -> Result < ArrayRef , DataFusionError > {
91
+ match ( args[ 0 ] . data_type ( ) , args[ 1 ] . data_type ( ) ) {
92
+ ( Utf8View , Utf8View ) => {
93
+ let mod_str = args[ 0 ] . as_string_view ( ) ;
94
+ let match_str = args[ 1 ] . as_string_view ( ) ;
95
+ let res = regexp_is_match_utf8 :: <
96
+ StringViewArray ,
97
+ StringViewArray ,
98
+ GenericStringArray < i32 > ,
99
+ > ( mod_str, match_str, None ) ?;
100
+
101
+ Ok ( Arc :: new ( res) as ArrayRef )
102
+ }
103
+ ( Utf8View , Utf8 ) => {
104
+ let mod_str = args[ 0 ] . as_string_view ( ) ;
105
+ let match_str = args[ 1 ] . as_string :: < i32 > ( ) ;
106
+ let res = regexp_is_match_utf8 :: <
107
+ StringViewArray ,
108
+ GenericStringArray < i32 > ,
109
+ GenericStringArray < i32 > ,
110
+ > ( mod_str, match_str, None ) ?;
111
+
112
+ Ok ( Arc :: new ( res) as ArrayRef )
113
+ }
114
+ ( Utf8View , LargeUtf8 ) => {
115
+ let mod_str = args[ 0 ] . as_string_view ( ) ;
116
+ let match_str = args[ 1 ] . as_string :: < i64 > ( ) ;
117
+ let res = regexp_is_match_utf8 :: <
118
+ StringViewArray ,
119
+ GenericStringArray < i64 > ,
120
+ GenericStringArray < i32 > ,
121
+ > ( mod_str, match_str, None ) ?;
122
+
123
+ Ok ( Arc :: new ( res) as ArrayRef )
124
+ }
125
+ ( Utf8 , Utf8View ) => {
126
+ let mod_str = args[ 0 ] . as_string :: < i32 > ( ) ;
127
+ let match_str = args[ 1 ] . as_string_view ( ) ;
128
+ let res = regexp_is_match_utf8 :: <
129
+ GenericStringArray < i32 > ,
130
+ StringViewArray ,
131
+ GenericStringArray < i32 > ,
132
+ > ( mod_str, match_str, None ) ?;
133
+
134
+ Ok ( Arc :: new ( res) as ArrayRef )
135
+ }
136
+ ( Utf8 , Utf8 ) => {
137
+ let mod_str = args[ 0 ] . as_string :: < i32 > ( ) ;
138
+ let match_str = args[ 1 ] . as_string :: < i32 > ( ) ;
139
+ let res = regexp_is_match_utf8 :: <
140
+ GenericStringArray < i32 > ,
141
+ GenericStringArray < i32 > ,
142
+ GenericStringArray < i32 > ,
143
+ > ( mod_str, match_str, None ) ?;
144
+
145
+ Ok ( Arc :: new ( res) as ArrayRef )
146
+ }
147
+ ( Utf8 , LargeUtf8 ) => {
148
+ let mod_str = args[ 0 ] . as_string :: < i32 > ( ) ;
149
+ let match_str = args[ 1 ] . as_string :: < i64 > ( ) ;
150
+ let res = regexp_is_match_utf8 :: <
151
+ GenericStringArray < i32 > ,
152
+ GenericStringArray < i64 > ,
153
+ GenericStringArray < i32 > ,
154
+ > ( mod_str, match_str, None ) ?;
155
+
156
+ Ok ( Arc :: new ( res) as ArrayRef )
157
+ }
158
+ ( LargeUtf8 , Utf8View ) => {
159
+ let mod_str = args[ 0 ] . as_string :: < i64 > ( ) ;
160
+ let match_str = args[ 1 ] . as_string_view ( ) ;
161
+ let res = regexp_is_match_utf8 :: <
162
+ GenericStringArray < i64 > ,
163
+ StringViewArray ,
164
+ GenericStringArray < i32 > ,
165
+ > ( mod_str, match_str, None ) ?;
166
+
167
+ Ok ( Arc :: new ( res) as ArrayRef )
168
+ }
169
+ ( LargeUtf8 , Utf8 ) => {
170
+ let mod_str = args[ 0 ] . as_string :: < i64 > ( ) ;
171
+ let match_str = args[ 1 ] . as_string :: < i32 > ( ) ;
172
+ let res = regexp_is_match_utf8 :: <
173
+ GenericStringArray < i64 > ,
174
+ GenericStringArray < i32 > ,
175
+ GenericStringArray < i32 > ,
176
+ > ( mod_str, match_str, None ) ?;
177
+
178
+ Ok ( Arc :: new ( res) as ArrayRef )
179
+ }
180
+ ( LargeUtf8 , LargeUtf8 ) => {
181
+ let mod_str = args[ 0 ] . as_string :: < i64 > ( ) ;
182
+ let match_str = args[ 1 ] . as_string :: < i64 > ( ) ;
183
+ let res = regexp_is_match_utf8 :: <
184
+ GenericStringArray < i64 > ,
185
+ GenericStringArray < i64 > ,
186
+ GenericStringArray < i32 > ,
187
+ > ( mod_str, match_str, None ) ?;
188
+
189
+ Ok ( Arc :: new ( res) as ArrayRef )
190
+ }
191
+ other => {
192
+ exec_err ! ( "Unsupported data type {other:?} for function `contains`." )
193
+ }
194
+ }
94
195
}
95
196
96
197
#[ cfg( test) ]
@@ -138,6 +239,49 @@ mod tests {
138
239
Boolean ,
139
240
BooleanArray
140
241
) ;
242
+
243
+ test_function ! (
244
+ ContainsFunc :: new( ) ,
245
+ & [
246
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
247
+ "Apache"
248
+ ) ) ) ) ,
249
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "pac" ) ) ) ) ,
250
+ ] ,
251
+ Ok ( Some ( true ) ) ,
252
+ bool ,
253
+ Boolean ,
254
+ BooleanArray
255
+ ) ;
256
+ test_function ! (
257
+ ContainsFunc :: new( ) ,
258
+ & [
259
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
260
+ "Apache"
261
+ ) ) ) ) ,
262
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( String :: from( "ap" ) ) ) ) ,
263
+ ] ,
264
+ Ok ( Some ( false ) ) ,
265
+ bool ,
266
+ Boolean ,
267
+ BooleanArray
268
+ ) ;
269
+ test_function ! (
270
+ ContainsFunc :: new( ) ,
271
+ & [
272
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
273
+ "Apache"
274
+ ) ) ) ) ,
275
+ ColumnarValue :: Scalar ( ScalarValue :: LargeUtf8 ( Some ( String :: from(
276
+ "DataFusion"
277
+ ) ) ) ) ,
278
+ ] ,
279
+ Ok ( Some ( false ) ) ,
280
+ bool ,
281
+ Boolean ,
282
+ BooleanArray
283
+ ) ;
284
+
141
285
Ok ( ( ) )
142
286
}
143
287
}
0 commit comments