From 81f33b0e27f5694348cd953a937203d835b57178 Mon Sep 17 00:00:00 2001 From: casperhart <39182232+casperhart@users.noreply.github.com> Date: Mon, 18 Sep 2023 17:44:59 -0400 Subject: [PATCH] implement string_to_array (#7577) * implement string_to_array * string_to_array doc and test updates * move string_to_array from string functions to array functions --- datafusion/expr/src/built_in_function.rs | 17 +++- datafusion/expr/src/expr_fn.rs | 2 + .../physical-expr/src/array_expressions.rs | 89 +++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 15 ++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 46 ++++++++++ .../sqllogictest/test_files/functions.slt | 2 +- .../source/user-guide/sql/scalar_functions.md | 24 +++++ 12 files changed, 202 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f38c7f12a859..3f1eb581aa98 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -242,6 +242,8 @@ pub enum BuiltinScalarFunction { SHA512, /// split_part SplitPart, + /// string_to_array + StringToArray, /// starts_with StartsWith, /// strpos @@ -426,6 +428,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SHA512 => Volatility::Immutable, BuiltinScalarFunction::Digest => Volatility::Immutable, BuiltinScalarFunction::SplitPart => Volatility::Immutable, + BuiltinScalarFunction::StringToArray => Volatility::Immutable, BuiltinScalarFunction::StartsWith => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, @@ -711,6 +714,11 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } + BuiltinScalarFunction::StringToArray => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), BuiltinScalarFunction::StartsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos") @@ -1068,7 +1076,13 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - + BuiltinScalarFunction::StringToArray => Signature::one_of( + vec![ + TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), + TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), + ], + self.volatility(), + ), BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { Signature::one_of( vec![ @@ -1279,6 +1293,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::Rpad => &["rpad"], BuiltinScalarFunction::Rtrim => &["rtrim"], BuiltinScalarFunction::SplitPart => &["split_part"], + BuiltinScalarFunction::StringToArray => &["string_to_array", "string_to_list"], BuiltinScalarFunction::StartsWith => &["starts_with"], BuiltinScalarFunction::Strpos => &["strpos"], BuiltinScalarFunction::Substr => &["substr"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 325d2f16fb0b..711dc123a4a4 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -746,6 +746,7 @@ scalar_expr!(SHA256, sha256, string, "SHA-256 hash"); scalar_expr!(SHA384, sha384, string, "SHA-384 hash"); scalar_expr!(SHA512, sha512, string, "SHA-512 hash"); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); +scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`"); scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); @@ -1080,6 +1081,7 @@ mod test { test_scalar_expr!(SHA384, sha384, string); test_scalar_expr!(SHA512, sha512, string); test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); + test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value); test_scalar_expr!(StartsWith, starts_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 2d84d8b3bdcf..34fbfc3c0269 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1864,6 +1864,95 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result { Ok(Arc::new(boolean_builder.finish())) } +/// Splits string at occurrences of delimiter and returns an array of parts +/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' +pub fn string_to_array(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + + let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + )); + + match args.len() { + 2 => { + string_array.iter().zip(delimiter_array.iter()).for_each( + |(string, delimiter)| { + match (string, delimiter) { + (Some(string), Some("")) => { + list_builder.values().append_value(string); + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + list_builder.values().append_value(s); + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + list_builder.values().append_value(c); + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }, + ); + } + + 3 => { + let null_value_array = as_generic_string_array::(&args[2])?; + string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(s); + } + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }); + } + _ => { + return internal_err!( + "Expect string_to_array function to take two or three parameters" + ) + } + } + + let list_array = list_builder.finish(); + Ok(Arc::new(list_array) as ArrayRef) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 503cfe412bc8..5de0dc366b85 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -800,6 +800,21 @@ pub fn create_physical_fun( internal_err!("Unsupported data type {other:?} for function split_part") } }), + BuiltinScalarFunction::StringToArray => { + Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(array_expressions::string_to_array::)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(array_expressions::string_to_array::)(args) + } + other => { + internal_err!( + "Unsupported data type {other:?} for function string_to_array" + ) + } + }) + } BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function(string_expressions::starts_with::)(args) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f036d6f447a9..89e307a2299f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -599,6 +599,7 @@ enum ScalarFunction { Iszero = 114; ArrayEmpty = 115; ArrayPopBack = 116; + StringToArray = 117; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index a70de83342fa..5ae817d1783f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -19105,6 +19105,7 @@ impl serde::Serialize for ScalarFunction { Self::Iszero => "Iszero", Self::ArrayEmpty => "ArrayEmpty", Self::ArrayPopBack => "ArrayPopBack", + Self::StringToArray => "StringToArray", }; serializer.serialize_str(variant) } @@ -19233,6 +19234,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Iszero", "ArrayEmpty", "ArrayPopBack", + "StringToArray", ]; struct GeneratedVisitor; @@ -19392,6 +19394,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Iszero" => Ok(ScalarFunction::Iszero), "ArrayEmpty" => Ok(ScalarFunction::ArrayEmpty), "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), + "StringToArray" => Ok(ScalarFunction::StringToArray), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f90f46202487..2fbf4d282aec 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2393,6 +2393,7 @@ pub enum ScalarFunction { Iszero = 114, ArrayEmpty = 115, ArrayPopBack = 116, + StringToArray = 117, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2518,6 +2519,7 @@ impl ScalarFunction { ScalarFunction::Iszero => "Iszero", ScalarFunction::ArrayEmpty => "ArrayEmpty", ScalarFunction::ArrayPopBack => "ArrayPopBack", + ScalarFunction::StringToArray => "StringToArray", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2640,6 +2642,7 @@ impl ScalarFunction { "Iszero" => Some(Self::Iszero), "ArrayEmpty" => Some(Self::ArrayEmpty), "ArrayPopBack" => Some(Self::ArrayPopBack), + "StringToArray" => Some(Self::StringToArray), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 0352f703c735..e13cbd87932a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -511,6 +511,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Right => Self::Right, ScalarFunction::Rpad => Self::Rpad, ScalarFunction::SplitPart => Self::SplitPart, + ScalarFunction::StringToArray => Self::StringToArray, ScalarFunction::StartsWith => Self::StartsWith, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1a9ee64fcf91..8a8550d05d13 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1518,6 +1518,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Right => Self::Right, BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::SplitPart => Self::SplitPart, + BuiltinScalarFunction::StringToArray => Self::StringToArray, BuiltinScalarFunction::StartsWith => Self::StartsWith, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index f54c2f71718c..f11bc5206eb4 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2485,6 +2485,52 @@ NULL false false +query ? +SELECT string_to_array('abcxxxdef', 'xxx') +---- +[abc, def] + +query ? +SELECT string_to_array('abc', '') +---- +[abc] + +query ? +SELECT string_to_array('abc', NULL) +---- +[a, b, c] + +query ? +SELECT string_to_array('abc def', ' ', 'def') +---- +[abc, ] + +query ? +select string_to_array(e, ',') from values; +---- +[Lorem] +[ipsum] +[dolor] +[sit] +[amet] +[, ] +[consectetur] +[adipiscing] +NULL + +query ? +select string_to_list(e, 'm') from values; +---- +[Lore, ] +[ipsu, ] +[dolor] +[sit] +[a, et] +[,] +[consectetur] +[adipiscing] +NULL + ### Delete tables statement ok diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index d6831cbd8ba3..e3e39ef6cc4c 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -810,4 +810,4 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1001 OldBrand Product 1 39.98 1002 OldBrand Product 2 59.98 1003 OldBrand Product 3 79.98 -1004 OldBrand Product 4 99.98 \ No newline at end of file +1004 OldBrand Product 4 99.98 diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index cd7245b34707..b68cac5cb774 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1523,6 +1523,8 @@ from_unixtime(expression) - [list_to_string](#list_to_string) - [make_array](#make_array) - [make_list](#make_list) +- [string_to_array](#string_to_array) +- [string_to_list](#string_to_list) - [trim_array](#trim_array) ### `array_append` @@ -2369,6 +2371,28 @@ make_array(expression1[, ..., expression_n]) _Alias of [make_array](#make_array)._ +### `string_to_array` + +Splits a string in to an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL. + +``` +starts_with(str, delimiter[, null_str]) +``` + +#### Arguments + +- **str**: String expression to split. +- **delimiter**: Delimiter string to split on. +- **null_str**: Substring values to be replaced with `NULL` + +#### Aliases + +- string_to_list + +### `string_to_list` + +_Alias of [string_to_array](#string_to_array)._ + ### `trim_array` Removes the last n elements from the array.