diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggWithDistinctThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggWithDistinctThroughJoinOneSide.java index 02f6c2b825fbd64..d8db87ddcb68501 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggWithDistinctThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggWithDistinctThroughJoinOneSide.java @@ -39,10 +39,8 @@ import com.google.common.collect.Sets; import java.util.ArrayList; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; /** @@ -57,7 +55,8 @@ public List buildRules() { .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) .when(agg -> !agg.isGenerated()) .whenNot(agg -> agg.getAggregateFunctions().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + .whenNot(agg -> agg.child() + .child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); if (funcs.size() > 1) { @@ -86,14 +85,18 @@ private static LogicalAggregate pushDownAggWithDistinct(LogicalAggregate leftFuncs = new ArrayList<>(); List rightFuncs = new ArrayList<>(); + Set leftFuncSlotSet = new HashSet<>(); + Set rightFuncSlotSet = new HashSet<>(); Set newAggOverJoinGroupByKeys = new HashSet<>(); for (AggregateFunction func : agg.getAggregateFunctions()) { Slot slot = (Slot) func.child(0); newAggOverJoinGroupByKeys.add(slot); if (leftJoinOutput.contains(slot)) { leftFuncs.add(func); + leftFuncSlotSet.add(slot); } else if (rightJoinOutput.contains(slot)) { rightFuncs.add(func); + rightFuncSlotSet.add(slot); } else { throw new IllegalStateException("Slot " + slot + " not found in join output"); } @@ -127,27 +130,16 @@ private static LogicalAggregate pushDownAggWithDistinct(LogicalAggregate leftSlotToOutput = new HashMap<>(); - Map rightSlotToOutput = new HashMap<>(); if (isLeftSideAggDistinct) { leftPushDownGroupBy.add((Slot) leftFuncs.get(0).child(0)); Builder leftAggOutputBuilder = ImmutableList.builder() .addAll(leftPushDownGroupBy); - leftFuncs.forEach(func -> { - Alias alias = func.alias("PDADT_" + func.getName()); - leftSlotToOutput.put((Slot) func.child(0), alias); - }); leftJoin = new LogicalAggregate<>(ImmutableList.copyOf(leftPushDownGroupBy), leftAggOutputBuilder.build(), join.left()); } else { rightPushDownGroupBy.add((Slot) rightFuncs.get(0).child(0)); Builder rightAggOutputBuilder = ImmutableList.builder() .addAll(rightPushDownGroupBy); - rightFuncs.forEach(func -> { - Alias alias = func.alias("PDADT_" + func.getName()); - rightSlotToOutput.put((Slot) func.child(0), alias); - rightAggOutputBuilder.add(alias); - }); rightJoin = new LogicalAggregate<>(ImmutableList.copyOf(rightPushDownGroupBy), rightAggOutputBuilder.build(), join.right()); } @@ -162,10 +154,7 @@ private static LogicalAggregate pushDownAggWithDistinct(LogicalAggregate