diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs
new file mode 100644
index 000000000000..faf979f80614
--- /dev/null
+++ b/datafusion/functions/src/string/contains.rs
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::utils::make_scalar_function;
+use arrow::array::{ArrayRef, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use arrow::datatypes::DataType::Boolean;
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_common::{arrow_datafusion_err, exec_err};
+use datafusion_expr::ScalarUDFImpl;
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{ColumnarValue, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Scalar UDF `contains(string, search_string)`.
+///
+/// Evaluates to `true` when `search_string` occurs within `string`. Both
+/// arguments must be `Utf8`, or both must be `LargeUtf8`.
+#[derive(Debug)]
+pub struct ContainsFunc {
+    signature: Signature,
+}
+
+impl Default for ContainsFunc {
+    fn default() -> Self {
+        ContainsFunc::new()
+    }
+}
+
+impl ContainsFunc {
+    /// Creates the function with its `(Utf8, Utf8)` and
+    /// `(LargeUtf8, LargeUtf8)` signatures.
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::one_of(
+                vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for ContainsFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        Ok(Boolean)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        // The first argument's type selects the string offset width.
+        match args[0].data_type() {
+            DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args),
+            DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args),
+            other => {
+                exec_err!("unsupported data type {other:?} for function contains")
+            }
+        }
+    }
+}
+
+/// Returns a boolean array that is `true` wherever `args[0]` contains
+/// `args[1]` as a substring.
+///
+/// `args[1]` is matched literally through arrow's `contains` kernel. It is
+/// not interpreted as a regular expression, so characters such as `.` or `(`
+/// only match themselves.
+///
+/// # Errors
+///
+/// Returns an error if `args` does not hold exactly two arrays, if either
+/// array cannot be read as a string array with offset type `T`, or if the
+/// arrow kernel reports an error.
+pub fn contains<T: OffsetSizeTrait>(
+    args: &[ArrayRef],
+) -> Result<ArrayRef, DataFusionError> {
+    if args.len() != 2 {
+        return exec_err!("contains expects exactly 2 arguments, got {}", args.len());
+    }
+    let mod_str = as_generic_string_array::<T>(&args[0])?;
+    let match_str = as_generic_string_array::<T>(&args[1])?;
+    let res = arrow::compute::kernels::comparison::contains(mod_str, match_str)
+        .map_err(|e| arrow_datafusion_err!(e))?;
+
+    Ok(Arc::new(res) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::string::contains::ContainsFunc;
+    use crate::utils::test::test_function;
+    use arrow::array::Array;
+    use arrow::{array::BooleanArray, datatypes::DataType::Boolean};
+    use datafusion_common::Result;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::ColumnarValue;
+    use datafusion_expr::ScalarUDFImpl;
+
+    #[test]
+    fn test_functions() -> Result<()> {
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from("alph")),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from("dddddd")),
+            ],
+            Ok(Some(false)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from("pha")),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        // Regex metacharacters in the search string only match themselves.
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("abc")),
+                ColumnarValue::Scalar(ScalarValue::from(".")),
+            ],
+            Ok(Some(false)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("a(b")),
+                ColumnarValue::Scalar(ScalarValue::from("(")),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        Ok(())
+    }
+}
diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs
index 219ef8b5a50f..5bf372c29f2d 100644
--- a/datafusion/functions/src/string/mod.rs
+++ b/datafusion/functions/src/string/mod.rs
@@ -28,6 +28,7 @@ pub mod chr;
 pub mod common;
 pub mod concat;
 pub mod concat_ws;
+pub mod contains;
 pub mod ends_with;
 pub mod initcap;
 pub mod levenshtein;
@@ -43,7 +44,6 @@ pub mod starts_with;
 pub mod to_hex;
 pub mod upper;
 pub mod uuid;
-
 // create UDFs
 make_udf_function!(ascii::AsciiFunc, ASCII, ascii);
 make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length);
@@ -66,7 +66,7 @@ make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part);
 make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
 make_udf_function!(upper::UpperFunc, UPPER, upper);
 make_udf_function!(uuid::UuidFunc, UUID, uuid);
-
+make_udf_function!(contains::ContainsFunc, CONTAINS, contains);
 pub mod expr_fn {
     use datafusion_expr::Expr;
 
@@ -149,6 +149,9 @@ pub mod expr_fn {
     ),(
         uuid,
         "returns uuid v4 as a string value",
+    ), (
+        contains,
+        "Return true if search_string is found within string. The match is literal and case-sensitive",
     ));
 
     #[doc = "Removes all characters, spaces by default, from both sides of a string"]
@@ -188,5 +191,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         to_hex(),
         upper(),
         uuid(),
+        contains(),
     ]
 }
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index df6295d63b81..c3dd791f6ca8 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -1158,3 +1158,27 @@ drop table uuid_table
 
 statement ok
 drop table t
+
+
+# test for contains
+
+query B
+select contains('alphabet', 'pha');
+----
+true
+
+query B
+select contains('alphabet', 'dddd');
+----
+false
+
+query B
+select contains('', '');
+----
+true
+
+# search_string is matched literally, not as a regular expression
+query B
+select contains('alphabet', 'a.p');
+----
+false
diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs
new file mode 100644
index 000000000000..b4c5659a3a49
--- /dev/null
+++ b/datafusion/substrait/tests/cases/function_test.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for Function Compatibility
+
+#[cfg(test)]
+mod tests {
+    use datafusion::common::Result;
+    use datafusion::prelude::{CsvReadOptions, SessionContext};
+    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
+    use std::fs::File;
+    use std::io::BufReader;
+    use substrait::proto::Plan;
+
+    /// Checks that a Substrait plan calling `contains:str_str` is consumed
+    /// into a logical plan whose filter uses the `contains` function.
+    #[tokio::test]
+    async fn contains_function_test() -> Result<()> {
+        let ctx = create_context().await?;
+
+        let path = "tests/testdata/contains_plan.substrait.json";
+        let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
+            File::open(path).expect("file not found"),
+        ))
+        .expect("failed to parse json");
+
+        let plan = from_substrait_plan(&ctx, &proto).await?;
+
+        let plan_str = format!("{:?}", plan);
+
+        assert_eq!(
+            plan_str,
+            "Projection: nation.b AS n_name\
+            \n  Filter: contains(nation.b, Utf8(\"IA\"))\
+            \n    TableScan: nation projection=[a, b, c, d, e, f]"
+        );
+        Ok(())
+    }
+
+    /// Builds a session with `tests/testdata/data.csv` registered as `nation`.
+    async fn create_context() -> Result<SessionContext> {
+        let ctx = SessionContext::new();
+        ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new())
+            .await?;
+        Ok(ctx)
+    }
+}
diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs
index a31f93087d83..d3ea7695e4b9 100644
--- a/datafusion/substrait/tests/cases/mod.rs
+++ b/datafusion/substrait/tests/cases/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
mod consumer_integration; +mod function_test; mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; diff --git a/datafusion/substrait/tests/testdata/contains_plan.substrait.json b/datafusion/substrait/tests/testdata/contains_plan.substrait.json new file mode 100644 index 000000000000..76edde34e3b0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/contains_plan.substrait.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "contains:str_str" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 4 + ] + } + }, + "input": { + "filter": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "n_nationkey", + "n_name", + "n_regionkey", + "n_comment" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "nation" + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": "IA" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "n_name" + ] + } + } + ], + 
"version": { + "minorNumber": 38, + "producer": "ibis-substrait" + } +} \ No newline at end of file diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 10c52bc5de9e..ec34dbf9ba6c 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -681,6 +681,7 @@ _Alias of [nvl](#nvl)._ - [substr_index](#substr_index) - [find_in_set](#find_in_set) - [position](#position) +- [contains](#contains) ### `ascii` @@ -1443,6 +1444,19 @@ position(substr in origstr) - **substr**: The pattern string. - **origstr**: The model string. +### `contains` + +Return true if search_string is found within string. + +``` +contains(string, search_string) +``` + +#### Arguments + +- **string**: The string to search in. +- **search_string**: The string to search for. + ## Time and Date Functions - [now](#now)