Skip to content

Commit c412c74

Browse files
Fix bug in TopK aggregates (#12766)
Fix bug in TopK aggregates (#12766)
1 parent 7d36059 commit c412c74

File tree

11 files changed

+100
-20
lines changed

11 files changed

+100
-20
lines changed

datafusion/physical-optimizer/src/topk_aggregation.rs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
use std::sync::Arc;
2121

2222
use datafusion_physical_plan::aggregates::AggregateExec;
23-
use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
24-
use datafusion_physical_plan::filter::FilterExec;
25-
use datafusion_physical_plan::repartition::RepartitionExec;
2623
use datafusion_physical_plan::sorts::sort::SortExec;
2724
use datafusion_physical_plan::ExecutionPlan;
2825

@@ -31,9 +28,10 @@ use datafusion_common::config::ConfigOptions;
3128
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3229
use datafusion_common::Result;
3330
use datafusion_physical_expr::expressions::Column;
34-
use datafusion_physical_expr::PhysicalSortExpr;
3531

3632
use crate::PhysicalOptimizerRule;
33+
use datafusion_physical_plan::execution_plan::CardinalityEffect;
34+
use datafusion_physical_plan::projection::ProjectionExec;
3735
use itertools::Itertools;
3836

3937
/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
@@ -48,12 +46,13 @@ impl TopKAggregation {
4846

4947
fn transform_agg(
5048
aggr: &AggregateExec,
51-
order: &PhysicalSortExpr,
49+
order_by: &str,
50+
order_desc: bool,
5251
limit: usize,
5352
) -> Option<Arc<dyn ExecutionPlan>> {
5453
// ensure the sort direction matches that of the aggregate function
5554
let (field, desc) = aggr.get_minmax_desc()?;
56-
if desc != order.options.descending {
55+
if desc != order_desc {
5756
return None;
5857
}
5958
let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
@@ -66,8 +65,7 @@ impl TopKAggregation {
6665
}
6766

6867
// ensure the sort is on the same field as the aggregate output
69-
let col = order.expr.as_any().downcast_ref::<Column>()?;
70-
if col.name() != field.name() {
68+
if order_by != field.name() {
7169
return None;
7270
}
7371

@@ -92,31 +90,39 @@ impl TopKAggregation {
9290
let child = children.into_iter().exactly_one().ok()?;
9391
let order = sort.properties().output_ordering()?;
9492
let order = order.iter().exactly_one().ok()?;
93+
let order_desc = order.options.descending;
94+
let order = order.expr.as_any().downcast_ref::<Column>()?;
95+
let mut cur_col_name = order.name().to_string();
9596
let limit = sort.fetch()?;
9697

97-
let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
98-
plan.as_any()
99-
.downcast_ref::<CoalesceBatchesExec>()
100-
.is_some()
101-
|| plan.as_any().downcast_ref::<RepartitionExec>().is_some()
102-
|| plan.as_any().downcast_ref::<FilterExec>().is_some()
103-
};
104-
10598
let mut cardinality_preserved = true;
10699
let closure = |plan: Arc<dyn ExecutionPlan>| {
107100
if !cardinality_preserved {
108101
return Ok(Transformed::no(plan));
109102
}
110103
if let Some(aggr) = plan.as_any().downcast_ref::<AggregateExec>() {
111104
// either we run into an Aggregate and transform it
112-
match Self::transform_agg(aggr, order, limit) {
105+
match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) {
113106
None => cardinality_preserved = false,
114107
Some(plan) => return Ok(Transformed::yes(plan)),
115108
}
109+
} else if let Some(proj) = plan.as_any().downcast_ref::<ProjectionExec>() {
110+
// track renames due to successive projections
111+
for (src_expr, proj_name) in proj.expr() {
112+
let Some(src_col) = src_expr.as_any().downcast_ref::<Column>() else {
113+
continue;
114+
};
115+
if *proj_name == cur_col_name {
116+
cur_col_name = src_col.name().to_string();
117+
}
118+
}
116119
} else {
117-
// or we continue down whitelisted nodes of other types
118-
if !is_cardinality_preserving(Arc::clone(&plan)) {
119-
cardinality_preserved = false;
120+
// or we continue down through operators that cannot reduce cardinality
121+
match plan.cardinality_effect() {
122+
CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {}
123+
CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => {
124+
cardinality_preserved = false;
125+
}
120126
}
121127
}
122128
Ok(Transformed::no(plan))

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use datafusion_physical_expr::{
4848
PhysicalExpr, PhysicalSortRequirement,
4949
};
5050

51+
use crate::execution_plan::CardinalityEffect;
5152
use datafusion_physical_expr::aggregate::AggregateFunctionExpr;
5253
use itertools::Itertools;
5354

@@ -866,6 +867,10 @@ impl ExecutionPlan for AggregateExec {
866867
}
867868
}
868869
}
870+
871+
fn cardinality_effect(&self) -> CardinalityEffect {
872+
CardinalityEffect::LowerEqual
873+
}
869874
}
870875

871876
fn create_schema(

datafusion/physical-plan/src/coalesce_batches.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use datafusion_common::Result;
3434
use datafusion_execution::TaskContext;
3535

3636
use crate::coalesce::{BatchCoalescer, CoalescerState};
37+
use crate::execution_plan::CardinalityEffect;
3738
use futures::ready;
3839
use futures::stream::{Stream, StreamExt};
3940

@@ -199,6 +200,10 @@ impl ExecutionPlan for CoalesceBatchesExec {
199200
fn fetch(&self) -> Option<usize> {
200201
self.fetch
201202
}
203+
204+
fn cardinality_effect(&self) -> CardinalityEffect {
205+
CardinalityEffect::Equal
206+
}
202207
}
203208

204209
/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details.

datafusion/physical-plan/src/coalesce_partitions.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use super::{
3030

3131
use crate::{DisplayFormatType, ExecutionPlan, Partitioning};
3232

33+
use crate::execution_plan::CardinalityEffect;
3334
use datafusion_common::{internal_err, Result};
3435
use datafusion_execution::TaskContext;
3536

@@ -178,6 +179,10 @@ impl ExecutionPlan for CoalescePartitionsExec {
178179
fn supports_limit_pushdown(&self) -> bool {
179180
true
180181
}
182+
183+
fn cardinality_effect(&self) -> CardinalityEffect {
184+
CardinalityEffect::Equal
185+
}
181186
}
182187

183188
#[cfg(test)]

datafusion/physical-plan/src/execution_plan.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,11 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
416416
fn fetch(&self) -> Option<usize> {
417417
None
418418
}
419+
420+
/// Gets the effect this operator has on cardinality, if known (defaults to [`CardinalityEffect::Unknown`])
421+
fn cardinality_effect(&self) -> CardinalityEffect {
422+
CardinalityEffect::Unknown
423+
}
419424
}
420425

421426
/// Extension trait provides an easy API to fetch various properties of
@@ -898,6 +903,20 @@ pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> {
898903
actual.iter().map(|elem| elem.to_string()).collect()
899904
}
900905

906+
/// Indicates the effect an execution plan operator will have on the cardinality
907+
/// of its input stream
908+
pub enum CardinalityEffect {
909+
/// Unknown effect. This is the default.
910+
Unknown,
911+
/// The operator is guaranteed to produce exactly one row for
912+
/// each input row
913+
Equal,
914+
/// The operator may produce fewer output rows than it receives input rows
915+
LowerEqual,
916+
/// The operator may produce more output rows than it receives input rows
917+
GreaterEqual,
918+
}
919+
901920
#[cfg(test)]
902921
mod tests {
903922
use super::*;

datafusion/physical-plan/src/filter.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use datafusion_physical_expr::{
4848
analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr,
4949
};
5050

51+
use crate::execution_plan::CardinalityEffect;
5152
use futures::stream::{Stream, StreamExt};
5253
use log::trace;
5354

@@ -372,6 +373,10 @@ impl ExecutionPlan for FilterExec {
372373
fn statistics(&self) -> Result<Statistics> {
373374
Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity)
374375
}
376+
377+
fn cardinality_effect(&self) -> CardinalityEffect {
378+
CardinalityEffect::LowerEqual
379+
}
375380
}
376381

377382
/// This function ensures that all bounds in the `ExprBoundaries` vector are

datafusion/physical-plan/src/limit.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use arrow::record_batch::RecordBatch;
3434
use datafusion_common::{internal_err, Result};
3535
use datafusion_execution::TaskContext;
3636

37+
use crate::execution_plan::CardinalityEffect;
3738
use futures::stream::{Stream, StreamExt};
3839
use log::trace;
3940

@@ -336,6 +337,10 @@ impl ExecutionPlan for LocalLimitExec {
336337
fn supports_limit_pushdown(&self) -> bool {
337338
true
338339
}
340+
341+
fn cardinality_effect(&self) -> CardinalityEffect {
342+
CardinalityEffect::LowerEqual
343+
}
339344
}
340345

341346
/// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows.

datafusion/physical-plan/src/projection.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ use datafusion_execution::TaskContext;
4242
use datafusion_physical_expr::equivalence::ProjectionMapping;
4343
use datafusion_physical_expr::expressions::Literal;
4444

45+
use crate::execution_plan::CardinalityEffect;
4546
use futures::stream::{Stream, StreamExt};
4647
use log::trace;
4748

@@ -233,6 +234,10 @@ impl ExecutionPlan for ProjectionExec {
233234
fn supports_limit_pushdown(&self) -> bool {
234235
true
235236
}
237+
238+
fn cardinality_effect(&self) -> CardinalityEffect {
239+
CardinalityEffect::Equal
240+
}
236241
}
237242

238243
/// If e is a direct column reference, returns the field level

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use datafusion_execution::memory_pool::MemoryConsumer;
4848
use datafusion_execution::TaskContext;
4949
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr};
5050

51+
use crate::execution_plan::CardinalityEffect;
5152
use futures::stream::Stream;
5253
use futures::{FutureExt, StreamExt, TryStreamExt};
5354
use hashbrown::HashMap;
@@ -669,6 +670,10 @@ impl ExecutionPlan for RepartitionExec {
669670
fn statistics(&self) -> Result<Statistics> {
670671
self.input.statistics()
671672
}
673+
674+
fn cardinality_effect(&self) -> CardinalityEffect {
675+
CardinalityEffect::Equal
676+
}
672677
}
673678

674679
impl RepartitionExec {

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ use datafusion_execution::TaskContext;
5555
use datafusion_physical_expr::LexOrdering;
5656
use datafusion_physical_expr_common::sort_expr::PhysicalSortRequirement;
5757

58+
use crate::execution_plan::CardinalityEffect;
5859
use futures::{StreamExt, TryStreamExt};
5960
use log::{debug, trace};
6061

@@ -972,6 +973,14 @@ impl ExecutionPlan for SortExec {
972973
fn fetch(&self) -> Option<usize> {
973974
self.fetch
974975
}
976+
977+
fn cardinality_effect(&self) -> CardinalityEffect {
978+
if self.fetch.is_none() {
979+
CardinalityEffect::Equal
980+
} else {
981+
CardinalityEffect::LowerEqual
982+
}
983+
}
975984
}
976985

977986
#[cfg(test)]

datafusion/sqllogictest/test_files/aggregates_topk.slt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ physical_plan
5353
07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)]
5454
08)--------------MemoryExec: partitions=1, partition_sizes=[1]
5555

56+
query TI
57+
select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where trace_id != 'b' order by max_ts desc limit 3;
58+
----
59+
c 4
60+
a 1
5661

5762
query TI
5863
select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
@@ -89,6 +94,12 @@ c 1 2
8994
statement ok
9095
set datafusion.optimizer.enable_topk_aggregation = true;
9196

97+
query TI
98+
select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where max_ts != 3 order by max_ts desc limit 2;
99+
----
100+
c 4
101+
a 1
102+
92103
query TT
93104
explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4;
94105
----

0 commit comments

Comments
 (0)