diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index b8671c39a943..244a44acdcb5 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -153,12 +153,11 @@ impl Accumulator for DistinctArrayAggAccumulator { return Ok(()); } - let array = &states[0]; - - assert_eq!(array.len(), 1, "state array should only include 1 row!"); - // Unwrap outer ListArray then do update batch - let inner_array = array.as_list::().value(0); - self.update_batch(&[inner_array]) + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 40d66f9b52ce..78421d0b6431 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -198,6 +198,73 @@ statement error This feature is not implemented: LIMIT not supported in ARRAY_AG SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 +# Test distinct aggregate function with merge batch +query II +with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 + ---- The order is non-deterministic, verify with length +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +3 1 + +# It has only AggregateExec with FinalPartitioned mode, so `merge_batch` is used +# If the plan is changed, whether the `merge_batch` is used should be verified to ensure the test coverage +query TT +explain with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +logical_plan +01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1)) +02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))]] +03)----SubqueryAlias: a +04)------SubqueryAlias: a +05)--------Union +06)----------Projection: Int64(1) AS id, Int64(2) AS foo +07)------------EmptyRelation +08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +09)------------EmptyRelation +10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +11)------------EmptyRelation +12)----------Projection: Int64(1) AS id, Int64(3) AS foo +13)------------EmptyRelation +14)----------Projection: Int64(1) AS id, Int64(2) AS foo +15)------------EmptyRelation +physical_plan +01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), SUM(DISTINCT Int64(1))@2 as SUM(DISTINCT Int64(1))] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), SUM(DISTINCT Int64(1))] +06)----------UnionExec +07)------------ProjectionExec: expr=[1 as id, 2 as foo] +08)--------------PlaceholderRowExec +09)------------ProjectionExec: expr=[1 as id, NULL as foo] +10)--------------PlaceholderRowExec +11)------------ProjectionExec: expr=[1 as id, NULL as foo] +12)--------------PlaceholderRowExec +13)------------ProjectionExec: expr=[1 as id, 3 as foo] +14)--------------PlaceholderRowExec +15)------------ProjectionExec: expr=[1 as id, 2 as foo] +16)--------------PlaceholderRowExec + + # FIX: custom absolute values # csv_query_avg_multi_batch