From 8746e073b7a7ffcf86fac19d5ea1984ee6970d20 Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Sun, 1 Sep 2024 14:10:10 +0530 Subject: [PATCH] Support `map_keys` & `map_values` for MAP type (#12194) * impl map_keys * rename field name * add logic tests * one more * owned to clone * more tests * typo * impl * add logic tests * chore * add docs * trying to make prettier happy * Update scalar_functions.md Co-authored-by: Alex Huang * reface signature * format docs * Update map_values.rs Co-authored-by: Alex Huang --------- Co-authored-by: Alex Huang --- datafusion/common/src/utils/mod.rs | 17 +-- datafusion/functions-nested/src/lib.rs | 6 ++ .../functions-nested/src/map_extract.rs | 3 +- datafusion/functions-nested/src/map_keys.rs | 102 ++++++++++++++++++ datafusion/functions-nested/src/map_values.rs | 102 ++++++++++++++++++ datafusion/functions-nested/src/utils.rs | 19 +++- datafusion/sqllogictest/test_files/map.slt | 97 ++++++++++++++++- .../source/user-guide/sql/scalar_functions.md | 52 +++++++++ 8 files changed, 377 insertions(+), 21 deletions(-) create mode 100644 datafusion/functions-nested/src/map_keys.rs create mode 100644 datafusion/functions-nested/src/map_values.rs diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 839f890bf077..418ea380bc2c 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -35,7 +35,7 @@ use arrow_array::{ Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, RecordBatchOptions, }; -use arrow_schema::{DataType, Fields}; +use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; @@ -754,21 +754,6 @@ pub fn combine_limit( (combined_skip, combined_fetch) } -pub fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { - match data_type { - DataType::Map(field, _) => { - let field_data_type = field.data_type(); - match field_data_type { - DataType::Struct(fields) => 
Ok(fields), - _ => { - _internal_err!("Expected a Struct type, got {:?}", field_data_type) - } - } - } - _ => _internal_err!("Expected a Map type, got {:?}", data_type), - } -} - #[cfg(test)] mod tests { use crate::ScalarValue::Null; diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 863b5a876adc..b548cf6db8b1 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -44,6 +44,8 @@ pub mod length; pub mod make_array; pub mod map; pub mod map_extract; +pub mod map_keys; +pub mod map_values; pub mod planner; pub mod position; pub mod range; @@ -85,6 +87,8 @@ pub mod expr_fn { pub use super::length::array_length; pub use super::make_array::make_array; pub use super::map_extract::map_extract; + pub use super::map_keys::map_keys; + pub use super::map_values::map_values; pub use super::position::array_position; pub use super::position::array_positions; pub use super::range::gen_series; @@ -149,6 +153,8 @@ pub fn all_default_nested_functions() -> Vec> { replace::array_replace_udf(), map::map_udf(), map_extract::map_extract_udf(), + map_keys::map_keys_udf(), + map_values::map_values_udf(), ] } diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs index 82f0d8d6c15e..9f0c4ad29c60 100644 --- a/datafusion/functions-nested/src/map_extract.rs +++ b/datafusion/functions-nested/src/map_extract.rs @@ -24,7 +24,6 @@ use arrow::datatypes::DataType; use arrow_array::{Array, MapArray}; use arrow_buffer::OffsetBuffer; use arrow_schema::Field; -use datafusion_common::utils::get_map_entry_field; use datafusion_common::{cast::as_map_array, exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -32,7 +31,7 @@ use std::any::Any; use std::sync::Arc; use std::vec; -use crate::utils::make_scalar_function; +use crate::utils::{get_map_entry_field, make_scalar_function}; // Create static instances of ScalarUDFs 
for each function make_udf_expr_and_func!( diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs new file mode 100644 index 000000000000..0b1cebb27c86 --- /dev/null +++ b/datafusion/functions-nested/src/map_keys.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_keys function. 
+
+use crate::utils::{get_map_entry_field, make_scalar_function};
+use arrow_array::{Array, ArrayRef, ListArray};
+use arrow_schema::{DataType, Field};
+use datafusion_common::{cast::as_map_array, exec_err, Result};
+use datafusion_expr::{
+    ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature,
+    Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    MapKeysFunc,
+    map_keys,
+    map,
+    "Return a list of all keys in the map.",
+    map_keys_udf
+);
+
+/// Scalar UDF implementing the SQL function `map_keys(map)`.
+#[derive(Debug)]
+pub(crate) struct MapKeysFunc {
+    /// Signature declared as `ArrayFunctionSignature::MapArray`.
+    signature: Signature,
+}
+
+impl MapKeysFunc {
+    /// Creates the UDF with an immutable `MapArray` signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::new(
+                TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MapKeysFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "map_keys"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Returns `List<key type>` (non-nullable items) for a single map argument.
+    ///
+    /// # Errors
+    ///
+    /// Fails when the number of arguments is not exactly one, when the
+    /// argument is not a map type, or when the map's entry struct has no
+    /// fields.
+    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
+        if arg_types.len() != 1 {
+            return exec_err!("map_keys expects single argument");
+        }
+        let map_type = &arg_types[0];
+        let map_fields = get_map_entry_field(map_type)?;
+        // The key is the first field of the entry struct. Report a malformed
+        // type as an error instead of panicking during planning.
+        match map_fields.first() {
+            Some(key_field) => Ok(DataType::List(Arc::new(Field::new(
+                "item",
+                key_field.data_type().clone(),
+                false,
+            )))),
+            None => exec_err!("map_keys expects a map with a key field"),
+        }
+    }
+
+    /// Evaluates the function by delegating to `map_keys_inner` through
+    /// `make_scalar_function`.
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        make_scalar_function(map_keys_inner)(args)
+    }
+}
+
+/// Builds a `ListArray` containing, for every row of the map array in
+/// `args[0]`, the keys of that row.
+///
+/// The result uses the map's own offsets and shares its keys child array
+/// (`Arc::clone`), so no key data is rewritten.
+///
+/// # Errors
+///
+/// Returns an execution error when `args` does not hold exactly one array or
+/// when that array is not of type `DataType::Map`.
+fn map_keys_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 1 {
+        return exec_err!("map_keys expects single argument");
+    }
+
+    let map_array = match args[0].data_type() {
+        DataType::Map(_, _) => as_map_array(&args[0])?,
+        _ => return exec_err!("Argument for map_keys should be a map"),
+    };
+
+    Ok(Arc::new(ListArray::new(
+        Arc::new(Field::new("item", map_array.key_type().clone(), false)),
+        map_array.offsets().clone(),
+        Arc::clone(map_array.keys()),
+        // Carry the map's validity buffer over so that a NULL map row yields a
+        // NULL list rather than a non-null list.
+        map_array.nulls().cloned(),
+    )))
+}
diff --git 
a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs new file mode 100644 index 000000000000..58c0d74eed5f --- /dev/null +++ b/datafusion/functions-nested/src/map_values.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_values function. 
+
+use crate::utils::{get_map_entry_field, make_scalar_function};
+use arrow_array::{Array, ArrayRef, ListArray};
+use arrow_schema::{DataType, Field};
+use datafusion_common::{cast::as_map_array, exec_err, Result};
+use datafusion_expr::{
+    ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl, Signature, TypeSignature,
+    Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    MapValuesFunc,
+    map_values,
+    map,
+    "Return a list of all values in the map.",
+    map_values_udf
+);
+
+/// Scalar UDF implementing the SQL function `map_values(map)`.
+#[derive(Debug)]
+pub(crate) struct MapValuesFunc {
+    /// Signature declared as `ArrayFunctionSignature::MapArray`.
+    signature: Signature,
+}
+
+impl MapValuesFunc {
+    /// Creates the UDF with an immutable `MapArray` signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::new(
+                TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray),
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for MapValuesFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "map_values"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// Returns `List<value type>` (nullable items) for a single map argument.
+    ///
+    /// # Errors
+    ///
+    /// Fails when the number of arguments is not exactly one, when the
+    /// argument is not a map type, or when the map's entry struct has no
+    /// fields.
+    fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result<DataType> {
+        if arg_types.len() != 1 {
+            return exec_err!("map_values expects single argument");
+        }
+        let map_type = &arg_types[0];
+        let map_fields = get_map_entry_field(map_type)?;
+        // The value is the last field of the entry struct. Report a malformed
+        // type as an error instead of panicking during planning.
+        match map_fields.last() {
+            Some(value_field) => Ok(DataType::List(Arc::new(Field::new(
+                "item",
+                value_field.data_type().clone(),
+                true,
+            )))),
+            None => exec_err!("map_values expects a map with a value field"),
+        }
+    }
+
+    /// Evaluates the function by delegating to `map_values_inner` through
+    /// `make_scalar_function`.
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        make_scalar_function(map_values_inner)(args)
+    }
+}
+
+/// Builds a `ListArray` containing, for every row of the map array in
+/// `args[0]`, the values of that row.
+///
+/// The result uses the map's own offsets and shares its values child array
+/// (`Arc::clone`), so no value data is rewritten.
+///
+/// # Errors
+///
+/// Returns an execution error when `args` does not hold exactly one array or
+/// when that array is not of type `DataType::Map`.
+fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 1 {
+        return exec_err!("map_values expects single argument");
+    }
+
+    let map_array = match args[0].data_type() {
+        DataType::Map(_, _) => as_map_array(&args[0])?,
+        _ => return exec_err!("Argument for map_values should be a map"),
+    };
+
+    Ok(Arc::new(ListArray::new(
+        Arc::new(Field::new("item", map_array.value_type().clone(), true)),
+        map_array.offsets().clone(),
+        Arc::clone(map_array.values()),
+        // Carry the map's validity buffer over so that a NULL map row yields a
+        // NULL list rather than a non-null list.
+        map_array.nulls().cloned(),
+    )))
+}
diff --git 
a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 3d5b261618d5..0765f6cd237d 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -26,9 +26,9 @@ use arrow_array::{ UInt32Array, }; use arrow_buffer::OffsetBuffer; -use arrow_schema::Field; +use arrow_schema::{Field, Fields}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use core::any::type_name; use datafusion_common::DataFusionError; @@ -253,6 +253,21 @@ pub(crate) fn compute_array_dims( } } 
+/// Returns the fields of the entry struct of a map data type.
+///
+/// # Errors
+///
+/// Returns an internal error when `data_type` is not `DataType::Map`, or when
+/// the map's entry field is not a `DataType::Struct`.
+pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> {
+    // Guard: anything other than a map is rejected up front.
+    let entries = match data_type {
+        DataType::Map(field, _) => field,
+        _ => return internal_err!("Expected a Map type, got {:?}", data_type),
+    };
+
+    match entries.data_type() {
+        DataType::Struct(fields) => Ok(fields),
+        other => internal_err!("Expected a Struct type, got {:?}", other),
+    }
+}
+
 #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 270e4beccc52..c66334c4de2a 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -568,8 +568,103 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) [] [[4, , 6]] [] [] [] [[1, , 3]] +# Tests for map_keys + +query ? +SELECT map_keys(MAP { 'a': 1, 2: 3 }); +---- +[a, 2] + +query ? +SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t; +---- +[a, b, c] +[a, b, c] +[a, b, c] + +query ? +SELECT map_keys(Map{column1: column2, column3: column4}) FROM t; +---- +[a, k1] +[b, k3] +[d, k5] + +query ? +SELECT map_keys(map(column5, column6)) FROM t; +---- +[k1, k2] +[k3] +[k5] + +query ? 
+SELECT map_keys(map(column8, column9)) FROM t; +---- +[[1, 2, 3]] +[[4]] +[[1, 2]] + +query ? +SELECT map_keys(Map{}); +---- +[] + +query ? +SELECT map_keys(column1) from map_array_table_1; +---- +[1, 2, 3] +[4, 5, 6] +[7, 8, 9] + + +# Tests for map_values + +query ? +SELECT map_values(MAP { 'a': 1, 2: 3 }); +---- +[1, 3] + +query ? +SELECT map_values(MAP {'a':1, 'b':2, 'c':3 }) FROM t; +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +query ? +SELECT map_values(Map{column1: column2, column3: column4}) FROM t; +---- +[1, 10] +[2, 30] +[4, 50] + +query ? +SELECT map_values(map(column5, column6)) FROM t; +---- +[1, 2] +[3] +[5] + +query ? +SELECT map_values(map(column8, column9)) FROM t; +---- +[a] +[b] +[c] + +query ? +SELECT map_values(Map{}); +---- +[] + +query ? +SELECT map_values(column1) from map_array_table_1; +---- +[[1, , 3], [4, , 6], [7, 8, 9]] +[[1, , 3], [4, , 6], [7, 8, 9]] +[[1, , 3], [9, , 6], [7, 8, 9]] + statement ok drop table map_array_table_1; statement ok -drop table map_array_table_2; \ No newline at end of file +drop table map_array_table_2; diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 47e35d2e72e3..80b61f8242ef 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -3677,6 +3677,8 @@ Unwraps struct fields into columns. - [map](#map) - [make_map](#make_map) - [map_extract](#map_extract) +- [map_keys](#map_keys) +- [map_values](#map_values) ### `map` @@ -3765,6 +3767,56 @@ SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); - element_at +### `map_keys` + +Return a list of all keys in the map. + +``` +map_keys(map) +``` + +#### Arguments + +- `map`: Map expression. + Can be a constant, column, or function, and any combination of map operators. 
+ +#### Example + +``` +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +select map_keys(map([100, 5], [42,43])); +---- +[100, 5] +``` + +### `map_values` + +Return a list of all values in the map. + +``` +map_values(map) +``` + +#### Arguments + +- `map`: Map expression. + Can be a constant, column, or function, and any combination of map operators. + +#### Example + +``` +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +select map_values(map([100, 5], [42,43])); +---- +[42, 43] +``` + ## Hashing Functions - [digest](#digest)