Skip to content

Commit

Permalink
[opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE (#43856)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE has some restrictions

do not support count(*)
do not support join with other join conditions
do not support the project between agg and join that contains non-slot
expressions
this pr removes above restrictions for pattern: agg-project-join
  • Loading branch information
englefly authored and Your Name committed Dec 5, 2024
1 parent bbfb363 commit 97af1ff
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -74,8 +75,8 @@ public List<Rule> buildRules() {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum
|| (f instanceof Count && !((Count) f).isCountStar())) && !f.isDistinct()
&& f.child(0) instanceof Slot);
|| f instanceof Count && !f.isDistinct()
&& (f.children().isEmpty() || f.child(0) instanceof Slot)));
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
Expand All @@ -88,15 +89,16 @@ public List<Rule> buildRules() {
})
.toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE),
logicalAggregate(logicalProject(innerLogicalJoin()))
.when(agg -> agg.child().isAllSlots())
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
// .when(agg -> agg.child().isAllSlots())
// .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.whenNot(agg -> agg.child()
.child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum
|| (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct()
&& f.child(0) instanceof Slot);
|| f instanceof Count) && !f.isDistinct()
&& (f.children().isEmpty() || f.child(0) instanceof Slot));
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules = ctx.cascadesContext.getConnectContext()
Expand All @@ -118,23 +120,6 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
for (AggregateFunction func : agg.getAggregateFunctions()) {
Slot slot = (Slot) func.child(0);
if (leftOutput.contains(slot)) {
leftFuncs.add(func);
} else if (rightOutput.contains(slot)) {
rightFuncs.add(func);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
return null;
}

Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
for (Expression e : agg.getGroupByExpressions()) {
Expand All @@ -144,18 +129,71 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
return null;
if (projects.isEmpty()) {
// TODO: select ... from ... group by A , B, 1.2; 1.2 is constant
return null;
} else {
for (NamedExpression proj : projects) {
if (proj instanceof Alias && proj.toSlot().equals(slot)) {
Set<Slot> inputForAliasSet = proj.getInputSlots();
for (Slot aliasInputSlot : inputForAliasSet) {
if (leftOutput.contains(aliasInputSlot)) {
leftGroupBy.add(aliasInputSlot);
} else if (rightOutput.contains(aliasInputSlot)) {
rightGroupBy.add(aliasInputSlot);
} else {
return null;
}
}
break;
}
}
}
}
}
join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> {
if (leftOutput.contains(slot)) {
leftGroupBy.add(slot);
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
Count countStar = null;
Count rewrittenCountStar = null;
for (AggregateFunction func : agg.getAggregateFunctions()) {
if (func instanceof Count && ((Count) func).isCountStar()) {
countStar = (Count) func;
} else {
Slot slot = (Slot) func.child(0);
if (leftOutput.contains(slot)) {
leftFuncs.add(func);
} else if (rightOutput.contains(slot)) {
rightFuncs.add(func);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
}
}
// rewrite count(*) to count(A), where A is slot from left/right group by key
if (countStar != null) {
if (!leftGroupBy.isEmpty()) {
rewrittenCountStar = (Count) countStar.withChildren(leftGroupBy.iterator().next());
leftFuncs.add(rewrittenCountStar);
} else if (!rightGroupBy.isEmpty()) {
rewrittenCountStar = (Count) countStar.withChildren(rightGroupBy.iterator().next());
rightFuncs.add(rewrittenCountStar);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
return null;
}
}
for (Expression condition : join.getHashJoinConjuncts()) {
for (Slot joinConditionSlot : condition.getInputSlots()) {
if (leftOutput.contains(joinConditionSlot)) {
leftGroupBy.add(joinConditionSlot);
} else if (rightOutput.contains(joinConditionSlot)) {
rightGroupBy.add(joinConditionSlot);
} else {
// apply failed
return null;
}
}
}));
}

Plan left = join.left();
Plan right = join.right();
Expand Down Expand Up @@ -196,6 +234,10 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) {
AggregateFunction func = (AggregateFunction) ((Alias) ne).child();
if (func instanceof Count && ((Count) func).isCountStar()) {
// countStar is already rewritten as count(left_slot) or count(right_slot)
func = rewrittenCountStar;
}
Slot slot = (Slot) func.child(0);
if (leftSlotToOutput.containsKey(slot)) {
Expression newFunc = replaceAggFunc(func, leftSlotToOutput.get(slot).toSlot());
Expand All @@ -210,9 +252,20 @@ public static LogicalAggregate<Plan> pushMinMaxSumCount(LogicalAggregate<? exten
newOutputExprs.add(ne);
}
}

// TODO: column prune project
return agg.withAggOutputChild(newOutputExprs, newJoin);
Plan newAggChild = newJoin;
if (agg.child() instanceof LogicalProject) {
LogicalProject project = (LogicalProject) agg.child();
List<NamedExpression> newProjections = Lists.newArrayList();
newProjections.addAll(project.getProjects());
Set<NamedExpression> leftDifference = new HashSet<NamedExpression>(left.getOutput());
leftDifference.removeAll(project.getProjects());
newProjections.addAll(leftDifference);
Set<NamedExpression> rightDifference = new HashSet<NamedExpression>(right.getOutput());
rightDifference.removeAll(project.getProjects());
newProjections.addAll(rightDifference);
newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin);
}
return agg.withAggOutputChild(newOutputExprs, newAggChild);
}

private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,11 @@ void testSingleCountStar() {
.applyTopDown(new PushDownAggThroughJoinOneSide())
.printlnTree()
.matches(
logicalAggregate(
logicalJoin(
logicalOlapScan(),
logicalJoin(
logicalAggregate(
logicalOlapScan()
)
),
logicalOlapScan()
)
);
}
Expand All @@ -346,11 +346,9 @@ void testBothSideCountAndCountStar() {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushDownAggThroughJoinOneSide())
.matches(
logicalAggregate(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
logicalJoin(
logicalAggregate(logicalOlapScan()),
logicalAggregate(logicalOlapScan())
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1034,3 +1034,25 @@ Used:
UnUsed: use_push_down_agg_through_join_one_side
SyntaxError:

-- !shape --
PhysicalResultSink
--PhysicalTopN[MERGE_SORT]
----PhysicalTopN[LOCAL_SORT]
------hashAgg[GLOBAL]
--------hashAgg[LOCAL]
----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10)))
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'))
------------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd]
------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'))
--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd]

Hint log:
Used: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE
UnUsed:
SyntaxError:

-- !agg_pushed --
2 是 2024-08-19

Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") {
qt_with_hint_groupby_pushdown_nested_queries """
explain shape plan select /*+ USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select * from count_t_one_side where score > 20) t1 join (select * from count_t_one_side where id < 100) t2 on t1.id = t2.id group by t1.name;
"""

sql """
drop table if exists dw_user_b2c_tracking_info_tmp_ymd;
create table dw_user_b2c_tracking_info_tmp_ymd (
guid int,
dt varchar,
first_visit_time varchar
)Engine=Olap
DUPLICATE KEY(guid)
distributed by hash(dt) buckets 3
properties('replication_num' = '1');
insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', '2024-08-19');
drop table if exists dwd_tracking_sensor_init_tmp_ymd;
create table dwd_tracking_sensor_init_tmp_ymd (
guid int,
dt varchar,
tracking_type varchar
)Engine=Olap
DUPLICATE KEY(guid)
distributed by hash(dt) buckets 3
properties('replication_num' = '1');
insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 'click'), (1, '2024-08-19', 'click');
"""
sql """
set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE";
set disable_join_reorder=true;
"""

qt_shape """
explain shape plan
SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
Count(*) AS accee593,
CASE
WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'是'
WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'否'
ELSE '-1'
end AS a1302fb2,
dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
FROM dwd_tracking_sensor_init_tmp_ymd
LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
ON dwd_tracking_sensor_init_tmp_ymd.guid =
dw_user_b2c_tracking_info_tmp_ymd.guid
AND dwd_tracking_sensor_init_tmp_ymd.dt =
dw_user_b2c_tracking_info_tmp_ymd.dt
WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
AND dwd_tracking_sensor_init_tmp_ymd.dt >=
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10)
AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
GROUP BY 2,
3
ORDER BY 3 ASC
LIMIT 10000;
"""

qt_agg_pushed """
SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
Count(*) AS accee593,
CASE
WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'是'
WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10) THEN
'否'
ELSE '-1'
end AS a1302fb2,
dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
FROM dwd_tracking_sensor_init_tmp_ymd
LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
ON dwd_tracking_sensor_init_tmp_ymd.guid =
dw_user_b2c_tracking_info_tmp_ymd.guid
AND dwd_tracking_sensor_init_tmp_ymd.dt =
dw_user_b2c_tracking_info_tmp_ymd.dt
WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
AND dwd_tracking_sensor_init_tmp_ymd.dt >=
Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10)
AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
GROUP BY 2,
3
ORDER BY 3 ASC
LIMIT 10000;
"""
}

0 comments on commit 97af1ff

Please sign in to comment.