diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 5f64d8875bf5..a42b9c6cb857 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -18,18 +18,18 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, +}; use arrow::datatypes::DataType; use hashbrown::HashMap; use unicode_segmentation::UnicodeSegmentation; -use datafusion_common::cast::as_generic_string_array; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct TranslateFunc { signature: Signature, @@ -46,7 +46,10 @@ impl TranslateFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8, Utf8])], + vec![ + Exact(vec![Utf8View, Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8]), + ], Volatility::Immutable, ), } @@ -71,27 +74,54 @@ impl ScalarUDFImpl for TranslateFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(translate::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(translate::, vec![])(args), - other => { - exec_err!("Unsupported data type {other:?} for function translate") - } + make_scalar_function(invoke_translate, vec![])(args) + } +} + +fn invoke_translate(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + translate::(string_array, from_array, to_array) + } + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + translate::(string_array, from_array, to_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + translate::(string_array, from_array, to_array) + } + other => { + exec_err!("Unsupported data type {other:?} for function translate") } } } /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. /// translate('12345', '143', 'ax') = 'a2x5' -fn translate(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) +fn translate<'a, T: OffsetSizeTrait, V, B>( + string_array: V, + from_array: B, + to_array: B, +) -> Result +where + V: ArrayAccessor, + B: ArrayAccessor, +{ + let string_array_iter = ArrayIter::new(string_array); + let from_array_iter = ArrayIter::new(from_array); + let to_array_iter = ArrayIter::new(to_array); + + let result = string_array_iter + .zip(from_array_iter) + .zip(to_array_iter) .map(|((string, from), to)| match (string, from, to) { (Some(string), Some(from), Some(to)) => { // create a hashmap of [char, index] to change from O(n) to O(1) for from list diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index fcd71b7f7e94..2bfc0978ba5d 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -425,6 +425,43 @@ logical_plan 01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4 02)--TableScan: test projection=[column1_utf8view] +### Test TRANSLATE + +# Should run TRANSLATE using utf8view column successfully +query T +SELECT + TRANSLATE(column1_utf8view, 'foo', 'bar') as c +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + +# Should run TRANSLATE using utf8 column successfully +query T +SELECT + TRANSLATE(column1_utf8, 'foo', 'bar') as c +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + +# Should run TRANSLATE using large_utf8 column successfully +query T +SELECT + TRANSLATE(column1_large_utf8, 'foo', 'bar') as c +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + + + ### Initcap query TT @@ -895,14 +932,13 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for TRANSLATE -## TODO file ticket query TT EXPLAIN SELECT TRANSLATE(column1_utf8view, 'foo', 'bar') as c FROM test; ---- logical_plan -01)Projection: translate(CAST(test.column1_utf8view AS Utf8), Utf8("foo"), Utf8("bar")) AS c +01)Projection: translate(test.column1_utf8view, Utf8("foo"), Utf8("bar")) AS c 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for FIND_IN_SET