Skip to content

Commit

Permalink
add rule count distinct split
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Dec 10, 2024
1 parent c1d0c19 commit fb5ffba
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin;
import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
import org.apache.doris.nereids.rules.rewrite.CountDistinctSplit;
import org.apache.doris.nereids.rules.rewrite.DistinctSplit;
import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow;
import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
Expand Down Expand Up @@ -549,7 +549,7 @@ private static List<RewriteJob> getWholeTreeRewriteJobs(
rewriteJobs.addAll(jobs(topic("or expansion",
custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE))));
}
rewriteJobs.addAll(jobs(topic("count distinct split", topDown(new CountDistinctSplit())
rewriteJobs.addAll(jobs(topic("count distinct split", topDown(new DistinctSplit())
)));

if (needSubPathPushDown) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ public enum RuleType {
MERGE_TOP_N(RuleTypeClass.REWRITE),
BUILD_AGG_FOR_UNION(RuleTypeClass.REWRITE),
COUNT_DISTINCT_REWRITE(RuleTypeClass.REWRITE),
COUNT_DISTINCT_SPLIT(RuleTypeClass.REWRITE),
DISTINCT_SPLIT(RuleTypeClass.REWRITE),
INNER_TO_CROSS_JOIN(RuleTypeClass.REWRITE),
CROSS_TO_INNER_JOIN(RuleTypeClass.REWRITE),
PRUNE_EMPTY_PARTITION(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
Expand All @@ -36,6 +41,8 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -58,36 +65,53 @@
* +--LogicalAggregate(output:count(distinct b))
* +--LogicalCTEConsumer
* */
public class CountDistinctSplit extends OneRewriteRuleFactory {
public class DistinctSplit extends OneRewriteRuleFactory {
private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions =
ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class);

@Override
public Rule build() {
return logicalAggregate()
.whenNot(agg -> agg.getSourceRepeat().isPresent())
.thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext))
.toRule(RuleType.COUNT_DISTINCT_SPLIT);
.toRule(RuleType.DISTINCT_SPLIT);
}

/**
 * Decides whether a distinct aggregate function is applied over more than one column.
 *
 * <p>The first argument is never inspected. Every later argument counts as an extra
 * column unless it is an {@link OrderExpression} or it references no input slots.
 * For example, in {@code group_concat(distinct col_1, ',')} the trailing {@code ','}
 * argument is expected to reference no slots, so it is not treated as a second column.
 *
 * @param func the aggregate function whose arguments are examined
 * @return {@code true} if some argument after the first is neither an
 *         {@link OrderExpression} nor slot-free; {@code false} otherwise,
 *         including when the function has at most one argument
 */
private static boolean isDistinctMultiColumns(AggregateFunction func) {
    final int argCount = func.arity();
    // Starting at index 1 also covers functions with zero or one argument:
    // the loop body never runs and the method falls through to false.
    for (int idx = 1; idx < argCount; idx++) {
        Expression arg = func.child(idx);
        if (arg instanceof OrderExpression) {
            // ordering clauses are not distinct columns
            continue;
        }
        if (arg.getInputSlots().isEmpty()) {
            // arguments that reference no slots add no column
            continue;
        }
        return true;
    }
    return false;
}

private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) {
List<Alias> aliases = new ArrayList<>();
Set<Expression> countDistinct = new HashSet<>();
Set<Expression> distinctFunc = new HashSet<>();
List<Alias> otherAggFuncs = new ArrayList<>();
boolean distinctMultiColumns = false;
for (NamedExpression namedExpression : agg.getOutputExpressions()) {
if (!(namedExpression instanceof Alias)) {
if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) {
continue;
}
if (namedExpression.child(0) instanceof Count
&& ((Count) namedExpression.child(0)).isDistinct()) {
AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0);
if (supportedFunctions.contains(aggFunc.getClass()) && aggFunc.isDistinct()) {
aliases.add((Alias) namedExpression);
countDistinct.add(namedExpression.child(0));
distinctMultiColumns |= namedExpression.child(0).arity() > 1;
distinctFunc.add(aggFunc);
distinctMultiColumns |= isDistinctMultiColumns(aggFunc);
} else {
otherAggFuncs.add((Alias) namedExpression);
}
}
if (countDistinct.size() <= 1) {
if (distinctFunc.size() <= 1) {
return null;
}
if (!distinctMultiColumns && agg.getGroupByExpressions().isEmpty()) {
if (!distinctMultiColumns && !agg.getGroupByExpressions().isEmpty()) {
return null;
}

Expand Down

0 comments on commit fb5ffba

Please sign in to comment.