@@ -19,10 +19,12 @@ use std::any::Any;
19
19
use std:: cmp:: max;
20
20
use std:: sync:: Arc ;
21
21
22
- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
22
+ use arrow:: array:: {
23
+ ArrayAccessor , ArrayIter , ArrayRef , AsArray , GenericStringArray , OffsetSizeTrait ,
24
+ } ;
23
25
use arrow:: datatypes:: DataType ;
24
26
25
- use datafusion_common:: cast:: { as_generic_string_array , as_int64_array} ;
27
+ use datafusion_common:: cast:: as_int64_array;
26
28
use datafusion_common:: { exec_err, Result } ;
27
29
use datafusion_expr:: TypeSignature :: Exact ;
28
30
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl , Signature , Volatility } ;
@@ -51,6 +53,8 @@ impl SubstrFunc {
51
53
Exact ( vec![ LargeUtf8 , Int64 ] ) ,
52
54
Exact ( vec![ Utf8 , Int64 , Int64 ] ) ,
53
55
Exact ( vec![ LargeUtf8 , Int64 , Int64 ] ) ,
56
+ Exact ( vec![ Utf8View , Int64 ] ) ,
57
+ Exact ( vec![ Utf8View , Int64 , Int64 ] ) ,
54
58
] ,
55
59
Volatility :: Immutable ,
56
60
) ,
@@ -77,30 +81,47 @@ impl ScalarUDFImpl for SubstrFunc {
77
81
}
78
82
79
83
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
80
- match args[ 0 ] . data_type ( ) {
81
- DataType :: Utf8 => make_scalar_function ( substr :: < i32 > , vec ! [ ] ) ( args) ,
82
- DataType :: LargeUtf8 => make_scalar_function ( substr :: < i64 > , vec ! [ ] ) ( args) ,
83
- other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
84
- }
84
+ make_scalar_function ( substr, vec ! [ ] ) ( args)
85
85
}
86
86
87
87
fn aliases ( & self ) -> & [ String ] {
88
88
& self . aliases
89
89
}
90
90
}
91
91
92
+ pub fn substr ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
93
+ match args[ 0 ] . data_type ( ) {
94
+ DataType :: Utf8 => {
95
+ let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
96
+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
97
+ }
98
+ DataType :: LargeUtf8 => {
99
+ let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
100
+ calculate_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
101
+ }
102
+ DataType :: Utf8View => {
103
+ let string_array = args[ 0 ] . as_string_view ( ) ;
104
+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
105
+ }
106
+ other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
107
+ }
108
+ }
109
+
92
110
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
93
111
/// substr('alphabet', 3) = 'phabet'
94
112
/// substr('alphabet', 3, 2) = 'ph'
95
113
/// The implementation uses UTF-8 code points as characters
96
- pub fn substr < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
114
+ fn calculate_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
115
+ where
116
+ V : ArrayAccessor < Item = & ' a str > ,
117
+ T : OffsetSizeTrait ,
118
+ {
97
119
match args. len ( ) {
98
- 2 => {
99
- let string_array = as_generic_string_array :: < T > ( & args [ 0 ] ) ? ;
100
- let start_array = as_int64_array ( & args[ 1 ] ) ?;
120
+ 1 => {
121
+ let iter = ArrayIter :: new ( string_array ) ;
122
+ let start_array = as_int64_array ( & args[ 0 ] ) ?;
101
123
102
- let result = string_array
103
- . iter ( )
124
+ let result = iter
104
125
. zip ( start_array. iter ( ) )
105
126
. map ( |( string, start) | match ( string, start) {
106
127
( Some ( string) , Some ( start) ) => {
@@ -113,16 +134,14 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
113
134
_ => None ,
114
135
} )
115
136
. collect :: < GenericStringArray < T > > ( ) ;
116
-
117
137
Ok ( Arc :: new ( result) as ArrayRef )
118
138
}
119
- 3 => {
120
- let string_array = as_generic_string_array :: < T > ( & args [ 0 ] ) ? ;
121
- let start_array = as_int64_array ( & args[ 1 ] ) ?;
122
- let count_array = as_int64_array ( & args[ 2 ] ) ?;
139
+ 2 => {
140
+ let iter = ArrayIter :: new ( string_array ) ;
141
+ let start_array = as_int64_array ( & args[ 0 ] ) ?;
142
+ let count_array = as_int64_array ( & args[ 1 ] ) ?;
123
143
124
- let result = string_array
125
- . iter ( )
144
+ let result = iter
126
145
. zip ( start_array. iter ( ) )
127
146
. zip ( count_array. iter ( ) )
128
147
. map ( |( ( string, start) , count) | match ( string, start, count) {
@@ -162,6 +181,71 @@ mod tests {
162
181
163
182
#[ test]
164
183
fn test_functions ( ) -> Result < ( ) > {
184
+ test_function ! (
185
+ SubstrFunc :: new( ) ,
186
+ & [
187
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( None ) ) ,
188
+ ColumnarValue :: Scalar ( ScalarValue :: from( 1i64 ) ) ,
189
+ ] ,
190
+ Ok ( None ) ,
191
+ & str ,
192
+ Utf8 ,
193
+ StringArray
194
+ ) ;
195
+ test_function ! (
196
+ SubstrFunc :: new( ) ,
197
+ & [
198
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
199
+ "alphabet"
200
+ ) ) ) ) ,
201
+ ColumnarValue :: Scalar ( ScalarValue :: from( 0i64 ) ) ,
202
+ ] ,
203
+ Ok ( Some ( "alphabet" ) ) ,
204
+ & str ,
205
+ Utf8 ,
206
+ StringArray
207
+ ) ;
208
+ test_function ! (
209
+ SubstrFunc :: new( ) ,
210
+ & [
211
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
212
+ "joséésoj"
213
+ ) ) ) ) ,
214
+ ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
215
+ ] ,
216
+ Ok ( Some ( "ésoj" ) ) ,
217
+ & str ,
218
+ Utf8 ,
219
+ StringArray
220
+ ) ;
221
+ test_function ! (
222
+ SubstrFunc :: new( ) ,
223
+ & [
224
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
225
+ "alphabet"
226
+ ) ) ) ) ,
227
+ ColumnarValue :: Scalar ( ScalarValue :: from( 3i64 ) ) ,
228
+ ColumnarValue :: Scalar ( ScalarValue :: from( 2i64 ) ) ,
229
+ ] ,
230
+ Ok ( Some ( "ph" ) ) ,
231
+ & str ,
232
+ Utf8 ,
233
+ StringArray
234
+ ) ;
235
+ test_function ! (
236
+ SubstrFunc :: new( ) ,
237
+ & [
238
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
239
+ "alphabet"
240
+ ) ) ) ) ,
241
+ ColumnarValue :: Scalar ( ScalarValue :: from( 3i64 ) ) ,
242
+ ColumnarValue :: Scalar ( ScalarValue :: from( 20i64 ) ) ,
243
+ ] ,
244
+ Ok ( Some ( "phabet" ) ) ,
245
+ & str ,
246
+ Utf8 ,
247
+ StringArray
248
+ ) ;
165
249
test_function ! (
166
250
SubstrFunc :: new( ) ,
167
251
& [
0 commit comments