Skip to content

Commit ae4113d

Browse files
authored
fix: issue #9213 substitute ArrayAgg to NthValue to optimize query plan (#9295)
* fix: issue #9213 substitute ArrayAgg to NthValue to optimize query plan
* fix format
* adding type check
* adding test
1 parent a851ecf commit ae4113d

File tree

2 files changed

+142
-4
lines changed

2 files changed

+142
-4
lines changed

datafusion/sql/src/expr/mod.rs

+42-4
Original file line number | Diff line number | Diff line change
@@ -203,9 +203,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
203203
}
204204

205205
SQLExpr::ArrayIndex { obj, indexes } => {
206+
/// Returns `true` when the index expression is a `SQLExpr::JsonAccess`.
///
/// The enclosing `ArrayIndex` arm uses this to skip the ARRAY_AGG -> NTH_VALUE
/// rewrite and hand such indexes straight to `plan_indexed`.
fn is_unsupported(expr: &SQLExpr) -> bool {
207+
matches!(expr, SQLExpr::JsonAccess { .. })
208+
}
209+
/// Tries to rewrite `ARRAY_AGG(args)[index]` into `NTH_VALUE(args, index)`.
///
/// Returns `(rewritten, true)` when `expr` is the built-in `ArrayAgg`
/// aggregate: the result is an `NthValue` aggregate with the original
/// arguments plus `index` appended, and the same `distinct`, `filter`
/// and `order_by`. Otherwise returns `(expr, false)` with `expr`
/// handed back unchanged so the caller can fall back to `plan_indexed`.
fn simplify_array_index_expr(expr: Expr, index: Expr) -> (Expr, bool) {
210+
match &expr {
211+
// Only the built-in ArrayAgg aggregate is rewritten; any other expression
// falls through to the catch-all arm below.
Expr::AggregateFunction(agg_func) if agg_func.func_def == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(AggregateFunction::ArrayAgg) => {
212+
let mut new_args = agg_func.args.clone();
213+
// The index becomes the trailing argument of the NthValue call.
new_args.push(index.clone());
214+
(Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new(
215+
datafusion_expr::AggregateFunction::NthValue,
216+
new_args,
217+
agg_func.distinct,
218+
agg_func.filter.clone(),
219+
agg_func.order_by.clone(),
220+
)), true)
221+
},
222+
// Not an ARRAY_AGG: return the input untouched and report "no change".
_ => (expr, false),
223+
}
224+
}
206225
let expr =
207226
self.sql_expr_to_logical_expr(*obj, schema, planner_context)?;
208-
self.plan_indexed(expr, indexes, schema, planner_context)
227+
if indexes.len() > 1 || is_unsupported(&indexes[0]) {
228+
return self.plan_indexed(expr, indexes, schema, planner_context);
229+
}
230+
let (new_expr, changed) = simplify_array_index_expr(
231+
expr,
232+
self.sql_expr_to_logical_expr(
233+
indexes[0].clone(),
234+
schema,
235+
planner_context,
236+
)?,
237+
);
238+
239+
if changed {
240+
Ok(new_expr)
241+
} else {
242+
self.plan_indexed(new_expr, indexes, schema, planner_context)
243+
}
209244
}
210245

211246
SQLExpr::CompoundIdentifier(ids) => {
@@ -557,7 +592,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
557592
limit,
558593
within_group,
559594
} = array_agg;
560-
561595
let order_by = if let Some(order_by) = order_by {
562596
Some(self.order_by_to_sort_expr(
563597
&order_by,
@@ -581,10 +615,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
581615
vec![self.sql_expr_to_logical_expr(*expr, input_schema, planner_context)?];
582616

583617
// next, aggregate built-ins
584-
let fun = AggregateFunction::ArrayAgg;
585618
Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
586-
fun, args, distinct, None, order_by,
619+
AggregateFunction::ArrayAgg,
620+
args,
621+
distinct,
622+
None,
623+
order_by,
587624
)))
625+
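// NOTE: the ARRAY_AGG(...)[i] -> NTH_VALUE rewrite is not done here; it is
// performed in the `SQLExpr::ArrayIndex` arm via `simplify_array_index_expr`.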
588626
}
589627

590628
fn sql_in_list_to_expr(
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,100 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
#######
19+
# Setup test data table
20+
#######
21+
statement ok
22+
CREATE EXTERNAL TABLE multiple_ordered_table (
23+
a0 INTEGER,
24+
a INTEGER,
25+
b INTEGER,
26+
c INTEGER,
27+
d INTEGER
28+
)
29+
STORED AS CSV
30+
WITH HEADER ROW
31+
WITH ORDER (a ASC, b ASC)
32+
WITH ORDER (c ASC)
33+
LOCATION '../../datafusion/core/tests/data/window_2.csv';
34+
35+
36+
query TT
37+
EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1] as result
38+
FROM multiple_ordered_table
39+
GROUP BY a;
40+
----
41+
logical_plan
42+
Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result
43+
--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]]
44+
----TableScan: multiple_ordered_table projection=[a, c]
45+
physical_plan
46+
ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
47+
--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
48+
----SortExec: expr=[a@0 ASC NULLS LAST]
49+
------CoalesceBatchesExec: target_batch_size=8192
50+
--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
51+
----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
52+
------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
53+
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
54+
55+
56+
query TT
57+
EXPLAIN SELECT a, NTH_VALUE(c, 1 ORDER BY c) as result
58+
FROM multiple_ordered_table
59+
GROUP BY a;
60+
----
61+
logical_plan
62+
Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result
63+
--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]]
64+
----TableScan: multiple_ordered_table projection=[a, c]
65+
physical_plan
66+
ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
67+
--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
68+
----SortExec: expr=[a@0 ASC NULLS LAST]
69+
------CoalesceBatchesExec: target_batch_size=8192
70+
--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
71+
----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
72+
------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
73+
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
74+
75+
query TT
76+
EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1 + 100] as result
77+
FROM multiple_ordered_table
78+
GROUP BY a;
79+
----
80+
logical_plan
81+
Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result
82+
--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(101)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]]
83+
----TableScan: multiple_ordered_table projection=[a, c]
84+
physical_plan
85+
ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
86+
--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
87+
----SortExec: expr=[a@0 ASC NULLS LAST]
88+
------CoalesceBatchesExec: target_batch_size=8192
89+
--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
90+
----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
91+
------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
92+
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true
93+
94+
query II
95+
SELECT a, ARRAY_AGG(c ORDER BY c)[1] as result
96+
FROM multiple_ordered_table
97+
GROUP BY a;
98+
----
99+
0 0
100+
1 50

0 commit comments

Comments
 (0)