From 58151bd8fa574e878da9eecdbba362106c0fb48b Mon Sep 17 00:00:00 2001 From: zhuqi-lucas <821684824@qq.com> Date: Mon, 17 Mar 2025 22:54:09 +0800 Subject: [PATCH] fix --- .../expr-common/src/type_coercion/binary.rs | 148 +++++++++++++++--- .../physical-expr/src/expressions/binary.rs | 66 +++++++- .../test_files/string/string_view.slt | 8 +- 3 files changed, 194 insertions(+), 28 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 682cc885cd6b..fb559e163bb1 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -1177,26 +1177,6 @@ pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { - use arrow::datatypes::DataType::*; - match (lhs_type, rhs_type) { - // If Utf8View is in any side, we coerce to Utf8. - (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { - Some(Utf8) - } - // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. 
- (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), - // Utf8 coerces to Utf8 - (Utf8, Utf8) => Some(Utf8), - _ => None, - } -} - fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -1327,7 +1307,7 @@ fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { - regex_comparison_string_coercion(lhs_type, rhs_type) + string_coercion(lhs_type, rhs_type) .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) .or_else(|| regex_null_coercion(lhs_type, rhs_type)) } @@ -1802,42 +1782,168 @@ mod tests { Operator::RegexMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Utf8, DataType::Utf8, Operator::RegexNotMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Utf8, DataType::Utf8, Operator::RegexNotIMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8, + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Utf8View, + DataType::Utf8View, + Operator::RegexNotIMatch, + 
DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), DataType::Utf8, Operator::RegexMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8View, + Operator::RegexMatch, + DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), DataType::Utf8, Operator::RegexIMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8View, + Operator::RegexIMatch, + DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), DataType::Utf8, Operator::RegexNotMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexNotMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + 
DataType::Utf8View, + Operator::RegexNotMatch, + DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), DataType::Utf8, Operator::RegexNotIMatch, DataType::Utf8 ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8.into()), + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); + test_coercion_binary_rule!( + DataType::Dictionary(DataType::Int32.into(), DataType::Utf8View.into()), + DataType::Utf8View, + Operator::RegexNotIMatch, + DataType::Utf8View + ); test_coercion_binary_rule!( DataType::Int16, DataType::Int64, diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 872773b06fa6..a00d135ef3c1 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -168,9 +168,12 @@ fn boolean_op( macro_rules! binary_string_array_flag_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ match $LEFT.data_type() { - DataType::Utf8View | DataType::Utf8 => { + DataType::Utf8 => { compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) }, + DataType::Utf8View => { + compute_utf8view_flag_op!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG) + } DataType::LargeUtf8 => { compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) }, @@ -207,14 +210,42 @@ macro_rules! compute_utf8_flag_op { }}; } +/// Invoke a compute kernel on a pair of binary data arrays with flags +macro_rules! 
compute_utf8view_flag_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .expect("compute_utf8view_flag_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .expect("compute_utf8view_flag_op failed to downcast array"); + + let flag = if $FLAG { + Some($ARRAYTYPE::from(vec!["i"; ll.len()])) + } else { + None + }; + let mut array = $OP(ll, rr, flag.as_ref())?; + if $NOT { + array = not(&array).unwrap(); + } + Ok(Arc::new(array)) + }}; +} + macro_rules! binary_string_array_flag_op_scalar { ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value, // the query can be optimized in such a way that operands will be dicts, so we need to support it here let result: Result> = match $LEFT.data_type() { - DataType::Utf8View | DataType::Utf8 => { + DataType::Utf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) }, + DataType::Utf8View => { + compute_utf8view_flag_op_scalar!($LEFT, $RIGHT, $OP, StringViewArray, $NOT, $FLAG) + } DataType::LargeUtf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) }, @@ -222,7 +253,8 @@ macro_rules! 
binary_string_array_flag_op_scalar { let values = $LEFT.as_any_dictionary().values(); match values.data_type() { - DataType::Utf8View | DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), + DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), + DataType::Utf8View => compute_utf8view_flag_op_scalar!(values, $RIGHT, $OP, StringViewArray, $NOT, $FLAG), DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG), other => internal_err!( "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array", @@ -276,6 +308,34 @@ macro_rules! compute_utf8_flag_op_scalar { }}; } +/// Invoke a compute kernel on a data array and a scalar value with flag +macro_rules! compute_utf8view_flag_op_scalar { + ($LEFT:expr, $RIGHT:expr, $OP:ident, $ARRAYTYPE:ident, $NOT:expr, $FLAG:expr) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$ARRAYTYPE>() + .expect("compute_utf8view_flag_op_scalar failed to downcast array"); + + let string_value = match $RIGHT.try_as_str() { + Some(Some(string_value)) => string_value, + // null literal or non string + _ => return internal_err!( + "compute_utf8view_flag_op_scalar failed to cast literal value {} for operation '{}'", + $RIGHT, stringify!($OP) + ) + }; + + let flag = $FLAG.then_some("i"); + let mut array = + paste::expr! 
{[<$OP _scalar>]}(ll, &string_value, flag)?; + if $NOT { + array = not(&array).unwrap(); + } + + Ok(Arc::new(array)) + }}; +} + impl PhysicalExpr for BinaryExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 69c4b9bfcb4b..96fb2477598c 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1100,7 +1100,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: CAST(test.column1_utf8view AS Utf8) LIKE Utf8("%an%") AS c1 +01)Projection: test.column1_utf8view ~ Utf8View("an") AS c1 02)--TableScan: test projection=[column1_utf8view] # `~*` operator (regex match case-insensitive) @@ -1110,7 +1110,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: CAST(test.column1_utf8view AS Utf8) ~* Utf8("^a.{3}e") AS c1 +01)Projection: test.column1_utf8view ~* Utf8View("^a.{3}e") AS c1 02)--TableScan: test projection=[column1_utf8view] # `!~~` operator (not like match) @@ -1120,7 +1120,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: CAST(test.column1_utf8view AS Utf8) !~~ Utf8("xia_g%g") AS c1 +01)Projection: test.column1_utf8view !~~ Utf8View("xia_g%g") AS c1 02)--TableScan: test projection=[column1_utf8view] # `!~~*` operator (not like match case-insensitive) @@ -1130,7 +1130,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: CAST(test.column1_utf8view AS Utf8) !~~* Utf8("xia_g%g") AS c1 +01)Projection: test.column1_utf8view !~~* Utf8View("xia_g%g") AS c1 02)--TableScan: test projection=[column1_utf8view] # coercions between stringview and date types