18
18
use std:: any:: Any ;
19
19
use std:: sync:: Arc ;
20
20
21
- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
21
+ use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait , StringArray } ;
22
22
use arrow:: datatypes:: DataType ;
23
23
24
- use datafusion_common:: cast:: { as_generic_string_array, as_int64_array} ;
24
+ use datafusion_common:: cast:: { as_generic_string_array, as_int64_array, as_string_view_array } ;
25
25
use datafusion_common:: { exec_err, Result } ;
26
26
use datafusion_expr:: TypeSignature :: * ;
27
27
use datafusion_expr:: { ColumnarValue , Volatility } ;
@@ -45,7 +45,14 @@ impl RepeatFunc {
45
45
use DataType :: * ;
46
46
Self {
47
47
signature : Signature :: one_of (
48
- vec ! [ Exact ( vec![ Utf8 , Int64 ] ) , Exact ( vec![ LargeUtf8 , Int64 ] ) ] ,
48
+ vec ! [
49
+ // Planner attempts coercion to the target type starting with the most preferred candidate.
50
+ // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
51
+ // If that fails, it proceeds to `(Utf8, Int64)`.
52
+ Exact ( vec![ Utf8View , Int64 ] ) ,
53
+ Exact ( vec![ Utf8 , Int64 ] ) ,
54
+ Exact ( vec![ LargeUtf8 , Int64 ] ) ,
55
+ ] ,
49
56
Volatility :: Immutable ,
50
57
) ,
51
58
}
@@ -71,9 +78,10 @@ impl ScalarUDFImpl for RepeatFunc {
71
78
72
79
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
73
80
match args[ 0 ] . data_type ( ) {
81
+ DataType :: Utf8View => make_scalar_function ( repeat_utf8view, vec ! [ ] ) ( args) ,
74
82
DataType :: Utf8 => make_scalar_function ( repeat :: < i32 > , vec ! [ ] ) ( args) ,
75
83
DataType :: LargeUtf8 => make_scalar_function ( repeat :: < i64 > , vec ! [ ] ) ( args) ,
76
- other => exec_err ! ( "Unsupported data type {other:?} for function repeat" ) ,
84
+ other => exec_err ! ( "Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8 " ) ,
77
85
}
78
86
}
79
87
}
@@ -87,18 +95,35 @@ fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
87
95
let result = string_array
88
96
. iter ( )
89
97
. zip ( number_array. iter ( ) )
90
- . map ( |( string, number) | match ( string, number) {
91
- ( Some ( string) , Some ( number) ) if number >= 0 => {
92
- Some ( string. repeat ( number as usize ) )
93
- }
94
- ( Some ( _) , Some ( _) ) => Some ( "" . to_string ( ) ) ,
95
- _ => None ,
96
- } )
98
+ . map ( |( string, number) | repeat_common ( string, number) )
97
99
. collect :: < GenericStringArray < T > > ( ) ;
98
100
99
101
Ok ( Arc :: new ( result) as ArrayRef )
100
102
}
101
103
104
+ fn repeat_utf8view ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
105
+ let string_view_array = as_string_view_array ( & args[ 0 ] ) ?;
106
+ let number_array = as_int64_array ( & args[ 1 ] ) ?;
107
+
108
+ let result = string_view_array
109
+ . iter ( )
110
+ . zip ( number_array. iter ( ) )
111
+ . map ( |( string, number) | repeat_common ( string, number) )
112
+ . collect :: < StringArray > ( ) ;
113
+
114
+ Ok ( Arc :: new ( result) as ArrayRef )
115
+ }
116
+
117
/// Repeats `string` `number` times, propagating missing inputs.
///
/// # Returns
/// * `None` if either `string` or `number` is `None`.
/// * `Some("")` if `number` is zero or negative.
/// * `Some(string)` concatenated with itself `number` times otherwise.
fn repeat_common(string: Option<&str>, number: Option<i64>) -> Option<String> {
    // Imported locally so the checked conversion below resolves regardless of
    // which prelude the enclosing crate's edition provides.
    use std::convert::TryFrom;

    let (text, count) = (string?, number?);
    if count <= 0 {
        return Some(String::new());
    }

    // A plain `as usize` cast wraps silently when `usize` is narrower than
    // `i64`, producing a smaller repeat count and therefore a wrong result.
    // Saturating keeps an oversized request oversized instead of wrapping.
    let times = usize::try_from(count).unwrap_or(usize::MAX);
    Some(text.repeat(times))
}
126
+
102
127
#[ cfg( test) ]
103
128
mod tests {
104
129
use arrow:: array:: { Array , StringArray } ;
@@ -124,7 +149,6 @@ mod tests {
124
149
Utf8 ,
125
150
StringArray
126
151
) ;
127
-
128
152
test_function ! (
129
153
RepeatFunc :: new( ) ,
130
154
& [
@@ -148,6 +172,40 @@ mod tests {
148
172
StringArray
149
173
) ;
150
174
175
+ test_function ! (
176
+ RepeatFunc :: new( ) ,
177
+ & [
178
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "Pg" ) ) ) ) ,
179
+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 4 ) ) ) ,
180
+ ] ,
181
+ Ok ( Some ( "PgPgPgPg" ) ) ,
182
+ & str ,
183
+ Utf8 ,
184
+ StringArray
185
+ ) ;
186
+ test_function ! (
187
+ RepeatFunc :: new( ) ,
188
+ & [
189
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( None ) ) ,
190
+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 4 ) ) ) ,
191
+ ] ,
192
+ Ok ( None ) ,
193
+ & str ,
194
+ Utf8 ,
195
+ StringArray
196
+ ) ;
197
+ test_function ! (
198
+ RepeatFunc :: new( ) ,
199
+ & [
200
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "Pg" ) ) ) ) ,
201
+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( None ) ) ,
202
+ ] ,
203
+ Ok ( None ) ,
204
+ & str ,
205
+ Utf8 ,
206
+ StringArray
207
+ ) ;
208
+
151
209
Ok ( ( ) )
152
210
}
153
211
}
0 commit comments