Skip to content

Commit 87c1c17

Browse files
dharanadalamb
andauthored
Support string concat || for StringViewArray (#12063)
* naive impl * calc capacity * cleanup * Update test * simplify coercion logic * write some more tests * Update tests * Improve implementation and do the right thing for null * add ticket reference --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 41e7378 commit 87c1c17

File tree

4 files changed

+141
-50
lines changed

4 files changed

+141
-50
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -922,26 +922,22 @@ fn dictionary_comparison_coercion(
922922

923923
/// Coercion rules for string concat.
924924
/// This is a union of string coercion rules and specified rules:
925-
/// 1. At lease one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
925+
/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
926926
/// 2. Data type of the other side should be able to cast to string type
927927
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
928928
use arrow::datatypes::DataType::*;
929-
match (lhs_type, rhs_type) {
930-
// If Utf8View is in any side, we coerce to Utf8.
931-
// Ref: https://github.com/apache/datafusion/pull/11796
932-
(Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => {
933-
Some(Utf8)
929+
string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
930+
(Utf8View, from_type) | (from_type, Utf8View) => {
931+
string_concat_internal_coercion(from_type, &Utf8View)
934932
}
935-
_ => string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
936-
(Utf8, from_type) | (from_type, Utf8) => {
937-
string_concat_internal_coercion(from_type, &Utf8)
938-
}
939-
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
940-
string_concat_internal_coercion(from_type, &LargeUtf8)
941-
}
942-
_ => None,
943-
}),
944-
}
933+
(Utf8, from_type) | (from_type, Utf8) => {
934+
string_concat_internal_coercion(from_type, &Utf8)
935+
}
936+
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
937+
string_concat_internal_coercion(from_type, &LargeUtf8)
938+
}
939+
_ => None,
940+
})
945941
}
946942

947943
fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
@@ -952,6 +948,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
952948
}
953949
}
954950

951+
/// If `from_type` can be casted to `to_type`, return `to_type`, otherwise
952+
/// return `None`.
955953
fn string_concat_internal_coercion(
956954
from_type: &DataType,
957955
to_type: &DataType,
@@ -977,6 +975,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
977975
}
978976
// Then, if LargeUtf8 is in any side, we coerce to LargeUtf8.
979977
(LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
978+
// Utf8 coerces to Utf8
980979
(Utf8, Utf8) => Some(Utf8),
981980
_ => None,
982981
}

datafusion/physical-expr/src/expressions/binary.rs

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type;
4141
use datafusion_expr::{ColumnarValue, Operator};
4242
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};
4343

44+
use crate::expressions::binary::kernels::concat_elements_utf8view;
4445
use kernels::{
4546
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
4647
bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
@@ -131,34 +132,6 @@ impl std::fmt::Display for BinaryExpr {
131132
}
132133
}
133134

134-
/// Invoke a compute kernel on a pair of binary data arrays
135-
macro_rules! compute_utf8_op {
136-
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
137-
let ll = $LEFT
138-
.as_any()
139-
.downcast_ref::<$DT>()
140-
.expect("compute_op failed to downcast left side array");
141-
let rr = $RIGHT
142-
.as_any()
143-
.downcast_ref::<$DT>()
144-
.expect("compute_op failed to downcast right side array");
145-
Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
146-
}};
147-
}
148-
149-
macro_rules! binary_string_array_op {
150-
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
151-
match $LEFT.data_type() {
152-
DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
153-
DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray),
154-
other => internal_err!(
155-
"Data type {:?} not supported for binary operation '{}' on string arrays",
156-
other, stringify!($OP)
157-
),
158-
}
159-
}};
160-
}
161-
162135
/// Invoke a boolean kernel on a pair of arrays
163136
macro_rules! boolean_op {
164137
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
@@ -662,14 +635,36 @@ impl BinaryExpr {
662635
BitwiseXor => bitwise_xor_dyn(left, right),
663636
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
664637
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
665-
StringConcat => binary_string_array_op!(left, right, concat_elements),
638+
StringConcat => concat_elements(left, right),
666639
AtArrow | ArrowAt => {
667640
unreachable!("ArrowAt and AtArrow should be rewritten to function")
668641
}
669642
}
670643
}
671644
}
672645

646+
fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
647+
Ok(match left.data_type() {
648+
DataType::Utf8 => Arc::new(concat_elements_utf8(
649+
left.as_string::<i32>(),
650+
right.as_string::<i32>(),
651+
)?),
652+
DataType::LargeUtf8 => Arc::new(concat_elements_utf8(
653+
left.as_string::<i64>(),
654+
right.as_string::<i64>(),
655+
)?),
656+
DataType::Utf8View => Arc::new(concat_elements_utf8view(
657+
left.as_string_view(),
658+
right.as_string_view(),
659+
)?),
660+
other => {
661+
return internal_err!(
662+
"Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
663+
);
664+
}
665+
})
666+
}
667+
673668
/// Create a binary expression whose arguments are correctly coerced.
674669
/// This function errors if it is not possible to coerce the arguments
675670
/// to computational types supported by the operator.

datafusion/physical-expr/src/expressions/binary/kernels.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use arrow::datatypes::DataType;
2727
use datafusion_common::internal_err;
2828
use datafusion_common::{Result, ScalarValue};
2929

30+
use arrow_schema::ArrowError;
3031
use std::sync::Arc;
3132

3233
/// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT)
@@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar);
131132
create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar);
132133
create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, bitwise_shift_right_scalar);
133134
create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, bitwise_shift_left_scalar);
135+
136+
pub fn concat_elements_utf8view(
137+
left: &StringViewArray,
138+
right: &StringViewArray,
139+
) -> std::result::Result<StringViewArray, ArrowError> {
140+
let capacity = left
141+
.data_buffers()
142+
.iter()
143+
.zip(right.data_buffers().iter())
144+
.map(|(b1, b2)| b1.len() + b2.len())
145+
.sum();
146+
let mut result = StringViewBuilder::with_capacity(capacity);
147+
148+
// Avoid reallocations by writing to a reused buffer (note we
149+
// could be even more efficient r by creating the view directly
150+
// here and avoid the buffer but that would be more complex)
151+
let mut buffer = String::new();
152+
153+
for (left, right) in left.iter().zip(right.iter()) {
154+
if let (Some(left), Some(right)) = (left, right) {
155+
use std::fmt::Write;
156+
buffer.clear();
157+
write!(&mut buffer, "{left}{right}")
158+
.expect("writing into string buffer failed");
159+
result.append_value(&buffer);
160+
} else {
161+
// at least one of the values is null, so the output is also null
162+
result.append_null()
163+
}
164+
}
165+
Ok(result.finish())
166+
}

datafusion/sqllogictest/test_files/string_view.slt

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,63 @@ FROM test;
11441144
0
11451145
NULL
11461146

1147+
# || mixed types
1148+
# expect all results to be the same for each row as they all have the same values
1149+
query TTTTTTTT
1150+
SELECT
1151+
column1_utf8view || column2_utf8view,
1152+
column1_utf8 || column2_utf8view,
1153+
column1_large_utf8 || column2_utf8view,
1154+
column1_dict || column2_utf8view,
1155+
-- reverse argument order
1156+
column2_utf8view || column1_utf8view,
1157+
column2_utf8view || column1_utf8,
1158+
column2_utf8view || column1_large_utf8,
1159+
column2_utf8view || column1_dict
1160+
FROM test;
1161+
----
1162+
AndrewX AndrewX AndrewX AndrewX XAndrew XAndrew XAndrew XAndrew
1163+
XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
1164+
RaphaelR RaphaelR RaphaelR RaphaelR RRaphael RRaphael RRaphael RRaphael
1165+
NULL NULL NULL NULL NULL NULL NULL NULL
1166+
1167+
# || constants
1168+
# expect all results to be the same for each row as they all have the same values
1169+
query TTTTTTTT
1170+
SELECT
1171+
column1_utf8view || 'foo',
1172+
column1_utf8 || 'foo',
1173+
column1_large_utf8 || 'foo',
1174+
column1_dict || 'foo',
1175+
-- reverse argument order
1176+
'foo' || column1_utf8view,
1177+
'foo' || column1_utf8,
1178+
'foo' || column1_large_utf8,
1179+
'foo' || column1_dict
1180+
FROM test;
1181+
----
1182+
Andrewfoo Andrewfoo Andrewfoo Andrewfoo fooAndrew fooAndrew fooAndrew fooAndrew
1183+
Xiangpengfoo Xiangpengfoo Xiangpengfoo Xiangpengfoo fooXiangpeng fooXiangpeng fooXiangpeng fooXiangpeng
1184+
Raphaelfoo Raphaelfoo Raphaelfoo Raphaelfoo fooRaphael fooRaphael fooRaphael fooRaphael
1185+
NULL NULL NULL NULL NULL NULL NULL NULL
1186+
1187+
# || same type (column1 has null, so also tests NULL || NULL)
1188+
# expect all results to be the same for each row as they all have the same values
1189+
query TTT
1190+
SELECT
1191+
column1_utf8view || column1_utf8view,
1192+
column1_utf8 || column1_utf8,
1193+
column1_large_utf8 || column1_large_utf8
1194+
-- Dictionary/Dictionary coercion doesn't work
1195+
-- https://github.com/apache/datafusion/issues/12101
1196+
--column1_dict || column1_dict
1197+
FROM test;
1198+
----
1199+
AndrewAndrew AndrewAndrew AndrewAndrew
1200+
XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
1201+
RaphaelRaphael RaphaelRaphael RaphaelRaphael
1202+
NULL NULL NULL
1203+
11471204
statement ok
11481205
drop table test;
11491206

@@ -1167,18 +1224,25 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt;
11671224
statement ok
11681225
drop table dates;
11691226

1227+
### Tests for `||` with Utf8View specifically
1228+
11701229
statement ok
11711230
create table temp as values
11721231
('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')),
11731232
('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool', 'Utf8View'));
11741233

1234+
query TTT
1235+
select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from temp;
1236+
----
1237+
Utf8 Utf8View Utf8View
1238+
Utf8 Utf8View Utf8View
1239+
11751240
query T
11761241
select column2||' is fast' from temp;
11771242
----
11781243
rust is fast
11791244
datafusion is fast
11801245

1181-
11821246
query T
11831247
select column2 || ' is ' || column3 from temp;
11841248
----
@@ -1189,15 +1253,15 @@ query TT
11891253
explain select column2 || 'is' || column3 from temp;
11901254
----
11911255
logical_plan
1192-
01)Projection: CAST(temp.column2 AS Utf8) || Utf8("is") || CAST(temp.column3 AS Utf8)
1256+
01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2 || Utf8("is") || temp.column3
11931257
02)--TableScan: temp projection=[column2, column3]
11941258

1195-
1259+
# should not cast the column2 to utf8
11961260
query TT
11971261
explain select column2||' is fast' from temp;
11981262
----
11991263
logical_plan
1200-
01)Projection: CAST(temp.column2 AS Utf8) || Utf8(" is fast")
1264+
01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8(" is fast")
12011265
02)--TableScan: temp projection=[column2]
12021266

12031267

@@ -1211,7 +1275,7 @@ query TT
12111275
explain select column2||column3 from temp;
12121276
----
12131277
logical_plan
1214-
01)Projection: CAST(temp.column2 AS Utf8) || CAST(temp.column3 AS Utf8)
1278+
01)Projection: temp.column2 || temp.column3
12151279
02)--TableScan: temp projection=[column2, column3]
12161280

12171281
query T

0 commit comments

Comments
 (0)