From ee59dcc730ce6dda2340cde31a6168ff7a5e9c7a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 10 Aug 2023 01:14:09 +0800 Subject: [PATCH] Support array `flatten` sql function (#7239) * Support array flatten sql function Signed-off-by: jayzhan211 * add null and float Signed-off-by: jayzhan211 * add alias, 1d test and docs Signed-off-by: jayzhan211 * pretty Signed-off-by: jayzhan211 * rename Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../tests/sqllogictests/test_files/array.slt | 34 ++++++++++++++ datafusion/expr/src/built_in_function.rs | 25 +++++++++- datafusion/expr/src/expr_fn.rs | 6 +++ datafusion/expr/src/expr_schema.rs | 1 + .../physical-expr/src/array_expressions.rs | 47 +++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 4 ++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 ++ datafusion/proto/src/generated/prost.rs | 3 ++ .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 18 +++++++ 13 files changed, 143 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 218817fc165d..569f14f99a4e 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -110,6 +110,13 @@ AS VALUES (NULL, NULL, NULL, NULL) ; +statement ok +CREATE TABLE flatten_table +AS VALUES + (make_array([1], [2], [3]), make_array([[1, 2, 3]], [[4, 5]], [[6]]), make_array([[[1]]], [[[2, 3]]]), make_array([1.0], [2.1, 2.2], [3.2, 3.3, 3.4])), + (make_array([1, 2], [3, 4], [5, 6]), make_array([[8]]), make_array([[[1,2]]], [[[3]]]), make_array([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])) +; + statement ok CREATE TABLE array_has_table_1D AS VALUES @@ -2330,6 +2337,30 @@ select array_concat(column1, [7]) from 
arrays_values_v2; [11, 12, 7] [7] +# flatten +query ??? +select flatten(make_array(1, 2, 1, 3, 2)), + flatten(make_array([1], [2, 3], [null], make_array(4, null, 5))), + flatten(make_array([[1.1]], [[2.2]], [[3.3], [4.4]])); +---- +[1, 2, 1, 3, 2] [1, 2, 3, , 4, , 5] [1.1, 2.2, 3.3, 4.4] + +query ???? +select column1, column2, column3, column4 from flatten_table; +---- +[[1], [2], [3]] [[[1, 2, 3]], [[4, 5]], [[6]]] [[[[1]]], [[[2, 3]]]] [[1.0], [2.1, 2.2], [3.2, 3.3, 3.4]] +[[1, 2], [3, 4], [5, 6]] [[[8]]] [[[[1, 2]]], [[[3]]]] [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + +query ???? +select flatten(column1), + flatten(column2), + flatten(column3), + flatten(column4) +from flatten_table; +---- +[1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] +[1, 2, 3, 4, 5, 6] [8] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + ### Delete tables statement ok @@ -2382,3 +2413,6 @@ drop table arrays_with_repeating_elements; statement ok drop table nested_arrays_with_repeating_elements; + +statement ok +drop table flatten_table; diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f239155a9218..703d41cbee3d 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -165,6 +165,8 @@ pub enum BuiltinScalarFunction { Cardinality, /// construct an array from columns MakeArray, + /// Flatten + Flatten, // struct functions /// struct @@ -368,6 +370,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplace => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable, BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable, + BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, @@ -501,6 +504,22 @@ impl BuiltinScalarFunction { // the return type of the 
built in function. // Some built-in functions' return type depends on the incoming type. match self { + BuiltinScalarFunction::Flatten => { + fn get_base_type(data_type: &DataType) -> Result<DataType> { + match data_type { + DataType::List(field) => match field.data_type() { + DataType::List(_) => get_base_type(field.data_type()), + _ => Ok(data_type.to_owned()), + }, + _ => Err(DataFusionError::Internal( + "Not reachable, data_type should be List".to_string(), + )), + } + } + + let data_type = get_base_type(&input_expr_types[0])?; + Ok(data_type) + } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; @@ -819,11 +838,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayConcat => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), + BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayDims => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayLength => { Signature::variadic_any(self.volatility()) } @@ -1307,6 +1327,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "list_element", "list_extract", ], + BuiltinScalarFunction::Flatten => &["flatten"], BuiltinScalarFunction::ArrayHasAll => &["array_has_all", "list_has_all"], BuiltinScalarFunction::ArrayHasAny => &["array_has_any", "list_has_any"], BuiltinScalarFunction::ArrayHas => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ef6ce8171153..47767c23b363 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -564,6 
+564,12 @@ scalar_expr!( first_array second_array, "Returns true if at least one element of the second array appears in the first array; otherwise, it returns false." ); +scalar_expr!( + Flatten, + flatten, + array, + "flattens an array of arrays into a single array." +); scalar_expr!( ArrayDims, array_dims, diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 1d26485b4e03..d7bc86158b69 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -92,6 +92,7 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::<Result<Vec<_>>>()?; + fun.return_type(&data_types) } Expr::WindowFunction(WindowFunction { fun, args, .. }) => { diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index fcd9adf19dee..819b389cff1d 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -1789,6 +1789,53 @@ pub fn cardinality(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(Arc::new(result) as ArrayRef) } +// Create new offsets that are equivalent to flattening the array. 
+fn get_offsets_for_flatten( + offsets: OffsetBuffer<i32>, + indexes: OffsetBuffer<i32>, +) -> OffsetBuffer<i32> { + let buffer = offsets.into_inner(); + let offsets: Vec<i32> = indexes.iter().map(|i| buffer[*i as usize]).collect(); + OffsetBuffer::new(offsets.into()) +} + +fn flatten_internal( + array: &dyn Array, + indexes: Option<OffsetBuffer<i32>>, +) -> Result<ListArray> { + let list_arr = as_list_array(array)?; + let (field, offsets, values, nulls) = list_arr.clone().into_parts(); + let data_type = field.data_type(); + + match data_type { + // Recursively get the base offsets for flattened array + DataType::List(_) => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + flatten_internal(&values, Some(offsets)) + } else { + flatten_internal(&values, Some(offsets)) + } + } + // Reach the base level, create a new list array + _ => { + if let Some(indexes) = indexes { + let offsets = get_offsets_for_flatten(offsets, indexes); + let list_arr = ListArray::new(field, offsets, values, nulls); + Ok(list_arr) + } else { + Ok(list_arr.clone()) + } + } + } +} + +/// Flatten SQL function +pub fn flatten(args: &[ArrayRef]) -> Result<ArrayRef> { + let flattened_array = flatten_internal(&args[0], None)?; + Ok(Arc::new(flattened_array) as ArrayRef) +} + /// Array_length SQL function pub fn array_length(args: &[ArrayRef]) -> Result<ArrayRef> { let list_array = as_list_array(&args[0])?; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index df76d55bfcaa..d1a5119ee8a3 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -437,6 +437,10 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayLength => { Arc::new(|args| make_scalar_function(array_expressions::array_length)(args)) } + BuiltinScalarFunction::Flatten => { + Arc::new(|args| make_scalar_function(array_expressions::flatten)(args)) + } + BuiltinScalarFunction::ArrayNdims => { Arc::new(|args| 
make_scalar_function(array_expressions::array_ndims)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1081fca2e1fb..1a8ad093b965 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -594,6 +594,7 @@ enum ScalarFunction { ArrayRemoveAll = 109; ArrayReplaceAll = 110; Nanvl = 111; + Flatten = 112; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8691487c7282..a33d80be9ddb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -18944,6 +18944,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayRemoveAll => "ArrayRemoveAll", Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", + Self::Flatten => "Flatten", }; serializer.serialize_str(variant) } @@ -19067,6 +19068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayRemoveAll", "ArrayReplaceAll", "Nanvl", + "Flatten", ]; struct GeneratedVisitor; @@ -19221,6 +19223,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayRemoveAll" => Ok(ScalarFunction::ArrayRemoveAll), "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), + "Flatten" => Ok(ScalarFunction::Flatten), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 87371ba2772c..519cd002df3f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2374,6 +2374,7 @@ pub enum ScalarFunction { ArrayRemoveAll = 109, ArrayReplaceAll = 110, Nanvl = 111, + Flatten = 112, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. 
@@ -2494,6 +2495,7 @@ impl ScalarFunction { ScalarFunction::ArrayRemoveAll => "ArrayRemoveAll", ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", + ScalarFunction::Flatten => "Flatten", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2611,6 +2613,7 @@ impl ScalarFunction { "ArrayRemoveAll" => Some(Self::ArrayRemoveAll), "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), + "Flatten" => Some(Self::Flatten), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index c17d8dbd8ca9..d2e037aa4b1b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -457,6 +457,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayHas => Self::ArrayHas, ScalarFunction::ArrayDims => Self::ArrayDims, ScalarFunction::ArrayElement => Self::ArrayElement, + ScalarFunction::Flatten => Self::Flatten, ScalarFunction::ArrayLength => Self::ArrayLength, ScalarFunction::ArrayNdims => Self::ArrayNdims, ScalarFunction::ArrayPosition => Self::ArrayPosition, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index aa1132e8b1f6..cdb9b008036c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1456,6 +1456,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayHas => Self::ArrayHas, BuiltinScalarFunction::ArrayDims => Self::ArrayDims, BuiltinScalarFunction::ArrayElement => Self::ArrayElement, + BuiltinScalarFunction::Flatten => Self::Flatten, BuiltinScalarFunction::ArrayLength => Self::ArrayLength, BuiltinScalarFunction::ArrayNdims => Self::ArrayNdims, BuiltinScalarFunction::ArrayPosition => Self::ArrayPosition, diff --git a/docs/source/user-guide/expressions.md 
b/docs/source/user-guide/expressions.md index a04f43fd4b2b..88a5a73a6df3 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -188,6 +188,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_has_any(array, sub-array) | Returns true if any elements exist in both arrays `array_has_any([1,2,3], [1,4]) -> true` | | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | +| flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | | array_position(array, element) | Searches for an element in the array, returns first occurrence. `array_position([1, 2, 2, 3, 4], 2) -> 2` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index dec120db18c5..9bcf2ae0b09b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1685,6 +1685,24 @@ array_fill(element, array) Can be a constant, column, or function, and any combination of array operators. - **element**: Element to copy to the array. +### `flatten` + +Converts an array of arrays to a flat array + +- Applies to any depth of nested arrays +- Does not change arrays that are already flat + +The flattened array contains all the elements from all source arrays. + +#### Arguments + +- **array**: Array expression + Can be a constant, column, or function, and any combination of array operators. 
+ +``` +flatten(array) +``` + ### `array_indexof` _Alias of [array_position](#array_position)._