From 99b8ad57030e2ec3225a6428eff6aa5674687b73 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Tue, 10 Dec 2024 20:41:05 +0800 Subject: [PATCH] add regression test --- .../nereids/rules/analysis/CheckAnalysis.java | 8 +- .../nereids/rules/rewrite/DistinctSplit.java | 81 +++++++++++-------- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index 7ca8637446b0d66..597cbb68eac9c87 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -165,10 +165,10 @@ private void checkAggregate(LogicalAggregate aggregate) { distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0; } - if (distinctMultiColumns && distinctFunctionNum > 1) { - throw new AnalysisException( - "The query contains multi count distinct or sum distinct, each can't have multi columns"); - } + // if (distinctMultiColumns && distinctFunctionNum > 1) { + // throw new AnalysisException( + // "The query contains multi count distinct or sum distinct, each can't have multi columns"); + // } for (Expression expr : aggregate.getGroupByExpressions()) { if (expr.anyMatch(AggregateFunction.class::isInstance)) { throw new AnalysisException( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java index cb9ba1da7d8cd2f..d9de391fd4cf812 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java @@ -94,24 +94,7 @@ private static Plan doSplit(LogicalAggregate agg, CascadesContext ctx) { List aliases = new ArrayList<>(); Set distinctFunc = new HashSet<>(); List otherAggFuncs = new ArrayList<>(); - boolean distinctMultiColumns = false; - for (NamedExpression namedExpression : agg.getOutputExpressions()) { - if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) { - continue; - } - AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0); - if (supportedFunctions.contains(aggFunc.getClass()) && aggFunc.isDistinct()) { - aliases.add((Alias) namedExpression); - distinctFunc.add(aggFunc); - distinctMultiColumns |= isDistinctMultiColumns(aggFunc); - } else { - otherAggFuncs.add((Alias) namedExpression); - } - } - if (distinctFunc.size() <= 1) { - return null; - } - if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) { + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { return null; } @@ -154,7 +137,51 @@ private static Plan doSplit(LogicalAggregate agg, CascadesContext ctx) { newAggs.add(newAgg); joinOutput.put(alias, aliases.get(i)); } - // construct join + LogicalJoin join = constructJoin(newAggs, groupBy); + LogicalProject project = constructProject(groupBy, joinOutput, outputJoinGroupBys, join); + return new LogicalCTEAnchor(producer.getCteId(), producer, project); + } + + private static boolean needTransform(LogicalAggregate agg, List aliases, + Set distinctFunc, List otherAggFuncs) { + boolean distinctMultiColumns = false; + for (NamedExpression namedExpression : agg.getOutputExpressions()) { + if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) { + continue; + } + AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0); + if (supportedFunctions.contains(aggFunc.getClass()) && aggFunc.isDistinct()) { + aliases.add((Alias) namedExpression); + distinctFunc.add(aggFunc); + distinctMultiColumns |= isDistinctMultiColumns(aggFunc); + } else { + otherAggFuncs.add((Alias) namedExpression); + } + } + if (distinctFunc.size() <= 1) { + return false; + } + if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) { + return false; + } + return true; + } + + private static LogicalProject constructProject(List groupBy, Map joinOutput, + List outputJoinGroupBys, LogicalJoin join) { + List projects = new ArrayList<>(); + for (Map.Entry entry : joinOutput.entrySet()) { + projects.add(new Alias(entry.getValue().getExprId(), entry.getKey().toSlot(), entry.getValue().getName())); + } + // outputJoinGroupBys.size() == agg.getGroupByExpressions().size() + for (int i = 0; i < groupBy.size(); ++i) { + Slot slot = (Slot) groupBy.get(i); + projects.add(new Alias(slot.getExprId(), outputJoinGroupBys.get(i), slot.getName())); + } + return new LogicalProject<>(projects, join); + } + + private static LogicalJoin constructJoin(List> newAggs, List groupBy) { LogicalJoin join; if (groupBy.isEmpty()) { join = new LogicalJoin<>(JoinType.CROSS_JOIN, newAggs.get(0), newAggs.get(1), null); @@ -177,21 +204,9 @@ private static Plan doSplit(LogicalAggregate agg, CascadesContext ctx) { for (int i = 0; i < len; ++i) { aboveHashConditions.add(new EqualTo(belowJoinSlots.get(i), belowRightSlots.get(i))); } - join = new LogicalJoin<>(JoinType.CROSS_JOIN, aboveHashConditions, join, newAggs.get(j), null); + join = new LogicalJoin<>(JoinType.INNER_JOIN, aboveHashConditions, join, newAggs.get(j), null); } } - // construct top projects - List projects = new ArrayList<>(); - for (Map.Entry entry : joinOutput.entrySet()) { - projects.add(new Alias(entry.getValue().getExprId(), entry.getKey().toSlot(), entry.getValue().getName())); - } - // outputJoinGroupBys.size() == agg.getGroupByExpressions().size() - for (int i = 0; i < groupBy.size(); ++i) { - Slot slot = (Slot) groupBy.get(i); - projects.add(new Alias(slot.getExprId(), outputJoinGroupBys.get(i), slot.getName())); - } - - LogicalProject project = new LogicalProject<>(projects, join); - return new LogicalCTEAnchor(producer.getCteId(), producer, project); + return join; } }