diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 9e1e6b81b61df..2798a13497c48 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -16,33 +16,15 @@ // under the License. use crate::utils::make_scalar_function; -use arrow::array::Int32Array; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{Int32Array, ArrayRef, AsArray, ArrayAccessor, ArrayIter}; +use arrow::error::ArrowError; use arrow::datatypes::DataType; -use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_common::{Result, internal_err}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -/// Returns the numeric code of the first character of the argument. -/// ascii('x') = 120 -pub fn ascii(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - let mut chars = string.chars(); - chars.next().map_or(0, |v| v as i32) - }) - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} - #[derive(Debug)] pub struct AsciiFunc { signature: Signature, @@ -60,7 +42,7 @@ impl AsciiFunc { Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8], + vec![Utf8, LargeUtf8, Utf8View], Volatility::Immutable, ), } @@ -87,12 +69,109 @@ impl ScalarUDFImpl for AsciiFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(ascii::, vec![])(args), - DataType::LargeUtf8 => { - return make_scalar_function(ascii::, vec![])(args); - } - _ => internal_err!("Unsupported data type"), + make_scalar_function(ascii, vec![])(args) + } +} + +// fn calculate_ascii<'a, I>(string_array: I) -> Result +// where +// I: IntoIterator>, +// { +// let result = string_array +// .into_iter() +// .map(|string| { +// string.map(|s| { +// let mut chars = s.chars(); +// chars.next().map_or(0, |v| v as i32) +// }) +// }) +// .collect::(); + +// Ok(Arc::new(result) as ArrayRef) +// } + +fn calculate_ascii<'a, V>(array: V) -> Result +where + V: ArrayAccessor, +{ + let iter = ArrayIter::new(array); + let result = iter + .map(|string| { + string.map(|s| { + let mut chars = s.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns the numeric code of the first character of the argument. +pub fn ascii(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + Ok(calculate_ascii(string_array)?) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + Ok(calculate_ascii(string_array)?) } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + Ok(calculate_ascii(string_array)?) + } + _ => internal_err!("Unsupported data type"), + } +} + +#[cfg(test)] +mod tests { + use crate::string::ascii::AsciiFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_ascii { + ($INPUT:expr, $EXPECTED:expr) => { + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + #[test] + fn test_functions() -> Result<()> { + test_ascii!(Some(String::from("x")), Ok(Some(120))); + test_ascii!(Some(String::from("a")), Ok(Some(97))); + test_ascii!(Some(String::from("")), Ok(Some(0))); + test_ascii!(None, Ok(None)); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 4d3f72b1e8d4e..fc10a34256c52 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -500,3 +500,102 @@ select column2|| ' ' ||column3 from temp; ---- rust fast datafusion cool + +### ASCII +# Setup the initial test data +statement ok +create table test_source as values + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R'); + +# Table with the different combination of column types +statement ok +create table test as +SELECT + arrow_cast(column1, 'Utf8') as column1_utf8, + arrow_cast(column2, 'Utf8') as column2_utf8, + arrow_cast(column1, 'LargeUtf8') as column1_large_utf8, + arrow_cast(column2, 'LargeUtf8') as column2_large_utf8, + arrow_cast(column1, 'Utf8View') as column1_utf8view, + arrow_cast(column2, 'Utf8View') as column2_utf8view +FROM test_source; + +# Test ASCII with utf8view against utf8view, utf8, and largeutf8 +# (should be no casts) +query TT +EXPLAIN SELECT + ASCII(column1_utf8view) as c1, + ASCII(column2_utf8) as c2, + ASCII(column2_large_utf8) as c3 +FROM test; +---- +logical_plan +01)Projection: ascii(test.column1_utf8view) AS c1, ascii(test.column2_utf8) AS c2, ascii(test.column2_large_utf8) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view] + +query III +SELECT + ASCII(column1_utf8view) as c1, + ASCII(column2_utf8) as c2, + ASCII(column2_large_utf8) as c3 +FROM test; +---- +65 88 88 +88 88 88 +82 82 82 +NULL 82 82 + +query TT +EXPLAIN SELECT + ASCII(column1_utf8) as c1, + ASCII(column1_large_utf8) as c2, + ASCII(column2_utf8view) as c3, + ASCII('hello') as c4, + ASCII(arrow_cast('world', 'Utf8View')) as c5 +FROM test; +---- +logical_plan +01)Projection: ascii(test.column1_utf8) AS c1, ascii(test.column1_large_utf8) AS c2, ascii(test.column2_utf8view) AS c3, Int32(104) AS c4, Int32(119) AS c5 +02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column2_utf8view] + +query IIIII +SELECT + ASCII(column1_utf8) as c1, + ASCII(column1_large_utf8) as c2, + ASCII(column2_utf8view) as c3, + ASCII('hello') as c4, + ASCII(arrow_cast('world', 'Utf8View')) as c5 +FROM test; +---- +65 65 88 104 119 +88 88 88 104 119 +82 82 82 104 119 +NULL NULL 82 104 119 + +# Test ASCII with literals cast to Utf8View +query TT +EXPLAIN SELECT + ASCII(arrow_cast('äöüß', 'Utf8View')) as c1, + ASCII(arrow_cast('', 'Utf8View')) as c2, + ASCII(arrow_cast(NULL, 'Utf8View')) as c3 +FROM test; +---- +logical_plan +01)Projection: Int32(228) AS c1, Int32(0) AS c2, Int32(NULL) AS c3 +02)--TableScan: test projection=[] + +query III +SELECT + ASCII(arrow_cast('äöüß', 'Utf8View')) as c1, + ASCII(arrow_cast('', 'Utf8View')) as c2, + ASCII(arrow_cast(NULL, 'Utf8View')) as c3 +---- +228 0 NULL + +statement ok +drop table test; + +statement ok +drop table test_source;