diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java index 366730f7dc521e..5df1bfc0ce1e56 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java @@ -115,16 +115,6 @@ public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate // this means one stage gather agg, usually bad pattern return false; } - // forbid three or four stage distinct agg inter by distribute - if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER && children.get(0).getPlan() instanceof PhysicalDistribute) { - // if distinct without group by key, we prefer three or four stage distinct agg - // because the second phase of multi-distinct only have one instance, and it is slow generally. - if (agg.getGroupByExpressions().size() == 1 - && agg.getOutputExpressions().size() == 1) { - return true; - } - return false; - } // forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle // TODO: this is forbid good plan after cte reuse by mistake diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 696463523f6904..e0f8ef23013a0c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -416,6 +416,7 @@ public enum RuleType { TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION), THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION), FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION), + FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION), LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION), LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION), LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index edbd28677b4a00..556ca20e47aa80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -292,16 +292,23 @@ && couldConvertToMulti(agg)) // .when(agg -> agg.getDistinctArguments().size() == 1) // .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) // ), + RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build( + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + .thenApplyMulti(ctx -> + fourPhaseAggregateWithDistinctAndFullDistribute(ctx.root, ctx.connectContext) + ) + ), RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build( - basePattern - .when(agg -> agg.getDistinctArguments().size() == 1) - .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + .thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) ), RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build( - basePattern - .when(agg -> agg.getDistinctArguments().size() == 1) - .when(agg -> agg.getGroupByExpressions().isEmpty()) - .thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) + basePattern + .when(agg -> agg.getDistinctArguments().size() == 1) + // .when(agg -> agg.getGroupByExpressions().isEmpty()) + .thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext)) ) ); } @@ -1831,6 +1838,189 @@ private List> fourPhaseAggregateWithDistin .build(); } + /** + * sql: + * select count(distinct name), sum(age) from student; + *

+ * 4 phase plan + * DISTINCT_GLOBAL, BUFFER_TO_RESULT groupBy(), output[count(name), sum(age#5)], [GATHER] + * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy()), output(count(name), partial_sum(age)), hash distribute by name + * +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), hash_distribute by name + * +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age)) + * +--scan(name, age) + */ + private List> fourPhaseAggregateWithDistinctAndFullDistribute( + LogicalAggregate logicalAgg, ConnectContext connectContext) { + boolean couldBanned = couldConvertToMulti(logicalAgg); + + Set aggregateFunctions = logicalAgg.getAggregateFunctions(); + + Set distinctArguments = aggregateFunctions.stream() + .filter(AggregateFunction::isDistinct) + .flatMap(aggregateExpression -> aggregateExpression.getArguments().stream()) + .filter(NamedExpression.class::isInstance) + .map(NamedExpression.class::cast) + .collect(ImmutableSet.toImmutableSet()); + + Set localAggGroupBySet = ImmutableSet.builder() + .addAll((List) (List) logicalAgg.getGroupByExpressions()) + .addAll(distinctArguments) + .build(); + + AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); + + Map nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream() + .filter(aggregateFunction -> !aggregateFunction.isDistinct()) + .collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> { + AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam); + return new Alias(localAggExpr); + }, (oldValue, newValue) -> newValue)); + + List localAggOutput = ImmutableList.builder() + .addAll(localAggGroupBySet) + .addAll(nonDistinctAggFunctionToAliasPhase1.values()) + .build(); + + List localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet); + boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy); + List partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg); + RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY); + + boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty(); + + // be not recommend generate an aggregate node with empty group by and empty output, + // so add a null int slot to group by slot and output + if (isGroupByEmptySelectEmpty) { + localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE)); + localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE))); + } + + PhysicalHashAggregate anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy, + localAggOutput, Optional.of(partitionExpressions), inputToBufferParam, + maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(), + requireAny, logicalAgg.child()); + + AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned); + Map nonDistinctAggFunctionToAliasPhase2 = + nonDistinctAggFunctionToAliasPhase1.entrySet() + .stream() + .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> { + AggregateFunction originFunction = kv.getKey(); + Alias localOutput = kv.getValue(); + AggregateExpression globalAggExpr = new AggregateExpression( + originFunction, bufferToBufferParam, localOutput.toSlot()); + return new Alias(globalAggExpr); + })); + + List globalAggOutput = ImmutableList.builder() + .addAll(localAggGroupBySet) + .addAll(nonDistinctAggFunctionToAliasPhase2.values()) + .build(); + + // be not recommend generate an aggregate node with empty group by and empty output, + // so add a null int slot to group by slot and output + if (isGroupByEmptySelectEmpty) { + globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE))); + } + + RequireProperties requireGroupByAndDistinctHash = RequireProperties.of( + PhysicalProperties.createHash(localAggGroupBy, ShuffleType.REQUIRE)); + + //phase 2 + PhysicalHashAggregate anyLocalHashGlobalAgg = new PhysicalHashAggregate<>( + localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())), + bufferToBufferParam, false, logicalAgg.getLogicalProperties(), + requireGroupByAndDistinctHash, anyLocalAgg); + + // phase 3 + AggregateParam distinctLocalParam = new AggregateParam( + AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned); + Map nonDistinctAggFunctionToAliasPhase3 = new HashMap<>(); + List localDistinctOutput = Lists.newArrayList(); + for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) { + NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i); + List needUpdateSlot = Lists.newArrayList(); + NamedExpression outputExprPhase3 = (NamedExpression) outputExpr + .rewriteDownShortCircuit(expr -> { + if (expr instanceof AggregateFunction) { + AggregateFunction aggregateFunction = (AggregateFunction) expr; + if (aggregateFunction.isDistinct()) { + Set aggChild = Sets.newLinkedHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1 + || aggregateFunction.getDistinctArguments().size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); + AggregateFunction nonDistinct = aggregateFunction + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); + AggregateExpression nonDistinctAggExpr = new AggregateExpression(nonDistinct, + distinctLocalParam, aggregateFunction.child(0)); + return nonDistinctAggExpr; + } else { + needUpdateSlot.add(aggregateFunction); + Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr); + return new AggregateExpression(aggregateFunction, + new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER), + alias.toSlot()); + } + } + return expr; + }); + for (AggregateFunction originFunction : needUpdateSlot) { + nonDistinctAggFunctionToAliasPhase3.put(originFunction, (Alias) outputExprPhase3); + } + localDistinctOutput.add(outputExprPhase3); + + } + PhysicalHashAggregate distinctLocal = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(), + distinctLocalParam, false, logicalAgg.getLogicalProperties(), + requireGroupByAndDistinctHash, anyLocalHashGlobalAgg); + + //phase 4 + AggregateParam distinctGlobalParam = new AggregateParam( + AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned); + List globalDistinctOutput = Lists.newArrayList(); + for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) { + NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i); + NamedExpression outputExprPhase4 = (NamedExpression) outputExpr.rewriteDownShortCircuit(expr -> { + if (expr instanceof AggregateFunction) { + AggregateFunction aggregateFunction = (AggregateFunction) expr; + if (aggregateFunction.isDistinct()) { + Set aggChild = Sets.newLinkedHashSet(aggregateFunction.children()); + Preconditions.checkArgument(aggChild.size() == 1 + || aggregateFunction.getDistinctArguments().size() == 1, + "cannot process more than one child in aggregate distinct function: " + + aggregateFunction); + AggregateFunction nonDistinct = aggregateFunction + .withDistinctAndChildren(false, ImmutableList.copyOf(aggChild)); + int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr); + Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx)); + return new AggregateExpression(nonDistinct, + distinctGlobalParam, localDistinctAlias.toSlot()); + } else { + Alias alias = nonDistinctAggFunctionToAliasPhase3.get(expr); + return new AggregateExpression(aggregateFunction, + new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT), + alias.toSlot()); + } + } + return expr; + }); + globalDistinctOutput.add(outputExprPhase4); + } + + RequireProperties requireGroupByHash = RequireProperties.of( + PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE)); + PhysicalHashAggregate distinctGlobal = new PhysicalHashAggregate<>( + logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(), + distinctGlobalParam, false, logicalAgg.getLogicalProperties(), + requireGroupByHash, distinctLocal); + + return ImmutableList.>builder() + .add(distinctGlobal) + .build(); + } + private boolean couldConvertToMulti(LogicalAggregate aggregate) { Set aggregateFunctions = aggregate.getAggregateFunctions(); for (AggregateFunction func : aggregateFunctions) {