Skip to content

Commit 1ab2662

Browse files
committed
support Utf8View
1 parent 4baa901 commit 1ab2662

File tree

2 files changed

+129
-23
lines changed

2 files changed

+129
-23
lines changed

datafusion/functions/src/unicode/substr.rs

+104-20
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ use std::any::Any;
1919
use std::cmp::max;
2020
use std::sync::Arc;
2121

22-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
22+
use arrow::array::{
23+
ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
24+
};
2325
use arrow::datatypes::DataType;
2426

25-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
27+
use datafusion_common::cast::as_int64_array;
2628
use datafusion_common::{exec_err, Result};
2729
use datafusion_expr::TypeSignature::Exact;
2830
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -51,6 +53,8 @@ impl SubstrFunc {
5153
Exact(vec![LargeUtf8, Int64]),
5254
Exact(vec![Utf8, Int64, Int64]),
5355
Exact(vec![LargeUtf8, Int64, Int64]),
56+
Exact(vec![Utf8View, Int64]),
57+
Exact(vec![Utf8View, Int64, Int64]),
5458
],
5559
Volatility::Immutable,
5660
),
@@ -77,30 +81,47 @@ impl ScalarUDFImpl for SubstrFunc {
7781
}
7882

7983
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
80-
match args[0].data_type() {
81-
DataType::Utf8 => make_scalar_function(substr::<i32>, vec![])(args),
82-
DataType::LargeUtf8 => make_scalar_function(substr::<i64>, vec![])(args),
83-
other => exec_err!("Unsupported data type {other:?} for function substr"),
84-
}
84+
make_scalar_function(substr, vec![])(args)
8585
}
8686

8787
fn aliases(&self) -> &[String] {
8888
&self.aliases
8989
}
9090
}
9191

92+
pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
93+
match args[0].data_type() {
94+
DataType::Utf8 => {
95+
let string_array = args[0].as_string::<i32>();
96+
calculate_substr::<_, i32>(string_array, &args[1..])
97+
}
98+
DataType::LargeUtf8 => {
99+
let string_array = args[0].as_string::<i64>();
100+
calculate_substr::<_, i64>(string_array, &args[1..])
101+
}
102+
DataType::Utf8View => {
103+
let string_array = args[0].as_string_view();
104+
calculate_substr::<_, i32>(string_array, &args[1..])
105+
}
106+
other => exec_err!("Unsupported data type {other:?} for function substr"),
107+
}
108+
}
109+
92110
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
93111
/// substr('alphabet', 3) = 'phabet'
94112
/// substr('alphabet', 3, 2) = 'ph'
95113
/// The implementation uses UTF-8 code points as characters
96-
pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
114+
fn calculate_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
115+
where
116+
V: ArrayAccessor<Item = &'a str>,
117+
T: OffsetSizeTrait,
118+
{
97119
match args.len() {
98-
2 => {
99-
let string_array = as_generic_string_array::<T>(&args[0])?;
100-
let start_array = as_int64_array(&args[1])?;
120+
1 => {
121+
let iter = ArrayIter::new(string_array);
122+
let start_array = as_int64_array(&args[0])?;
101123

102-
let result = string_array
103-
.iter()
124+
let result = iter
104125
.zip(start_array.iter())
105126
.map(|(string, start)| match (string, start) {
106127
(Some(string), Some(start)) => {
@@ -113,16 +134,14 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
113134
_ => None,
114135
})
115136
.collect::<GenericStringArray<T>>();
116-
117137
Ok(Arc::new(result) as ArrayRef)
118138
}
119-
3 => {
120-
let string_array = as_generic_string_array::<T>(&args[0])?;
121-
let start_array = as_int64_array(&args[1])?;
122-
let count_array = as_int64_array(&args[2])?;
139+
2 => {
140+
let iter = ArrayIter::new(string_array);
141+
let start_array = as_int64_array(&args[0])?;
142+
let count_array = as_int64_array(&args[1])?;
123143

124-
let result = string_array
125-
.iter()
144+
let result = iter
126145
.zip(start_array.iter())
127146
.zip(count_array.iter())
128147
.map(|((string, start), count)| match (string, start, count) {
@@ -162,6 +181,71 @@ mod tests {
162181

163182
#[test]
164183
fn test_functions() -> Result<()> {
184+
test_function!(
185+
SubstrFunc::new(),
186+
&[
187+
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
188+
ColumnarValue::Scalar(ScalarValue::from(1i64)),
189+
],
190+
Ok(None),
191+
&str,
192+
Utf8,
193+
StringArray
194+
);
195+
test_function!(
196+
SubstrFunc::new(),
197+
&[
198+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
199+
"alphabet"
200+
)))),
201+
ColumnarValue::Scalar(ScalarValue::from(0i64)),
202+
],
203+
Ok(Some("alphabet")),
204+
&str,
205+
Utf8,
206+
StringArray
207+
);
208+
test_function!(
209+
SubstrFunc::new(),
210+
&[
211+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
212+
"joséésoj"
213+
)))),
214+
ColumnarValue::Scalar(ScalarValue::from(5i64)),
215+
],
216+
Ok(Some("ésoj")),
217+
&str,
218+
Utf8,
219+
StringArray
220+
);
221+
test_function!(
222+
SubstrFunc::new(),
223+
&[
224+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
225+
"alphabet"
226+
)))),
227+
ColumnarValue::Scalar(ScalarValue::from(3i64)),
228+
ColumnarValue::Scalar(ScalarValue::from(2i64)),
229+
],
230+
Ok(Some("ph")),
231+
&str,
232+
Utf8,
233+
StringArray
234+
);
235+
test_function!(
236+
SubstrFunc::new(),
237+
&[
238+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
239+
"alphabet"
240+
)))),
241+
ColumnarValue::Scalar(ScalarValue::from(3i64)),
242+
ColumnarValue::Scalar(ScalarValue::from(20i64)),
243+
],
244+
Ok(Some("phabet")),
245+
&str,
246+
Utf8,
247+
StringArray
248+
);
165249
test_function!(
166250
SubstrFunc::new(),
167251
&[

datafusion/sqllogictest/test_files/string_view.slt

+25-3
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,30 @@ logical_plan
484484
01)Projection: test.column1_utf8view LIKE Utf8View("foo") AS like, test.column1_utf8view ILIKE Utf8View("foo") AS ilike
485485
02)--TableScan: test projection=[column1_utf8view]
486486

487+
## Ensure no casts for SUBSTR
487488

489+
query TT
490+
EXPLAIN SELECT
491+
SUBSTR(column1_utf8view, 1, 3) as c1,
492+
SUBSTR(column2_utf8, 1, 3) as c2,
493+
SUBSTR(column2_large_utf8, 1, 3) as c3
494+
FROM test;
495+
----
496+
logical_plan
497+
01)Projection: substr(test.column1_utf8view, Int64(1), Int64(3)) AS c1, substr(test.column2_utf8, Int64(1), Int64(3)) AS c2, substr(test.column2_large_utf8, Int64(1), Int64(3)) AS c3
498+
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view]
499+
500+
query TTT
501+
SELECT
502+
SUBSTR(column1_utf8view, 1, 3) as c1,
503+
SUBSTR(column2_utf8, 1, 3) as c2,
504+
SUBSTR(column2_large_utf8, 1, 3) as c3
505+
FROM test;
506+
----
507+
And X X
508+
Xia Xia Xia
509+
Rap R R
510+
NULL R R
488511

489512
## Ensure no casts for ASCII
490513

@@ -1010,9 +1033,8 @@ EXPLAIN SELECT
10101033
FROM test;
10111034
----
10121035
logical_plan
1013-
01)Projection: substr(__common_expr_1, Int64(1)) AS c, substr(__common_expr_1, Int64(1), Int64(2)) AS c2
1014-
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1
1015-
03)----TableScan: test projection=[column1_utf8view]
1036+
01)Projection: substr(test.column1_utf8view, Int64(1)) AS c, substr(test.column1_utf8view, Int64(1), Int64(2)) AS c2
1037+
02)--TableScan: test projection=[column1_utf8view]
10161038

10171039
## Ensure no casts for SUBSTRINDEX
10181040
query TT

0 commit comments

Comments
 (0)