18
18
use std:: any:: Any ;
19
19
use std:: sync:: Arc ;
20
20
21
- use arrow:: array:: { ArrayRef , OffsetSizeTrait , StringBuilder } ;
22
- use arrow:: datatypes:: DataType ;
21
+ use arrow:: array:: {
22
+ ArrayAccessor , ArrayIter , ArrayRef , ArrowPrimitiveType , AsArray , OffsetSizeTrait ,
23
+ PrimitiveArray , StringBuilder ,
24
+ } ;
25
+ use arrow:: datatypes:: { DataType , Int32Type , Int64Type } ;
23
26
24
- use datafusion_common:: cast:: { as_generic_string_array, as_int64_array} ;
25
27
use datafusion_common:: { exec_err, Result } ;
26
28
use datafusion_expr:: TypeSignature :: Exact ;
27
29
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl , Signature , Volatility } ;
@@ -46,6 +48,7 @@ impl SubstrIndexFunc {
46
48
Self {
47
49
signature : Signature :: one_of (
48
50
vec ! [
51
+ Exact ( vec![ Utf8View , Utf8View , Int64 ] ) ,
49
52
Exact ( vec![ Utf8 , Utf8 , Int64 ] ) ,
50
53
Exact ( vec![ LargeUtf8 , LargeUtf8 , Int64 ] ) ,
51
54
] ,
@@ -74,15 +77,7 @@ impl ScalarUDFImpl for SubstrIndexFunc {
74
77
}
75
78
76
79
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
77
- match args[ 0 ] . data_type ( ) {
78
- DataType :: Utf8 => make_scalar_function ( substr_index :: < i32 > , vec ! [ ] ) ( args) ,
79
- DataType :: LargeUtf8 => {
80
- make_scalar_function ( substr_index :: < i64 > , vec ! [ ] ) ( args)
81
- }
82
- other => {
83
- exec_err ! ( "Unsupported data type {other:?} for function substr_index" )
84
- }
85
- }
80
+ make_scalar_function ( substr_index, vec ! [ ] ) ( args)
86
81
}
87
82
88
83
fn aliases ( & self ) -> & [ String ] {
@@ -95,23 +90,71 @@ impl ScalarUDFImpl for SubstrIndexFunc {
95
90
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
96
91
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
97
92
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
98
- pub fn substr_index < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
93
+ fn substr_index ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
99
94
if args. len ( ) != 3 {
100
95
return exec_err ! (
101
96
"substr_index was called with {} arguments. It requires 3." ,
102
97
args. len( )
103
98
) ;
104
99
}
105
100
106
- let string_array = as_generic_string_array :: < T > ( & args[ 0 ] ) ?;
107
- let delimiter_array = as_generic_string_array :: < T > ( & args[ 1 ] ) ?;
108
- let count_array = as_int64_array ( & args[ 2 ] ) ?;
101
+ match args[ 0 ] . data_type ( ) {
102
+ DataType :: Utf8 => {
103
+ let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
104
+ let delimiter_array = args[ 1 ] . as_string :: < i32 > ( ) ;
105
+ let count_array: & PrimitiveArray < Int64Type > = args[ 2 ] . as_primitive ( ) ;
106
+ substr_index_general :: < Int32Type , _ , _ > (
107
+ string_array,
108
+ delimiter_array,
109
+ count_array,
110
+ )
111
+ }
112
+ DataType :: LargeUtf8 => {
113
+ let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
114
+ let delimiter_array = args[ 1 ] . as_string :: < i64 > ( ) ;
115
+ let count_array: & PrimitiveArray < Int64Type > = args[ 2 ] . as_primitive ( ) ;
116
+ substr_index_general :: < Int64Type , _ , _ > (
117
+ string_array,
118
+ delimiter_array,
119
+ count_array,
120
+ )
121
+ }
122
+ DataType :: Utf8View => {
123
+ let string_array = args[ 0 ] . as_string_view ( ) ;
124
+ let delimiter_array = args[ 1 ] . as_string_view ( ) ;
125
+ let count_array: & PrimitiveArray < Int64Type > = args[ 2 ] . as_primitive ( ) ;
126
+ substr_index_general :: < Int32Type , _ , _ > (
127
+ string_array,
128
+ delimiter_array,
129
+ count_array,
130
+ )
131
+ }
132
+ other => {
133
+ exec_err ! ( "Unsupported data type {other:?} for function substr_index" )
134
+ }
135
+ }
136
+ }
109
137
138
+ pub fn substr_index_general <
139
+ ' a ,
140
+ T : ArrowPrimitiveType ,
141
+ V : ArrayAccessor < Item = & ' a str > ,
142
+ P : ArrayAccessor < Item = i64 > ,
143
+ > (
144
+ string_array : V ,
145
+ delimiter_array : V ,
146
+ count_array : P ,
147
+ ) -> Result < ArrayRef >
148
+ where
149
+ T :: Native : OffsetSizeTrait ,
150
+ {
110
151
let mut builder = StringBuilder :: new ( ) ;
111
- string_array
112
- . iter ( )
113
- . zip ( delimiter_array. iter ( ) )
114
- . zip ( count_array. iter ( ) )
152
+ let string_iter = ArrayIter :: new ( string_array) ;
153
+ let delimiter_array_iter = ArrayIter :: new ( delimiter_array) ;
154
+ let count_array_iter = ArrayIter :: new ( count_array) ;
155
+ string_iter
156
+ . zip ( delimiter_array_iter)
157
+ . zip ( count_array_iter)
115
158
. for_each ( |( ( string, delimiter) , n) | match ( string, delimiter, n) {
116
159
( Some ( string) , Some ( delimiter) , Some ( n) ) => {
117
160
// In MySQL, these cases will return an empty string.
0 commit comments