From f9c4c6ad738a366c26572775282659b387a95177 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 9 Oct 2024 11:41:42 -0400 Subject: [PATCH 1/2] Fix convert_to_state bug --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 18 ++++++------------ .../src/aggregate/groups_accumulator.rs | 13 +++++++++++++ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 5cc5157c3af9..b0852501415e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -254,10 +254,8 @@ async fn test_basic_string_aggr_group_by_single_int64() { let fuzzer = builder .data_gen_config(data_gen_config) .data_gen_rounds(8) - // FIXME: Encounter error in min/max - // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema")) - // .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") - // .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") .table_name("fuzz_table") @@ -291,10 +289,8 @@ async fn test_basic_string_aggr_group_by_single_string() { let fuzzer = builder .data_gen_config(data_gen_config) .data_gen_rounds(16) - // FIXME: Encounter error in min/max - // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema")) - // .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") - // .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") .table_name("fuzz_table") @@ -329,10 +325,8 @@ async fn test_basic_string_aggr_group_by_mixed_string_int64() { let fuzzer = builder .data_gen_config(data_gen_config) .data_gen_rounds(16) - // FIXME: Encounter error in min/max - // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema")) - // .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") - // .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") .table_name("fuzz_table") diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index fbbf4d303515..d6b535375181 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -28,6 +28,7 @@ use arrow::{ compute, datatypes::UInt32Type, }; +use arrow::array::new_empty_array; use datafusion_common::{ arrow_datafusion_err, utils::take_arrays, DataFusionError, Result, ScalarValue, }; @@ -405,6 +406,18 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { ) -> Result> { let num_rows = values[0].len(); + // If there are no rows, return empty arrays + if num_rows == 0 { + // create empty accumulator to get the state types + let empty_state = (self.factory)()?.state()?; + let empty_arrays = empty_state + .into_iter() + .map(|state_val| new_empty_array(&state_val.data_type())) + .collect::>(); + + return Ok(empty_arrays) + } + // Each row has its respective group let mut results = vec![]; for row_idx in 0..num_rows { From 6d32043e9c32966cfb12a414073d57d5538ee1f2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 9 Oct 2024 11:49:47 -0400 Subject: [PATCH 2/2] fmt --- .../src/aggregate/groups_accumulator.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index d6b535375181..b03df0224089 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -23,12 +23,12 @@ pub mod bool_op; pub mod nulls; pub mod prim_op; +use arrow::array::new_empty_array; use arrow::{ array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, datatypes::UInt32Type, }; -use arrow::array::new_empty_array; use datafusion_common::{ arrow_datafusion_err, utils::take_arrays, DataFusionError, Result, ScalarValue, }; @@ -415,7 +415,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { .map(|state_val| new_empty_array(&state_val.data_type())) .collect::>(); - return Ok(empty_arrays) + return Ok(empty_arrays); } // Each row has its respective group