diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 8873c9a6deed5e2..6bb86d6aa95537e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -53,7 +53,7 @@ import org.apache.doris.nereids.rules.rewrite.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin; import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite; -import org.apache.doris.nereids.rules.rewrite.CountDistinctSplit; +import org.apache.doris.nereids.rules.rewrite.DistinctSplit; import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; @@ -549,7 +549,7 @@ private static List getWholeTreeRewriteJobs( rewriteJobs.addAll(jobs(topic("or expansion", custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)))); } - rewriteJobs.addAll(jobs(topic("count distinct split", topDown(new CountDistinctSplit()) + rewriteJobs.addAll(jobs(topic("count distinct split", topDown(new DistinctSplit()) ))); if (needSubPathPushDown) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 9c7e84dc9e7b839..388a8e4ef4489f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -315,7 +315,7 @@ public enum RuleType { MERGE_TOP_N(RuleTypeClass.REWRITE), BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE), COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE), - COUNT_DISTINCT_SPLIT(RuleTypeClass.REWRITE), + DISTINCT_SPLIT(RuleTypeClass.REWRITE), INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE), CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE), PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctSplit.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java similarity index 81% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctSplit.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java index ac68943c615346e..cb9ba1da7d8cd2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctSplit.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java @@ -24,8 +24,13 @@ import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -36,6 +41,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ImmutableSet; + import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -58,36 +65,53 @@ * +--LogicalAggregate(output:count(distinct b)) * +--LogicalCTEConsumer * */ -public class CountDistinctSplit extends OneRewriteRuleFactory { +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + @Override public Rule build() { return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) - .toRule(RuleType.COUNT_DISTINCT_SPLIT); + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; } private static Plan doSplit(LogicalAggregate agg, CascadesContext ctx) { List aliases = new ArrayList<>(); - Set countDistinct = new HashSet<>(); + Set distinctFunc = new HashSet<>(); List otherAggFuncs = new ArrayList<>(); boolean distinctMultiColumns = false; for (NamedExpression namedExpression : agg.getOutputExpressions()) { - if (!(namedExpression instanceof Alias)) { + if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) { continue; } - if (namedExpression.child(0) instanceof Count - && ((Count) namedExpression.child(0)).isDistinct()) { + AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0); + if (supportedFunctions.contains(aggFunc.getClass()) && aggFunc.isDistinct()) { aliases.add((Alias) namedExpression); - countDistinct.add(namedExpression.child(0)); - distinctMultiColumns |= namedExpression.child(0).arity() > 1; + distinctFunc.add(aggFunc); + distinctMultiColumns |= isDistinctMultiColumns(aggFunc); } else { otherAggFuncs.add((Alias) namedExpression); } } - if (countDistinct.size() <= 1) { + if (distinctFunc.size() <= 1) { return null; } - if (!distinctMultiColumns && agg.getGroupByExpressions().isEmpty()) { + if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) { return null; }