Backport "physical-plan: Cast nested group values back to dictionary if necessary" (#12586) #12976

Merged: 1 commit on Oct 16, 2024
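Context for the backport: GroupValuesRows stores group keys in arrow's row format, and the row format encodes dictionary columns by value, so the emitted group arrays come back typed as the dictionary's value type rather than the Dictionary type the schema declares. The previous code handled this only for top-level dictionary columns and raised an internal error for anything else; the change below recurses through Struct and List so nested dictionaries are re-encoded as well. A minimal standalone sketch of the round trip, not part of this PR, assuming a recent arrow-rs in which the row format encodes dictionary keys by value:

use std::sync::Arc;

use arrow::array::{ArrayRef, DictionaryArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Int32Type};
use arrow::row::{RowConverter, SortField};

fn main() -> Result<(), arrow::error::ArrowError> {
    // A Dictionary(Int32, Utf8) group key, like the columns in this PR's test.
    let dict: DictionaryArray<Int32Type> =
        vec![Some("a"), None, Some("a")].into_iter().collect();
    let input: ArrayRef = Arc::new(dict);
    let expected = input.data_type().clone();

    // Round-trip through the row format used by GroupValuesRows.
    let converter = RowConverter::new(vec![SortField::new(expected.clone())])?;
    let rows = converter.convert_columns(&[Arc::clone(&input)])?;
    let output = converter.convert_rows(&rows)?.remove(0);

    // The row format encodes dictionaries by value, so the column comes back
    // as plain Utf8 and must be cast to the schema's dictionary type again.
    assert_eq!(output.data_type(), &DataType::Utf8);
    let restored = cast(output.as_ref(), &expected)?;
    assert_eq!(restored.data_type(), &expected);
    Ok(())
}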
60 changes: 49 additions & 11 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -20,13 +20,14 @@
 use ahash::RandomState;
 use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::{Array, ArrayRef};
+use arrow_array::{Array, ArrayRef, ListArray, StructArray};
 use arrow_schema::{DataType, SchemaRef};
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_expr::EmitTo;
 use hashbrown::raw::RawTable;
+use std::sync::Arc;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
@@ -221,15 +222,10 @@ impl GroupValues for GroupValuesRows {
         // TODO: Materialize dictionaries in group keys (#7647)
         for (field, array) in self.schema.fields.iter().zip(&mut output) {
             let expected = field.data_type();
-            if let DataType::Dictionary(_, v) = expected {
-                let actual = array.data_type();
-                if v.as_ref() != actual {
-                    return Err(DataFusionError::Internal(format!(
-                        "Converted group rows expected dictionary of {v} got {actual}"
-                    )));
-                }
-                *array = cast(array.as_ref(), expected)?;
-            }
+            *array = dictionary_encode_if_necessary(
+                Arc::<dyn arrow_array::Array>::clone(array),
+                expected,
+            )?;
         }
 
         self.group_values = Some(group_values);
@@ -249,3 +245,45 @@
         self.hashes_buffer.shrink_to(count);
     }
 }
+
+fn dictionary_encode_if_necessary(
+    array: ArrayRef,
+    expected: &DataType,
+) -> Result<ArrayRef> {
+    match (expected, array.data_type()) {
+        (DataType::Struct(expected_fields), _) => {
+            let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
+            let arrays = expected_fields
+                .iter()
+                .zip(struct_array.columns())
+                .map(|(expected_field, column)| {
+                    dictionary_encode_if_necessary(
+                        Arc::<dyn arrow_array::Array>::clone(column),
+                        expected_field.data_type(),
+                    )
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            Ok(Arc::new(StructArray::try_new(
+                expected_fields.clone(),
+                arrays,
+                struct_array.nulls().cloned(),
+            )?))
+        }
+        (DataType::List(expected_field), &DataType::List(_)) => {
+            let list = array.as_any().downcast_ref::<ListArray>().unwrap();
+
+            Ok(Arc::new(ListArray::try_new(
+                Arc::<arrow_schema::Field>::clone(expected_field),
+                list.offsets().clone(),
+                dictionary_encode_if_necessary(
+                    Arc::<dyn arrow_array::Array>::clone(list.values()),
+                    expected_field.data_type(),
+                )?,
+                list.nulls().cloned(),
+            )?))
+        }
+        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
+        (_, _) => Ok(Arc::<dyn arrow_array::Array>::clone(&array)),
+    }
+}
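The new dictionary_encode_if_necessary walks the expected type alongside the actual one: Struct rebuilds the array from recursively re-encoded children, List rebuilds around re-encoded values, Dictionary casts, and every other type passes through unchanged. A hypothetical standalone sketch of one level of that recursion (the names actual and fixed are illustrative, not from the PR):

use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray, StructArray};
use arrow::compute::cast;
use arrow::datatypes::{DataType, Field, Fields};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));

    // What emit() hands back: the struct's child decoded as plain Utf8.
    let child: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "a"]));
    let actual = StructArray::try_new(
        Fields::from(vec![Field::new("a", DataType::Utf8, true)]),
        vec![child],
        None,
    )?;

    // What the aggregation schema expects: a dictionary-encoded child.
    let expected_fields = Fields::from(vec![Field::new("a", dict_type.clone(), true)]);

    // One level of the recursion done by hand: cast every child to the field
    // type the schema declares, then rebuild the struct around the results.
    let children = actual
        .columns()
        .iter()
        .zip(expected_fields.iter())
        .map(|(column, field)| cast(column.as_ref(), field.data_type()))
        .collect::<Result<Vec<_>, ArrowError>>()?;
    let fixed = StructArray::try_new(expected_fields, children, actual.nulls().cloned())?;

    assert_eq!(fixed.column(0).data_type(), &dict_type);
    Ok(())
}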
128 changes: 126 additions & 2 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -1200,8 +1200,10 @@ mod tests {
 
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::compute::{concat_batches, SortOptions};
-    use arrow::datatypes::DataType;
-    use arrow_array::{Float32Array, Int32Array};
+    use arrow::datatypes::{DataType, Int32Type};
+    use arrow_array::{
+        DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
+    };
     use datafusion_common::{
         assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
         ScalarValue,
@@ -1214,6 +1216,7 @@
     use datafusion_functions_aggregate::count::count_udaf;
     use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
     use datafusion_functions_aggregate::median::median_udaf;
+    use datafusion_functions_aggregate::sum::sum_udaf;
     use datafusion_physical_expr::expressions::lit;
     use datafusion_physical_expr::PhysicalSortExpr;
 
@@ -2316,6 +2319,127 @@
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_agg_exec_struct_of_dicts() -> Result<()> {
+        let batch = RecordBatch::try_new(
+            Arc::new(Schema::new(vec![
+                Field::new(
+                    "labels".to_string(),
+                    DataType::Struct(
+                        vec![
+                            Field::new_dict(
+                                "a".to_string(),
+                                DataType::Dictionary(
+                                    Box::new(DataType::Int32),
+                                    Box::new(DataType::Utf8),
+                                ),
+                                true,
+                                0,
+                                false,
+                            ),
+                            Field::new_dict(
+                                "b".to_string(),
+                                DataType::Dictionary(
+                                    Box::new(DataType::Int32),
+                                    Box::new(DataType::Utf8),
+                                ),
+                                true,
+                                0,
+                                false,
+                            ),
+                        ]
+                        .into(),
+                    ),
+                    false,
+                ),
+                Field::new("value", DataType::UInt64, false),
+            ])),
+            vec![
+                Arc::new(StructArray::from(vec![
+                    (
+                        Arc::new(Field::new_dict(
+                            "a".to_string(),
+                            DataType::Dictionary(
+                                Box::new(DataType::Int32),
+                                Box::new(DataType::Utf8),
+                            ),
+                            true,
+                            0,
+                            false,
+                        )),
+                        Arc::new(
+                            vec![Some("a"), None, Some("a")]
+                                .into_iter()
+                                .collect::<DictionaryArray<Int32Type>>(),
+                        ) as ArrayRef,
+                    ),
+                    (
+                        Arc::new(Field::new_dict(
+                            "b".to_string(),
+                            DataType::Dictionary(
+                                Box::new(DataType::Int32),
+                                Box::new(DataType::Utf8),
+                            ),
+                            true,
+                            0,
+                            false,
+                        )),
+                        Arc::new(
+                            vec![Some("b"), Some("c"), Some("b")]
+                                .into_iter()
+                                .collect::<DictionaryArray<Int32Type>>(),
+                        ) as ArrayRef,
+                    ),
+                ])),
+                Arc::new(UInt64Array::from(vec![1, 1, 1])),
+            ],
+        )
+        .expect("Failed to create RecordBatch");
+
+        let group_by = PhysicalGroupBy::new_single(vec![(
+            col("labels", &batch.schema())?,
+            "labels".to_string(),
+        )]);
+
+        let aggr_expr = vec![AggregateExprBuilder::new(
+            sum_udaf(),
+            vec![col("value", &batch.schema())?],
+        )
+        .schema(Arc::clone(&batch.schema()))
+        .alias(String::from("SUM(value)"))
+        .build()?];
+
+        let input = Arc::new(MemoryExec::try_new(
+            &[vec![batch.clone()]],
+            Arc::<arrow_schema::Schema>::clone(&batch.schema()),
+            None,
+        )?);
+        let aggregate_exec = Arc::new(AggregateExec::try_new(
+            AggregateMode::FinalPartitioned,
+            group_by,
+            aggr_expr,
+            vec![None],
+            Arc::clone(&input) as Arc<dyn ExecutionPlan>,
+            batch.schema(),
+        )?);
+
+        let session_config = SessionConfig::default();
+        let ctx = TaskContext::default().with_session_config(session_config);
+        let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;
+
+        let expected = [
+            "+--------------+------------+",
+            "| labels       | SUM(value) |",
+            "+--------------+------------+",
+            "| {a: a, b: b} | 2          |",
+            "| {a: , b: c}  | 1          |",
+            "+--------------+------------+",
+        ];
+        assert_batches_eq!(expected, &output);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_skip_aggregation_after_first_batch() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![
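For an end-to-end view of what the new test locks in, a hypothetical reproduction sketch against the public DataFusion API; the exact query that triggered #12586 may differ:

use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, DictionaryArray, StructArray, UInt64Array};
use datafusion::arrow::datatypes::{DataType, Field, Int32Type};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::MemTable;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // A struct column with a dictionary-encoded child, plus a value column.
    let dict_type =
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
    let labels: ArrayRef = Arc::new(StructArray::from(vec![(
        Arc::new(Field::new("a", dict_type, true)),
        Arc::new(
            vec![Some("a"), None, Some("a")]
                .into_iter()
                .collect::<DictionaryArray<Int32Type>>(),
        ) as ArrayRef,
    )]));
    let value: ArrayRef = Arc::new(UInt64Array::from(vec![1, 1, 1]));
    let batch = RecordBatch::try_from_iter(vec![("labels", labels), ("value", value)])?;

    let ctx = SessionContext::new();
    let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
    ctx.register_table("t", Arc::new(table))?;

    // Grouping on the struct column exercises the nested re-encoding path;
    // before this fix it could fail with the internal error
    // "Converted group rows expected dictionary of ... got ...".
    ctx.sql("SELECT labels, SUM(value) FROM t GROUP BY labels")
        .await?
        .show()
        .await?;
    Ok(())
}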