Skip to content

Commit f482212

Browse files
committed
Implement native support StringView for overlay
Signed-off-by: Chojan Shang <[email protected]>
1 parent b60cdc7 commit f482212

File tree

2 files changed

+103
-4
lines changed

2 files changed

+103
-4
lines changed

datafusion/functions/src/string/overlay.rs

+102-3
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ use std::sync::Arc;
2121
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
2222
use arrow::datatypes::DataType;
2323

24-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
24+
use datafusion_common::cast::{
25+
as_generic_string_array, as_int64_array, as_string_view_array,
26+
};
2527
use datafusion_common::{exec_err, Result};
2628
use datafusion_expr::TypeSignature::*;
2729
use datafusion_expr::{ColumnarValue, Volatility};
@@ -46,8 +48,10 @@ impl OverlayFunc {
4648
Self {
4749
signature: Signature::one_of(
4850
vec![
51+
Exact(vec![Utf8View, Utf8View, Int64, Int64]),
4952
Exact(vec![Utf8, Utf8, Int64, Int64]),
5053
Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
54+
Exact(vec![Utf8View, Utf8View, Int64]),
5155
Exact(vec![Utf8, Utf8, Int64]),
5256
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
5357
],
@@ -76,7 +80,9 @@ impl ScalarUDFImpl for OverlayFunc {
7680

7781
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
7882
match args[0].data_type() {
79-
DataType::Utf8 => make_scalar_function(overlay::<i32>, vec![])(args),
83+
DataType::Utf8View | DataType::Utf8 => {
84+
make_scalar_function(overlay::<i32>, vec![])(args)
85+
}
8086
DataType::LargeUtf8 => make_scalar_function(overlay::<i64>, vec![])(args),
8187
other => exec_err!("Unsupported data type {other:?} for function overlay"),
8288
}
@@ -87,7 +93,16 @@ impl ScalarUDFImpl for OverlayFunc {
8793
/// Replaces a substring of string1 with string2 starting at the integer bit
8894
/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
8995
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead
90-
pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
96+
fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
97+
let use_string_view = args[0].data_type() == &DataType::Utf8View;
98+
if use_string_view {
99+
string_view_overlay::<T>(args)
100+
} else {
101+
string_overlay::<T>(args)
102+
}
103+
}
104+
105+
pub fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
91106
match args.len() {
92107
3 => {
93108
let string_array = as_generic_string_array::<T>(&args[0])?;
@@ -171,6 +186,90 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
171186
}
172187
}
173188

189+
pub fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
190+
match args.len() {
191+
3 => {
192+
let string_array = as_string_view_array(&args[0])?;
193+
let characters_array = as_string_view_array(&args[1])?;
194+
let pos_num = as_int64_array(&args[2])?;
195+
196+
let result = string_array
197+
.iter()
198+
.zip(characters_array.iter())
199+
.zip(pos_num.iter())
200+
.map(|((string, characters), start_pos)| {
201+
match (string, characters, start_pos) {
202+
(Some(string), Some(characters), Some(start_pos)) => {
203+
let string_len = string.chars().count();
204+
let characters_len = characters.chars().count();
205+
let replace_len = characters_len as i64;
206+
let mut res =
207+
String::with_capacity(string_len.max(characters_len));
208+
209+
//as sql replace index start from 1 while string index start from 0
210+
if start_pos > 1 && start_pos - 1 < string_len as i64 {
211+
let start = (start_pos - 1) as usize;
212+
res.push_str(&string[..start]);
213+
}
214+
res.push_str(characters);
215+
// if start + replace_len - 1 >= string_length, just to string end
216+
if start_pos + replace_len - 1 < string_len as i64 {
217+
let end = (start_pos + replace_len - 1) as usize;
218+
res.push_str(&string[end..]);
219+
}
220+
Ok(Some(res))
221+
}
222+
_ => Ok(None),
223+
}
224+
})
225+
.collect::<Result<GenericStringArray<T>>>()?;
226+
Ok(Arc::new(result) as ArrayRef)
227+
}
228+
4 => {
229+
let string_array = as_string_view_array(&args[0])?;
230+
let characters_array = as_string_view_array(&args[1])?;
231+
let pos_num = as_int64_array(&args[2])?;
232+
let len_num = as_int64_array(&args[3])?;
233+
234+
let result = string_array
235+
.iter()
236+
.zip(characters_array.iter())
237+
.zip(pos_num.iter())
238+
.zip(len_num.iter())
239+
.map(|(((string, characters), start_pos), len)| {
240+
match (string, characters, start_pos, len) {
241+
(Some(string), Some(characters), Some(start_pos), Some(len)) => {
242+
let string_len = string.chars().count();
243+
let characters_len = characters.chars().count();
244+
let replace_len = len.min(string_len as i64);
245+
let mut res =
246+
String::with_capacity(string_len.max(characters_len));
247+
248+
//as sql replace index start from 1 while string index start from 0
249+
if start_pos > 1 && start_pos - 1 < string_len as i64 {
250+
let start = (start_pos - 1) as usize;
251+
res.push_str(&string[..start]);
252+
}
253+
res.push_str(characters);
254+
// if start + replace_len - 1 >= string_length, just to string end
255+
if start_pos + replace_len - 1 < string_len as i64 {
256+
let end = (start_pos + replace_len - 1) as usize;
257+
res.push_str(&string[end..]);
258+
}
259+
Ok(Some(res))
260+
}
261+
_ => Ok(None),
262+
}
263+
})
264+
.collect::<Result<GenericStringArray<T>>>()?;
265+
Ok(Arc::new(result) as ArrayRef)
266+
}
267+
other => {
268+
exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
269+
}
270+
}
271+
}
272+
174273
#[cfg(test)]
175274
mod tests {
176275
use arrow::array::{Int64Array, StringArray};

datafusion/sqllogictest/test_files/string_view.slt

+1-1
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ EXPLAIN SELECT
691691
FROM test;
692692
----
693693
logical_plan
694-
01)Projection: overlay(CAST(test.column1_utf8view AS Utf8), Utf8("foo"), Int64(2)) AS c1
694+
01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1
695695
02)--TableScan: test projection=[column1_utf8view]
696696

697697
## Ensure no casts for REGEXP_LIKE

0 commit comments

Comments
 (0)