From d0539272da6a517f3cc1c09eaba6af998805c384 Mon Sep 17 00:00:00 2001
From: Matthew Turner
Date: Wed, 16 Oct 2024 12:30:59 -0400
Subject: [PATCH] Patch for PR 12586

---
 .../src/aggregates/group_values/row.rs        |  60 ++++++--
 .../physical-plan/src/aggregates/mod.rs       | 128 +++++++++++++++++-
 2 files changed, 175 insertions(+), 13 deletions(-)

diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index dc948e28bb2d..93a3e04a90ee 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -20,13 +20,14 @@ use ahash::RandomState;
 use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::{Array, ArrayRef};
+use arrow_array::{Array, ArrayRef, ListArray, StructArray};
 use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
 use hashbrown::raw::RawTable;
+use std::sync::Arc;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
@@ -221,15 +222,10 @@ impl GroupValues for GroupValuesRows {
         // TODO: Materialize dictionaries in group keys (#7647)
         for (field, array) in self.schema.fields.iter().zip(&mut output) {
             let expected = field.data_type();
-            if let DataType::Dictionary(_, v) = expected {
-                let actual = array.data_type();
-                if v.as_ref() != actual {
-                    return Err(DataFusionError::Internal(format!(
-                        "Converted group rows expected dictionary of {v} got {actual}"
-                    )));
-                }
-                *array = cast(array.as_ref(), expected)?;
-            }
+            *array = dictionary_encode_if_necessary(
+                Arc::<dyn Array>::clone(array),
+                expected,
+            )?;
         }
 
         self.group_values = Some(group_values);
@@ -249,3 +245,45 @@ impl GroupValues for GroupValuesRows {
         self.hashes_buffer.shrink_to(count);
     }
 }
+
+fn dictionary_encode_if_necessary(
+    array: ArrayRef,
+    expected: &DataType,
+) -> Result<ArrayRef> {
+    match (expected, array.data_type()) {
+        (DataType::Struct(expected_fields), _) => {
+            let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
+            let arrays = expected_fields
+                .iter()
+                .zip(struct_array.columns())
+                .map(|(expected_field, column)| {
+                    dictionary_encode_if_necessary(
+                        Arc::<dyn Array>::clone(column),
+                        expected_field.data_type(),
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(Arc::new(StructArray::try_new(
+                expected_fields.clone(),
+                arrays,
+                struct_array.nulls().cloned(),
+            )?))
+        }
+        (DataType::List(expected_field), &DataType::List(_)) => {
+            let list = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+            Ok(Arc::new(ListArray::try_new(
+                Arc::<arrow_schema::Field>::clone(expected_field),
+                list.offsets().clone(),
+                dictionary_encode_if_necessary(
+                    Arc::<dyn Array>::clone(list.values()),
+                    expected_field.data_type(),
+                )?,
+                list.nulls().cloned(),
+            )?))
+        }
+        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
+        (_, _) => Ok(Arc::<dyn Array>::clone(&array)),
+    }
+}
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index c3bc7b042e65..617f1da3abdd 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1200,8 +1200,10 @@ mod tests {
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::compute::{concat_batches, SortOptions};
-    use arrow::datatypes::DataType;
-    use arrow_array::{Float32Array, Int32Array};
+    use arrow::datatypes::{DataType, Int32Type};
+    use arrow_array::{
+        DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
+    };
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
         ScalarValue,
     };
@@ -1214,6 +1216,7 @@ mod tests {
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
     use datafusion_functions_aggregate::median::median_udaf;
+    use datafusion_functions_aggregate::sum::sum_udaf;
     use datafusion_physical_expr::expressions::lit;
     use datafusion_physical_expr::PhysicalSortExpr;
 
@@ -2316,6 +2319,127 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![
+                Field::new(
+                    "labels".to_string(),
+                    DataType::Struct(
+                        vec![
+                            Field::new_dict(
+                                "a".to_string(),
+                                DataType::Dictionary(
+                                    Box::new(DataType::Int32),
+                                    Box::new(DataType::Utf8),
+                                ),
+                                true,
+                                0,
+                                false,
+                            ),
+                            Field::new_dict(
+                                "b".to_string(),
+                                DataType::Dictionary(
+                                    Box::new(DataType::Int32),
+                                    Box::new(DataType::Utf8),
+                                ),
+                                true,
+                                0,
+                                false,
+                            ),
+                        ]
+                        .into(),
+                    ),
+                    false,
+                ),
+                Field::new("value", DataType::UInt64, false),
+            ])),
+            vec![
+                Arc::new(StructArray::from(vec![
+                    (
+                        Arc::new(Field::new_dict(
+                            "a".to_string(),
+                            DataType::Dictionary(
+                                Box::new(DataType::Int32),
+                                Box::new(DataType::Utf8),
+                            ),
+                            true,
+                            0,
+                            false,
+                        )),
+                        Arc::new(
+                            vec![Some("a"), None, Some("a")]
+                                .into_iter()
+                                .collect::<DictionaryArray<Int32Type>>(),
+                        ) as ArrayRef,
+                    ),
+                    (
+                        Arc::new(Field::new_dict(
+                            "b".to_string(),
+                            DataType::Dictionary(
+                                Box::new(DataType::Int32),
+                                Box::new(DataType::Utf8),
+                            ),
+                            true,
+                            0,
+                            false,
+                        )),
+                        Arc::new(
+                            vec![Some("b"), Some("c"), Some("b")]
+                                .into_iter()
+                                .collect::<DictionaryArray<Int32Type>>(),
+                        ) as ArrayRef,
+                    ),
+                ])),
+                Arc::new(UInt64Array::from(vec![1, 1, 1])),
+            ],
+        )
+        .expect("Failed to create RecordBatch");
+
+        let group_by = PhysicalGroupBy::new_single(vec![(
+            col("labels", &batch.schema())?,
+            "labels".to_string(),
+        )]);
+
+        let aggr_expr = vec![AggregateExprBuilder::new(
+            sum_udaf(),
+            vec![col("value", &batch.schema())?],
+        )
+        .schema(Arc::clone(&batch.schema()))
+        .alias(String::from("SUM(value)"))
+        .build()?];
+
+        let input = Arc::new(MemoryExec::try_new(
+            &[vec![batch.clone()]],
+            Arc::<Schema>::clone(&batch.schema()),
+            None,
+        )?);
+        let aggregate_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::FinalPartitioned,
+            group_by,
+            aggr_expr,
+            vec![None],
+            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
+            batch.schema(),
+        )?);
+
+        let session_config = SessionConfig::default();
+        let ctx = TaskContext::default().with_session_config(session_config);
+        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
+
+        let expected = [
+            "+--------------+------------+",
+            "| labels       | SUM(value) |",
+            "+--------------+------------+",
+            "| {a: a, b: b} | 2          |",
+            "| {a: , b: c}  | 1          |",
+            "+--------------+------------+",
+        ];
+        assert_batches_eq!(expected, &output);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_skip_aggregation_after_first_batch() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
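
For context on the diff above: group keys round-tripped through `RowConverter` come back with plain (non-dictionary) nested children, so the new `dictionary_encode_if_necessary` helper walks the expected type and casts those children back to the dictionary type declared in the output schema rather than returning the old "Converted group rows expected dictionary" internal error. The sketch below is a minimal, self-contained illustration of that recursion using only public arrow-rs APIs; it is not part of the patch, and the function name `reencode`, the one-field example schema, and the handling of only the struct and dictionary cases are illustrative assumptions.

// Standalone sketch: re-encode a struct column whose child lost its
// dictionary encoding, so that its type matches the declared schema again.
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StringArray, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Fields};
use arrow::error::ArrowError;

// Recursively rebuild `array` so its data type matches `expected`,
// dictionary-encoding any child whose expected type is a dictionary.
// (Simplified: only the struct and dictionary arms of the patch's helper.)
fn reencode(array: ArrayRef, expected: &DataType) -> Result<ArrayRef, ArrowError> {
    match (expected, array.data_type()) {
        (DataType::Struct(expected_fields), DataType::Struct(_)) => {
            let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
            let children = expected_fields
                .iter()
                .zip(struct_array.columns())
                .map(|(f, col)| reencode(Arc::clone(col), f.data_type()))
                .collect::<Result<Vec<_>, ArrowError>>()?;
            Ok(Arc::new(StructArray::try_new(
                expected_fields.clone(),
                children,
                struct_array.nulls().cloned(),
            )?))
        }
        // A plain child whose schema says "dictionary": `cast` re-encodes it.
        (DataType::Dictionary(_, _), _) => cast(array.as_ref(), expected),
        // Anything else is already in the expected shape; hand it back as-is.
        _ => Ok(array),
    }
}

fn main() -> Result<(), ArrowError> {
    // A struct group column whose child `a` came back as plain Utf8.
    let plain_child: ArrayRef =
        Arc::new(StringArray::from(vec![Some("x"), None, Some("x")]));
    let actual_fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
    let group_column: ArrayRef =
        Arc::new(StructArray::try_new(actual_fields, vec![plain_child], None)?);

    // The output schema still declares `a` as Dictionary(Int32, Utf8).
    let expected = DataType::Struct(Fields::from(vec![Field::new(
        "a",
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
        true,
    )]));

    let reencoded = reencode(group_column, &expected)?;
    assert_eq!(reencoded.data_type(), &expected);
    Ok(())
}

The helper in the patch additionally handles `List` children and, through its final `(_, _)` arm, returns any other column unchanged by cloning the `Arc`, so non-dictionary group keys stay zero-copy.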