Skip to content

Commit 1f76b9c

Browse files
committed
Implement native StringView support for REPEAT
Signed-off-by: Tai Le Manh <[email protected]>
1 parent 18193e6 commit 1f76b9c

File tree

1 file changed

+70
-12
lines changed

1 file changed

+70
-12
lines changed

datafusion/functions/src/string/repeat.rs

+70-12
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
21+
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
2222
use arrow::datatypes::DataType;
2323

24-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
24+
use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_string_view_array};
2525
use datafusion_common::{exec_err, Result};
2626
use datafusion_expr::TypeSignature::*;
2727
use datafusion_expr::{ColumnarValue, Volatility};
@@ -45,7 +45,14 @@ impl RepeatFunc {
4545
use DataType::*;
4646
Self {
4747
signature: Signature::one_of(
48-
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
48+
vec![
49+
// Planner attempts coercion to the target type starting with the most preferred candidate.
50+
// For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
51+
// If that fails, it proceeds to `(Utf8, Int64)`.
52+
Exact(vec![Utf8View, Int64]),
53+
Exact(vec![Utf8, Int64]),
54+
Exact(vec![LargeUtf8, Int64]),
55+
],
4956
Volatility::Immutable,
5057
),
5158
}
@@ -71,9 +78,10 @@ impl ScalarUDFImpl for RepeatFunc {
7178

7279
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
7380
match args[0].data_type() {
81+
DataType::Utf8View => make_scalar_function(repeat_utf8view, vec![])(args),
7482
DataType::Utf8 => make_scalar_function(repeat::<i32>, vec![])(args),
7583
DataType::LargeUtf8 => make_scalar_function(repeat::<i64>, vec![])(args),
76-
other => exec_err!("Unsupported data type {other:?} for function repeat"),
84+
other => exec_err!("Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8"),
7785
}
7886
}
7987
}
@@ -87,18 +95,35 @@ fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
8795
let result = string_array
8896
.iter()
8997
.zip(number_array.iter())
90-
.map(|(string, number)| match (string, number) {
91-
(Some(string), Some(number)) if number >= 0 => {
92-
Some(string.repeat(number as usize))
93-
}
94-
(Some(_), Some(_)) => Some("".to_string()),
95-
_ => None,
96-
})
98+
.map(|(string, number)| repeat_common(string, number))
9799
.collect::<GenericStringArray<T>>();
98100

99101
Ok(Arc::new(result) as ArrayRef)
100102
}
101103

104+
fn repeat_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
105+
let string_view_array = as_string_view_array(&args[0])?;
106+
let number_array = as_int64_array(&args[1])?;
107+
108+
let result = string_view_array
109+
.iter()
110+
.zip(number_array.iter())
111+
.map(|(string, number)| repeat_common(string, number))
112+
.collect::<StringArray>();
113+
114+
Ok(Arc::new(result) as ArrayRef)
115+
}
116+
117+
/// Computes `repeat` for a single row.
///
/// Returns `None` when either `string` or `number` is null. Otherwise returns
/// `string` concatenated with itself `number` times; a zero or negative
/// `number` yields an empty string.
fn repeat_common(string: Option<&str>, number: Option<i64>) -> Option<String> {
    let (string, number) = (string?, number?);

    // Negative counts clamp to zero. The conversion is checked rather than an
    // `as` cast: where `usize` is narrower than `i64`, `as` would silently
    // truncate a large count and produce a wrong-length result. Saturating
    // instead makes such an oversized request fail loudly inside `str::repeat`.
    let count = usize::try_from(number.max(0)).unwrap_or(usize::MAX);

    Some(string.repeat(count))
}
126+
102127
#[cfg(test)]
103128
mod tests {
104129
use arrow::array::{Array, StringArray};
@@ -124,7 +149,6 @@ mod tests {
124149
Utf8,
125150
StringArray
126151
);
127-
128152
test_function!(
129153
RepeatFunc::new(),
130154
&[
@@ -148,6 +172,40 @@ mod tests {
148172
StringArray
149173
);
150174

175+
test_function!(
176+
RepeatFunc::new(),
177+
&[
178+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
179+
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
180+
],
181+
Ok(Some("PgPgPgPg")),
182+
&str,
183+
Utf8,
184+
StringArray
185+
);
186+
test_function!(
187+
RepeatFunc::new(),
188+
&[
189+
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
190+
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
191+
],
192+
Ok(None),
193+
&str,
194+
Utf8,
195+
StringArray
196+
);
197+
test_function!(
198+
RepeatFunc::new(),
199+
&[
200+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
201+
ColumnarValue::Scalar(ScalarValue::Int64(None)),
202+
],
203+
Ok(None),
204+
&str,
205+
Utf8,
206+
StringArray
207+
);
208+
151209
Ok(())
152210
}
153211
}

0 commit comments

Comments
 (0)