Skip to content

Commit

Permalink
Add contains function, and support in datafusion substrait consumer (#…
Browse files Browse the repository at this point in the history
…10879)

* adding new function contains

* adding substrait test

* adding doc

* adding doc

* Update docs/source/user-guide/sql/scalar_functions.md

Co-authored-by: Alex Huang <[email protected]>

* adding entry

---------

Co-authored-by: Alex Huang <[email protected]>
  • Loading branch information
Lordworms and Weijun-H authored Jun 15, 2024
1 parent 2f43476 commit 87aea14
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 2 deletions.
143 changes: 143 additions & 0 deletions datafusion/functions/src/string/contains.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::utils::make_scalar_function;
use arrow::array::{ArrayRef, OffsetSizeTrait};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Boolean;
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::DataFusionError;
use datafusion_common::Result;
use datafusion_common::{arrow_datafusion_err, exec_err};
use datafusion_expr::ScalarUDFImpl;
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;
#[derive(Debug)]
pub struct ContainsFunc {
signature: Signature,
}

impl Default for ContainsFunc {
fn default() -> Self {
ContainsFunc::new()
}
}

impl ContainsFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
Volatility::Immutable,
),
}
}
}

impl ScalarUDFImpl for ContainsFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"contains"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(Boolean)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args),
DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args),
other => {
exec_err!("unsupported data type {other:?} for function contains")
}
}
}
}

/// use regexp_is_match_utf8_scalar to do the calculation for contains
pub fn contains<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef, DataFusionError> {
let mod_str = as_generic_string_array::<T>(&args[0])?;
let match_str = as_generic_string_array::<T>(&args[1])?;
let res = arrow::compute::kernels::comparison::regexp_is_match_utf8(
mod_str, match_str, None,
)
.map_err(|e| arrow_datafusion_err!(e))?;

Ok(Arc::new(res) as ArrayRef)
}

#[cfg(test)]
mod tests {
use crate::string::contains::ContainsFunc;
use crate::utils::test::test_function;
use arrow::array::Array;
use arrow::{array::BooleanArray, datatypes::DataType::Boolean};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use datafusion_expr::ScalarUDFImpl;
#[test]
fn test_functions() -> Result<()> {
test_function!(
ContainsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("alphabet")),
ColumnarValue::Scalar(ScalarValue::from("alph")),
],
Ok(Some(true)),
bool,
Boolean,
BooleanArray
);
test_function!(
ContainsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("alphabet")),
ColumnarValue::Scalar(ScalarValue::from("dddddd")),
],
Ok(Some(false)),
bool,
Boolean,
BooleanArray
);
test_function!(
ContainsFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("alphabet")),
ColumnarValue::Scalar(ScalarValue::from("pha")),
],
Ok(Some(true)),
bool,
Boolean,
BooleanArray
);
Ok(())
}
}
8 changes: 6 additions & 2 deletions datafusion/functions/src/string/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod chr;
pub mod common;
pub mod concat;
pub mod concat_ws;
pub mod contains;
pub mod ends_with;
pub mod initcap;
pub mod levenshtein;
Expand All @@ -43,7 +44,6 @@ pub mod starts_with;
pub mod to_hex;
pub mod upper;
pub mod uuid;

// create UDFs
make_udf_function!(ascii::AsciiFunc, ASCII, ascii);
make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length);
Expand All @@ -66,7 +66,7 @@ make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part);
make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
make_udf_function!(upper::UpperFunc, UPPER, upper);
make_udf_function!(uuid::UuidFunc, UUID, uuid);

make_udf_function!(contains::ContainsFunc, CONTAINS, contains);
pub mod expr_fn {
use datafusion_expr::Expr;

Expand Down Expand Up @@ -149,6 +149,9 @@ pub mod expr_fn {
),(
uuid,
"returns uuid v4 as a string value",
), (
contains,
"Return true if search_string is found within string. treated it like a reglike",
));

#[doc = "Removes all characters, spaces by default, from both sides of a string"]
Expand Down Expand Up @@ -188,5 +191,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
to_hex(),
upper(),
uuid(),
contains(),
]
}
18 changes: 18 additions & 0 deletions datafusion/sqllogictest/test_files/functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1158,3 +1158,21 @@ drop table uuid_table

statement ok
drop table t


# test for contains

query B
select contains('alphabet', 'pha');
----
true

query B
select contains('alphabet', 'dddd');
----
false

query B
select contains('', '');
----
true
58 changes: 58 additions & 0 deletions datafusion/substrait/tests/cases/function_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Tests for Function Compatibility

#[cfg(test)]
mod tests {
use datafusion::common::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;

#[tokio::test]
async fn contains_function_test() -> Result<()> {
let ctx = create_context().await?;

let path = "tests/testdata/contains_plan.substrait.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let plan = from_substrait_plan(&ctx, &proto).await?;

let plan_str = format!("{:?}", plan);

assert_eq!(
plan_str,
"Projection: nation.b AS n_name\
\n Filter: contains(nation.b, Utf8(\"IA\"))\
\n TableScan: nation projection=[a, b, c, d, e, f]"
);
Ok(())
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new())
.await?;
Ok(ctx)
}
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

mod consumer_integration;
mod function_test;
mod logical_plans;
mod roundtrip_logical_plan;
mod roundtrip_physical_plan;
Expand Down
Loading

0 comments on commit 87aea14

Please sign in to comment.