From c71a9d7508e37e5d082e22d2953a12b61d290df5 Mon Sep 17 00:00:00 2001 From: Tai Le Manh <49281946+tlm365@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:56:57 +0700 Subject: [PATCH] Implement native support StringView for `CONTAINS` function (#12168) * Implement native support StringView for contains function Signed-off-by: Tai Le Manh * Fix cargo fmt * Implement native support StringView for contains function Signed-off-by: Tai Le Manh * Fix cargo check * Fix unresolved doc link * Implement native support StringView for contains function Signed-off-by: Tai Le Manh * Update datafusion/functions/src/regexp_common.rs --------- Signed-off-by: Tai Le Manh Co-authored-by: Andrew Lamb --- datafusion/functions/Cargo.toml | 2 +- datafusion/functions/src/lib.rs | 3 + datafusion/functions/src/regex/mod.rs | 3 +- datafusion/functions/src/regexp_common.rs | 123 ++++++++++++ datafusion/functions/src/string/contains.rs | 190 +++++++++++++++--- .../sqllogictest/test_files/string_view.slt | 42 +++- .../source/user-guide/sql/scalar_functions.md | 2 +- 7 files changed, 329 insertions(+), 36 deletions(-) create mode 100644 datafusion/functions/src/regexp_common.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 3c95c03896e2..5b6dceaa420d 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = ["uuid"] +string_expressions = ["regex_expressions", "uuid"] # enable unicode functions unicode_expressions = ["hashbrown", "unicode-segmentation"] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 81be5552666d..bb680f3c67de 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -92,6 +92,9 @@ pub mod macros; pub mod string; make_stub_package!(string, "string_expressions"); +#[cfg(feature = "string_expressions")] +mod 
regexp_common; + /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 4ac162290ddb..4afbe6cbbb89 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! "regx" DataFusion functions +//! "regex" DataFusion functions pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; + // create UDFs make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); diff --git a/datafusion/functions/src/regexp_common.rs b/datafusion/functions/src/regexp_common.rs new file mode 100644 index 000000000000..748c1a294f97 --- /dev/null +++ b/datafusion/functions/src/regexp_common.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Common utilities for implementing regex functions + +use crate::string::common::StringArrayType; + +use arrow::array::{Array, ArrayDataBuilder, BooleanArray}; +use arrow::datatypes::DataType; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; +use datafusion_common::DataFusionError; +use regex::Regex; + +use std::collections::HashMap; + +#[cfg(doc)] +use arrow::array::{LargeStringArray, StringArray, StringViewArray}; +/// Perform the SQL `array ~ regex_array` operation on +/// [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`]. +/// +/// If a `regex_array` element is an empty string, the corresponding result value is always true. +/// +/// `flags_array` is an optional [`StringArray`] / [`LargeStringArray`] / [`StringViewArray`] of flags, +/// which allow special search modes, such as case-insensitive and multi-line mode. +/// See the documentation [here](https://docs.rs/regex/1.5.4/regex/#grouping-and-flags) +/// for more information. +/// +/// It is inspired by / copied from `regexp_is_match_utf8` in [arrow-rs]. 
+/// +/// Can remove when is implemented upstream +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/8c956a9f9ab26c14072740cce64c2b99cb039b13/arrow-string/src/regexp.rs#L31-L37 +pub fn regexp_is_match_utf8<'a, S1, S2, S3>( + array: &'a S1, + regex_array: &'a S2, + flags_array: Option<&'a S3>, +) -> datafusion_common::Result +where + &'a S1: StringArrayType<'a>, + &'a S2: StringArrayType<'a>, + &'a S3: StringArrayType<'a>, +{ + if array.len() != regex_array.len() { + return Err(DataFusionError::Execution( + "Cannot perform comparison operation on arrays of different length" + .to_string(), + )); + } + + let nulls = NullBuffer::union(array.nulls(), regex_array.nulls()); + + let mut patterns: HashMap = HashMap::new(); + let mut result = BooleanBufferBuilder::new(array.len()); + + let complete_pattern = match flags_array { + Some(flags) => Box::new(regex_array.iter().zip(flags.iter()).map( + |(pattern, flags)| { + pattern.map(|pattern| match flags { + Some(flag) => format!("(?{flag}){pattern}"), + None => pattern.to_string(), + }) + }, + )) as Box>>, + None => Box::new( + regex_array + .iter() + .map(|pattern| pattern.map(|pattern| pattern.to_string())), + ), + }; + + array + .iter() + .zip(complete_pattern) + .map(|(value, pattern)| { + match (value, pattern) { + (Some(_), Some(pattern)) if pattern == *"" => { + result.append(true); + } + (Some(value), Some(pattern)) => { + let existing_pattern = patterns.get(&pattern); + let re = match existing_pattern { + Some(re) => re, + None => { + let re = Regex::new(pattern.as_str()).map_err(|e| { + DataFusionError::Execution(format!( + "Regular expression did not compile: {e:?}" + )) + })?; + patterns.entry(pattern).or_insert(re) + } + }; + result.append(re.is_match(value)); + } + _ => result.append(false), + } + Ok(()) + }) + .collect::, DataFusionError>>()?; + + let data = unsafe { + ArrayDataBuilder::new(DataType::Boolean) + .len(array.len()) + .buffers(vec![result.into()]) + .nulls(nulls) + .build_unchecked() + 
}; + + Ok(BooleanArray::from(data)) +} diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index faf979f80614..c319f80661c3 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. +use crate::regexp_common::regexp_is_match_utf8; use crate::utils::make_scalar_function; -use arrow::array::{ArrayRef, OffsetSizeTrait}; + +use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Boolean; -use datafusion_common::cast::as_generic_string_array; +use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; +use datafusion_common::exec_err; use datafusion_common::DataFusionError; use datafusion_common::Result; -use datafusion_common::{arrow_datafusion_err, exec_err}; use datafusion_expr::ScalarUDFImpl; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, Signature, Volatility}; + use std::any::Any; use std::sync::Arc; + #[derive(Debug)] pub struct ContainsFunc { signature: Signature, @@ -44,7 +47,17 @@ impl ContainsFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + vec![ + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8View, Utf8]), + Exact(vec![Utf8View, LargeUtf8]), + Exact(vec![Utf8, Utf8View]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8View]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], Volatility::Immutable, ), } @@ -69,28 +82,116 @@ impl ScalarUDFImpl for ContainsFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(contains::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(contains::, 
vec![])(args), - other => { - exec_err!("unsupported data type {other:?} for function contains") - } - } + make_scalar_function(contains, vec![])(args) } } /// use regexp_is_match_utf8_scalar to do the calculation for contains -pub fn contains( - args: &[ArrayRef], -) -> Result { - let mod_str = as_generic_string_array::(&args[0])?; - let match_str = as_generic_string_array::(&args[1])?; - let res = arrow::compute::kernels::comparison::regexp_is_match_utf8( - mod_str, match_str, None, - ) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(res) as ArrayRef) +pub fn contains(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (Utf8View, Utf8View) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string_view(); + let res = regexp_is_match_utf8::< + StringViewArray, + StringViewArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8View, Utf8) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string::(); + let res = regexp_is_match_utf8::< + StringViewArray, + GenericStringArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8View, LargeUtf8) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string::(); + let res = regexp_is_match_utf8::< + StringViewArray, + GenericStringArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, Utf8View) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string_view(); + let res = regexp_is_match_utf8::< + GenericStringArray, + StringViewArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, Utf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match_utf8::< + GenericStringArray, + GenericStringArray, + GenericStringArray, + >(mod_str, 
match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, LargeUtf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match_utf8::< + GenericStringArray, + GenericStringArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, Utf8View) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string_view(); + let res = regexp_is_match_utf8::< + GenericStringArray, + StringViewArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, Utf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match_utf8::< + GenericStringArray, + GenericStringArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, LargeUtf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = regexp_is_match_utf8::< + GenericStringArray, + GenericStringArray, + GenericStringArray, + >(mod_str, match_str, None)?; + + Ok(Arc::new(res) as ArrayRef) + } + other => { + exec_err!("Unsupported data type {other:?} for function `contains`.") + } + } } #[cfg(test)] @@ -138,6 +239,49 @@ mod tests { Boolean, BooleanArray ); + + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("pac")))), + ], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ap")))), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + ContainsFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "Apache" + 
)))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( + "DataFusion" + )))), + ], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index f478ba47aa4c..171b8ec6c1d1 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -21,11 +21,10 @@ statement ok create table test_source as values - ('Andrew', 'X'), - ('Xiangpeng', 'Xiangpeng'), - ('Raphael', 'R'), - (NULL, 'R') -; + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R'); # Table with the different combination of column types statement ok @@ -800,17 +799,40 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for CONTAINS -## TODO https://github.com/apache/datafusion/issues/11838 query TT EXPLAIN SELECT CONTAINS(column1_utf8view, 'foo') as c1, - CONTAINS(column2_utf8view, column2_utf8view) as c2 + CONTAINS(column1_utf8view, column2_utf8view) as c2, + CONTAINS(column1_utf8view, column2_large_utf8) as c3, + CONTAINS(column1_utf8, column2_utf8view) as c4, + CONTAINS(column1_utf8, column2_utf8) as c5, + CONTAINS(column1_utf8, column2_large_utf8) as c6, + CONTAINS(column1_large_utf8, column1_utf8view) as c7, + CONTAINS(column1_large_utf8, column2_utf8) as c8, + CONTAINS(column1_large_utf8, column2_large_utf8) as c9 FROM test; ---- logical_plan -01)Projection: contains(CAST(test.column1_utf8view AS Utf8), Utf8("foo")) AS c1, contains(__common_expr_1, __common_expr_1) AS c2 -02)--Projection: CAST(test.column2_utf8view AS Utf8) AS __common_expr_1, test.column1_utf8view -03)----TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: contains(test.column1_utf8view, Utf8("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, test.column2_large_utf8) 
AS c3, contains(test.column1_utf8, test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(test.column1_utf8, test.column2_large_utf8) AS c6, contains(test.column1_large_utf8, test.column1_utf8view) AS c7, contains(test.column1_large_utf8, test.column2_utf8) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9 +02)--TableScan: test projection=[column1_utf8, column2_utf8, column1_large_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] + +query BBBBBBBBB +SELECT + CONTAINS(column1_utf8view, 'foo') as c1, + CONTAINS(column1_utf8view, column2_utf8view) as c2, + CONTAINS(column1_utf8view, column2_large_utf8) as c3, + CONTAINS(column1_utf8, column2_utf8view) as c4, + CONTAINS(column1_utf8, column2_utf8) as c5, + CONTAINS(column1_utf8, column2_large_utf8) as c6, + CONTAINS(column1_large_utf8, column1_utf8view) as c7, + CONTAINS(column1_large_utf8, column2_utf8) as c8, + CONTAINS(column1_large_utf8, column2_large_utf8) as c9 +FROM test; +---- +false false false false false false true false false +false true true true true true true true true +false true true true true true true true true +NULL NULL NULL NULL NULL NULL NULL NULL NULL ## Ensure no casts for ENDS_WITH query TT diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index c3d3ab7a64a7..e08524dcd3a7 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1454,7 +1454,7 @@ position(substr in origstr) ### `contains` -Return true if search_string is found within string. +Return true if search_string is found within string (case-sensitive). ``` contains(string, search_string)