From cbc9532f910e23b1aaaafe3deab61a2886770740 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 10 Nov 2024 10:02:16 -0500 Subject: [PATCH 1/5] Add string view options to concat, fix simplifier for handling concat to return the same schema as without --- datafusion/functions/src/string/concat.rs | 83 ++++++++++++++++++++--- 1 file changed, 73 insertions(+), 10 deletions(-) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index e429a938b27d..bb1e83151131 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -110,8 +110,19 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => { + result.push_str(v); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} + other => plan_err!( + "Concat function does not support scalar type {:?}", + other + )?, } } @@ -282,15 +293,34 @@ pub fn simplify_concat(args: Vec) -> Result { let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); + let mut merged_type = DataType::Utf8; for arg in args.clone() { match arg { + Expr::Literal(ScalarValue::Utf8(None)) => {} + Expr::Literal(ScalarValue::LargeUtf8(None)) => { + if merged_type != DataType::Utf8View { + merged_type = DataType::LargeUtf8; + } + } + Expr::Literal(ScalarValue::Utf8View(None)) => { merged_type = DataType::Utf8View } + // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. - Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), - ) => contiguous_scalar += &v, + Expr::Literal(ScalarValue::Utf8(Some(v))) => { + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { + if merged_type != DataType::Utf8View { + merged_type = DataType::LargeUtf8; + } + contiguous_scalar += &v; + } + Expr::Literal(ScalarValue::Utf8View(Some(v))) => { + merged_type = DataType::Utf8View; + contiguous_scalar += &v; + } + Expr::Literal(x) => { return internal_err!( "The scalar {x} should be casted to string type during the type coercion." @@ -301,7 +331,13 @@ pub fn simplify_concat(args: Vec) -> Result { // Then pushing this arg to the `new_args`. arg => { if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match merged_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), + DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), + _ => unreachable!(), + } + merged_type = DataType::Utf8; contiguous_scalar = "".to_string(); } new_args.push(arg); @@ -310,7 +346,16 @@ pub fn simplify_concat(args: Vec) -> Result { } if !contiguous_scalar.is_empty() { - new_args.push(lit(contiguous_scalar)); + match merged_type { + DataType::Utf8 => new_args.push(lit(contiguous_scalar)), + DataType::LargeUtf8 => { + new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))) + } + DataType::Utf8View => { + new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))) + } + _ => unreachable!(), + } } if !args.eq(&new_args) { @@ -392,6 +437,17 @@ mod tests { LargeUtf8, LargeStringArray ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); Ok(()) } @@ -406,12 +462,19 @@ mod tests { None, Some("z"), ]))); - let args = &[c0, c1, c2]; + let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string()))); + let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![ + Some("a"), + None, + Some("b"), + ]))); + let args = &[c0, c1, c2, c3, c4]; #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatFunc::new().invoke(args)?; let expected = - Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"])) + as ArrayRef; match &result { ColumnarValue::Array(array) => { assert_eq!(&expected, array); From 8937955217af80075e3e0aa75280419d5a979752 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 10 Nov 2024 10:55:04 -0500 Subject: [PATCH 2/5] Set coersion ordering --- datafusion/functions/src/string/concat.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index bb1e83151131..61677550fdf9 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -48,7 +48,7 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8, Utf8View, LargeUtf8], + vec![Utf8View, LargeUtf8, Utf8], Volatility::Immutable, ), } From 0c6b886ba9330d2f36eebb814e5bfaa5c86b20d0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 10 Nov 2024 10:57:17 -0500 Subject: [PATCH 3/5] Add to simplification unit test to catch changes in type for concat --- .../core/tests/expr_api/simplification.rs | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/expr_api/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs index 68785b7a5a45..1e6ff8088d0a 100644 --- a/datafusion/core/tests/expr_api/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -483,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef { Field::new("c2", DataType::Boolean, true), Field::new("c3", DataType::Int64, true), Field::new("c4", DataType::UInt32, true), + Field::new("c5", DataType::Utf8View, true), Field::new("c1_non_null", DataType::Utf8, false), Field::new("c2_non_null", DataType::Boolean, false), Field::new("c3_non_null", DataType::Int64, false), Field::new("c4_non_null", DataType::UInt32, false), + Field::new("c5_non_null", DataType::Utf8View, false), ]) .to_dfschema_ref() .unwrap() @@ -665,20 +667,32 @@ fn test_simplify_concat_ws_with_null() { } #[test] -fn test_simplify_concat() { +fn test_simplify_concat() -> Result<()> { + let schema = expr_test_schema(); let null = lit(ScalarValue::Utf8(None)); let expr = concat(vec![ null.clone(), - col("c0"), + col("c1"), lit("hello "), null.clone(), lit("rust"), - col("c1"), + lit(ScalarValue::Utf8View(Some("!".to_string()))), + col("c2"), lit(""), null, + col("c5"), ]); - let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); - test_simplify(expr, expected) + let expr_datatype = expr.get_type(schema.as_ref())?; + let expected = concat(vec![ + col("c1"), + lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))), + col("c2"), + col("c5"), + ]); + let expected_datatype = expected.get_type(schema.as_ref())?; + assert_eq!(expr_datatype, expected_datatype); + test_simplify(expr, expected); + Ok(()) } #[test] fn test_simplify_cycles() { From 8bad384744cca071b074475c59c2aa09874b656a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 11 Nov 2024 07:00:40 -0500 Subject: [PATCH 4/5] Update coersion ordering --- datafusion/functions/src/string/concat.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 61677550fdf9..228d727dc2cc 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -48,7 +48,7 @@ impl ConcatFunc { use DataType::*; Self { signature: Signature::variadic( - vec![Utf8View, LargeUtf8, Utf8], + vec![Utf8View, Utf8, LargeUtf8], Volatility::Immutable, ), } From bac1f2bf4965b5116d2a605ece9031e40f4dc602 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 Nov 2024 10:31:23 -0600 Subject: [PATCH 5/5] Simplify computing merged type for concat --- datafusion/functions/src/string/concat.rs | 26 ++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 228d727dc2cc..c76a08653f53 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -293,16 +293,23 @@ pub fn simplify_concat(args: Vec) -> Result { let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); - let mut merged_type = DataType::Utf8; + let return_type = { + let data_types: Vec<_> = args + .iter() + .filter_map(|expr| match expr { + Expr::Literal(l) => Some(l.data_type()), + _ => None, + }) + .collect(); + ConcatFunc::new().return_type(&data_types) + }?; + for arg in args.clone() { match arg { Expr::Literal(ScalarValue::Utf8(None)) => {} Expr::Literal(ScalarValue::LargeUtf8(None)) => { - if merged_type != DataType::Utf8View { - merged_type = DataType::LargeUtf8; - } } - Expr::Literal(ScalarValue::Utf8View(None)) => { merged_type = DataType::Utf8View } + Expr::Literal(ScalarValue::Utf8View(None)) => { } // filter out `null` args // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. @@ -311,13 +318,9 @@ pub fn simplify_concat(args: Vec) -> Result { contiguous_scalar += &v; } Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => { - if merged_type != DataType::Utf8View { - merged_type = DataType::LargeUtf8; - } contiguous_scalar += &v; } Expr::Literal(ScalarValue::Utf8View(Some(v))) => { - merged_type = DataType::Utf8View; contiguous_scalar += &v; } @@ -331,13 +334,12 @@ pub fn simplify_concat(args: Vec) -> Result { // Then pushing this arg to the `new_args`. arg => { if !contiguous_scalar.is_empty() { - match merged_type { + match return_type { DataType::Utf8 => new_args.push(lit(contiguous_scalar)), DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))), DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))), _ => unreachable!(), } - merged_type = DataType::Utf8; contiguous_scalar = "".to_string(); } new_args.push(arg); @@ -346,7 +348,7 @@ pub fn simplify_concat(args: Vec) -> Result { } if !contiguous_scalar.is_empty() { - match merged_type { + match return_type { DataType::Utf8 => new_args.push(lit(contiguous_scalar)), DataType::LargeUtf8 => { new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))