Skip to content

Commit

Permalink
add regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Dec 10, 2024
1 parent fb5ffba commit 99b8ad5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) {
distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
}

if (distinctMultiColumns && distinctFunctionNum > 1) {
throw new AnalysisException(
"The query contains multi count distinct or sum distinct, each can't have multi columns");
}
// if (distinctMultiColumns && distinctFunctionNum > 1) {
// throw new AnalysisException(
// "The query contains multi count distinct or sum distinct, each can't have multi columns");
// }
for (Expression expr : aggregate.getGroupByExpressions()) {
if (expr.anyMatch(AggregateFunction.class::isInstance)) {
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,24 +94,7 @@ private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
List<Alias> aliases = new ArrayList<>();
Set<Expression> distinctFunc = new HashSet<>();
List<Alias> otherAggFuncs = new ArrayList<>();
boolean distinctMultiColumns = false;
for (NamedExpression namedExpression : agg.getOutputExpressions()) {
if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) {
continue;
}
AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0);
if (supportedFunctions.contains(aggFunc.getClass()) && aggFunc.isDistinct()) {
aliases.add((Alias) namedExpression);
distinctFunc.add(aggFunc);
distinctMultiColumns |= isDistinctMultiColumns(aggFunc);
} else {
otherAggFuncs.add((Alias) namedExpression);
}
}
if (distinctFunc.size() <= 1) {
return null;
}
if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) {
if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) {
return null;
}

Expand Down Expand Up @@ -154,7 +137,51 @@ private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
newAggs.add(newAgg);
joinOutput.put(alias, aliases.get(i));
}
// construct join
LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy);
LogicalProject<Plan> project = constructProject(groupBy, joinOutput, outputJoinGroupBys, join);
return new LogicalCTEAnchor<Plan, Plan>(producer.getCteId(), producer, project);
}

/**
 * Classifies the aggregate's output expressions and decides whether the aggregate should be split.
 *
 * <p>Every output that is an {@code Alias} whose first child is an {@code AggregateFunction} is
 * sorted into one of the caller-supplied collections: distinct functions whose class is in
 * {@code supportedFunctions} go into {@code aliases} / {@code distinctFunc}, everything else goes
 * into {@code otherAggFuncs}. Outputs that are not aliased aggregate functions are ignored.
 *
 * @param agg the aggregate whose outputs are inspected
 * @param aliases output collection: aliases wrapping supported distinct functions
 * @param distinctFunc output collection: the supported distinct functions themselves (a set, so
 *     equal functions collapse into one element)
 * @param otherAggFuncs output collection: aliases wrapping all remaining aggregate functions
 * @return {@code true} only when more than one distinct function was collected and, in addition,
 *     either at least one of them is reported as multi-column by {@code isDistinctMultiColumns}
 *     or the aggregate has no group-by expressions
 */
private static boolean needTransform(LogicalAggregate<Plan> agg, List<Alias> aliases,
        Set<Expression> distinctFunc, List<Alias> otherAggFuncs) {
    boolean hasMultiColumnDistinct = false;
    for (NamedExpression output : agg.getOutputExpressions()) {
        // child(0) is only inspected once the output is known to be an Alias.
        boolean aliasedAggregate = output instanceof Alias && output.child(0) instanceof AggregateFunction;
        if (!aliasedAggregate) {
            continue;
        }
        Alias outputAlias = (Alias) output;
        AggregateFunction function = (AggregateFunction) output.child(0);
        boolean supportedDistinct = supportedFunctions.contains(function.getClass()) && function.isDistinct();
        if (!supportedDistinct) {
            otherAggFuncs.add(outputAlias);
            continue;
        }
        aliases.add(outputAlias);
        distinctFunc.add(function);
        // Evaluated for every supported distinct function, even after the flag is already set.
        if (isDistinctMultiColumns(function)) {
            hasMultiColumnDistinct = true;
        }
    }
    if (distinctFunc.size() <= 1) {
        // Zero or one distinct function: nothing to split.
        return false;
    }
    // With only single-column distinct functions the split is limited to aggregates without group by.
    return hasMultiColumnDistinct || agg.getGroupByExpressions().isEmpty();
}

/**
 * Builds the projection placed on top of the join.
 *
 * <p>For each {@code joinOutput} entry, the key's slot is wrapped in a new alias that carries the
 * expr id and name of the entry's value. After that, each group-by expression (cast to
 * {@code Slot}) is re-exposed under its own expr id and name, backed by the expression at the same
 * index in {@code outputJoinGroupBys}.
 *
 * @param groupBy group-by expressions; each element is cast to {@code Slot}
 * @param joinOutput mapping whose keys supply the projected slot and whose values supply the
 *     expr id and name of the resulting alias
 * @param outputJoinGroupBys expressions backing the group-by outputs, indexed in step with
 *     {@code groupBy}
 * @param join the join that becomes the child of the projection
 * @return a {@code LogicalProject} over {@code join} with the aliases described above
 */
private static LogicalProject<Plan> constructProject(List<Expression> groupBy, Map<Alias, Alias> joinOutput,
        List<Expression> outputJoinGroupBys, LogicalJoin<Plan, Plan> join) {
    List<NamedExpression> projections = new ArrayList<>();
    for (Map.Entry<Alias, Alias> mapping : joinOutput.entrySet()) {
        Alias source = mapping.getKey();
        Alias target = mapping.getValue();
        projections.add(new Alias(target.getExprId(), source.toSlot(), target.getName()));
    }
    // outputJoinGroupBys.size() == agg.getGroupByExpressions().size()
    int position = 0;
    for (Expression groupByExpr : groupBy) {
        Slot groupBySlot = (Slot) groupByExpr;
        projections.add(new Alias(groupBySlot.getExprId(), outputJoinGroupBys.get(position),
                groupBySlot.getName()));
        position++;
    }
    return new LogicalProject<>(projections, join);
}

private static LogicalJoin<Plan, Plan> constructJoin(List<LogicalAggregate<Plan>> newAggs, List<Expression> groupBy) {
LogicalJoin<Plan, Plan> join;
if (groupBy.isEmpty()) {
join = new LogicalJoin<>(JoinType.CROSS_JOIN, newAggs.get(0), newAggs.get(1), null);
Expand All @@ -177,21 +204,9 @@ private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
for (int i = 0; i < len; ++i) {
aboveHashConditions.add(new EqualTo(belowJoinSlots.get(i), belowRightSlots.get(i)));
}
join = new LogicalJoin<>(JoinType.CROSS_JOIN, aboveHashConditions, join, newAggs.get(j), null);
join = new LogicalJoin<>(JoinType.INNER_JOIN, aboveHashConditions, join, newAggs.get(j), null);
}
}
// construct top projects
List<NamedExpression> projects = new ArrayList<>();
for (Map.Entry<Alias, Alias> entry : joinOutput.entrySet()) {
projects.add(new Alias(entry.getValue().getExprId(), entry.getKey().toSlot(), entry.getValue().getName()));
}
// outputJoinGroupBys.size() == agg.getGroupByExpressions().size()
for (int i = 0; i < groupBy.size(); ++i) {
Slot slot = (Slot) groupBy.get(i);
projects.add(new Alias(slot.getExprId(), outputJoinGroupBys.get(i), slot.getName()));
}

LogicalProject<Plan> project = new LogicalProject<>(projects, join);
return new LogicalCTEAnchor<Plan, Plan>(producer.getCteId(), producer, project);
return join;
}
}

0 comments on commit 99b8ad5

Please sign in to comment.