Skip to content

Commit

Permalink
[opt](nereids) support pushdown agg distinct through join
Browse files Browse the repository at this point in the history
  • Loading branch information
xzj7019 committed Nov 29, 2024
1 parent 17b2c1d commit e99aa8c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
Expand All @@ -57,7 +55,8 @@ public List<Rule> buildRules() {
.when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty())
.when(agg -> !agg.isGenerated())
.whenNot(agg -> agg.getAggregateFunctions().isEmpty())
.whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.whenNot(agg -> agg.child()
.child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs = agg.getAggregateFunctions();
if (funcs.size() > 1) {
Expand Down Expand Up @@ -86,14 +85,18 @@ private static LogicalAggregate<Plan> pushDownAggWithDistinct(LogicalAggregate<?

List<AggregateFunction> leftFuncs = new ArrayList<>();
List<AggregateFunction> rightFuncs = new ArrayList<>();
Set<Slot> leftFuncSlotSet = new HashSet<>();
Set<Slot> rightFuncSlotSet = new HashSet<>();
Set<Slot> newAggOverJoinGroupByKeys = new HashSet<>();
for (AggregateFunction func : agg.getAggregateFunctions()) {
Slot slot = (Slot) func.child(0);
newAggOverJoinGroupByKeys.add(slot);
if (leftJoinOutput.contains(slot)) {
leftFuncs.add(func);
leftFuncSlotSet.add(slot);
} else if (rightJoinOutput.contains(slot)) {
rightFuncs.add(func);
rightFuncSlotSet.add(slot);
} else {
throw new IllegalStateException("Slot " + slot + " not found in join output");
}
Expand Down Expand Up @@ -127,27 +130,16 @@ private static LogicalAggregate<Plan> pushDownAggWithDistinct(LogicalAggregate<?
}
}));

Map<Slot, NamedExpression> leftSlotToOutput = new HashMap<>();
Map<Slot, NamedExpression> rightSlotToOutput = new HashMap<>();
if (isLeftSideAggDistinct) {
leftPushDownGroupBy.add((Slot) leftFuncs.get(0).child(0));
Builder<NamedExpression> leftAggOutputBuilder = ImmutableList.<NamedExpression>builder()
.addAll(leftPushDownGroupBy);
leftFuncs.forEach(func -> {
Alias alias = func.alias("PDADT_" + func.getName());
leftSlotToOutput.put((Slot) func.child(0), alias);
});
leftJoin = new LogicalAggregate<>(ImmutableList.copyOf(leftPushDownGroupBy),
leftAggOutputBuilder.build(), join.left());
} else {
rightPushDownGroupBy.add((Slot) rightFuncs.get(0).child(0));
Builder<NamedExpression> rightAggOutputBuilder = ImmutableList.<NamedExpression>builder()
.addAll(rightPushDownGroupBy);
rightFuncs.forEach(func -> {
Alias alias = func.alias("PDADT_" + func.getName());
rightSlotToOutput.put((Slot) func.child(0), alias);
rightAggOutputBuilder.add(alias);
});
rightJoin = new LogicalAggregate<>(ImmutableList.copyOf(rightPushDownGroupBy),
rightAggOutputBuilder.build(), join.right());
}
Expand All @@ -162,10 +154,7 @@ private static LogicalAggregate<Plan> pushDownAggWithDistinct(LogicalAggregate<?
if (ne instanceof Alias && ((Alias) ne).child() instanceof AggregateFunction) {
AggregateFunction func = (AggregateFunction) ((Alias) ne).child();
Slot slot = (Slot) func.child(0);
if (leftSlotToOutput.containsKey(slot)) {
Expression newFunc = discardDistinct(func);
newOutputExprs.add((NamedExpression) ne.withChildren(newFunc));
} else if (rightSlotToOutput.containsKey(slot)) {
if (leftFuncSlotSet.contains(slot) || rightFuncSlotSet.contains(slot)) {
Expression newFunc = discardDistinct(func);
newOutputExprs.add((NamedExpression) ne.withChildren(newFunc));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ suite("push_down_aggr_distinct_through_join_one_side_cust") {
contains"groupByExpr=[gz_user_id#1, dt#2]"
contains"groupByExpr=[gz_user_id#1, dt#2, group_name#5], outputExpr=[gz_user_id#1, dt#2, group_name#5]"
contains"[group_name#5, dt#2]"
contains"groupByExpr=[group_name#5, dt#2], outputExpr=[group_name#5, dt#2, count(partial_count(gz_user_id)#13) AS `a2c1a830_1`#7]"
contains"groupByExpr=[group_name#5, dt#2], outputExpr=[group_name#5, dt#2, count(partial_count(gz_user_id)#12) AS `a2c1a830_1`#7]"
}

explain {
Expand All @@ -120,7 +120,7 @@ suite("push_down_aggr_distinct_through_join_one_side_cust") {
"GROUP BY 2, 3 ORDER BY 3 asc limit 10000;");
contains"groupByExpr=[ip#0, gz_user_id#1, dt#2], outputExpr=[ip#0, gz_user_id#1, dt#2]"
contains"groupByExpr=[ip#0, dt#2, group_name#5], outputExpr=[ip#0, dt#2, group_name#5]"
contains"groupByExpr=[group_name#5, dt#2], outputExpr=[group_name#5, dt#2, partial_count(ip#0) AS `partial_count(ip)`#13]"
contains"groupByExpr=[group_name#5, dt#2], outputExpr=[group_name#5, dt#2, count(partial_count(ip)#13) AS `a2c1a830_1`#7]"
contains"groupByExpr=[group_name#5, dt#2], outputExpr=[group_name#5, dt#2, partial_count(ip#0) AS `partial_count(ip)`#12]"
contains"groupByExpr=[group_name#5, dt#2], outputExpr=[group_name#5, dt#2, count(partial_count(ip)#12) AS `a2c1a830_1`#7]"
}
}
}

0 comments on commit e99aa8c

Please sign in to comment.