Skip to content

Commit 0b061be

Browse files
rluvaton and comphead authored
fix: update group by columns for merge phase after spill (#15531)
* fix: update group by columns for merge phase after spill

  fixes #15530

* Update datafusion/physical-plan/src/aggregates/row_hash.rs

  Co-authored-by: Oleks V <[email protected]>

* move test to aggregate_fuzz

---------

Co-authored-by: Oleks V <[email protected]>
1 parent 387541c commit 0b061be

File tree

2 files changed

+153
-4
lines changed

2 files changed

+153
-4
lines changed

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 142 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,17 @@ use crate::fuzz_cases::aggregation_fuzzer::{
2121
AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder,
2222
};
2323

24-
use arrow::array::{types::Int64Type, Array, ArrayRef, AsArray, Int64Array, RecordBatch};
24+
use arrow::array::{
25+
types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch,
26+
StringArray,
27+
};
2528
use arrow::compute::{concat_batches, SortOptions};
2629
use arrow::datatypes::{
2730
DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
2831
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
2932
};
3033
use arrow::util::pretty::pretty_format_batches;
34+
use arrow_schema::{Field, Schema, SchemaRef};
3135
use datafusion::common::Result;
3236
use datafusion::datasource::memory::MemorySourceConfig;
3337
use datafusion::datasource::source::DataSourceExec;
@@ -42,14 +46,18 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}
4246
use datafusion_common::HashMap;
4347
use datafusion_common_runtime::JoinSet;
4448
use datafusion_functions_aggregate::sum::sum_udaf;
45-
use datafusion_physical_expr::expressions::col;
49+
use datafusion_physical_expr::expressions::{col, lit, Column};
4650
use datafusion_physical_expr::PhysicalSortExpr;
4751
use datafusion_physical_expr_common::sort_expr::LexOrdering;
4852
use datafusion_physical_plan::InputOrderMode;
4953
use test_utils::{add_empty_batches, StringBatchGenerator};
5054

55+
use datafusion_execution::memory_pool::FairSpillPool;
56+
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
57+
use datafusion_execution::TaskContext;
58+
use datafusion_physical_plan::metrics::MetricValue;
5159
use rand::rngs::StdRng;
52-
use rand::{thread_rng, Rng, SeedableRng};
60+
use rand::{random, thread_rng, Rng, SeedableRng};
5361

5462
// ========================================================================
5563
// The new aggregation fuzz tests based on [`AggregationFuzzer`]
@@ -663,3 +671,134 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i
663671
}
664672
output
665673
}
674+
675+
fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
676+
if let Some(metrics_set) = single_aggregate.metrics() {
677+
let mut spill_count = 0;
678+
679+
// Inspect metrics for SpillCount
680+
for metric in metrics_set.iter() {
681+
if let MetricValue::SpillCount(count) = metric.value() {
682+
spill_count = count.value();
683+
break;
684+
}
685+
}
686+
687+
if expect_spill && spill_count == 0 {
688+
panic!("Expected spill but SpillCount metric not found or SpillCount was 0.");
689+
} else if !expect_spill && spill_count > 0 {
690+
panic!("Expected no spill but found SpillCount metric with value greater than 0.");
691+
}
692+
} else {
693+
panic!("No metrics returned from the operator; cannot verify spilling.");
694+
}
695+
}
696+
697+
// Fix for https://github.com/apache/datafusion/issues/15530
698+
#[tokio::test]
699+
async fn test_single_mode_aggregate_with_spill() -> Result<()> {
700+
let scan_schema = Arc::new(Schema::new(vec![
701+
Field::new("col_0", DataType::Int64, true),
702+
Field::new("col_1", DataType::Utf8, true),
703+
Field::new("col_2", DataType::Utf8, true),
704+
Field::new("col_3", DataType::Utf8, true),
705+
Field::new("col_4", DataType::Utf8, true),
706+
Field::new("col_5", DataType::Int32, true),
707+
Field::new("col_6", DataType::Utf8, true),
708+
Field::new("col_7", DataType::Utf8, true),
709+
Field::new("col_8", DataType::Utf8, true),
710+
]));
711+
712+
let group_by = PhysicalGroupBy::new_single(vec![
713+
(Arc::new(Column::new("col_1", 1)), "col_1".to_string()),
714+
(Arc::new(Column::new("col_7", 7)), "col_7".to_string()),
715+
(Arc::new(Column::new("col_0", 0)), "col_0".to_string()),
716+
(Arc::new(Column::new("col_8", 8)), "col_8".to_string()),
717+
]);
718+
719+
fn generate_int64_array() -> ArrayRef {
720+
Arc::new(Int64Array::from_iter_values(
721+
(0..1024).map(|_| random::<i64>()),
722+
))
723+
}
724+
fn generate_int32_array() -> ArrayRef {
725+
Arc::new(Int32Array::from_iter_values(
726+
(0..1024).map(|_| random::<i32>()),
727+
))
728+
}
729+
730+
fn generate_string_array() -> ArrayRef {
731+
Arc::new(StringArray::from(
732+
(0..1024)
733+
.map(|_| -> String {
734+
thread_rng()
735+
.sample_iter::<char, _>(rand::distributions::Standard)
736+
.take(5)
737+
.collect()
738+
})
739+
.collect::<Vec<_>>(),
740+
))
741+
}
742+
743+
fn generate_record_batch(schema: &SchemaRef) -> Result<RecordBatch> {
744+
RecordBatch::try_new(
745+
Arc::clone(schema),
746+
vec![
747+
generate_int64_array(),
748+
generate_string_array(),
749+
generate_string_array(),
750+
generate_string_array(),
751+
generate_string_array(),
752+
generate_int32_array(),
753+
generate_string_array(),
754+
generate_string_array(),
755+
generate_string_array(),
756+
],
757+
)
758+
.map_err(|err| err.into())
759+
}
760+
761+
let aggregate_expressions = vec![Arc::new(
762+
AggregateExprBuilder::new(sum_udaf(), vec![lit(1i64)])
763+
.schema(Arc::clone(&scan_schema))
764+
.alias("SUM(1i64)")
765+
.build()?,
766+
)];
767+
768+
let batches = (0..5)
769+
.map(|_| generate_record_batch(&scan_schema))
770+
.collect::<Result<Vec<_>>>()?;
771+
772+
let plan: Arc<dyn ExecutionPlan> =
773+
MemorySourceConfig::try_new_exec(&[batches], Arc::clone(&scan_schema), None)
774+
.unwrap();
775+
776+
let single_aggregate = Arc::new(AggregateExec::try_new(
777+
AggregateMode::Single,
778+
group_by,
779+
aggregate_expressions.clone(),
780+
vec![None; aggregate_expressions.len()],
781+
plan,
782+
Arc::clone(&scan_schema),
783+
)?);
784+
785+
let memory_pool = Arc::new(FairSpillPool::new(250000));
786+
let task_ctx = Arc::new(
787+
TaskContext::default()
788+
.with_session_config(SessionConfig::new().with_batch_size(248))
789+
.with_runtime(Arc::new(
790+
RuntimeEnvBuilder::new()
791+
.with_memory_pool(memory_pool)
792+
.build()?,
793+
)),
794+
);
795+
796+
datafusion_physical_plan::common::collect(
797+
single_aggregate.execute(0, Arc::clone(&task_ctx))?,
798+
)
799+
.await?;
800+
801+
assert_spill_count_metric(true, single_aggregate);
802+
803+
Ok(())
804+
}

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,16 @@ impl GroupedHashAggregateStream {
507507
AggregateMode::Partial,
508508
)?;
509509

510+
// Need to update the GROUP BY expressions to point to the correct column after schema change
511+
let merging_group_by_expr = agg_group_by
512+
.expr
513+
.iter()
514+
.enumerate()
515+
.map(|(idx, (_, name))| {
516+
(Arc::new(Column::new(name.as_str(), idx)) as _, name.clone())
517+
})
518+
.collect();
519+
510520
let partial_agg_schema = Arc::new(partial_agg_schema);
511521

512522
let spill_expr = group_schema
@@ -550,7 +560,7 @@ impl GroupedHashAggregateStream {
550560
spill_schema: partial_agg_schema,
551561
is_stream_merging: false,
552562
merging_aggregate_arguments,
553-
merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
563+
merging_group_by: PhysicalGroupBy::new_single(merging_group_by_expr),
554564
peak_mem_used: MetricBuilder::new(&agg.metrics)
555565
.gauge("peak_mem_used", partition),
556566
spill_manager,

0 commit comments

Comments
 (0)