Skip to content

Commit

Permalink
4 phase
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Apr 25, 2024
1 parent 4aee706 commit aed2979
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,6 @@ public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan>
// this means one stage gather agg, usually bad pattern
return false;
}
// forbid three or four stage distinct agg inter by distribute
if (agg.getAggMode() == AggMode.BUFFER_TO_BUFFER && children.get(0).getPlan() instanceof PhysicalDistribute) {
// if distinct without group by key, we prefer three or four stage distinct agg
// because the second phase of multi-distinct only have one instance, and it is slow generally.
if (agg.getGroupByExpressions().size() == 1
&& agg.getOutputExpressions().size() == 1) {
return true;
}
return false;
}

// forbid TWO_PHASE_AGGREGATE_WITH_DISTINCT after shuffle
// TODO: this mistakenly forbids a good plan after CTE reuse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ public enum RuleType {
TWO_PHASE_AGGREGATE_WITH_MULTI_DISTINCT(RuleTypeClass.IMPLEMENTATION),
THREE_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT(RuleTypeClass.IMPLEMENTATION),
FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE(RuleTypeClass.IMPLEMENTATION),
LOGICAL_UNION_TO_PHYSICAL_UNION(RuleTypeClass.IMPLEMENTATION),
LOGICAL_EXCEPT_TO_PHYSICAL_EXCEPT(RuleTypeClass.IMPLEMENTATION),
LOGICAL_INTERSECT_TO_PHYSICAL_INTERSECT(RuleTypeClass.IMPLEMENTATION),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,23 @@ && couldConvertToMulti(agg))
// .when(agg -> agg.getDistinctArguments().size() == 1)
// .thenApplyMulti(ctx -> twoPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
// ),
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT_WITH_FULL_DISTRIBUTE.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx ->
fourPhaseAggregateWithDistinctAndFullDistribute(ctx.root, ctx.connectContext)
)
),
RuleType.THREE_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.thenApplyMulti(ctx -> threePhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
),
RuleType.FOUR_PHASE_AGGREGATE_WITH_DISTINCT.build(
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
.when(agg -> agg.getGroupByExpressions().isEmpty())
.thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
basePattern
.when(agg -> agg.getDistinctArguments().size() == 1)
// .when(agg -> agg.getGroupByExpressions().isEmpty())
.thenApplyMulti(ctx -> fourPhaseAggregateWithDistinct(ctx.root, ctx.connectContext))
)
);
}
Expand Down Expand Up @@ -1831,6 +1838,189 @@ private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistin
.build();
}

/**
 * Builds a four-phase physical aggregate plan for a logical aggregate that contains a
 * distinct aggregate function. Phases 2, 3 and 4 each declare a hash-distribution
 * requirement on their child; phase 1 accepts any distribution.
 * <p>
 * Example sql:
 * select count(distinct name), sum(age) from student;
 * <p>
 * Illustrative 4 phase plan for that sql:
 * DISTINCT_GLOBAL, BUFFER_TO_RESULT, groupBy(), output(count(name), sum(age#5)), [GATHER]
 * +--DISTINCT_LOCAL, INPUT_TO_BUFFER, groupBy(), output(count(name), partial_sum(age)), hash distribute by name
 *    +--GLOBAL, BUFFER_TO_BUFFER, groupBy(name), output(name, partial_sum(age)), hash distribute by name
 *       +--LOCAL, INPUT_TO_BUFFER, groupBy(name), output(name, partial_sum(age))
 *          +--scan(name, age)
 * <p>
 * Phases 1 and 2 group by (original group-by keys + distinct arguments); phases 3 and 4
 * group by the original group-by keys only.
 *
 * @param logicalAgg the logical aggregate to implement; its child becomes the child of phase 1
 * @param connectContext session context, only forwarded to {@code maybeUsingStreamAgg}
 * @return an immutable single-element list holding the top (phase 4, DISTINCT_GLOBAL) aggregate
 */
private List<PhysicalHashAggregate<? extends Plan>> fourPhaseAggregateWithDistinctAndFullDistribute(
LogicalAggregate<? extends Plan> logicalAgg, ConnectContext connectContext) {
// Flag handed to most AggregateParam instances below as their third constructor argument.
// NOTE(review): presumably marks plans that could also be implemented by the multi-distinct
// rules so they can be banned later -- confirm against AggregateParam.
boolean couldBanned = couldConvertToMulti(logicalAgg);

Set<AggregateFunction> aggregateFunctions = logicalAgg.getAggregateFunctions();

// Arguments of all distinct aggregate functions. Arguments that are not NamedExpression
// instances are dropped by the filter rather than reported.
Set<NamedExpression> distinctArguments = aggregateFunctions.stream()
.filter(AggregateFunction::isDistinct)
.flatMap(aggregateExpression -> aggregateExpression.getArguments().stream())
.filter(NamedExpression.class::isInstance)
.map(NamedExpression.class::cast)
.collect(ImmutableSet.toImmutableSet());

// Grouping keys of phases 1 and 2: original group-by keys followed by the distinct arguments.
// NOTE(review): the raw-type cast is unchecked; it assumes every group-by expression is a
// NamedExpression -- confirm that callers guarantee this.
Set<NamedExpression> localAggGroupBySet = ImmutableSet.<NamedExpression>builder()
.addAll((List<NamedExpression>) (List) logicalAgg.getGroupByExpressions())
.addAll(distinctArguments)
.build();

// ---- phase 1: LOCAL, INPUT_TO_BUFFER ----
AggregateParam inputToBufferParam = new AggregateParam(AggPhase.LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);

// Maps each non-distinct aggregate function to the alias of its phase-1 partial aggregate.
// On a duplicate key the merge function keeps the newer alias.
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase1 = aggregateFunctions.stream()
.filter(aggregateFunction -> !aggregateFunction.isDistinct())
.collect(ImmutableMap.toImmutableMap(expr -> expr, expr -> {
AggregateExpression localAggExpr = new AggregateExpression(expr, inputToBufferParam);
return new Alias(localAggExpr);
}, (oldValue, newValue) -> newValue));

// Phase-1 output: the grouping keys followed by the partial non-distinct aggregates.
List<NamedExpression> localAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
.addAll(nonDistinctAggFunctionToAliasPhase1.values())
.build();

List<Expression> localAggGroupBy = ImmutableList.copyOf(localAggGroupBySet);
boolean maybeUsingStreamAgg = maybeUsingStreamAgg(connectContext, localAggGroupBy);
List<Expression> partitionExpressions = getHashAggregatePartitionExpressions(logicalAgg);
// Phase 1 places no distribution requirement on its child.
RequireProperties requireAny = RequireProperties.of(PhysicalProperties.ANY);

boolean isGroupByEmptySelectEmpty = localAggGroupBy.isEmpty() && localAggOutput.isEmpty();

// The backend does not recommend generating an aggregate node with both an empty group-by
// and an empty output, so use a NULL tinyint literal as the group-by key and as the output.
if (isGroupByEmptySelectEmpty) {
localAggGroupBy = ImmutableList.of(new NullLiteral(TinyIntType.INSTANCE));
localAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}

PhysicalHashAggregate<Plan> anyLocalAgg = new PhysicalHashAggregate<>(localAggGroupBy,
localAggOutput, Optional.of(partitionExpressions), inputToBufferParam,
maybeUsingStreamAgg, Optional.empty(), logicalAgg.getLogicalProperties(),
requireAny, logicalAgg.child());

// ---- phase 2: GLOBAL, BUFFER_TO_BUFFER ----
AggregateParam bufferToBufferParam = new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_BUFFER, couldBanned);
// Same keys as the phase-1 map; each value aliases an aggregate expression that takes the
// slot of the corresponding phase-1 alias as its input.
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase2 =
nonDistinctAggFunctionToAliasPhase1.entrySet()
.stream()
.collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> {
AggregateFunction originFunction = kv.getKey();
Alias localOutput = kv.getValue();
AggregateExpression globalAggExpr = new AggregateExpression(
originFunction, bufferToBufferParam, localOutput.toSlot());
return new Alias(globalAggExpr);
}));

// Phase-2 output: the same grouping keys followed by the phase-2 aggregates.
List<NamedExpression> globalAggOutput = ImmutableList.<NamedExpression>builder()
.addAll(localAggGroupBySet)
.addAll(nonDistinctAggFunctionToAliasPhase2.values())
.build();

// Same placeholder as in phase 1: avoid an aggregate node with an empty group-by and an
// empty output by emitting a NULL tinyint literal instead.
if (isGroupByEmptySelectEmpty) {
globalAggOutput = ImmutableList.of(new Alias(new NullLiteral(TinyIntType.INSTANCE)));
}

// Hash requirement on the phase-1/2 grouping keys (group-by keys + distinct arguments).
// It is used as the child requirement of both phase 2 and phase 3.
RequireProperties requireGroupByAndDistinctHash = RequireProperties.of(
PhysicalProperties.createHash(localAggGroupBy, ShuffleType.REQUIRE));

// phase 2 node: sits on top of the phase-1 node and carries the distinct arguments as its
// partition expressions.
PhysicalHashAggregate<? extends Plan> anyLocalHashGlobalAgg = new PhysicalHashAggregate<>(
localAggGroupBy, globalAggOutput, Optional.of(ImmutableList.copyOf(logicalAgg.getDistinctArguments())),
bufferToBufferParam, false, logicalAgg.getLogicalProperties(),
requireGroupByAndDistinctHash, anyLocalAgg);

// ---- phase 3: DISTINCT_LOCAL, INPUT_TO_BUFFER ----
AggregateParam distinctLocalParam = new AggregateParam(
AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
// Maps each non-distinct aggregate function to the phase-3 output expression containing it.
Map<AggregateFunction, Alias> nonDistinctAggFunctionToAliasPhase3 = new HashMap<>();
List<NamedExpression> localDistinctOutput = Lists.newArrayList();
// Rewrite every original output expression into its phase-3 form, keeping the output order.
for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
// Non-distinct aggregate functions found inside this output expression.
List<AggregateFunction> needUpdateSlot = Lists.newArrayList();
NamedExpression outputExprPhase3 = (NamedExpression) outputExpr
.rewriteDownShortCircuit(expr -> {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
// De-duplicate the children while keeping their order, then rebuild the
// function as a non-distinct one over those children.
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
// The function's first child is passed as the third argument of the aggregate expression.
AggregateExpression nonDistinctAggExpr = new AggregateExpression(nonDistinct,
distinctLocalParam, aggregateFunction.child(0));
return nonDistinctAggExpr;
} else {
// Non-distinct function: continue from the slot of its phase-2 alias.
// NOTE(review): this AggregateParam uses the two-argument constructor, so
// couldBanned is not passed here unlike distinctLocalParam -- confirm intended.
needUpdateSlot.add(aggregateFunction);
Alias alias = nonDistinctAggFunctionToAliasPhase2.get(expr);
return new AggregateExpression(aggregateFunction,
new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_BUFFER),
alias.toSlot());
}
}
return expr;
});
// Record the rewritten output for every non-distinct function found in it, for phase 4.
// NOTE(review): the cast assumes an output containing an aggregate function is an Alias.
for (AggregateFunction originFunction : needUpdateSlot) {
nonDistinctAggFunctionToAliasPhase3.put(originFunction, (Alias) outputExprPhase3);
}
localDistinctOutput.add(outputExprPhase3);

}
// phase 3 node: groups by the original group-by keys only, but still requires its child to be
// hash-distributed on the phase-1/2 keys (group-by keys + distinct arguments).
PhysicalHashAggregate<? extends Plan> distinctLocal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), localDistinctOutput, Optional.empty(),
distinctLocalParam, false, logicalAgg.getLogicalProperties(),
requireGroupByAndDistinctHash, anyLocalHashGlobalAgg);

// ---- phase 4: DISTINCT_GLOBAL, BUFFER_TO_RESULT ----
AggregateParam distinctGlobalParam = new AggregateParam(
AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT, couldBanned);
List<NamedExpression> globalDistinctOutput = Lists.newArrayList();
// Rewrite every original output expression into its final form, keeping the output order.
for (int i = 0; i < logicalAgg.getOutputExpressions().size(); i++) {
NamedExpression outputExpr = logicalAgg.getOutputExpressions().get(i);
NamedExpression outputExprPhase4 = (NamedExpression) outputExpr.rewriteDownShortCircuit(expr -> {
if (expr instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) expr;
if (aggregateFunction.isDistinct()) {
// Same de-duplication and distinct-to-non-distinct rewrite as in phase 3.
Set<Expression> aggChild = Sets.newLinkedHashSet(aggregateFunction.children());
Preconditions.checkArgument(aggChild.size() == 1
|| aggregateFunction.getDistinctArguments().size() == 1,
"cannot process more than one child in aggregate distinct function: "
+ aggregateFunction);
AggregateFunction nonDistinct = aggregateFunction
.withDistinctAndChildren(false, ImmutableList.copyOf(aggChild));
// Find the phase-3 output at the same position as this output expression. The loop
// variable i is not effectively final, so it cannot be captured by this lambda.
// NOTE(review): indexOf returns the first equal element, so idx can be smaller
// than i if the output list contains equal expressions -- confirm this is harmless.
int idx = logicalAgg.getOutputExpressions().indexOf(outputExpr);
Alias localDistinctAlias = (Alias) (localDistinctOutput.get(idx));
return new AggregateExpression(nonDistinct,
distinctGlobalParam, localDistinctAlias.toSlot());
} else {
// Non-distinct function: finalize from the slot of its phase-3 alias.
// NOTE(review): the param says AggPhase.DISTINCT_LOCAL although this is the
// DISTINCT_GLOBAL node, and couldBanned is not passed -- confirm intended.
Alias alias = nonDistinctAggFunctionToAliasPhase3.get(expr);
return new AggregateExpression(aggregateFunction,
new AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.BUFFER_TO_RESULT),
alias.toSlot());
}
}
return expr;
});
globalDistinctOutput.add(outputExprPhase4);
}

// The final phase requires its child to be hash-distributed on the original group-by keys.
RequireProperties requireGroupByHash = RequireProperties.of(
PhysicalProperties.createHash(logicalAgg.getGroupByExpressions(), ShuffleType.REQUIRE));
PhysicalHashAggregate<? extends Plan> distinctGlobal = new PhysicalHashAggregate<>(
logicalAgg.getGroupByExpressions(), globalDistinctOutput, Optional.empty(),
distinctGlobalParam, false, logicalAgg.getLogicalProperties(),
requireGroupByHash, distinctLocal);

// Only the top node is returned; the lower phases are reachable through its child chain.
return ImmutableList.<PhysicalHashAggregate<? extends Plan>>builder()
.add(distinctGlobal)
.build();
}

private boolean couldConvertToMulti(LogicalAggregate<? extends Plan> aggregate) {
Set<AggregateFunction> aggregateFunctions = aggregate.getAggregateFunctions();
for (AggregateFunction func : aggregateFunctions) {
Expand Down

0 comments on commit aed2979

Please sign in to comment.