Skip to content

Commit 02bfefe

Browse files
authored
Update RPAD scalar function to support Utf8View (#11942)
* Update RPAD scalar function to support Utf8View * adding more test coverage * optimize macro
1 parent f98f8a9 commit 02bfefe

File tree

3 files changed

+203
-80
lines changed

3 files changed

+203
-80
lines changed

datafusion/functions/src/unicode/rpad.rs

Lines changed: 156 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ use std::sync::Arc;
2020

2121
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
2222
use arrow::datatypes::DataType;
23-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
23+
use datafusion_common::cast::{
24+
as_generic_string_array, as_int64_array, as_string_view_array,
25+
};
2426
use unicode_segmentation::UnicodeSegmentation;
2527

2628
use crate::utils::{make_scalar_function, utf8_to_str_type};
@@ -45,11 +47,17 @@ impl RPadFunc {
4547
Self {
4648
signature: Signature::one_of(
4749
vec![
50+
Exact(vec![Utf8View, Int64]),
51+
Exact(vec![Utf8View, Int64, Utf8View]),
52+
Exact(vec![Utf8View, Int64, Utf8]),
53+
Exact(vec![Utf8View, Int64, LargeUtf8]),
4854
Exact(vec![Utf8, Int64]),
49-
Exact(vec![LargeUtf8, Int64]),
55+
Exact(vec![Utf8, Int64, Utf8View]),
5056
Exact(vec![Utf8, Int64, Utf8]),
51-
Exact(vec![LargeUtf8, Int64, Utf8]),
5257
Exact(vec![Utf8, Int64, LargeUtf8]),
58+
Exact(vec![LargeUtf8, Int64]),
59+
Exact(vec![LargeUtf8, Int64, Utf8View]),
60+
Exact(vec![LargeUtf8, Int64, Utf8]),
5361
Exact(vec![LargeUtf8, Int64, LargeUtf8]),
5462
],
5563
Volatility::Immutable,
@@ -76,97 +84,168 @@ impl ScalarUDFImpl for RPadFunc {
7684
}
7785

7886
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
79-
match args[0].data_type() {
80-
DataType::Utf8 => make_scalar_function(rpad::<i32>, vec![])(args),
81-
DataType::LargeUtf8 => make_scalar_function(rpad::<i64>, vec![])(args),
82-
other => exec_err!("Unsupported data type {other:?} for function rpad"),
87+
match args.len() {
88+
2 => match args[0].data_type() {
89+
DataType::Utf8 | DataType::Utf8View => {
90+
make_scalar_function(rpad::<i32, i32>, vec![])(args)
91+
}
92+
DataType::LargeUtf8 => {
93+
make_scalar_function(rpad::<i64, i64>, vec![])(args)
94+
}
95+
other => exec_err!("Unsupported data type {other:?} for function rpad"),
96+
},
97+
3 => match (args[0].data_type(), args[2].data_type()) {
98+
(
99+
DataType::Utf8 | DataType::Utf8View,
100+
DataType::Utf8 | DataType::Utf8View,
101+
) => make_scalar_function(rpad::<i32, i32>, vec![])(args),
102+
(DataType::LargeUtf8, DataType::LargeUtf8) => {
103+
make_scalar_function(rpad::<i64, i64>, vec![])(args)
104+
}
105+
(DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => {
106+
make_scalar_function(rpad::<i64, i32>, vec![])(args)
107+
}
108+
(DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => {
109+
make_scalar_function(rpad::<i32, i64>, vec![])(args)
110+
}
111+
(first_type, last_type) => {
112+
exec_err!("unsupported arguments type for rpad, first argument type is {}, last argument type is {}", first_type, last_type)
113+
}
114+
},
115+
number => {
116+
exec_err!("unsupported arguments number {} for rpad", number)
117+
}
83118
}
84119
}
85120
}
86121

87-
/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
88-
/// rpad('hi', 5, 'xy') = 'hixyx'
89-
pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
90-
match args.len() {
91-
2 => {
92-
let string_array = as_generic_string_array::<T>(&args[0])?;
93-
let length_array = as_int64_array(&args[1])?;
94-
95-
let result = string_array
96-
.iter()
97-
.zip(length_array.iter())
98-
.map(|(string, length)| match (string, length) {
99-
(Some(string), Some(length)) => {
100-
if length > i32::MAX as i64 {
101-
return exec_err!(
102-
"rpad requested length {length} too large"
103-
);
104-
}
105-
106-
let length = if length < 0 { 0 } else { length as usize };
107-
if length == 0 {
108-
Ok(Some("".to_string()))
109-
} else {
110-
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
111-
if length < graphemes.len() {
112-
Ok(Some(graphemes[..length].concat()))
113-
} else {
114-
let mut s = string.to_string();
115-
s.push_str(" ".repeat(length - graphemes.len()).as_str());
116-
Ok(Some(s))
117-
}
118-
}
122+
macro_rules! process_rpad {
123+
// For the two-argument case
124+
($string_array:expr, $length_array:expr) => {{
125+
$string_array
126+
.iter()
127+
.zip($length_array.iter())
128+
.map(|(string, length)| match (string, length) {
129+
(Some(string), Some(length)) => {
130+
if length > i32::MAX as i64 {
131+
return exec_err!("rpad requested length {} too large", length);
119132
}
120-
_ => Ok(None),
121-
})
122-
.collect::<Result<GenericStringArray<T>>>()?;
123-
Ok(Arc::new(result) as ArrayRef)
124-
}
125-
3 => {
126-
let string_array = as_generic_string_array::<T>(&args[0])?;
127-
let length_array = as_int64_array(&args[1])?;
128-
let fill_array = as_generic_string_array::<T>(&args[2])?;
129-
130-
let result = string_array
131-
.iter()
132-
.zip(length_array.iter())
133-
.zip(fill_array.iter())
134-
.map(|((string, length), fill)| match (string, length, fill) {
135-
(Some(string), Some(length), Some(fill)) => {
136-
if length > i32::MAX as i64 {
137-
return exec_err!(
138-
"rpad requested length {length} too large"
139-
);
140-
}
141133

142-
let length = if length < 0 { 0 } else { length as usize };
134+
let length = if length < 0 { 0 } else { length as usize };
135+
if length == 0 {
136+
Ok(Some("".to_string()))
137+
} else {
143138
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
144-
let fill_chars = fill.chars().collect::<Vec<char>>();
145-
146139
if length < graphemes.len() {
147140
Ok(Some(graphemes[..length].concat()))
148-
} else if fill_chars.is_empty() {
149-
Ok(Some(string.to_string()))
150141
} else {
151142
let mut s = string.to_string();
152-
let mut char_vector =
153-
Vec::<char>::with_capacity(length - graphemes.len());
154-
for l in 0..length - graphemes.len() {
155-
char_vector
156-
.push(*fill_chars.get(l % fill_chars.len()).unwrap());
157-
}
158-
s.push_str(char_vector.iter().collect::<String>().as_str());
143+
s.push_str(" ".repeat(length - graphemes.len()).as_str());
159144
Ok(Some(s))
160145
}
161146
}
162-
_ => Ok(None),
163-
})
164-
.collect::<Result<GenericStringArray<T>>>()?;
147+
}
148+
_ => Ok(None),
149+
})
150+
.collect::<Result<GenericStringArray<StringArrayLen>>>()
151+
}};
165152

153+
// For the three-argument case
154+
($string_array:expr, $length_array:expr, $fill_array:expr) => {{
155+
$string_array
156+
.iter()
157+
.zip($length_array.iter())
158+
.zip($fill_array.iter())
159+
.map(|((string, length), fill)| match (string, length, fill) {
160+
(Some(string), Some(length), Some(fill)) => {
161+
if length > i32::MAX as i64 {
162+
return exec_err!("rpad requested length {} too large", length);
163+
}
164+
165+
let length = if length < 0 { 0 } else { length as usize };
166+
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
167+
let fill_chars = fill.chars().collect::<Vec<char>>();
168+
169+
if length < graphemes.len() {
170+
Ok(Some(graphemes[..length].concat()))
171+
} else if fill_chars.is_empty() {
172+
Ok(Some(string.to_string()))
173+
} else {
174+
let mut s = string.to_string();
175+
let char_vector: Vec<char> = (0..length - graphemes.len())
176+
.map(|l| fill_chars[l % fill_chars.len()])
177+
.collect();
178+
s.push_str(&char_vector.iter().collect::<String>());
179+
Ok(Some(s))
180+
}
181+
}
182+
_ => Ok(None),
183+
})
184+
.collect::<Result<GenericStringArray<StringArrayLen>>>()
185+
}};
186+
}
187+
188+
/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
189+
/// rpad('hi', 5, 'xy') = 'hixyx'
190+
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
191+
args: &[ArrayRef],
192+
) -> Result<ArrayRef> {
193+
match (args.len(), args[0].data_type()) {
194+
(2, DataType::Utf8View) => {
195+
let string_array = as_string_view_array(&args[0])?;
196+
let length_array = as_int64_array(&args[1])?;
197+
198+
let result = process_rpad!(string_array, length_array)?;
199+
Ok(Arc::new(result) as ArrayRef)
200+
}
201+
(2, _) => {
202+
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
203+
let length_array = as_int64_array(&args[1])?;
204+
205+
let result = process_rpad!(string_array, length_array)?;
166206
Ok(Arc::new(result) as ArrayRef)
167207
}
168-
other => exec_err!(
169-
"rpad was called with {other} arguments. It requires at least 2 and at most 3."
208+
(3, DataType::Utf8View) => {
209+
let string_array = as_string_view_array(&args[0])?;
210+
let length_array = as_int64_array(&args[1])?;
211+
match args[2].data_type() {
212+
DataType::Utf8View => {
213+
let fill_array = as_string_view_array(&args[2])?;
214+
let result = process_rpad!(string_array, length_array, fill_array)?;
215+
Ok(Arc::new(result) as ArrayRef)
216+
}
217+
DataType::Utf8 | DataType::LargeUtf8 => {
218+
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
219+
let result = process_rpad!(string_array, length_array, fill_array)?;
220+
Ok(Arc::new(result) as ArrayRef)
221+
}
222+
other_type => {
223+
exec_err!("unsupported type for rpad's third operator: {}", other_type)
224+
}
225+
}
226+
}
227+
(3, _) => {
228+
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
229+
let length_array = as_int64_array(&args[1])?;
230+
match args[2].data_type() {
231+
DataType::Utf8View => {
232+
let fill_array = as_string_view_array(&args[2])?;
233+
let result = process_rpad!(string_array, length_array, fill_array)?;
234+
Ok(Arc::new(result) as ArrayRef)
235+
}
236+
DataType::Utf8 | DataType::LargeUtf8 => {
237+
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
238+
let result = process_rpad!(string_array, length_array, fill_array)?;
239+
Ok(Arc::new(result) as ArrayRef)
240+
}
241+
other_type => {
242+
exec_err!("unsupported type for rpad's third operator: {}", other_type)
243+
}
244+
}
245+
}
246+
(other, other_type) => exec_err!(
247+
"rpad requires 2 or 3 arguments with corresponding types, but got {}. number of arguments with {}",
248+
other, other_type
170249
),
171250
}
172251
}

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ SELECT right(NULL, CAST(NULL AS INT))
294294
----
295295
NULL
296296

297+
297298
query T
298299
SELECT rpad('hi', -1, 'xy')
299300
----
@@ -354,6 +355,33 @@ SELECT rpad('xyxhi', 3)
354355
----
355356
xyx
356357

358+
# test for rpad with largeutf8 and utf8View
359+
360+
query T
361+
SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy')
362+
----
363+
hixyx
364+
365+
query T
366+
SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy')
367+
----
368+
hixyx
369+
370+
query T
371+
SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8'))
372+
----
373+
hixyx
374+
375+
query T
376+
SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View'))
377+
----
378+
hixyx
379+
380+
query T
381+
SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy')
382+
----
383+
NULL
384+
357385
query I
358386
SELECT strpos('abc', 'c')
359387
----

datafusion/sqllogictest/test_files/string_view.slt

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -926,10 +926,26 @@ EXPLAIN SELECT
926926
FROM test;
927927
----
928928
logical_plan
929-
01)Projection: rpad(__common_expr_1, Int64(1)) AS c1, rpad(__common_expr_1, Int64(2), CAST(test.column2_utf8view AS Utf8)) AS c2
930-
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view
931-
03)----TableScan: test projection=[column1_utf8view, column2_utf8view]
929+
01)Projection: rpad(test.column1_utf8view, Int64(1)) AS c1, rpad(test.column1_utf8view, Int64(2), test.column2_utf8view) AS c2
930+
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
931+
932+
query TT
933+
EXPLAIN SELECT
934+
RPAD(column1_utf8view, 12, column2_large_utf8) as c1
935+
FROM test;
936+
----
937+
logical_plan
938+
01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1
939+
02)--TableScan: test projection=[column2_large_utf8, column1_utf8view]
932940

941+
query TT
942+
EXPLAIN SELECT
943+
RPAD(column1_utf8view, 12, column2_utf8view) as c1
944+
FROM test;
945+
----
946+
logical_plan
947+
01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1
948+
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
933949

934950
## Ensure no casts for SPLIT_PART
935951
## TODO file ticket

0 commit comments

Comments
 (0)