Skip to content

Commit 2d2ea14

Browse files
committed
Update ASCII scalar function to support Utf8View #11834
1 parent 16a3557 commit 2d2ea14

File tree

1 file changed

+89
-27
lines changed

1 file changed

+89
-27
lines changed

datafusion/functions/src/string/ascii.rs

+89-27
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,14 @@
1717

1818
use crate::utils::make_scalar_function;
1919
use arrow::array::Int32Array;
20-
use arrow::array::{ArrayRef, OffsetSizeTrait};
20+
use arrow::array::{ArrayRef, AsArray};
2121
use arrow::datatypes::DataType;
22-
use datafusion_common::{cast::as_generic_string_array, internal_err, Result};
22+
use datafusion_common::Result;
2323
use datafusion_expr::ColumnarValue;
2424
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2525
use std::any::Any;
2626
use std::sync::Arc;
2727

28-
/// Returns the numeric code of the first character of the argument.
29-
/// ascii('x') = 120
///
/// Generic over the string offset type `T`. `args[0]` is downcast with
/// `as_generic_string_array::<T>`, and a failed downcast is propagated to the
/// caller through `?`.
///
/// Per element: a null input stays null, an empty string yields 0, and any
/// other string yields its first `char` cast to `i32` (so a non-ASCII first
/// character produces its Unicode scalar value, not a value limited to 0..=127).
///
/// # Panics
///
/// Indexing `args[0]` panics if `args` is empty.
30-
pub fn ascii<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
31-
let string_array = as_generic_string_array::<T>(&args[0])?;
32-
33-
let result = string_array
34-
.iter()
35-
.map(|string| {
36-
// Outer `map` keeps null elements as null in the output.
string.map(|string: &str| {
37-
let mut chars = string.chars();
38-
// No first char (empty string) maps to 0.
chars.next().map_or(0, |v| v as i32)
39-
})
40-
})
41-
.collect::<Int32Array>();
42-
43-
Ok(Arc::new(result) as ArrayRef)
44-
}
45-
4628
#[derive(Debug)]
4729
pub struct AsciiFunc {
4830
signature: Signature,
@@ -60,7 +42,7 @@ impl AsciiFunc {
6042
Self {
6143
signature: Signature::uniform(
6244
1,
63-
vec![Utf8, LargeUtf8],
45+
vec![Utf8, LargeUtf8, Utf8View],
6446
Volatility::Immutable,
6547
),
6648
}
@@ -87,12 +69,92 @@ impl ScalarUDFImpl for AsciiFunc {
8769
}
8870

8971
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
90-
match args[0].data_type() {
91-
DataType::Utf8 => make_scalar_function(ascii::<i32>, vec![])(args),
92-
DataType::LargeUtf8 => {
93-
return make_scalar_function(ascii::<i64>, vec![])(args);
94-
}
95-
_ => internal_err!("Unsupported data type"),
72+
make_scalar_function(ascii, vec![])(args)
73+
}
74+
}
75+
76+
/// Maps each optional string yielded by `string_array` to the numeric code of
/// its first character and collects the results into an `Int32Array`.
///
/// Per item: `None` stays `None`, an empty string maps to 0, and any other
/// string maps to its first `char` cast to `i32` (the Unicode scalar value of
/// that character).
///
/// The collected array is returned as `Ok(ArrayRef)`; nothing in this body
/// produces an `Err`.
fn calculate_ascii<'a, I>(string_array: I) -> Result<ArrayRef>
77+
where
78+
I: IntoIterator<Item = Option<&'a str>>,
79+
{
80+
let result = string_array
81+
.into_iter()
82+
.map(|string| {
83+
// `Option::map` keeps `None` (null) items as `None` in the output.
string.map(|s| {
84+
let mut chars = s.chars();
85+
// An empty string has no first char, so it maps to 0.
chars.next().map_or(0, |v| v as i32)
86+
})
87+
})
88+
.collect::<Int32Array>();
89+
90+
Ok(Arc::new(result) as ArrayRef)
91+
}
92+
93+
/// Returns the numeric code of the first character of the argument.
///
/// Dispatches on the data type of `args[0]`:
/// - `Utf8` is viewed with `as_string::<i32>()`,
/// - `LargeUtf8` is viewed with `as_string::<i64>()`,
/// - `Utf8View` is viewed with `as_string_view()`,
///
/// and in every case the array's iterator is handed to `calculate_ascii`,
/// whose result is returned unchanged.
///
/// # Panics
///
/// - Indexing `args[0]` panics if `args` is empty.
/// - Any data type other than the three above hits `unreachable!()`.
///
/// NOTE(review): presumably callers only arrive here through `AsciiFunc`,
/// whose signature lists exactly `Utf8`, `LargeUtf8` and `Utf8View` -- confirm.
/// Because this function is `pub`, a direct caller passing another array type
/// gets a panic rather than an `Err`; consider returning an error instead.
94+
pub fn ascii(args: &[ArrayRef]) -> Result<ArrayRef> {
95+
match args[0].data_type() {
96+
DataType::Utf8 => {
97+
let string_array = args[0].as_string::<i32>();
98+
calculate_ascii(string_array.iter())
99+
}
100+
DataType::LargeUtf8 => {
101+
let string_array = args[0].as_string::<i64>();
102+
calculate_ascii(string_array.iter())
96103
}
104+
DataType::Utf8View => {
105+
let string_array = args[0].as_string_view();
106+
calculate_ascii(string_array.iter())
107+
}
108+
// Panics for every other data type; see the `# Panics` note above.
_ => unreachable!(),
109+
}
110+
}
111+
112+
// Unit tests for `AsciiFunc`, compiled only for test builds.
#[cfg(test)]
113+
mod tests {
114+
use crate::string::ascii::AsciiFunc;
115+
use crate::utils::test::test_function;
116+
use arrow::array::{Array, Int32Array};
117+
use arrow::datatypes::DataType::Int32;
118+
use datafusion_common::{Result, ScalarValue};
119+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
120+
121+
// Expands to three `test_function!` invocations that pass the same
// `$INPUT` / `$EXPECTED` pair to a fresh `AsciiFunc`, wrapping the input as a
// `ScalarValue::Utf8`, `ScalarValue::LargeUtf8` and `ScalarValue::Utf8View`
// scalar in turn. `$INPUT` is expanded once per invocation.
// NOTE(review): `test_function!` is defined in `crate::utils::test` and is not
// visible here; presumably it invokes the function and compares the result
// with `$EXPECTED` as an `Int32` / `Int32Array` value -- confirm.
macro_rules! test_ascii {
122+
($INPUT:expr, $EXPECTED:expr) => {
123+
// Utf8 scalar input.
test_function!(
124+
AsciiFunc::new(),
125+
&[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
126+
$EXPECTED,
127+
i32,
128+
Int32,
129+
Int32Array
130+
);
131+
132+
// LargeUtf8 scalar input.
test_function!(
133+
AsciiFunc::new(),
134+
&[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
135+
$EXPECTED,
136+
i32,
137+
Int32,
138+
Int32Array
139+
);
140+
141+
// Utf8View scalar input.
test_function!(
142+
AsciiFunc::new(),
143+
&[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
144+
$EXPECTED,
145+
i32,
146+
Int32,
147+
Int32Array
148+
);
149+
};
150+
}
151+
152+
// Covers two single-character ASCII inputs, the empty string and a null
// input. No case here exercises a multi-character or non-ASCII string.
#[test]
153+
fn test_functions() -> Result<()> {
154+
// 'x' has code 120.
test_ascii!(Some(String::from("x")), Ok(Some(120)));
155+
// 'a' has code 97.
test_ascii!(Some(String::from("a")), Ok(Some(97)));
156+
// The empty string is expected to yield 0, not null.
test_ascii!(Some(String::from("")), Ok(Some(0)));
157+
// A null input is expected to yield a null output.
test_ascii!(None, Ok(None));
158+
Ok(())
97159
}
98160
}

0 commit comments

Comments
 (0)