
Commit 098ba30

Relax combine partial final rule (#10913)
* Minor changes
* Minor changes
* Re-introduce group by expression check
1 parent 6dffc53 commit 098ba30

File tree

3 files changed: +99, -61 lines


datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs

Lines changed: 17 additions & 49 deletions
@@ -27,8 +27,7 @@ use crate::physical_plan::ExecutionPlan;
 
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
+use datafusion_physical_expr::{physical_exprs_equal, AggregateExpr, PhysicalExpr};
 
 /// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs
 /// into a Single AggregateExec if their grouping exprs and aggregate exprs equal.
@@ -132,19 +131,23 @@ type GroupExprsRef<'a> = (
     &'a [Option<Arc<dyn PhysicalExpr>>],
 );
 
-type GroupExprs = (
-    PhysicalGroupBy,
-    Vec<Arc<dyn AggregateExpr>>,
-    Vec<Option<Arc<dyn PhysicalExpr>>>,
-);
-
 fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
-    let (final_group_by, final_aggr_expr, final_filter_expr) =
-        normalize_group_exprs(final_agg);
-    let (input_group_by, input_aggr_expr, input_filter_expr) =
-        normalize_group_exprs(partial_agg);
-
-    final_group_by.eq(&input_group_by)
+    let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg;
+    let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg;
+
+    // Compare output expressions of the partial, and input expressions of the final operator.
+    physical_exprs_equal(
+        &input_group_by.output_exprs(),
+        &final_group_by.input_exprs(),
+    ) && input_group_by.groups() == final_group_by.groups()
+        && input_group_by.null_expr().len() == final_group_by.null_expr().len()
+        && input_group_by
+            .null_expr()
+            .iter()
+            .zip(final_group_by.null_expr().iter())
+            .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| {
+                lhs_expr.eq(rhs_expr) && lhs_str == rhs_str
+            })
        && final_aggr_expr.len() == input_aggr_expr.len()
        && final_aggr_expr
            .iter()
@@ -160,41 +163,6 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
     )
 }
 
-// To compare the group expressions between the final and partial aggregations, need to discard all the column indexes and compare
-fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
-    let (group, agg, filter) = group_exprs;
-    let new_group_expr = group
-        .expr()
-        .iter()
-        .map(|(expr, name)| (discard_column_index(expr.clone()), name.clone()))
-        .collect::<Vec<_>>();
-    let new_group = PhysicalGroupBy::new(
-        new_group_expr,
-        group.null_expr().to_vec(),
-        group.groups().to_vec(),
-    );
-    (new_group, agg.to_vec(), filter.to_vec())
-}
-
-fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
-    group_expr
-        .clone()
-        .transform(|expr| {
-            let normalized_form: Option<Arc<dyn PhysicalExpr>> =
-                match expr.as_any().downcast_ref::<Column>() {
-                    Some(column) => Some(Arc::new(Column::new(column.name(), 0))),
-                    None => None,
-                };
-            Ok(if let Some(normalized_form) = normalized_form {
-                Transformed::yes(normalized_form)
-            } else {
-                Transformed::no(expr)
-            })
-        })
-        .data()
-        .unwrap_or(group_expr)
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
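
Note on the change above: the removed normalize_group_exprs/discard_column_index pair compared the partial and final group-by expressions directly, after zeroing every column index. A final aggregate's group expression, however, is typically a plain column reference to the partial aggregate's output, so a partial aggregate grouping on a computed expression such as date_bin(ts) could never match, and the pair was left uncombined. The relaxed check instead compares the partial's output expressions with the final's input expressions. Below is a minimal self-contained sketch of that idea, using invented stand-in types (Expr, output_exprs), not DataFusion's real PhysicalExpr machinery:

// Simplified stand-ins for physical expressions: a column reference or a
// computed expression. Illustrative only; not DataFusion's PhysicalExpr.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Column { name: String, index: usize },
    DateBin(Box<Expr>),
}

// The columns a group-by produces: one per group expression, named after it.
fn output_exprs(group: &[(Expr, String)]) -> Vec<Expr> {
    group
        .iter()
        .enumerate()
        .map(|(i, (_, name))| Expr::Column { name: name.clone(), index: i })
        .collect()
}

fn main() {
    let name = "date_bin(ts)".to_string();
    // Partial aggregate: groups on the computed expression itself.
    let partial = vec![(
        Expr::DateBin(Box::new(Expr::Column { name: "ts".into(), index: 0 })),
        name.clone(),
    )];
    // Final aggregate: groups on a column reference to the partial's output.
    let final_ = vec![(Expr::Column { name: name.clone(), index: 0 }, name.clone())];

    // Old-style check: compare the group expressions directly. A computed
    // expression never equals a column reference, so combining was rejected.
    let old_style = partial[0].0 == final_[0].0;

    // New-style check: compare the partial's *outputs* with the final's *inputs*.
    let inputs: Vec<Expr> = final_.iter().map(|(e, _)| e.clone()).collect();
    let new_style = output_exprs(&partial) == inputs;

    assert!(!old_style && new_style);
    println!("old rule combines: {old_style}, relaxed rule combines: {new_style}");
}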

datafusion/sqllogictest/test_files/group_by.slt

Lines changed: 71 additions & 0 deletions
@@ -5064,3 +5064,74 @@ statement error DataFusion error: Error during planning: Cannot find column with
 SELECT a, b, COUNT(1)
 FROM multiple_ordered_table
 GROUP BY 1, 2, 4, 5, 6;
+
+statement ok
+set datafusion.execution.target_partitions = 1;
+
+# Create a table that contains various keywords, with their corresponding timestamps
+statement ok
+CREATE TABLE keywords_stream (
+  ts TIMESTAMP,
+  sn INTEGER PRIMARY KEY,
+  keyword VARCHAR NOT NULL
+);
+
+statement ok
+INSERT INTO keywords_stream(ts, sn, keyword) VALUES
+('2024-01-01T00:00:00Z', '0', 'Drug'),
+('2024-01-01T00:00:05Z', '1', 'Bomb'),
+('2024-01-01T00:00:10Z', '2', 'Theft'),
+('2024-01-01T00:00:15Z', '3', 'Gun'),
+('2024-01-01T00:00:20Z', '4', 'Calm');
+
+# Create a table that contains alert keywords
+statement ok
+CREATE TABLE ALERT_KEYWORDS(keyword VARCHAR NOT NULL);
+
+statement ok
+INSERT INTO ALERT_KEYWORDS VALUES
+('Drug'),
+('Bomb'),
+('Theft'),
+('Gun'),
+('Knife'),
+('Fire');
+
+query TT
+explain SELECT
+  DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk,
+  COUNT(keyword) AS alert_keyword_count
+FROM
+  keywords_stream
+WHERE
+  keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS)
+GROUP BY
+  ts_chunk;
+----
+logical_plan
+01)Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01")) AS ts_chunk, COUNT(keywords_stream.keyword) AS alert_keyword_count
+02)--Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"), keywords_stream.ts, TimestampNanosecond(946684800000000000, None)) AS date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))]], aggr=[[COUNT(keywords_stream.keyword)]]
+03)----LeftSemi Join: keywords_stream.keyword = __correlated_sq_1.keyword
+04)------TableScan: keywords_stream projection=[ts, keyword]
+05)------SubqueryAlias: __correlated_sq_1
+06)--------TableScan: alert_keywords projection=[keyword]
+physical_plan
+01)ProjectionExec: expr=[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))@0 as ts_chunk, COUNT(keywords_stream.keyword)@1 as alert_keyword_count]
+02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[COUNT(keywords_stream.keyword)]
+03)----CoalesceBatchesExec: target_batch_size=2
+04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(keyword@0, keyword@1)]
+05)--------MemoryExec: partitions=1, partition_sizes=[1]
+06)--------MemoryExec: partitions=1, partition_sizes=[1]
+
+query PI
+SELECT
+  DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk,
+  COUNT(keyword) AS alert_keyword_count
+FROM
+  keywords_stream
+WHERE
+  keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS)
+GROUP BY
+  ts_chunk;
+----
+2024-01-01T00:00:00 4
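
For context, a hedged sketch of reproducing this test from Rust rather than sqllogictest. It assumes a DataFusion version of roughly this vintage (SessionContext::new_with_config, SessionConfig::with_target_partitions, DataFrame::explain) plus a tokio runtime, and abbreviates the tables to one row each:

// Hedged reproduction sketch, not part of the commit: run the new query and
// print the plan, which should contain `AggregateExec: mode=Single` rather
// than a Partial/Final pair.
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Mirror `set datafusion.execution.target_partitions = 1;` from the test.
    let ctx = SessionContext::new_with_config(
        SessionConfig::new().with_target_partitions(1),
    );

    // Abbreviated versions of the tables created in the slt file.
    ctx.sql(
        "CREATE TABLE keywords_stream(ts TIMESTAMP, keyword VARCHAR) AS VALUES \
         (TIMESTAMP '2024-01-01T00:00:00', 'Drug')",
    )
    .await?;
    ctx.sql("CREATE TABLE alert_keywords(keyword VARCHAR) AS VALUES ('Drug')")
        .await?;

    let df = ctx
        .sql(
            "SELECT DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk, \
                    COUNT(keyword) AS alert_keyword_count \
             FROM keywords_stream \
             WHERE keyword IN (SELECT keyword FROM alert_keywords) \
             GROUP BY ts_chunk",
        )
        .await?;

    // explain(verbose, analyze) yields the logical and physical plans.
    df.explain(false, false)?.show().await?;
    Ok(())
}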

datafusion/sqllogictest/test_files/joins.slt

Lines changed: 11 additions & 12 deletions
@@ -1382,18 +1382,17 @@ physical_plan
 02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)]
 03)----CoalescePartitionsExec
 04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)]
-05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[]
-06)----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[]
-07)------------CoalesceBatchesExec: target_batch_size=2
-08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
-09)----------------CoalesceBatchesExec: target_batch_size=2
-10)------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
-11)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
-12)----------------------MemoryExec: partitions=1, partition_sizes=[1]
-13)----------------CoalesceBatchesExec: target_batch_size=2
-14)------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
-15)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
-16)----------------------MemoryExec: partitions=1, partition_sizes=[1]
+05)--------AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as alias1], aggr=[]
+06)----------CoalesceBatchesExec: target_batch_size=2
+07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0]
+08)--------------CoalesceBatchesExec: target_batch_size=2
+09)----------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2
+10)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+11)--------------------MemoryExec: partitions=1, partition_sizes=[1]
+12)--------------CoalesceBatchesExec: target_batch_size=2
+13)----------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2
+14)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+15)--------------------MemoryExec: partitions=1, partition_sizes=[1]
 
 statement ok
 set datafusion.explain.logical_plan_only = true;
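
The collapsed plan above shows the mode rewrite the rule performs when it combines an adjacent pair: Partial + Final becomes Single, and Partial + FinalPartitioned (as here, where the join output is already hash-partitioned on the group key) becomes SinglePartitioned. A minimal sketch of that mapping follows; the enum mirrors DataFusion's AggregateMode names, but the function is illustrative, not the rule's actual code:

// Illustrative sketch of the mode rewrite performed when adjacent partial and
// final aggregates are combined; not the rule's actual implementation.
#[derive(Debug, Clone, Copy, PartialEq)]
enum AggregateMode {
    Partial,
    Final,
    FinalPartitioned,
    Single,
    SinglePartitioned,
}

// A combinable Final becomes Single; a combinable FinalPartitioned becomes
// SinglePartitioned, preserving the hash-partitioned execution of the final
// step (as in the joins.slt plan above). Other modes are left untouched.
fn combined_mode(final_mode: AggregateMode) -> Option<AggregateMode> {
    match final_mode {
        AggregateMode::Final => Some(AggregateMode::Single),
        AggregateMode::FinalPartitioned => Some(AggregateMode::SinglePartitioned),
        _ => None,
    }
}

fn main() {
    assert_eq!(
        combined_mode(AggregateMode::FinalPartitioned),
        Some(AggregateMode::SinglePartitioned)
    );
    assert_eq!(combined_mode(AggregateMode::Final), Some(AggregateMode::Single));
    assert_eq!(combined_mode(AggregateMode::Partial), None);
}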
