Skip to content

Commit b60cdc7

Browse files
authored
Implement native support StringView for Levenshtein (apache#11925)
* Implement native support StringView for Levenshtein Signed-off-by: Chojan Shang <[email protected]> * Remove useless code Signed-off-by: Chojan Shang <[email protected]> * Minor fix Signed-off-by: Chojan Shang <[email protected]> --------- Signed-off-by: Chojan Shang <[email protected]>
1 parent 8deba02 commit b60cdc7

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

datafusion/functions/src/string/levenshtein.rs

+30-7
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait};
2222
use arrow::datatypes::DataType;
2323

2424
use crate::utils::{make_scalar_function, utf8_to_int_type};
25-
use datafusion_common::cast::as_generic_string_array;
25+
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
2626
use datafusion_common::utils::datafusion_strsim;
2727
use datafusion_common::{exec_err, Result};
2828
use datafusion_expr::ColumnarValue;
@@ -42,10 +42,13 @@ impl Default for LevenshteinFunc {
4242

4343
impl LevenshteinFunc {
4444
pub fn new() -> Self {
45-
use DataType::*;
4645
Self {
4746
signature: Signature::one_of(
48-
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
47+
vec![
48+
Exact(vec![DataType::Utf8View, DataType::Utf8View]),
49+
Exact(vec![DataType::Utf8, DataType::Utf8]),
50+
Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
51+
],
4952
Volatility::Immutable,
5053
),
5154
}
@@ -71,7 +74,9 @@ impl ScalarUDFImpl for LevenshteinFunc {
7174

7275
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
7376
match args[0].data_type() {
74-
DataType::Utf8 => make_scalar_function(levenshtein::<i32>, vec![])(args),
77+
DataType::Utf8View | DataType::Utf8 => {
78+
make_scalar_function(levenshtein::<i32>, vec![])(args)
79+
}
7580
DataType::LargeUtf8 => make_scalar_function(levenshtein::<i64>, vec![])(args),
7681
other => {
7782
exec_err!("Unsupported data type {other:?} for function levenshtein")
@@ -89,10 +94,26 @@ pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
8994
args.len()
9095
);
9196
}
92-
let str1_array = as_generic_string_array::<T>(&args[0])?;
93-
let str2_array = as_generic_string_array::<T>(&args[1])?;
97+
9498
match args[0].data_type() {
99+
DataType::Utf8View => {
100+
let str1_array = as_string_view_array(&args[0])?;
101+
let str2_array = as_string_view_array(&args[1])?;
102+
let result = str1_array
103+
.iter()
104+
.zip(str2_array.iter())
105+
.map(|(string1, string2)| match (string1, string2) {
106+
(Some(string1), Some(string2)) => {
107+
Some(datafusion_strsim::levenshtein(string1, string2) as i32)
108+
}
109+
_ => None,
110+
})
111+
.collect::<Int32Array>();
112+
Ok(Arc::new(result) as ArrayRef)
113+
}
95114
DataType::Utf8 => {
115+
let str1_array = as_generic_string_array::<T>(&args[0])?;
116+
let str2_array = as_generic_string_array::<T>(&args[1])?;
96117
let result = str1_array
97118
.iter()
98119
.zip(str2_array.iter())
@@ -106,6 +127,8 @@ pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
106127
Ok(Arc::new(result) as ArrayRef)
107128
}
108129
DataType::LargeUtf8 => {
130+
let str1_array = as_generic_string_array::<T>(&args[0])?;
131+
let str2_array = as_generic_string_array::<T>(&args[1])?;
109132
let result = str1_array
110133
.iter()
111134
.zip(str2_array.iter())
@@ -120,7 +143,7 @@ pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
120143
}
121144
other => {
122145
exec_err!(
123-
"levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
146+
"levenshtein was called with {other} datatype arguments. It requires Utf8View, Utf8 or LargeUtf8."
124147
)
125148
}
126149
}

datafusion/sqllogictest/test_files/string_view.slt

+2-4
Original file line numberDiff line numberDiff line change
@@ -629,17 +629,15 @@ logical_plan
629629
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
630630

631631
## Ensure no casts for LEVENSHTEIN
632-
## TODO https://github.com/apache/datafusion/issues/11854
633632
query TT
634633
EXPLAIN SELECT
635634
levenshtein(column1_utf8view, 'foo') as c1,
636635
levenshtein(column1_utf8view, column2_utf8view) as c2
637636
FROM test;
638637
----
639638
logical_plan
640-
01)Projection: levenshtein(__common_expr_1, Utf8("foo")) AS c1, levenshtein(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2
641-
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view
642-
03)----TableScan: test projection=[column1_utf8view, column2_utf8view]
639+
01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2
640+
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]
643641

644642
## Ensure no casts for LOWER
645643
## TODO https://github.com/apache/datafusion/issues/11855

0 commit comments

Comments
 (0)