Skip to content

Commit 7fbc134

Browse files
Patch for PR 12586 (#12976)
1 parent 81b93e5 commit 7fbc134

File tree

2 files changed

+175
-13
lines changed
  • datafusion/physical-plan/src/aggregates

2 files changed

+175
-13
lines changed

datafusion/physical-plan/src/aggregates/group_values/row.rs

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@ use ahash::RandomState;
2020
use arrow::compute::cast;
2121
use arrow::record_batch::RecordBatch;
2222
use arrow::row::{RowConverter, Rows, SortField};
23-
use arrow_array::{Array, ArrayRef};
23+
use arrow_array::{Array, ArrayRef, ListArray, StructArray};
2424
use arrow_schema::{DataType, SchemaRef};
2525
use datafusion_common::hash_utils::create_hashes;
26-
use datafusion_common::{DataFusionError, Result};
26+
use datafusion_common::Result;
2727
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
2828
use datafusion_expr::EmitTo;
2929
use hashbrown::raw::RawTable;
30+
use std::sync::Arc;
3031

3132
/// A [`GroupValues`] making use of [`Rows`]
3233
pub struct GroupValuesRows {
@@ -221,15 +222,10 @@ impl GroupValues for GroupValuesRows {
221222
// TODO: Materialize dictionaries in group keys (#7647)
222223
for (field, array) in self.schema.fields.iter().zip(&mut output) {
223224
let expected = field.data_type();
224-
if let DataType::Dictionary(_, v) = expected {
225-
let actual = array.data_type();
226-
if v.as_ref() != actual {
227-
return Err(DataFusionError::Internal(format!(
228-
"Converted group rows expected dictionary of {v} got {actual}"
229-
)));
230-
}
231-
*array = cast(array.as_ref(), expected)?;
232-
}
225+
*array = dictionary_encode_if_necessary(
226+
Arc::<dyn arrow_array::Array>::clone(array),
227+
expected,
228+
)?;
233229
}
234230

235231
self.group_values = Some(group_values);
@@ -249,3 +245,45 @@ impl GroupValues for GroupValuesRows {
249245
self.hashes_buffer.shrink_to(count);
250246
}
251247
}
248+
249+
fn dictionary_encode_if_necessary(
250+
array: ArrayRef,
251+
expected: &DataType,
252+
) -> Result<ArrayRef> {
253+
match (expected, array.data_type()) {
254+
(DataType::Struct(expected_fields), _) => {
255+
let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
256+
let arrays = expected_fields
257+
.iter()
258+
.zip(struct_array.columns())
259+
.map(|(expected_field, column)| {
260+
dictionary_encode_if_necessary(
261+
Arc::<dyn arrow_array::Array>::clone(column),
262+
expected_field.data_type(),
263+
)
264+
})
265+
.collect::<Result<Vec<_>>>()?;
266+
267+
Ok(Arc::new(StructArray::try_new(
268+
expected_fields.clone(),
269+
arrays,
270+
struct_array.nulls().cloned(),
271+
)?))
272+
}
273+
(DataType::List(expected_field), &DataType::List(_)) => {
274+
let list = array.as_any().downcast_ref::<ListArray>().unwrap();
275+
276+
Ok(Arc::new(ListArray::try_new(
277+
Arc::<arrow_schema::Field>::clone(expected_field),
278+
list.offsets().clone(),
279+
dictionary_encode_if_necessary(
280+
Arc::<dyn arrow_array::Array>::clone(list.values()),
281+
expected_field.data_type(),
282+
)?,
283+
list.nulls().cloned(),
284+
)?))
285+
}
286+
(DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
287+
(_, _) => Ok(Arc::<dyn arrow_array::Array>::clone(&array)),
288+
}
289+
}

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 126 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,8 +1200,10 @@ mod tests {
12001200

12011201
use arrow::array::{Float64Array, UInt32Array};
12021202
use arrow::compute::{concat_batches, SortOptions};
1203-
use arrow::datatypes::DataType;
1204-
use arrow_array::{Float32Array, Int32Array};
1203+
use arrow::datatypes::{DataType, Int32Type};
1204+
use arrow_array::{
1205+
DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
1206+
};
12051207
use datafusion_common::{
12061208
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
12071209
ScalarValue,
@@ -1214,6 +1216,7 @@ mod tests {
12141216
use datafusion_functions_aggregate::count::count_udaf;
12151217
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
12161218
use datafusion_functions_aggregate::median::median_udaf;
1219+
use datafusion_functions_aggregate::sum::sum_udaf;
12171220
use datafusion_physical_expr::expressions::lit;
12181221
use datafusion_physical_expr::PhysicalSortExpr;
12191222

@@ -2316,6 +2319,127 @@ mod tests {
23162319
Ok(())
23172320
}
23182321

2322+
#[tokio::test]
2323+
async fn test_agg_exec_struct_of_dicts() -> Result<()> {
2324+
let batch = RecordBatch::try_new(
2325+
Arc::new(Schema::new(vec![
2326+
Field::new(
2327+
"labels".to_string(),
2328+
DataType::Struct(
2329+
vec![
2330+
Field::new_dict(
2331+
"a".to_string(),
2332+
DataType::Dictionary(
2333+
Box::new(DataType::Int32),
2334+
Box::new(DataType::Utf8),
2335+
),
2336+
true,
2337+
0,
2338+
false,
2339+
),
2340+
Field::new_dict(
2341+
"b".to_string(),
2342+
DataType::Dictionary(
2343+
Box::new(DataType::Int32),
2344+
Box::new(DataType::Utf8),
2345+
),
2346+
true,
2347+
0,
2348+
false,
2349+
),
2350+
]
2351+
.into(),
2352+
),
2353+
false,
2354+
),
2355+
Field::new("value", DataType::UInt64, false),
2356+
])),
2357+
vec![
2358+
Arc::new(StructArray::from(vec![
2359+
(
2360+
Arc::new(Field::new_dict(
2361+
"a".to_string(),
2362+
DataType::Dictionary(
2363+
Box::new(DataType::Int32),
2364+
Box::new(DataType::Utf8),
2365+
),
2366+
true,
2367+
0,
2368+
false,
2369+
)),
2370+
Arc::new(
2371+
vec![Some("a"), None, Some("a")]
2372+
.into_iter()
2373+
.collect::<DictionaryArray<Int32Type>>(),
2374+
) as ArrayRef,
2375+
),
2376+
(
2377+
Arc::new(Field::new_dict(
2378+
"b".to_string(),
2379+
DataType::Dictionary(
2380+
Box::new(DataType::Int32),
2381+
Box::new(DataType::Utf8),
2382+
),
2383+
true,
2384+
0,
2385+
false,
2386+
)),
2387+
Arc::new(
2388+
vec![Some("b"), Some("c"), Some("b")]
2389+
.into_iter()
2390+
.collect::<DictionaryArray<Int32Type>>(),
2391+
) as ArrayRef,
2392+
),
2393+
])),
2394+
Arc::new(UInt64Array::from(vec![1, 1, 1])),
2395+
],
2396+
)
2397+
.expect("Failed to create RecordBatch");
2398+
2399+
let group_by = PhysicalGroupBy::new_single(vec![(
2400+
col("labels", &batch.schema())?,
2401+
"labels".to_string(),
2402+
)]);
2403+
2404+
let aggr_expr = vec![AggregateExprBuilder::new(
2405+
sum_udaf(),
2406+
vec![col("value", &batch.schema())?],
2407+
)
2408+
.schema(Arc::clone(&batch.schema()))
2409+
.alias(String::from("SUM(value)"))
2410+
.build()?];
2411+
2412+
let input = Arc::new(MemoryExec::try_new(
2413+
&[vec![batch.clone()]],
2414+
Arc::<arrow_schema::Schema>::clone(&batch.schema()),
2415+
None,
2416+
)?);
2417+
let aggregate_exec = Arc::new(AggregateExec::try_new(
2418+
AggregateMode::FinalPartitioned,
2419+
group_by,
2420+
aggr_expr,
2421+
vec![None],
2422+
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
2423+
batch.schema(),
2424+
)?);
2425+
2426+
let session_config = SessionConfig::default();
2427+
let ctx = TaskContext::default().with_session_config(session_config);
2428+
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
2429+
2430+
let expected = [
2431+
"+--------------+------------+",
2432+
"| labels | SUM(value) |",
2433+
"+--------------+------------+",
2434+
"| {a: a, b: b} | 2 |",
2435+
"| {a: , b: c} | 1 |",
2436+
"+--------------+------------+",
2437+
];
2438+
assert_batches_eq!(expected, &output);
2439+
2440+
Ok(())
2441+
}
2442+
23192443
#[tokio::test]
23202444
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
23212445
let schema = Arc::new(Schema::new(vec![

0 commit comments

Comments
 (0)