diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 6d2fb660f669..b60b0635cf15 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -912,26 +912,22 @@ fn dictionary_coercion( /// Coercion rules for string concat. /// This is a union of string coercion rules and specified rules: -/// 1. At lease one side of lhs and rhs should be string type (Utf8 / LargeUtf8) +/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - // If Utf8View is in any side, we coerce to Utf8. - // Ref: https://github.com/apache/datafusion/pull/11796 - (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { - Some(Utf8) + string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + (Utf8View, from_type) | (from_type, Utf8View) => { + string_concat_internal_coercion(from_type, &Utf8View) } - _ => string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { - (Utf8, from_type) | (from_type, Utf8) => { - string_concat_internal_coercion(from_type, &Utf8) - } - (LargeUtf8, from_type) | (from_type, LargeUtf8) => { - string_concat_internal_coercion(from_type, &LargeUtf8) - } - _ => None, - }), - } + (Utf8, from_type) | (from_type, Utf8) => { + string_concat_internal_coercion(from_type, &Utf8) + } + (LargeUtf8, from_type) | (from_type, LargeUtf8) => { + string_concat_internal_coercion(from_type, &LargeUtf8) + } + _ => None, + }) } fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -942,6 +938,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } +/// If `from_type` can be cast to `to_type`, return `to_type`, otherwise +/// return `None`. 
fn string_concat_internal_coercion( from_type: &DataType, to_type: &DataType, @@ -967,6 +965,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), + // Utf8 coerces to Utf8 (Utf8, Utf8) => Some(Utf8), _ => None, } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 26885ae1350c..b663d8614275 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -41,6 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; +use crate::expressions::binary::kernels::concat_elements_utf8view; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -131,34 +132,6 @@ impl std::fmt::Display for BinaryExpr { } } -/// Invoke a compute kernel on a pair of binary data arrays -macro_rules! compute_utf8_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?)) - }}; -} - -macro_rules! 
binary_string_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - match $LEFT.data_type() { - DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), - DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray), - other => internal_err!( - "Data type {:?} not supported for binary operation '{}' on string arrays", - other, stringify!($OP) - ), - } - }}; -} - /// Invoke a boolean kernel on a pair of arrays macro_rules! boolean_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ @@ -662,7 +635,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => binary_string_array_op!(left, right, concat_elements), + StringConcat => concat_elements(left, right), AtArrow | ArrowAt => { unreachable!("ArrowAt and AtArrow should be rewritten to function") } @@ -670,6 +643,28 @@ impl BinaryExpr { } } +fn concat_elements(left: Arc, right: Arc) -> Result { + Ok(match left.data_type() { + DataType::Utf8 => Arc::new(concat_elements_utf8( + left.as_string::(), + right.as_string::(), + )?), + DataType::LargeUtf8 => Arc::new(concat_elements_utf8( + left.as_string::(), + right.as_string::(), + )?), + DataType::Utf8View => Arc::new(concat_elements_utf8view( + left.as_string_view(), + right.as_string_view(), + )?), + other => { + return internal_err!( + "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays" + ); + } + }) +} + /// Create a binary expression whose arguments are correctly coerced. /// This function errors if it is not possible to coerce the arguments /// to computational types supported by the operator. 
diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index b0736e140fec..1f9cfed1a44f 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -27,6 +27,7 @@ use arrow::datatypes::DataType; use datafusion_common::internal_err; use datafusion_common::{Result, ScalarValue}; +use arrow_schema::ArrowError; use std::sync::Arc; /// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT) @@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar); create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar); create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, bitwise_shift_right_scalar); create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, bitwise_shift_left_scalar); + +pub fn concat_elements_utf8view( + left: &StringViewArray, + right: &StringViewArray, +) -> std::result::Result { + let capacity = left + .data_buffers() + .iter() + .zip(right.data_buffers().iter()) + .map(|(b1, b2)| b1.len() + b2.len()) + .sum(); + let mut result = StringViewBuilder::with_capacity(capacity); + + // Avoid reallocations by writing to a reused buffer (note we + // could be even more efficient by creating the view directly + // here and avoiding the buffer, but that would be more complex) + let mut buffer = String::new(); + + for (left, right) in left.iter().zip(right.iter()) { + if let (Some(left), Some(right)) = (left, right) { + use std::fmt::Write; + buffer.clear(); + write!(&mut buffer, "{left}{right}") + .expect("writing into string buffer failed"); + result.append_value(&buffer); + } else { + // at least one of the values is null, so the output is also null + result.append_null() + } + } + Ok(result.finish()) +} diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 
0b441bcbeb8f..96ca212a513e 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -1145,6 +1145,63 @@ FROM test; 0 NULL +# || mixed types +# expect all results to be the same for each row as they all have the same values +query TTTTTTTT +SELECT + column1_utf8view || column2_utf8view, + column1_utf8 || column2_utf8view, + column1_large_utf8 || column2_utf8view, + column1_dict || column2_utf8view, + -- reverse argument order + column2_utf8view || column1_utf8view, + column2_utf8view || column1_utf8, + column2_utf8view || column1_large_utf8, + column2_utf8view || column1_dict +FROM test; +---- +AndrewX AndrewX AndrewX AndrewX XAndrew XAndrew XAndrew XAndrew +XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng +RaphaelR RaphaelR RaphaelR RaphaelR RRaphael RRaphael RRaphael RRaphael +NULL NULL NULL NULL NULL NULL NULL NULL + +# || constants +# expect all results to be the same for each row as they all have the same values +query TTTTTTTT +SELECT + column1_utf8view || 'foo', + column1_utf8 || 'foo', + column1_large_utf8 || 'foo', + column1_dict || 'foo', + -- reverse argument order + 'foo' || column1_utf8view, + 'foo' || column1_utf8, + 'foo' || column1_large_utf8, + 'foo' || column1_dict +FROM test; +---- +Andrewfoo Andrewfoo Andrewfoo Andrewfoo fooAndrew fooAndrew fooAndrew fooAndrew +Xiangpengfoo Xiangpengfoo Xiangpengfoo Xiangpengfoo fooXiangpeng fooXiangpeng fooXiangpeng fooXiangpeng +Raphaelfoo Raphaelfoo Raphaelfoo Raphaelfoo fooRaphael fooRaphael fooRaphael fooRaphael +NULL NULL NULL NULL NULL NULL NULL NULL + +# || same type (column1 has null, so also tests NULL || NULL) +# expect all results to be the same for each row as they all have the same values +query TTT +SELECT + column1_utf8view || column1_utf8view, + column1_utf8 || column1_utf8, + column1_large_utf8 || column1_large_utf8 + -- 
Dictionary/Dictionary coercion doesn't work + -- https://github.com/apache/datafusion/issues/12101 + --column1_dict || column1_dict +FROM test; +---- +AndrewAndrew AndrewAndrew AndrewAndrew +XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng +RaphaelRaphael RaphaelRaphael RaphaelRaphael +NULL NULL NULL + statement ok drop table test; @@ -1168,18 +1225,25 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt; statement ok drop table dates; +### Tests for `||` with Utf8View specifically + statement ok create table temp as values ('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')), ('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool', 'Utf8View')); +query TTT +select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from temp; +---- +Utf8 Utf8View Utf8View +Utf8 Utf8View Utf8View + query T select column2||' is fast' from temp; ---- rust is fast datafusion is fast - query T select column2 || ' is ' || column3 from temp; ---- @@ -1190,15 +1254,15 @@ query TT explain select column2 || 'is' || column3 from temp; ---- logical_plan -01)Projection: CAST(temp.column2 AS Utf8) || Utf8("is") || CAST(temp.column3 AS Utf8) +01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2 || Utf8("is") || temp.column3 02)--TableScan: temp projection=[column2, column3] - +# should not cast the column2 to utf8 query TT explain select column2||' is fast' from temp; ---- logical_plan -01)Projection: CAST(temp.column2 AS Utf8) || Utf8(" is fast") +01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8(" is fast") 02)--TableScan: temp projection=[column2] @@ -1212,7 +1276,7 @@ query TT explain select column2||column3 from temp; ---- logical_plan -01)Projection: CAST(temp.column2 AS Utf8) || CAST(temp.column3 AS Utf8) +01)Projection: temp.column2 || temp.column3 02)--TableScan: temp projection=[column2, column3] query T