diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 7c864bc191d7..41a2b9d9e72d 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -19,11 +19,11 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, + PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -46,7 +46,11 @@ impl FindInSetFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + vec![ + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], Volatility::Immutable, ), } @@ -71,41 +75,52 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(find_in_set::, vec![])(args) - } - DataType::LargeUtf8 => { - make_scalar_function(find_in_set::, vec![])(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") - } - } + make_scalar_function(find_in_set, vec![])(args) } } ///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings ///A string list is a string composed of substrings separated by , characters. -pub fn find_in_set(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ +fn find_in_set(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!( "find_in_set was called with {} arguments. It requires 2.", args.len() ); } + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let str_list_array = args[1].as_string::(); + find_in_set_general::(string_array, str_list_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let str_list_array = args[1].as_string::(); + find_in_set_general::(string_array, str_list_array) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let str_list_array = args[1].as_string_view(); + find_in_set_general::(string_array, str_list_array) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + } +} - let str_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - let str_list_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = str_array - .iter() - .zip(str_list_array.iter()) +pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor>( + string_array: V, + str_list_array: V, +) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_iter = ArrayIter::new(string_array); + let str_list_iter = ArrayIter::new(str_list_array); + let result = string_iter + .zip(str_list_iter) .map(|(string, str_list)| match (string, str_list) { (Some(string), Some(str_list)) => { let mut res = 0; diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 3255ddccdb81..a34d6e89d063 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -1066,7 +1066,7 @@ docs.apache.com docs com community.influxdata.com community com arrow.apache.org arrow org - +# find_in_set tests query I SELECT find_in_set('b', 'a,b,c,d') ---- @@ -1110,6 +1110,23 @@ SELECT find_in_set(NULL, NULL) ---- NULL +# find_in_set tests with utf8view +query I +SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +1 + +query I +SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +0 + # Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index fcd71b7f7e94..8ae9ff2d6240 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -906,18 +906,24 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for FIND_IN_SET -## TODO file ticket query TT EXPLAIN SELECT FIND_IN_SET(column1_utf8view, 'a,b,c,d') as c FROM test; ---- logical_plan -01)Projection: find_in_set(CAST(test.column1_utf8view AS Utf8), Utf8("a,b,c,d")) AS c +01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c 02)--TableScan: test projection=[column1_utf8view] - - +query I +SELECT + FIND_IN_SET(column1_utf8view, 'a,b,c,d') as c +FROM test; +---- +0 +0 +0 +NULL statement ok drop table test;