Skip to content

Commit 488dc58

Browse files
committed
Implement native support StringView for substr_index
Signed-off-by: Chojan Shang <[email protected]>
1 parent 69c99a7 commit 488dc58

File tree

3 files changed

+144
-20
lines changed

3 files changed

+144
-20
lines changed

datafusion/functions/src/unicode/substrindex.rs

+63-20
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder};
22-
use arrow::datatypes::DataType;
21+
use arrow::array::{
22+
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23+
PrimitiveArray, StringBuilder,
24+
};
25+
use arrow::datatypes::{DataType, Int32Type, Int64Type};
2326

24-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
2527
use datafusion_common::{exec_err, Result};
2628
use datafusion_expr::TypeSignature::Exact;
2729
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -46,6 +48,7 @@ impl SubstrIndexFunc {
4648
Self {
4749
signature: Signature::one_of(
4850
vec![
51+
Exact(vec![Utf8View, Utf8View, Int64]),
4952
Exact(vec![Utf8, Utf8, Int64]),
5053
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
5154
],
@@ -74,15 +77,7 @@ impl ScalarUDFImpl for SubstrIndexFunc {
7477
}
7578

7679
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
77-
match args[0].data_type() {
78-
DataType::Utf8 => make_scalar_function(substr_index::<i32>, vec![])(args),
79-
DataType::LargeUtf8 => {
80-
make_scalar_function(substr_index::<i64>, vec![])(args)
81-
}
82-
other => {
83-
exec_err!("Unsupported data type {other:?} for function substr_index")
84-
}
85-
}
80+
make_scalar_function(substr_index, vec![])(args)
8681
}
8782

8883
fn aliases(&self) -> &[String] {
@@ -95,23 +90,71 @@ impl ScalarUDFImpl for SubstrIndexFunc {
9590
/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache
9691
/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org
9792
/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org
98-
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
93+
fn substr_index(args: &[ArrayRef]) -> Result<ArrayRef> {
9994
if args.len() != 3 {
10095
return exec_err!(
10196
"substr_index was called with {} arguments. It requires 3.",
10297
args.len()
10398
);
10499
}
105100

106-
let string_array = as_generic_string_array::<T>(&args[0])?;
107-
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
108-
let count_array = as_int64_array(&args[2])?;
101+
match args[0].data_type() {
102+
DataType::Utf8 => {
103+
let string_array = args[0].as_string::<i32>();
104+
let delimiter_array = args[1].as_string::<i32>();
105+
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
106+
substr_index_general::<Int32Type, _, _>(
107+
string_array,
108+
delimiter_array,
109+
count_array,
110+
)
111+
}
112+
DataType::LargeUtf8 => {
113+
let string_array = args[0].as_string::<i64>();
114+
let delimiter_array = args[1].as_string::<i64>();
115+
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
116+
substr_index_general::<Int64Type, _, _>(
117+
string_array,
118+
delimiter_array,
119+
count_array,
120+
)
121+
}
122+
DataType::Utf8View => {
123+
let string_array = args[0].as_string_view();
124+
let delimiter_array = args[1].as_string_view();
125+
let count_array: &PrimitiveArray<Int64Type> = args[2].as_primitive();
126+
substr_index_general::<Int32Type, _, _>(
127+
string_array,
128+
delimiter_array,
129+
count_array,
130+
)
131+
}
132+
other => {
133+
exec_err!("Unsupported data type {other:?} for function substr_index")
134+
}
135+
}
136+
}
109137

138+
pub fn substr_index_general<
139+
'a,
140+
T: ArrowPrimitiveType,
141+
V: ArrayAccessor<Item = &'a str>,
142+
P: ArrayAccessor<Item = i64>,
143+
>(
144+
string_array: V,
145+
delimiter_array: V,
146+
count_array: P,
147+
) -> Result<ArrayRef>
148+
where
149+
T::Native: OffsetSizeTrait,
150+
{
110151
let mut builder = StringBuilder::new();
111-
string_array
112-
.iter()
113-
.zip(delimiter_array.iter())
114-
.zip(count_array.iter())
152+
let string_iter = ArrayIter::new(string_array);
153+
let delimiter_array_iter = ArrayIter::new(delimiter_array);
154+
let count_array_iter = ArrayIter::new(count_array);
155+
string_iter
156+
.zip(delimiter_array_iter)
157+
.zip(count_array_iter)
115158
.for_each(|((string, delimiter), n)| match (string, delimiter, n) {
116159
(Some(string), Some(delimiter), Some(n)) => {
117160
// In MySQL, these cases will return an empty string.

datafusion/sqllogictest/test_files/functions.slt

+59
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,65 @@ arrow.apache.org 100 arrow.apache.org
10141014
. 3 .
10151015
. 100 .
10161016

1017+
query I
1018+
SELECT levenshtein(NULL, NULL)
1019+
----
1020+
NULL
1021+
1022+
# Test substring_index using '.' as delimiter with utf8view
1023+
query TIT
1024+
SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM
1025+
(VALUES
1026+
ROW('arrow.apache.org'),
1027+
ROW('.'),
1028+
ROW('...'),
1029+
ROW(NULL)
1030+
) AS strings(str),
1031+
(VALUES
1032+
ROW(1),
1033+
ROW(2),
1034+
ROW(3),
1035+
ROW(100),
1036+
ROW(-1),
1037+
ROW(-2),
1038+
ROW(-3),
1039+
ROW(-100)
1040+
) AS occurrences(n)
1041+
ORDER BY str DESC, n;
1042+
----
1043+
NULL -100 NULL
1044+
NULL -3 NULL
1045+
NULL -2 NULL
1046+
NULL -1 NULL
1047+
NULL 1 NULL
1048+
NULL 2 NULL
1049+
NULL 3 NULL
1050+
NULL 100 NULL
1051+
arrow.apache.org -100 arrow.apache.org
1052+
arrow.apache.org -3 arrow.apache.org
1053+
arrow.apache.org -2 apache.org
1054+
arrow.apache.org -1 org
1055+
arrow.apache.org 1 arrow
1056+
arrow.apache.org 2 arrow.apache
1057+
arrow.apache.org 3 arrow.apache.org
1058+
arrow.apache.org 100 arrow.apache.org
1059+
... -100 ...
1060+
... -3 ..
1061+
... -2 .
1062+
... -1 (empty)
1063+
... 1 (empty)
1064+
... 2 .
1065+
... 3 ..
1066+
... 100 ...
1067+
. -100 .
1068+
. -3 .
1069+
. -2 .
1070+
. -1 (empty)
1071+
. 1 (empty)
1072+
. 2 .
1073+
. 3 .
1074+
. 100 .
1075+
10171076
# Test substring_index using 'ac' as delimiter
10181077
query TIT
10191078
SELECT str, n, substring_index(str, 'ac', n) AS c FROM

datafusion/sqllogictest/test_files/string_view.slt

+22
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,28 @@ logical_plan
969969
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1
970970
03)----TableScan: test projection=[column1_utf8view]
971971

972+
## Ensure no casts for SUBSTRINDEX
973+
query TT
974+
EXPLAIN SELECT
975+
SUBSTR_INDEX(column1_utf8view, 'a', 1) as c,
976+
SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2
977+
FROM test;
978+
----
979+
logical_plan
980+
01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2
981+
02)--TableScan: test projection=[column1_utf8view]
982+
983+
query TT
984+
SELECT
985+
SUBSTR_INDEX(column1_utf8view, 'a', 1) as c,
986+
SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2
987+
FROM test;
988+
----
989+
Andrew Andrew
990+
Xi Xiangpeng
991+
R Raph
992+
NULL NULL
993+
972994
## Ensure no casts on columns for STARTS_WITH
973995
query TT
974996
EXPLAIN SELECT

0 commit comments

Comments
 (0)