Skip to content

Commit c71a9d7

Browse files
tlm365 and alamb authored
Implement native support StringView for CONTAINS function (#12168)
* Implement native support StringView for contains function Signed-off-by: Tai Le Manh <[email protected]> * Fix cargo fmt * Implement native support StringView for contains function Signed-off-by: Tai Le Manh <[email protected]> * Fix cargo check * Fix unresolved doc link * Implement native support StringView for contains function Signed-off-by: Tai Le Manh <[email protected]> * Update datafusion/functions/src/regexp_common.rs --------- Signed-off-by: Tai Le Manh <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 376a0b8 commit c71a9d7

File tree

7 files changed

+329
-36
lines changed

7 files changed

+329
-36
lines changed

datafusion/functions/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ math_expressions = []
5454
# enable regular expressions
5555
regex_expressions = ["regex"]
5656
# enable string functions
57-
string_expressions = ["uuid"]
57+
string_expressions = ["regex_expressions", "uuid"]
5858
# enable unicode functions
5959
unicode_expressions = ["hashbrown", "unicode-segmentation"]
6060

datafusion/functions/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ pub mod macros;
9292
pub mod string;
9393
make_stub_package!(string, "string_expressions");
9494

95+
#[cfg(feature = "string_expressions")]
96+
mod regexp_common;
97+
9598
/// Core datafusion expressions
9699
/// Enabled via feature flag `core_expressions`
97100
#[cfg(feature = "core_expressions")]

datafusion/functions/src/regex/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! "regx" DataFusion functions
18+
//! "regex" DataFusion functions
1919
2020
pub mod regexplike;
2121
pub mod regexpmatch;
2222
pub mod regexpreplace;
23+
2324
// create UDFs
2425
make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match);
2526
make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like);
+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Common utilities for implementing regex functions
19+
20+
use crate::string::common::StringArrayType;
21+
22+
use arrow::array::{Array, ArrayDataBuilder, BooleanArray};
23+
use arrow::datatypes::DataType;
24+
use arrow_buffer::{BooleanBufferBuilder, NullBuffer};
25+
use datafusion_common::DataFusionError;
26+
use regex::Regex;
27+
28+
use std::collections::HashMap;
29+
30+
#[cfg(doc)]
31+
use arrow::array::{LargeStringArray, StringArray, StringViewArray};
32+
/// Perform SQL `array ~ regex_array` operation on
33+
/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`].
34+
///
35+
/// If a `regex_array` element is an empty string, the corresponding result value is true for any non-null input value.
36+
///
37+
/// `flags_array` is an optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] of flags,
38+
/// which allow special search modes, such as case-insensitive and multi-line mode.
39+
/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags)
40+
/// for more information.
41+
///
42+
/// It is inspired by / copied from `regexp_is_match_utf8` in [arrow-rs].
43+
///
44+
/// Can be removed once <https://github.com/apache/arrow-rs/issues/6370> is implemented upstream.
45+
///
46+
/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37
47+
pub fn regexp_is_match_utf8<'a, S1, S2, S3>(
48+
array: &'a S1,
49+
regex_array: &'a S2,
50+
flags_array: Option<&'a S3>,
51+
) -> datafusion_common::Result<BooleanArray, DataFusionError>
52+
where
53+
&'a S1: StringArrayType<'a>,
54+
&'a S2: StringArrayType<'a>,
55+
&'a S3: StringArrayType<'a>,
56+
{
57+
if array.len() != regex_array.len() {
58+
return Err(DataFusionError::Execution(
59+
"Cannot perform comparison operation on arrays of different length"
60+
.to_string(),
61+
));
62+
}
63+
64+
let nulls = NullBuffer::union(array.nulls(), regex_array.nulls());
65+
66+
let mut patterns: HashMap<String, Regex> = HashMap::new();
67+
let mut result = BooleanBufferBuilder::new(array.len());
68+
69+
let complete_pattern = match flags_array {
70+
Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map(
71+
|(pattern, flags)| {
72+
pattern.map(|pattern| match flags {
73+
Some(flag) => format!("(?{flag}){pattern}"),
74+
None => pattern.to_string(),
75+
})
76+
},
77+
)) as Box<dyn Iterator<Item = Option<String>>>,
78+
None => Box::new(
79+
regex_array
80+
.iter()
81+
.map(|pattern| pattern.map(|pattern| pattern.to_string())),
82+
),
83+
};
84+
85+
array
86+
.iter()
87+
.zip(complete_pattern)
88+
.map(|(value, pattern)| {
89+
match (value, pattern) {
90+
(Some(_), Some(pattern)) if pattern == *"" => {
91+
result.append(true);
92+
}
93+
(Some(value), Some(pattern)) => {
94+
let existing_pattern = patterns.get(&pattern);
95+
let re = match existing_pattern {
96+
Some(re) => re,
97+
None => {
98+
let re = Regex::new(pattern.as_str()).map_err(|e| {
99+
DataFusionError::Execution(format!(
100+
"Regular expression did not compile: {e:?}"
101+
))
102+
})?;
103+
patterns.entry(pattern).or_insert(re)
104+
}
105+
};
106+
result.append(re.is_match(value));
107+
}
108+
_ => result.append(false),
109+
}
110+
Ok(())
111+
})
112+
.collect::<datafusion_common::Result<Vec<()>, DataFusionError>>()?;
113+
114+
let data = unsafe {
115+
ArrayDataBuilder::new(DataType::Boolean)
116+
.len(array.len())
117+
.buffers(vec![result.into()])
118+
.nulls(nulls)
119+
.build_unchecked()
120+
};
121+
122+
Ok(BooleanArray::from(data))
123+
}

datafusion/functions/src/string/contains.rs

+167-23
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,22 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::regexp_common::regexp_is_match_utf8;
1819
use crate::utils::make_scalar_function;
19-
use arrow::array::{ArrayRef, OffsetSizeTrait};
20+
21+
use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray};
2022
use arrow::datatypes::DataType;
21-
use arrow::datatypes::DataType::Boolean;
22-
use datafusion_common::cast::as_generic_string_array;
23+
use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
24+
use datafusion_common::exec_err;
2325
use datafusion_common::DataFusionError;
2426
use datafusion_common::Result;
25-
use datafusion_common::{arrow_datafusion_err, exec_err};
2627
use datafusion_expr::ScalarUDFImpl;
2728
use datafusion_expr::TypeSignature::Exact;
2829
use datafusion_expr::{ColumnarValue, Signature, Volatility};
30+
2931
use std::any::Any;
3032
use std::sync::Arc;
33+
3134
#[derive(Debug)]
3235
pub struct ContainsFunc {
3336
signature: Signature,
@@ -44,7 +47,17 @@ impl ContainsFunc {
4447
use DataType::*;
4548
Self {
4649
signature: Signature::one_of(
47-
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
50+
vec![
51+
Exact(vec![Utf8View, Utf8View]),
52+
Exact(vec![Utf8View, Utf8]),
53+
Exact(vec![Utf8View, LargeUtf8]),
54+
Exact(vec![Utf8, Utf8View]),
55+
Exact(vec![Utf8, Utf8]),
56+
Exact(vec![Utf8, LargeUtf8]),
57+
Exact(vec![LargeUtf8, Utf8View]),
58+
Exact(vec![LargeUtf8, Utf8]),
59+
Exact(vec![LargeUtf8, LargeUtf8]),
60+
],
4861
Volatility::Immutable,
4962
),
5063
}
@@ -69,28 +82,116 @@ impl ScalarUDFImpl for ContainsFunc {
6982
}
7083

7184
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
72-
match args[0].data_type() {
73-
DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args),
74-
DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args),
75-
other => {
76-
exec_err!("unsupported data type {other:?} for function contains")
77-
}
78-
}
85+
make_scalar_function(contains, vec![])(args)
7986
}
8087
}
8188

8289
/// Uses `regexp_is_match_utf8` to do the calculation for `contains`; the search string is therefore compiled as a regex pattern
83-
pub fn contains<T: OffsetSizeTrait>(
84-
args: &[ArrayRef],
85-
) -> Result<ArrayRef, DataFusionError> {
86-
let mod_str = as_generic_string_array::<T>(&args[0])?;
87-
let match_str = as_generic_string_array::<T>(&args[1])?;
88-
let res = arrow::compute::kernels::comparison::regexp_is_match_utf8(
89-
mod_str, match_str, None,
90-
)
91-
.map_err(|e| arrow_datafusion_err!(e))?;
92-
93-
Ok(Arc::new(res) as ArrayRef)
90+
pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
91+
match (args[0].data_type(), args[1].data_type()) {
92+
(Utf8View, Utf8View) => {
93+
let mod_str = args[0].as_string_view();
94+
let match_str = args[1].as_string_view();
95+
let res = regexp_is_match_utf8::<
96+
StringViewArray,
97+
StringViewArray,
98+
GenericStringArray<i32>,
99+
>(mod_str, match_str, None)?;
100+
101+
Ok(Arc::new(res) as ArrayRef)
102+
}
103+
(Utf8View, Utf8) => {
104+
let mod_str = args[0].as_string_view();
105+
let match_str = args[1].as_string::<i32>();
106+
let res = regexp_is_match_utf8::<
107+
StringViewArray,
108+
GenericStringArray<i32>,
109+
GenericStringArray<i32>,
110+
>(mod_str, match_str, None)?;
111+
112+
Ok(Arc::new(res) as ArrayRef)
113+
}
114+
(Utf8View, LargeUtf8) => {
115+
let mod_str = args[0].as_string_view();
116+
let match_str = args[1].as_string::<i64>();
117+
let res = regexp_is_match_utf8::<
118+
StringViewArray,
119+
GenericStringArray<i64>,
120+
GenericStringArray<i32>,
121+
>(mod_str, match_str, None)?;
122+
123+
Ok(Arc::new(res) as ArrayRef)
124+
}
125+
(Utf8, Utf8View) => {
126+
let mod_str = args[0].as_string::<i32>();
127+
let match_str = args[1].as_string_view();
128+
let res = regexp_is_match_utf8::<
129+
GenericStringArray<i32>,
130+
StringViewArray,
131+
GenericStringArray<i32>,
132+
>(mod_str, match_str, None)?;
133+
134+
Ok(Arc::new(res) as ArrayRef)
135+
}
136+
(Utf8, Utf8) => {
137+
let mod_str = args[0].as_string::<i32>();
138+
let match_str = args[1].as_string::<i32>();
139+
let res = regexp_is_match_utf8::<
140+
GenericStringArray<i32>,
141+
GenericStringArray<i32>,
142+
GenericStringArray<i32>,
143+
>(mod_str, match_str, None)?;
144+
145+
Ok(Arc::new(res) as ArrayRef)
146+
}
147+
(Utf8, LargeUtf8) => {
148+
let mod_str = args[0].as_string::<i32>();
149+
let match_str = args[1].as_string::<i64>();
150+
let res = regexp_is_match_utf8::<
151+
GenericStringArray<i32>,
152+
GenericStringArray<i64>,
153+
GenericStringArray<i32>,
154+
>(mod_str, match_str, None)?;
155+
156+
Ok(Arc::new(res) as ArrayRef)
157+
}
158+
(LargeUtf8, Utf8View) => {
159+
let mod_str = args[0].as_string::<i64>();
160+
let match_str = args[1].as_string_view();
161+
let res = regexp_is_match_utf8::<
162+
GenericStringArray<i64>,
163+
StringViewArray,
164+
GenericStringArray<i32>,
165+
>(mod_str, match_str, None)?;
166+
167+
Ok(Arc::new(res) as ArrayRef)
168+
}
169+
(LargeUtf8, Utf8) => {
170+
let mod_str = args[0].as_string::<i64>();
171+
let match_str = args[1].as_string::<i32>();
172+
let res = regexp_is_match_utf8::<
173+
GenericStringArray<i64>,
174+
GenericStringArray<i32>,
175+
GenericStringArray<i32>,
176+
>(mod_str, match_str, None)?;
177+
178+
Ok(Arc::new(res) as ArrayRef)
179+
}
180+
(LargeUtf8, LargeUtf8) => {
181+
let mod_str = args[0].as_string::<i64>();
182+
let match_str = args[1].as_string::<i64>();
183+
let res = regexp_is_match_utf8::<
184+
GenericStringArray<i64>,
185+
GenericStringArray<i64>,
186+
GenericStringArray<i32>,
187+
>(mod_str, match_str, None)?;
188+
189+
Ok(Arc::new(res) as ArrayRef)
190+
}
191+
other => {
192+
exec_err!("Unsupported data type {other:?} for function `contains`.")
193+
}
194+
}
94195
}
95196

96197
#[cfg(test)]
@@ -138,6 +239,49 @@ mod tests {
138239
Boolean,
139240
BooleanArray
140241
);
242+
243+
test_function!(
244+
ContainsFunc::new(),
245+
&[
246+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
247+
"Apache"
248+
)))),
249+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))),
250+
],
251+
Ok(Some(true)),
252+
bool,
253+
Boolean,
254+
BooleanArray
255+
);
256+
test_function!(
257+
ContainsFunc::new(),
258+
&[
259+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
260+
"Apache"
261+
)))),
262+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))),
263+
],
264+
Ok(Some(false)),
265+
bool,
266+
Boolean,
267+
BooleanArray
268+
);
269+
test_function!(
270+
ContainsFunc::new(),
271+
&[
272+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
273+
"Apache"
274+
)))),
275+
ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
276+
"DataFusion"
277+
)))),
278+
],
279+
Ok(Some(false)),
280+
bool,
281+
Boolean,
282+
BooleanArray
283+
);
284+
141285
Ok(())
142286
}
143287
}

0 commit comments

Comments
 (0)