Skip to content

Commit 3d76aa2

Browse files
tshauck2010YOUY01alamb
authored
feat: support Utf8View type in starts_with function (#11787)
* feat: support `Utf8View` for `starts_with` * style: clippy * simplify string view handling * fix: allow utf8 and largeutf8 to be cast into utf8view * fix: fix test * Apply suggestions from code review Co-authored-by: Yongting You <[email protected]> Co-authored-by: Andrew Lamb <[email protected]> * style: fix format * feat: add addiontal tests * tests: improve tests * fix: fix null case * tests: one more null test * Test comments and execution tests --------- Co-authored-by: Yongting You <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 1c98e6e commit 3d76aa2

File tree

4 files changed

+158
-21
lines changed

4 files changed

+158
-21
lines changed

datafusion/expr/src/expr_schema.rs

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ impl ExprSchemable for Expr {
148148
.iter()
149149
.map(|e| e.get_type(schema))
150150
.collect::<Result<Vec<_>>>()?;
151+
151152
// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
152153
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
153154
plan_datafusion_err!(

datafusion/expr/src/type_coercion/functions.rs

+16
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,10 @@ fn coerced_from<'a>(
583583
(Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => {
584584
Some(type_into.clone())
585585
}
586+
// We can go into a Utf8View from a Utf8 or LargeUtf8
587+
(Utf8View, _) if matches!(type_from, Utf8 | LargeUtf8 | Null) => {
588+
Some(type_into.clone())
589+
}
586590
// Any type can be coerced into strings
587591
(Utf8 | LargeUtf8, _) => Some(type_into.clone()),
588592
(Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
@@ -646,6 +650,18 @@ mod tests {
646650
use super::*;
647651
use arrow::datatypes::Field;
648652

653+
#[test]
654+
fn test_string_conversion() {
655+
let cases = vec![
656+
(DataType::Utf8View, DataType::Utf8, true),
657+
(DataType::Utf8View, DataType::LargeUtf8, true),
658+
];
659+
660+
for case in cases {
661+
assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
662+
}
663+
}
664+
649665
#[test]
650666
fn test_maybe_data_types() {
651667
// this vec contains: arg1, arg2, expected result

datafusion/functions/src/string/starts_with.rs

+72-20
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, OffsetSizeTrait};
21+
use arrow::array::ArrayRef;
2222
use arrow::datatypes::DataType;
2323

24-
use datafusion_common::{cast::as_generic_string_array, internal_err, Result};
24+
use datafusion_common::{internal_err, Result};
2525
use datafusion_expr::ColumnarValue;
2626
use datafusion_expr::TypeSignature::*;
2727
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
@@ -30,12 +30,8 @@ use crate::utils::make_scalar_function;
3030

3131
/// Returns true if string starts with prefix.
3232
/// starts_with('alphabet', 'alph') = 't'
33-
pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
34-
let left = as_generic_string_array::<T>(&args[0])?;
35-
let right = as_generic_string_array::<T>(&args[1])?;
36-
37-
let result = arrow::compute::kernels::comparison::starts_with(left, right)?;
38-
33+
pub fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
34+
let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?;
3935
Ok(Arc::new(result) as ArrayRef)
4036
}
4137

@@ -52,14 +48,15 @@ impl Default for StartsWithFunc {
5248

5349
impl StartsWithFunc {
5450
pub fn new() -> Self {
55-
use DataType::*;
5651
Self {
5752
signature: Signature::one_of(
5853
vec![
59-
Exact(vec![Utf8, Utf8]),
60-
Exact(vec![Utf8, LargeUtf8]),
61-
Exact(vec![LargeUtf8, Utf8]),
62-
Exact(vec![LargeUtf8, LargeUtf8]),
54+
// Planner attempts coercion to the target type starting with the most preferred candidate.
55+
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
56+
// If that fails, it proceeds to `(Utf8, Utf8)`.
57+
Exact(vec![DataType::Utf8View, DataType::Utf8View]),
58+
Exact(vec![DataType::Utf8, DataType::Utf8]),
59+
Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
6360
],
6461
Volatility::Immutable,
6562
),
@@ -81,18 +78,73 @@ impl ScalarUDFImpl for StartsWithFunc {
8178
}
8279

8380
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
84-
use DataType::*;
85-
86-
Ok(Boolean)
81+
Ok(DataType::Boolean)
8782
}
8883

8984
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
9085
match args[0].data_type() {
91-
DataType::Utf8 => make_scalar_function(starts_with::<i32>, vec![])(args),
92-
DataType::LargeUtf8 => {
93-
return make_scalar_function(starts_with::<i64>, vec![])(args);
86+
DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
87+
make_scalar_function(starts_with, vec![])(args)
9488
}
95-
_ => internal_err!("Unsupported data type"),
89+
_ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?,
9690
}
9791
}
9892
}
93+
94+
#[cfg(test)]
95+
mod tests {
96+
use crate::utils::test::test_function;
97+
use arrow::array::{Array, BooleanArray};
98+
use arrow::datatypes::DataType::Boolean;
99+
use datafusion_common::{Result, ScalarValue};
100+
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
101+
102+
use super::*;
103+
104+
#[test]
105+
fn test_functions() -> Result<()> {
106+
// Generate test cases for starts_with
107+
let test_cases = vec![
108+
(Some("alphabet"), Some("alph"), Some(true)),
109+
(Some("alphabet"), Some("bet"), Some(false)),
110+
(
111+
Some("somewhat large string"),
112+
Some("somewhat large"),
113+
Some(true),
114+
),
115+
(Some("somewhat large string"), Some("large"), Some(false)),
116+
]
117+
.into_iter()
118+
.flat_map(|(a, b, c)| {
119+
let utf_8_args = vec![
120+
ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))),
121+
ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))),
122+
];
123+
124+
let large_utf_8_args = vec![
125+
ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))),
126+
ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))),
127+
];
128+
129+
let utf_8_view_args = vec![
130+
ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))),
131+
ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))),
132+
];
133+
134+
vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)]
135+
});
136+
137+
for (args, expected) in test_cases {
138+
test_function!(
139+
StartsWithFunc::new(),
140+
&args,
141+
Ok(expected),
142+
bool,
143+
Boolean,
144+
BooleanArray
145+
);
146+
}
147+
148+
Ok(())
149+
}
150+
}

datafusion/sqllogictest/test_files/string_view.slt

+69-1
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,75 @@ logical_plan
355355
01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.column1_utf8), count(DISTINCT test.column1_utf8view), count(DISTINCT test.column1_dict)]]
356356
02)--TableScan: test projection=[column1_utf8, column1_utf8view, column1_dict]
357357

358+
### `STARTS_WITH`
359+
360+
# Test STARTS_WITH with utf8view against utf8view, utf8, and largeutf8
361+
# (should be no casts)
362+
query TT
363+
EXPLAIN SELECT
364+
STARTS_WITH(column1_utf8view, column2_utf8view) as c1,
365+
STARTS_WITH(column1_utf8view, column2_utf8) as c2,
366+
STARTS_WITH(column1_utf8view, column2_large_utf8) as c3
367+
FROM test;
368+
----
369+
logical_plan
370+
01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3
371+
02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view]
372+
373+
query BBB
374+
SELECT
375+
STARTS_WITH(column1_utf8view, column2_utf8view) as c1,
376+
STARTS_WITH(column1_utf8view, column2_utf8) as c2,
377+
STARTS_WITH(column1_utf8view, column2_large_utf8) as c3
378+
FROM test;
379+
----
380+
false false false
381+
true true true
382+
true true true
383+
NULL NULL NULL
384+
385+
# Test STARTS_WITH with utf8 against utf8view, utf8, and largeutf8
386+
# Should work, but will have to cast to common types
387+
# should cast utf8 -> utf8view and largeutf8 -> utf8view
388+
query TT
389+
EXPLAIN SELECT
390+
STARTS_WITH(column1_utf8, column2_utf8view) as c1,
391+
STARTS_WITH(column1_utf8, column2_utf8) as c3,
392+
STARTS_WITH(column1_utf8, column2_large_utf8) as c4
393+
FROM test;
394+
----
395+
logical_plan
396+
01)Projection: starts_with(__common_expr_1, test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(__common_expr_1, CAST(test.column2_large_utf8 AS Utf8View)) AS c4
397+
02)--Projection: CAST(test.column1_utf8 AS Utf8View) AS __common_expr_1, test.column1_utf8, test.column2_utf8, test.column2_large_utf8, test.column2_utf8view
398+
03)----TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view]
399+
400+
query BBB
401+
SELECT
402+
STARTS_WITH(column1_utf8, column2_utf8view) as c1,
403+
STARTS_WITH(column1_utf8, column2_utf8) as c3,
404+
STARTS_WITH(column1_utf8, column2_large_utf8) as c4
405+
FROM test;
406+
----
407+
false false false
408+
true true true
409+
true true true
410+
NULL NULL NULL
411+
412+
413+
# Test STARTS_WITH with utf8view against literals
414+
# In this case, the literals should be cast to utf8view. The columns
415+
# should not be cast to utf8.
416+
query TT
417+
EXPLAIN SELECT
418+
STARTS_WITH(column1_utf8view, 'äöüß') as c1,
419+
STARTS_WITH(column1_utf8view, '') as c2,
420+
STARTS_WITH(column1_utf8view, NULL) as c3,
421+
STARTS_WITH(NULL, column1_utf8view) as c4
422+
FROM test;
423+
----
424+
logical_plan
425+
01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4
426+
02)--TableScan: test projection=[column1_utf8view]
358427

359428
statement ok
360429
drop table test;
@@ -376,6 +445,5 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt;
376445
----
377446
2024-01-23
378447

379-
380448
statement ok
381449
drop table dates;

0 commit comments

Comments
 (0)