From 8a61eb9e69f6be4ee175bfdcf1bfa88b256f9bea Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Tue, 10 Dec 2024 10:41:22 +0800 Subject: [PATCH] [feat](nereids) add rewrite rule :EliminateGroupByKeyByUniform (#43391) (#45075) cherry-pick #43391 to branch-3.0 --- .../doris/nereids/jobs/executor/Rewriter.java | 2 + .../doris/nereids/properties/DataTrait.java | 211 +++++++++++-- .../apache/doris/nereids/rules/RuleType.java | 6 + .../rules/expression/ExpressionRewrite.java | 2 +- .../rewrite/EliminateGroupByKeyByUniform.java | 148 +++++++++ .../nereids/rules/rewrite/ExprIdRewriter.java | 284 ++++++++++++++++++ .../plans/commands/info/CreateMTMVInfo.java | 3 +- .../trees/plans/logical/LogicalFilter.java | 7 +- .../trees/plans/logical/LogicalHaving.java | 7 +- .../trees/plans/logical/LogicalJoin.java | 34 ++- .../trees/plans/logical/LogicalProject.java | 14 +- .../doris/nereids/util/ExpressionUtils.java | 11 +- .../mv/MaterializedViewUtilsTest.java | 2 +- .../EliminateGroupByKeyByUniformTest.java | 250 +++++++++++++++ .../rewrite/EliminateGroupByKeyTest.java | 4 +- .../eliminate_group_by_key_by_uniform.out | 267 ++++++++++++++++ .../eliminate_group_by_key_by_uniform.groovy | 221 ++++++++++++++ .../aggregate_without_roll_up.groovy | 6 +- 18 files changed, 1421 insertions(+), 58 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniformTest.java create mode 100644 regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.out create mode 100644 regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 4feeb6439e46c8..e70746701711e5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -64,6 +64,7 @@ import org.apache.doris.nereids.rules.rewrite.EliminateFilter; import org.apache.doris.nereids.rules.rewrite.EliminateGroupBy; import org.apache.doris.nereids.rules.rewrite.EliminateGroupByKey; +import org.apache.doris.nereids.rules.rewrite.EliminateGroupByKeyByUniform; import org.apache.doris.nereids.rules.rewrite.EliminateJoinByFK; import org.apache.doris.nereids.rules.rewrite.EliminateJoinByUnique; import org.apache.doris.nereids.rules.rewrite.EliminateJoinCondition; @@ -356,6 +357,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown(new EliminateJoinByUnique()) ), topic("eliminate Aggregate according to fd items", + custom(RuleType.ELIMINATE_GROUP_BY_KEY_BY_UNIFORM, EliminateGroupByKeyByUniform::new), topDown(new EliminateGroupByKey()), topDown(new PushDownAggThroughJoinOnPkFk()), topDown(new PullUpJoinFromUnionAll()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/DataTrait.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/DataTrait.java index e97fad6f479047..ff4756979e450e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/DataTrait.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/DataTrait.java @@ -17,18 +17,23 @@ package org.apache.doris.nereids.properties; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.util.ImmutableEqualSet; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -46,16 +51,16 @@ public class DataTrait { public static final DataTrait EMPTY_TRAIT - = new DataTrait(new NestedSet().toImmutable(), - new NestedSet().toImmutable(), new ImmutableSet.Builder().build(), + = new DataTrait(new UniqueDescription().toImmutable(), + new UniformDescription().toImmutable(), new ImmutableSet.Builder().build(), ImmutableEqualSet.empty(), new FuncDepsDG.Builder().build()); - private final NestedSet uniqueSet; - private final NestedSet uniformSet; + private final UniqueDescription uniqueSet; + private final UniformDescription uniformSet; private final ImmutableSet fdItems; private final ImmutableEqualSet equalSet; private final FuncDepsDG fdDg; - private DataTrait(NestedSet uniqueSet, NestedSet uniformSet, ImmutableSet fdItems, + private DataTrait(UniqueDescription uniqueSet, UniformDescription uniformSet, ImmutableSet fdItems, ImmutableEqualSet equalSet, FuncDepsDG fdDg) { this.uniqueSet = uniqueSet; this.uniformSet = uniformSet; @@ -86,8 +91,7 @@ public boolean isUniform(Slot slot) { } public boolean isUniform(Set slotSet) { - return !slotSet.isEmpty() - && uniformSet.slots.containsAll(slotSet); + return uniformSet.contains(slotSet); } public boolean isUniqueAndNotNull(Slot slot) { @@ -102,11 +106,25 @@ public boolean isUniqueAndNotNull(Set slotSet) { } public boolean isUniformAndNotNull(Slot slot) { - return !slot.nullable() && isUniform(slot); + return uniformSet.isUniformAndNotNull(slot); } + /** isUniformAndNotNull for slot set */ public boolean isUniformAndNotNull(ImmutableSet slotSet) { - return slotSet.stream().noneMatch(Slot::nullable) && isUniform(slotSet); + for (Slot slot : slotSet) { + if (!uniformSet.isUniformAndNotNull(slot)) { + return false; + } + } + return true; + } + + public boolean isUniformAndHasConstValue(Slot slot) { + return uniformSet.isUniformAndHasConstValue(slot); + } + + public Optional getUniformValue(Slot slot) { + return uniformSet.slotUniformValue.get(slot); } public boolean isNullSafeEqual(Slot l, Slot r) { @@ -143,23 +161,23 @@ public String toString() { * Builder of trait */ public static class Builder { - private final NestedSet uniqueSet; - private final NestedSet uniformSet; + private final UniqueDescription uniqueSet; + private final UniformDescription uniformSet; private ImmutableSet fdItems; private final ImmutableEqualSet.Builder equalSetBuilder; private final FuncDepsDG.Builder fdDgBuilder; public Builder() { - uniqueSet = new NestedSet(); - uniformSet = new NestedSet(); + uniqueSet = new UniqueDescription(); + uniformSet = new UniformDescription(); fdItems = new ImmutableSet.Builder().build(); equalSetBuilder = new ImmutableEqualSet.Builder<>(); fdDgBuilder = new FuncDepsDG.Builder(); } public Builder(DataTrait other) { - this.uniformSet = new NestedSet(other.uniformSet); - this.uniqueSet = new NestedSet(other.uniqueSet); + this.uniformSet = new UniformDescription(other.uniformSet); + this.uniqueSet = new UniqueDescription(other.uniqueSet); this.fdItems = ImmutableSet.copyOf(other.fdItems); equalSetBuilder = new ImmutableEqualSet.Builder<>(other.equalSet); fdDgBuilder = new FuncDepsDG.Builder(other.fdDg); @@ -173,6 +191,14 @@ public void addUniformSlot(DataTrait dataTrait) { uniformSet.add(dataTrait.uniformSet); } + public void addUniformSlotForOuterJoinNullableSide(DataTrait dataTrait) { + uniformSet.addUniformSlotForOuterJoinNullableSide(dataTrait.uniformSet); + } + + public void addUniformSlotAndLiteral(Slot slot, Expression literal) { + uniformSet.add(slot, literal); + } + public void addUniqueSlot(Slot slot) { uniqueSet.add(slot); } @@ -261,8 +287,21 @@ public void addUniqueByEqualSet(Set equalSet) { * if there is a uniform slot in the equivalence set, then all slots of an equivalence set are uniform */ public void addUniformByEqualSet(Set equalSet) { - if (uniformSet.isIntersect(uniformSet.slots, equalSet)) { - uniformSet.slots.addAll(equalSet); + List intersectionList = uniformSet.slotUniformValue.keySet().stream() + .filter(equalSet::contains) + .collect(Collectors.toList()); + if (intersectionList.isEmpty()) { + return; + } + Expression expr = null; + for (Slot slot : intersectionList) { + if (uniformSet.slotUniformValue.get(slot).isPresent()) { + expr = uniformSet.slotUniformValue.get(slot).get(); + break; + } + } + for (Slot equal : equalSet) { + uniformSet.add(equal, expr); } } @@ -293,9 +332,11 @@ public List> getAllUniqueAndNotNull() { */ public List> getAllUniformAndNotNull() { List> res = new ArrayList<>(); - for (Slot s : uniformSet.slots) { - if (!s.nullable()) { - res.add(ImmutableSet.of(s)); + for (Map.Entry> entry : uniformSet.slotUniformValue.entrySet()) { + if (!entry.getKey().nullable()) { + res.add(ImmutableSet.of(entry.getKey())); + } else if (entry.getValue().isPresent() && !entry.getValue().get().nullable()) { + res.add(ImmutableSet.of(entry.getKey())); } } return res; @@ -338,21 +379,21 @@ public void replaceFuncDepsBy(Map replaceMap) { } } - static class NestedSet { + static class UniqueDescription { Set slots; Set> slotSets; - NestedSet() { + UniqueDescription() { slots = new HashSet<>(); slotSets = new HashSet<>(); } - NestedSet(NestedSet o) { + UniqueDescription(UniqueDescription o) { this.slots = new HashSet<>(o.slots); this.slotSets = new HashSet<>(o.slotSets); } - NestedSet(Set slots, Set> slotSets) { + UniqueDescription(Set slots, Set> slotSets) { this.slots = slots; this.slotSets = slotSets; } @@ -408,9 +449,9 @@ public void add(ImmutableSet slotSet) { slotSets.add(slotSet); } - public void add(NestedSet nestedSet) { - slots.addAll(nestedSet.slots); - slotSets.addAll(nestedSet.slotSets); + public void add(UniqueDescription uniqueDescription) { + slots.addAll(uniqueDescription.slots); + slotSets.addAll(uniqueDescription.slotSets); } public boolean isIntersect(Set set1, Set set2) { @@ -446,8 +487,120 @@ public void replace(Map replaceMap) { .collect(Collectors.toSet()); } - public NestedSet toImmutable() { - return new NestedSet(ImmutableSet.copyOf(slots), ImmutableSet.copyOf(slotSets)); + public UniqueDescription toImmutable() { + return new UniqueDescription(ImmutableSet.copyOf(slots), ImmutableSet.copyOf(slotSets)); + } + } + + static class UniformDescription { + // slot and its uniform expression(literal or const expression) + // some slot can get uniform values, others can not. + // e.g.select a from t where a=10 group by a, b; + // in LogicalAggregate, a UniformDescription with map {a : 10} can be obtained. + // which means a is uniform and the uniform value is 10. + Map> slotUniformValue; + + public UniformDescription() { + slotUniformValue = new LinkedHashMap<>(); + } + + public UniformDescription(UniformDescription ud) { + slotUniformValue = new LinkedHashMap<>(ud.slotUniformValue); + } + + public UniformDescription(Map> slotUniformValue) { + this.slotUniformValue = slotUniformValue; + } + + public UniformDescription toImmutable() { + return new UniformDescription(ImmutableMap.copyOf(slotUniformValue)); + } + + public boolean isEmpty() { + return slotUniformValue.isEmpty(); + } + + public boolean contains(Slot slot) { + return slotUniformValue.containsKey(slot); + } + + public boolean contains(Set slots) { + return !slots.isEmpty() && slotUniformValue.keySet().containsAll(slots); + } + + public void add(Slot slot) { + slotUniformValue.putIfAbsent(slot, Optional.empty()); + } + + public void add(Set slots) { + for (Slot s : slots) { + slotUniformValue.putIfAbsent(s, Optional.empty()); + } + } + + public void add(UniformDescription ud) { + slotUniformValue.putAll(ud.slotUniformValue); + for (Map.Entry> entry : ud.slotUniformValue.entrySet()) { + add(entry.getKey(), entry.getValue().orElse(null)); + } + } + + public void add(Slot slot, Expression literal) { + if (null == literal) { + slotUniformValue.putIfAbsent(slot, Optional.empty()); + } else { + slotUniformValue.put(slot, Optional.of(literal)); + } + } + + public void addUniformSlotForOuterJoinNullableSide(UniformDescription ud) { + for (Map.Entry> entry : ud.slotUniformValue.entrySet()) { + if ((!entry.getValue().isPresent() && entry.getKey().nullable()) + || (entry.getValue().isPresent() && entry.getValue().get() instanceof NullLiteral)) { + add(entry.getKey(), entry.getValue().orElse(null)); + } + } + } + + public void removeNotContain(Set slotSet) { + if (slotSet.isEmpty()) { + return; + } + Map> newSlotUniformValue = new LinkedHashMap<>(); + for (Map.Entry> entry : slotUniformValue.entrySet()) { + if (slotSet.contains(entry.getKey())) { + newSlotUniformValue.put(entry.getKey(), entry.getValue()); + } + } + this.slotUniformValue = newSlotUniformValue; + } + + public void replace(Map replaceMap) { + Map> newSlotUniformValue = new LinkedHashMap<>(); + for (Map.Entry> entry : slotUniformValue.entrySet()) { + Slot newKey = replaceMap.getOrDefault(entry.getKey(), entry.getKey()); + newSlotUniformValue.put(newKey, entry.getValue()); + } + slotUniformValue = newSlotUniformValue; + } + + // The current implementation logic is: if a slot key exists in map slotUniformValue, + // its value is present and is not nullable, + // or if a slot key exists in map slotUniformValue and the slot is not nullable + // it indicates that this slot is uniform and not null. + public boolean isUniformAndNotNull(Slot slot) { + return slotUniformValue.containsKey(slot) + && (!slot.nullable() || slotUniformValue.get(slot).isPresent() + && !slotUniformValue.get(slot).get().nullable()); + } + + public boolean isUniformAndHasConstValue(Slot slot) { + return slotUniformValue.containsKey(slot) && slotUniformValue.get(slot).isPresent(); + } + + @Override + public String toString() { + return "{" + slotUniformValue + "}"; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 430d2b35b35727..e587f953e82c2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -209,6 +209,11 @@ public enum RuleType { REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_HAVING_EXPRESSION(RuleTypeClass.REWRITE), REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_WINDOW_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_SET_OPERATION_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE), + REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE), EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE), REORDER_JOIN(RuleTypeClass.REWRITE), MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE), @@ -238,6 +243,7 @@ public enum RuleType { ELIMINATE_JOIN_BY_UK(RuleTypeClass.REWRITE), ELIMINATE_JOIN_BY_FK(RuleTypeClass.REWRITE), ELIMINATE_GROUP_BY_KEY(RuleTypeClass.REWRITE), + ELIMINATE_GROUP_BY_KEY_BY_UNIFORM(RuleTypeClass.REWRITE), ELIMINATE_FILTER_GROUP_BY_KEY(RuleTypeClass.REWRITE), ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE), ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java index 0eb4d3a79a0c96..277f1e6f2d8adc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java @@ -53,7 +53,7 @@ * expression of plan rewrite rule. */ public class ExpressionRewrite implements RewriteRuleFactory { - private final ExpressionRuleExecutor rewriter; + protected final ExpressionRuleExecutor rewriter; public ExpressionRewrite(ExpressionRewriteRule... rules) { this.rewriter = new ExpressionRuleExecutor(ImmutableList.copyOf(rules)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java new file mode 100644 index 00000000000000..4cb39c2a9341ae --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniform.java @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.properties.DataTrait; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.CTEId; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * +--aggregate(group by a, b output a#0 ,b#1, max(c) as max(c)#2) + * (a is uniform and not null: e.g. a is projection 2 as a in logicalProject) + * -> + * +--aggregate(group by b output b#1, any_value(a#0) as a#3, max(c)#2) + * if output any_value(a#0) as a#0, the uniqueness of ExprId #0 is violated, because #0 is both any_value(a#0) and a#0 + * error will occurs in other module(e.g. mv rewrite). + * As a result, new aggregate outputs #3 instead of #0, but upper plan refer slot #0, + * therefore, all references to #0 in the upper plan need to be changed to #3. + * use ExprIdRewriter to do this ExprId rewrite, and use CustomRewriter to rewrite upward。 + * */ +public class EliminateGroupByKeyByUniform extends DefaultPlanRewriter> implements CustomRewriter { + private ExprIdRewriter exprIdReplacer; + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + Optional cteId = jobContext.getCascadesContext().getCurrentTree(); + if (cteId.isPresent()) { + return plan; + } + Map replaceMap = new HashMap<>(); + ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap); + exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext); + return plan.accept(this, replaceMap); + } + + @Override + public Plan visit(Plan plan, Map replaceMap) { + plan = visitChildren(this, plan, replaceMap); + plan = exprIdReplacer.rewriteExpr(plan, replaceMap); + return plan; + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate aggregate, Map replaceMap) { + aggregate = visitChildren(this, aggregate, replaceMap); + aggregate = (LogicalAggregate) exprIdReplacer.rewriteExpr(aggregate, replaceMap); + + if (aggregate.getGroupByExpressions().isEmpty() || aggregate.getSourceRepeat().isPresent()) { + return aggregate; + } + DataTrait aggChildTrait = aggregate.child().getLogicalProperties().getTrait(); + // Get the Group by column of agg. If there is a uniform one, delete the group by key. + Set removedExpression = new LinkedHashSet<>(); + List newGroupBy = new ArrayList<>(); + for (Expression groupBy : aggregate.getGroupByExpressions()) { + if (!(groupBy instanceof Slot)) { + newGroupBy.add(groupBy); + continue; + } + if (aggChildTrait.isUniformAndNotNull((Slot) groupBy)) { + removedExpression.add(groupBy); + } else { + newGroupBy.add(groupBy); + } + } + if (removedExpression.isEmpty()) { + return aggregate; + } + // when newGroupBy is empty, need retain one expr in group by, otherwise the result may be wrong in empty table + if (newGroupBy.isEmpty()) { + Expression expr = removedExpression.iterator().next(); + newGroupBy.add(expr); + removedExpression.remove(expr); + } + if (removedExpression.isEmpty()) { + return aggregate; + } + List newOutputs = new ArrayList<>(); + // If this output appears in the removedExpression column, replace it with any_value + for (NamedExpression output : aggregate.getOutputExpressions()) { + if (output instanceof Slot) { + if (removedExpression.contains(output)) { + Alias alias = new Alias(new AnyValue(false, output), output.getName()); + newOutputs.add(alias); + replaceMap.put(output.getExprId(), alias.getExprId()); + } else { + newOutputs.add(output); + } + } else if (output instanceof Alias) { + if (removedExpression.contains(output.child(0))) { + newOutputs.add(new Alias( + new AnyValue(false, output.child(0)), output.getName())); + } else { + newOutputs.add(output); + } + } else { + newOutputs.add(output); + } + } + + // Adjust the order of this new output so that aggregate functions are placed at the back + // and non-aggregated functions are placed at the front. + List aggFuncs = new ArrayList<>(); + List orderOutput = new ArrayList<>(); + for (NamedExpression output : newOutputs) { + if (output.anyMatch(e -> e instanceof AggregateFunction)) { + aggFuncs.add(output); + } else { + orderOutput.add(output); + } + } + orderOutput.addAll(aggFuncs); + return aggregate.withGroupByAndOutput(newGroupBy, orderOutput); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java new file mode 100644 index 00000000000000..60c9da4bc6eec5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExprIdRewriter.java @@ -0,0 +1,284 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.pattern.MatchingContext; +import org.apache.doris.nereids.pattern.Pattern; +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; +import org.apache.doris.nereids.trees.plans.logical.LogicalSink; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; + +import com.google.common.collect.ImmutableList; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** replace SlotReference ExprId in logical plans */ +public class ExprIdRewriter extends ExpressionRewrite { + private final List rules; + private final JobContext jobContext; + + public ExprIdRewriter(ReplaceRule replaceRule, JobContext jobContext) { + super(new ExpressionRuleExecutor(ImmutableList.of(bottomUp(replaceRule)))); + rules = buildRules(); + this.jobContext = jobContext; + } + + @Override + public List buildRules() { + ImmutableList.Builder builder = ImmutableList.builder(); + builder.addAll(super.buildRules()); + builder.addAll(ImmutableList.of( + new LogicalPartitionTopNExpressionRewrite().build(), + new LogicalTopNExpressionRewrite().build(), + new LogicalSetOperationRewrite().build(), + new LogicalWindowRewrite().build(), + new LogicalResultSinkRewrite().build(), + new LogicalFileSinkRewrite().build(), + new LogicalHiveTableSinkRewrite().build(), + new LogicalIcebergTableSinkRewrite().build(), + new LogicalJdbcTableSinkRewrite().build(), + new LogicalOlapTableSinkRewrite().build(), + new LogicalDeferMaterializeResultSinkRewrite().build() + )); + return builder.build(); + } + + /**rewriteExpr*/ + public Plan rewriteExpr(Plan plan, Map replaceMap) { + if (replaceMap.isEmpty()) { + return plan; + } + for (Rule rule : rules) { + Pattern pattern = (Pattern) rule.getPattern(); + if (pattern.matchPlanTree(plan)) { + List newPlans = rule.transform(plan, jobContext.getCascadesContext()); + Plan newPlan = newPlans.get(0); + if (!newPlan.deepEquals(plan)) { + return newPlan; + } + } + } + return plan; + } + + /** + * Iteratively rewrites IDs using the replaceMap: + * 1. For a given SlotReference with initial ID, retrieve the corresponding value ID from the replaceMap. + * 2. If the value ID exists within the replaceMap, continue the lookup process using the value ID + * until it no longer appears in the replaceMap. + * 3. return SlotReference final value ID as the result of the rewrite. + * e.g. replaceMap:{0:3, 1:6, 6:7} + * SlotReference:a#0 -> a#3, a#1 -> a#7 + * */ + public static class ReplaceRule implements ExpressionPatternRuleFactory { + private final Map replaceMap; + + public ReplaceRule(Map replaceMap) { + this.replaceMap = replaceMap; + } + + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(SlotReference.class).thenApply(ctx -> { + Slot slot = ctx.expr; + if (replaceMap.containsKey(slot.getExprId())) { + ExprId newId = replaceMap.get(slot.getExprId()); + while (replaceMap.containsKey(newId)) { + newId = replaceMap.get(newId); + } + return slot.withExprId(newId); + } + return slot; + }) + ); + } + } + + private class LogicalResultSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalResultSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalFileSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalFileSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalHiveTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalHiveTableSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalIcebergTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalIcebergTableSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalJdbcTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalJdbcTableSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalOlapTableSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalOlapTableSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalDeferMaterializeResultSinkRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalDeferMaterializeResultSink().thenApply(ExprIdRewriter.this::applyRewrite) + .toRule(RuleType.REWRITE_SINK_EXPRESSION); + } + } + + private class LogicalSetOperationRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalSetOperation().thenApply(ctx -> { + LogicalSetOperation setOperation = ctx.root; + List> slotsList = setOperation.getRegularChildrenOutputs(); + List> newSlotsList = new ArrayList<>(); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + for (List slots : slotsList) { + List newSlots = rewriteAll(slots, rewriter, context); + newSlotsList.add(newSlots); + } + if (newSlotsList.equals(slotsList)) { + return setOperation; + } + return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList); + }) + .toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION); + } + } + + private class LogicalWindowRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalWindow().thenApply(ctx -> { + LogicalWindow window = ctx.root; + List windowExpressions = window.getWindowExpressions(); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List newWindowExpressions = rewriteAll(windowExpressions, rewriter, context); + if (newWindowExpressions.equals(windowExpressions)) { + return window; + } + return window.withExpressionsAndChild(newWindowExpressions, window.child()); + }) + .toRule(RuleType.REWRITE_WINDOW_EXPRESSION); + } + } + + private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalTopN().thenApply(ctx -> { + LogicalTopN topN = ctx.root; + List orderKeys = topN.getOrderKeys(); + ImmutableList.Builder rewrittenOrderKeys + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + boolean changed = false; + for (OrderKey k : orderKeys) { + Expression expression = rewriter.rewrite(k.getExpr(), context); + changed |= expression != k.getExpr(); + rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst())); + } + return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN; + }).toRule(RuleType.REWRITE_TOPN_EXPRESSION); + } + } + + private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalPartitionTopN().thenApply(ctx -> { + LogicalPartitionTopN partitionTopN = ctx.root; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List newOrderExpressions = new ArrayList<>(); + boolean changed = false; + for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) { + OrderKey orderKey = orderExpression.getOrderKey(); + Expression expr = rewriter.rewrite(orderKey.getExpr(), context); + changed |= expr != orderKey.getExpr(); + OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst()); + newOrderExpressions.add(new OrderExpression(newOrderKey)); + } + List newPartitionKeys = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context); + if (!newPartitionKeys.equals(partitionTopN.getPartitionKeys())) { + changed = true; + } + if (!changed) { + return partitionTopN; + } + return partitionTopN.withPartitionKeysAndOrderKeys(newPartitionKeys, newOrderExpressions); + }).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION); + } + } + + private LogicalSink applyRewrite(MatchingContext> ctx) { + LogicalSink sink = ctx.root; + ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + List outputExprs = sink.getOutputExprs(); + List newOutputExprs = rewriteAll(outputExprs, rewriter, context); + if (outputExprs.equals(newOutputExprs)) { + return sink; + } + return sink.withOutputExprs(newOutputExprs); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/CreateMTMVInfo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/CreateMTMVInfo.java index 3e8e23d4044e69..a3e1f53acd94b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/CreateMTMVInfo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/CreateMTMVInfo.java @@ -103,7 +103,8 @@ */ public class CreateMTMVInfo { public static final Logger LOG = LogManager.getLogger(CreateMTMVInfo.class); - public static final String MTMV_PLANER_DISABLE_RULES = "OLAP_SCAN_PARTITION_PRUNE,PRUNE_EMPTY_PARTITION"; + public static final String MTMV_PLANER_DISABLE_RULES = "OLAP_SCAN_PARTITION_PRUNE,PRUNE_EMPTY_PARTITION," + + "ELIMINATE_GROUP_BY_KEY_BY_UNIFORM"; private final boolean ifNotExists; private final TableNameInfo mvName; private List keys; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java index d23ea3d2395f05..efd7e90c13615e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java @@ -37,6 +37,7 @@ import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -154,9 +155,9 @@ public void computeUnique(Builder builder) { @Override public void computeUniform(Builder builder) { for (Expression e : getConjuncts()) { - Set uniformSlots = ExpressionUtils.extractUniformSlot(e); - for (Slot slot : uniformSlots) { - builder.addUniformSlot(slot); + Map uniformSlots = ExpressionUtils.extractUniformSlot(e); + for (Map.Entry entry : uniformSlots.entrySet()) { + builder.addUniformSlotAndLiteral(entry.getKey(), entry.getValue()); } } builder.addUniformSlot(child(0).getLogicalProperties().getTrait()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java index f4f2178840b6ab..680988b39f6bb1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHaving.java @@ -35,6 +35,7 @@ import com.google.common.collect.ImmutableSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -125,9 +126,9 @@ public void computeUnique(Builder builder) { @Override public void computeUniform(Builder builder) { for (Expression e : getConjuncts()) { - Set uniformSlots = ExpressionUtils.extractUniformSlot(e); - for (Slot slot : uniformSlots) { - builder.addUniformSlot(slot); + Map uniformSlots = ExpressionUtils.extractUniformSlot(e); + for (Map.Entry entry : uniformSlots.entrySet()) { + builder.addUniformSlotAndLiteral(entry.getKey(), entry.getValue()); } } builder.addUniformSlot(child(0).getLogicalProperties().getTrait()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index f557b07d3b646e..c583360c3d8a76 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -535,11 +535,35 @@ public void computeUniform(Builder builder) { // TODO disable function dependence calculation for mark join, but need re-think this in future. return; } - if (!joinType.isLeftSemiOrAntiJoin()) { - builder.addUniformSlot(right().getLogicalProperties().getTrait()); - } - if (!joinType.isRightSemiOrAntiJoin()) { - builder.addUniformSlot(left().getLogicalProperties().getTrait()); + switch (joinType) { + case INNER_JOIN: + case CROSS_JOIN: + builder.addUniformSlot(left().getLogicalProperties().getTrait()); + builder.addUniformSlot(right().getLogicalProperties().getTrait()); + break; + case LEFT_SEMI_JOIN: + case LEFT_ANTI_JOIN: + case NULL_AWARE_LEFT_ANTI_JOIN: + builder.addUniformSlot(left().getLogicalProperties().getTrait()); + break; + case RIGHT_SEMI_JOIN: + case RIGHT_ANTI_JOIN: + builder.addUniformSlot(right().getLogicalProperties().getTrait()); + break; + case LEFT_OUTER_JOIN: + builder.addUniformSlot(left().getLogicalProperties().getTrait()); + builder.addUniformSlotForOuterJoinNullableSide(right().getLogicalProperties().getTrait()); + break; + case RIGHT_OUTER_JOIN: + builder.addUniformSlot(right().getLogicalProperties().getTrait()); + builder.addUniformSlotForOuterJoinNullableSide(left().getLogicalProperties().getTrait()); + break; + case FULL_OUTER_JOIN: + builder.addUniformSlotForOuterJoinNullableSide(left().getLogicalProperties().getTrait()); + builder.addUniformSlotForOuterJoinNullableSide(right().getLogicalProperties().getTrait()); + break; + default: + break; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java index 58d99b635e78ec..159552b4653b07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java @@ -248,14 +248,18 @@ public void computeUnique(DataTrait.Builder builder) { public void computeUniform(DataTrait.Builder builder) { builder.addUniformSlot(child(0).getLogicalProperties().getTrait()); for (NamedExpression proj : getProjects()) { - if (proj.children().isEmpty()) { + if (!(proj instanceof Alias)) { continue; } if (proj.child(0).isConstant()) { - builder.addUniformSlot(proj.toSlot()); - } else if (ExpressionUtils.isInjective(proj.child(0))) { - ImmutableSet inputs = ImmutableSet.copyOf(proj.getInputSlots()); - if (child(0).getLogicalProperties().getTrait().isUniform(inputs)) { + builder.addUniformSlotAndLiteral(proj.toSlot(), proj.child(0)); + } else if (proj.child(0) instanceof Slot) { + Slot slot = (Slot) proj.child(0); + DataTrait childTrait = child(0).getLogicalProperties().getTrait(); + if (childTrait.isUniformAndHasConstValue(slot)) { + builder.addUniformSlotAndLiteral(proj.toSlot(), + child(0).getLogicalProperties().getTrait().getUniformValue(slot).get()); + } else if (childTrait.isUniform(slot)) { builder.addUniformSlot(proj.toSlot()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 67734f66aa17c1..b303fd2a6b59fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -63,6 +63,7 @@ import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; @@ -726,15 +727,15 @@ public static List collectToList(Collection express /** * extract uniform slot for the given predicate, such as a = 1 and b = 2 */ - public static ImmutableSet extractUniformSlot(Expression expression) { - ImmutableSet.Builder builder = new ImmutableSet.Builder<>(); + public static ImmutableMap extractUniformSlot(Expression expression) { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); if (expression instanceof And) { - builder.addAll(extractUniformSlot(expression.child(0))); - builder.addAll(extractUniformSlot(expression.child(1))); + builder.putAll(extractUniformSlot(expression.child(0))); + builder.putAll(extractUniformSlot(expression.child(1))); } if (expression instanceof EqualTo) { if (isInjective(expression.child(0)) && expression.child(1).isConstant()) { - builder.add((Slot) expression.child(0)); + builder.put((Slot) expression.child(0), expression.child(1)); } } return builder.build(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java index f824a40eda6474..45e1190412d0a4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/mv/MaterializedViewUtilsTest.java @@ -248,7 +248,7 @@ protected void runBeforeAll() throws Exception { + "\"replication_allocation\" = \"tag.location.default: 1\"\n" + ");\n"); // Should not make scan to empty relation when the table used by materialized view has no data - connectContext.getSessionVariable().setDisableNereidsRules("OLAP_SCAN_PARTITION_PRUNE,PRUNE_EMPTY_PARTITION"); + connectContext.getSessionVariable().setDisableNereidsRules("OLAP_SCAN_PARTITION_PRUNE,PRUNE_EMPTY_PARTITION,ELIMINATE_GROUP_BY_KEY_BY_UNIFORM"); } // Test when join both side are all partition table and partition column name is same diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniformTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniformTest.java new file mode 100644 index 00000000000000..78d8034e3fdfed --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyByUniformTest.java @@ -0,0 +1,250 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +public class EliminateGroupByKeyByUniformTest extends TestWithFeService implements MemoPatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + createTable("create table test.eli_gbk_by_uniform_t(a int null, b int not null," + + "c varchar(10) null, d date, dt datetime)\n" + + "distributed by hash(a) properties('replication_num' = '1');"); + connectContext.setDatabase("test"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + } + + @Test + void testEliminateByFilter() { + PlanChecker.from(connectContext) + .analyze("select a, min(a), sum(a),b from eli_gbk_by_uniform_t where a = 1 group by a,b") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("b"))); + + } + + @Test + void testNotEliminateWhenOnlyOneGbyKey() { + PlanChecker.from(connectContext) + .analyze("select a, min(a), sum(a) from eli_gbk_by_uniform_t where a = 1 group by a") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("a"))); + + } + + @Test + void testEliminateByProjectConst() { + PlanChecker.from(connectContext) + .analyze("select sum(c1), c2 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t) t group by c2,c3 ") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("c3"))); + } + + @Test + void testEliminateByProjectUniformSlot() { + PlanChecker.from(connectContext) + .analyze("select max(c3), c1,c2,c3 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t where a=1) t group by c1,c2,c3") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("c3"))); + } + + @Test + void testEliminateDate() { + PlanChecker.from(connectContext) + .analyze("select d, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where d = '2023-01-06' group by d,a") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("a"))); + } + + @Test + void testSaveOneExpr() { + PlanChecker.from(connectContext) + .analyze("select a, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where a = 1 and b=100 group by a, b,'abc'") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("a"))); + } + + @Test + void testSaveOneExprProjectConst() { + PlanChecker.from(connectContext) + .analyze("select c2 from (select a c1,1 c2, 3 c3 from eli_gbk_by_uniform_t) t group by c2,c3 order by 1;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> + agg.getGroupByExpressions().size() == 1 + && agg.getGroupByExpressions().get(0).toSql().equals("c2"))); + } + + @Test + void testNotRewriteWhenHasRepeat() { + PlanChecker.from(connectContext) + .analyze("select c2 from (select a c1,1 c2, 3 c3 from eli_gbk_by_uniform_t) t group by grouping sets((c2),(c3)) order by 1;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 3)); + } + + @Test + void testInnerJoin() { + PlanChecker.from(connectContext) + .analyze("select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testLeftJoinOnConditionNotRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 3)); + } + + @Test + void testLeftJoinWhereConditionRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t1.b=100 group by t1.b,t2.b,t2.c;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 2)); + } + + @Test + void testRightJoinOnConditionNullableSideFilterNotRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 3)); + } + + @Test + void testRightJoinOnConditionNonNullableSideFilterNotRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 3)); + } + + @Test + void testRightJoinWhereConditionToInnerRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t1.b=100 group by t1.b,t2.b,t2.c;") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testLeftSemiJoinWhereConditionRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b from eli_gbk_by_uniform_t t1 left semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t1.a") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testLeftSemiJoinRetainOneSlotInGroupBy() { + PlanChecker.from(connectContext) + .analyze("select t1.b from eli_gbk_by_uniform_t t1 left semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testRightSemiJoinWhereConditionRewrite() { + PlanChecker.from(connectContext) + .analyze("select t2.b from eli_gbk_by_uniform_t t1 right semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t2.b,t2.a") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testRightSemiJoinRetainOneSlotInGroupBy() { + PlanChecker.from(connectContext) + .analyze("select t2.b from eli_gbk_by_uniform_t t1 right semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t2.b") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testLeftAntiJoinOnConditionNotRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b from eli_gbk_by_uniform_t t1 left anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t1.a") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 2)); + } + + @Test + void testLeftAntiJoinWhereConditionRewrite() { + PlanChecker.from(connectContext) + .analyze("select t1.b from eli_gbk_by_uniform_t t1 left anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t1.b=100 group by t1.b,t1.c") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } + + @Test + void testRightAntiJoinOnConditionNotRewrite() { + PlanChecker.from(connectContext) + .analyze("select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t2.b,t2.a") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 2)); + } + + @Test + void testRightAntiJoinWhereConditionRewrite() { + PlanChecker.from(connectContext) + .analyze("select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t2.b=100 group by t2.b,t2.c") + .rewrite() + .printlnTree() + .matches(logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1)); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java index 5a9e15cf4774d1..103e074c73bfd5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKeyTest.java @@ -111,7 +111,7 @@ void testProjectAlias() { .rewrite() .printlnTree() .matches(logicalAggregate().when(agg -> - agg.getGroupByExpressions().size() == 2)); + agg.getGroupByExpressions().size() == 1)); PlanChecker.from(connectContext) .analyze("select id as c, name as n from t1 group by name, id") .rewrite() @@ -123,7 +123,7 @@ void testProjectAlias() { .rewrite() .printlnTree() .matches(logicalAggregate().when(agg -> - agg.getGroupByExpressions().size() == 2)); + agg.getGroupByExpressions().size() == 1)); } @Test diff --git a/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.out b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.out new file mode 100644 index 00000000000000..32e6d2c32fb6d5 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.out @@ -0,0 +1,267 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !empty_tranform_not_to_scalar_agg -- + +-- !empty_tranform_multi_column -- + +-- !empty_tranform_multi_column -- +1 1 2 2 + +-- !tranform_to_scalar_agg_not_null_column -- + +-- !project_const -- +\N 1 +\N 1 +2 1 +2 1 +4 1 +6 1 +6 1 +10 1 + +-- !project_slot_uniform -- +2023-01-01 1 1 2023-01-01 + +-- !upper_refer -- + +-- !upper_refer_varchar_alias -- +cherry 3 + +-- !upper_refer_date -- +2023-01-06 + +-- !upper_refer_datetime_not_to_scalar_agg -- +2023-01-06T15:00 + +-- !upper_refer_datetime -- +2023-01-06T15:00 + +-- !project_no_other_agg_func -- +1 +1 +1 +1 +1 +1 +1 +1 + +-- !project_const_not_to_scalar_agg_multi -- +1 + +-- !not_to_scalar_agg_multi -- +1 1 2 2 + +-- !conflict_equal_value -- + +-- !project_slot_uniform_confict_value -- + +-- !inner_join_left_has_filter -- +100 100 + +-- !inner_join_right_has_filter -- +100 100 + +-- !left_join_right_has_filter -- +100 100 +101 \N +102 \N +103 \N +104 \N +105 \N +106 \N +107 \N + +-- !left_join_left_has_filter -- +100 100 +101 \N +102 \N +103 \N +104 \N +105 \N +106 \N +107 \N + +-- !right_join_right_has_filter -- +\N 101 +\N 102 +\N 103 +\N 104 +\N 105 +\N 106 +\N 107 +100 100 + +-- !right_join_left_has_filter -- +\N 101 +\N 102 +\N 103 +\N 104 +\N 105 +\N 106 +\N 107 +100 100 + +-- !left_semi_join_right_has_filter -- +100 + +-- !left_semi_join_left_has_filter -- +100 + +-- !left_anti_join_right_has_on_filter -- +101 +102 +103 +104 +105 +106 +107 + +-- !left_anti_join_left_has_on_filter -- +101 +102 +103 +104 +105 +106 +107 + +-- !left_anti_join_left_has_where_filter -- + +-- !right_semi_join_right_has_filter -- +100 + +-- !right_semi_join_left_has_filter -- +100 + +-- !right_anti_join_right_has_on_filter -- +101 +102 +103 +104 +105 +106 +107 + +-- !right_anti_join_left_has_on_filter -- +101 +102 +103 +104 +105 +106 +107 + +-- !right_anti_join_right_has_where_filter -- + +-- !cross_join_left_has_filter -- +100 100 +100 101 +100 102 +100 103 +100 104 +100 105 +100 106 +100 107 + +-- !cross_join_right_has_filter -- +100 100 +101 100 +102 100 +103 100 +104 100 +105 100 +106 100 +107 100 + +-- !union -- +1 100 +5 105 + +-- !union_all -- +1 100 +1 100 +5 105 + +-- !intersect -- + +-- !except -- + +-- !set_op_mixed -- +1 100 + +-- !window -- + +-- !partition_topn -- + +-- !cte_producer -- +1 1 100 + +-- !cte_multi_producer -- + +-- !cte_consumer -- + +-- !filter -- +1 100 + +-- !topn -- +1 100 + +-- !sink -- +\N 103 date 2023-01-04 2023-01-04T13:00 +\N 107 grape 2023-01-08 2023-01-08T17:00 +1 100 apple 2023-01-01 2023-01-01T10:00 +1 100 apple 2023-01-01 2023-01-01T10:00 +1 100 apple 2023-01-01 2023-01-01T10:00 +2 101 banana 2023-01-02 2023-01-02T11:00 +3 102 cherry 2023-01-03 2023-01-03T12:00 +3 102 cherry 2023-01-03 2023-01-03T12:00 +4 104 elderberry 2023-01-05 2023-01-05T14:00 +5 105 \N 2023-01-06 2023-01-06T15:00 +5 105 \N 2023-01-06 2023-01-06T15:00 +6 106 fig 2023-01-07 2023-01-07T16:00 + +-- !nest_exprid_replace -- +2023-10-17 2 2023-10-17 2 6 +2023-10-17 2 2023-10-18 2 6 +2023-10-17 2 2023-10-21 2 6 +2023-10-18 2 2023-10-17 2 6 +2023-10-18 2 2023-10-18 2 6 +2023-10-18 2 2023-10-21 2 6 + +-- !full_join_uniform_should_not_eliminate_group_by_key -- +\N 1 +105 1 + +-- !full2 -- +1 \N +1 105 + +-- !left_join_right_side_should_not_eliminate_group_by_key -- +\N 1 +105 1 + +-- !left_join_left_side_should_eliminate_group_by_key -- +\N 1 +105 1 + +-- !right_join_left_side_should_not_eliminate_group_by_key -- +1 \N +1 105 + +-- !right_join_right_side_should_eliminate_group_by_key -- +1 \N +1 105 + +-- !left_semi_left_side -- +1 +1 + +-- !left_anti_left_side -- +1 + +-- !right_semi_right_side -- +105 +105 + +-- !right_anti_right_side -- + diff --git a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy new file mode 100644 index 00000000000000..89abe58fc1f003 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by_key_by_uniform.groovy @@ -0,0 +1,221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("eliminate_group_by_key_by_uniform") { + sql "set enable_nereids_rules = 'ELIMINATE_GROUP_BY_KEY_BY_UNIFORM'" + sql "drop table if exists eli_gbk_by_uniform_t" + sql """create table eli_gbk_by_uniform_t(a int null, b int not null, c varchar(10) null, d date, dt datetime) + distributed by hash(a) properties("replication_num"="1"); + """ + qt_empty_tranform_not_to_scalar_agg "select a, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where a = 1 group by a" + qt_empty_tranform_multi_column "select a,b, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where a = 1 and b=2 group by a,b" + + sql """ + INSERT INTO eli_gbk_by_uniform_t (a, b, c, d, dt) VALUES + (1, 100, 'apple', '2023-01-01', '2023-01-01 10:00:00'), + (1, 100, 'apple', '2023-01-01', '2023-01-01 10:00:00'), + (2, 101, 'banana', '2023-01-02', '2023-01-02 11:00:00'), + (3, 102, 'cherry', '2023-01-03', '2023-01-03 12:00:00'), + (3, 102, 'cherry', '2023-01-03', '2023-01-03 12:00:00'), + (NULL, 103, 'date', '2023-01-04', '2023-01-04 13:00:00'), + (4, 104, 'elderberry', '2023-01-05', '2023-01-05 14:00:00'), + (5, 105, NULL, '2023-01-06', '2023-01-06 15:00:00'), + (5, 105, NULL, '2023-01-06', '2023-01-06 15:00:00'), + (6, 106, 'fig', '2023-01-07', '2023-01-07 16:00:00'), + (NULL, 107, 'grape', '2023-01-08', '2023-01-08 17:00:00'); + """ + qt_empty_tranform_multi_column "select a, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where a = 1 group by a, b,'abc' order by 1,2,3,4" + qt_tranform_to_scalar_agg_not_null_column "select b, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where b = 1 group by a, b order by 1,2,3,4" + + qt_project_const "select sum(c1), c2 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t) t group by c2,c3 order by 1,2;" + qt_project_slot_uniform "select max(c3), c1,c2,c3 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t where a=1) t group by c1,c2,c3 order by 1,2,3,4;" + + qt_upper_refer "select b from (select b, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where b = 1 group by a, b) t order by b" + qt_upper_refer_varchar_alias "select c1,c2 from (select c as c1, min(a) c2, sum(a), count(a) from eli_gbk_by_uniform_t where c = 'cherry' group by a, b,c) t order by c1,c2" + qt_upper_refer_date "select d from (select d, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where d = '2023-01-06' group by d,a) t order by 1" + qt_upper_refer_datetime_not_to_scalar_agg "select dt from (select dt, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where dt = '2023-01-06 15:00:00' group by dt) t order by 1" + qt_upper_refer_datetime "select dt from (select dt, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where dt = '2023-01-06 15:00:00' group by dt, a) t order by 1" + + qt_project_no_other_agg_func "select c2 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t) t group by c2,c3 order by 1;" + qt_project_const_not_to_scalar_agg_multi "select c2 from (select a c1,1 c2, 3 c3 from eli_gbk_by_uniform_t) t group by c2,c3 order by 1;" + qt_not_to_scalar_agg_multi "select a, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where a = 1 and b=100 group by a, b,'abc' order by 1,2,3,4" + qt_conflict_equal_value "select a, min(a), sum(a), count(a) from eli_gbk_by_uniform_t where a = 1 and a=2 group by a, b,'abc' order by 1,2,3,4" + qt_project_slot_uniform_confict_value "select max(c3), c1,c2,c3 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t where a=1) t where c2=2 group by c1,c2,c3 order by 1,2,3,4;" + + // test join + qt_inner_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_inner_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_left_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_left_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_right_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_right_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_left_semi_join_right_has_filter "select t1.b from eli_gbk_by_uniform_t t1 left semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t1.a order by 1" + qt_left_semi_join_left_has_filter "select t1.b from eli_gbk_by_uniform_t t1 left semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t1.a order by 1" + qt_left_anti_join_right_has_on_filter "select t1.b from eli_gbk_by_uniform_t t1 left anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t1.a order by 1" + qt_left_anti_join_left_has_on_filter "select t1.b from eli_gbk_by_uniform_t t1 left anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t1.a order by 1" + qt_left_anti_join_left_has_where_filter "select t1.b from eli_gbk_by_uniform_t t1 left anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t1.b=100 group by t1.b,t1.a order by 1" + qt_right_semi_join_right_has_filter "select t2.b from eli_gbk_by_uniform_t t1 right semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t2.b,t2.c order by 1" + qt_right_semi_join_left_has_filter "select t2.b from eli_gbk_by_uniform_t t1 right semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t2.b,t2.c order by 1" + qt_right_anti_join_right_has_on_filter "select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t2.b,t2.c order by 1" + qt_right_anti_join_left_has_on_filter "select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t2.b,t2.c order by 1" + qt_right_anti_join_right_has_where_filter "select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t2.b=100 group by t2.b,t2.c order by 1" + qt_cross_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 cross join eli_gbk_by_uniform_t t2 where t1.b=100 group by t1.b,t2.b,t2.c order by 1,2" + qt_cross_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 cross join eli_gbk_by_uniform_t t2 where t2.b=100 group by t1.b,t2.b,t2.c order by 1,2" + + //test union + qt_union "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b union select a,b from eli_gbk_by_uniform_t where b=100 group by a,b union select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5" + qt_union_all "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b union all select a,b from eli_gbk_by_uniform_t where b=100 group by a,b union all select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5" + qt_intersect "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b intersect select a,b from eli_gbk_by_uniform_t where b=100 group by a,b intersect select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5" + qt_except "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b except select a,b from eli_gbk_by_uniform_t where b=100 group by a,b except select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5" + qt_set_op_mixed "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b union select a,b from eli_gbk_by_uniform_t where b=100 group by a,b except select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5" + + //test window + qt_window "select max(a) over(partition by a order by a) from eli_gbk_by_uniform_t where a=10 group by a,b order by 1" + //test partition topn + qt_partition_topn "select r from (select rank() over(partition by a order by a) r from eli_gbk_by_uniform_t where a=10 group by a,b) t where r<2 order by 1" + + //test cte + qt_cte_producer "with t as (select a,b,count(*) from eli_gbk_by_uniform_t where a=1 group by a,b) select t1.a,t2.a,t2.b from t t1 inner join t t2 on t1.a=t2.a order by 1,2,3" + qt_cte_multi_producer "with t as (select a,b,count(*) from eli_gbk_by_uniform_t where a=1 group by a,b), tt as (select a,b,count(*) from eli_gbk_by_uniform_t where b=10 group by a,b) select t1.a,t2.a,t2.b from t t1 inner join tt t2 on t1.a=t2.a order by 1,2,3" + qt_cte_consumer "with t as (select * from eli_gbk_by_uniform_t) select t1.a,t2.b from t t1 inner join t t2 on t1.a=t2.a where t1.a=10 group by t1.a,t2.b order by 1,2 " + + //test filter + qt_filter "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b) t where a>0 order by 1,2" + + //test topn + qt_topn "select a,b from eli_gbk_by_uniform_t where a=1 group by a,b order by a limit 10 offset 0" + + //olap table sink + sql "insert into eli_gbk_by_uniform_t select a,b,c,d,dt from eli_gbk_by_uniform_t where a = 1 group by a,b,c,d,dt" + qt_sink "select * from eli_gbk_by_uniform_t order by 1,2,3,4,5" + + sql """ + drop table if exists orders_inner_1 + """ + + sql """CREATE TABLE `orders_inner_1` ( + `o_orderkey` BIGINT not NULL, + `o_custkey` INT NULL, + `o_orderstatus` VARCHAR(1) NULL, + `o_totalprice` DECIMAL(15, 2) NULL, + `o_orderpriority` VARCHAR(15) NULL, + `o_clerk` VARCHAR(15) NULL, + `o_shippriority` INT NULL, + `o_comment` VARCHAR(79) NULL, + `o_orderdate` DATE NULL + ) ENGINE=OLAP + DUPLICATE KEY(`o_orderkey`, `o_custkey`) + COMMENT 'OLAP' + PARTITION BY list(o_orderkey) ( + PARTITION p1 VALUES in ('1'), + PARTITION p2 VALUES in ('2'), + PARTITION p3 VALUES in ('3'), + PARTITION p4 VALUES in ('4') + ) + DISTRIBUTED BY HASH(`o_orderkey`) BUCKETS 96 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + );""" + + sql """ + drop table if exists lineitem_inner_1 + """ + + sql """CREATE TABLE `lineitem_inner_1` ( + `l_orderkey` BIGINT not NULL, + `l_linenumber` INT NULL, + `l_partkey` INT NULL, + `l_suppkey` INT NULL, + `l_quantity` DECIMAL(15, 2) NULL, + `l_extendedprice` DECIMAL(15, 2) NULL, + `l_discount` DECIMAL(15, 2) NULL, + `l_tax` DECIMAL(15, 2) NULL, + `l_returnflag` VARCHAR(1) NULL, + `l_linestatus` VARCHAR(1) NULL, + `l_commitdate` DATE NULL, + `l_receiptdate` DATE NULL, + `l_shipinstruct` VARCHAR(25) NULL, + `l_shipmode` VARCHAR(10) NULL, + `l_comment` VARCHAR(44) NULL, + `l_shipdate` DATE NULL + ) ENGINE=OLAP + DUPLICATE KEY(l_orderkey, l_linenumber, l_partkey, l_suppkey ) + COMMENT 'OLAP' + PARTITION BY list(l_orderkey) ( + PARTITION p1 VALUES in ('1'), + PARTITION p2 VALUES in ('2'), + PARTITION p3 VALUES in ('3') + ) + DISTRIBUTED BY HASH(`l_orderkey`) BUCKETS 96 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + );""" + + sql """ + insert into orders_inner_1 values + (2, 1, 'o', 99.5, 'a', 'b', 1, 'yy', '2023-10-17'), + (1, null, 'k', 109.2, 'c','d',2, 'mm', '2023-10-17'), + (3, 3, null, 99.5, 'a', 'b', 1, 'yy', '2023-10-19'), + (1, 2, 'o', null, 'a', 'b', 1, 'yy', '2023-10-20'), + (2, 3, 'k', 109.2, null,'d',2, 'mm', '2023-10-21'), + (3, 1, 'o', 99.5, 'a', null, 1, 'yy', '2023-10-22'), + (1, 3, 'k', 99.5, 'a', 'b', null, 'yy', '2023-10-19'), + (2, 1, 'o', 109.2, 'c','d',2, null, '2023-10-18'), + (3, 2, 'k', 99.5, 'a', 'b', 1, 'yy', '2023-10-17'), + (4, 5, 'o', 99.5, 'a', 'b', 1, 'yy', '2023-10-19'); + """ + + sql """ + insert into lineitem_inner_1 values + (2, 1, 2, 3, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-10-17', '2023-10-17', 'a', 'b', 'yyyyyyyyy', '2023-10-17'), + (1, null, 3, 1, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-10-18', '2023-10-18', 'a', 'b', 'yyyyyyyyy', '2023-10-17'), + (3, 3, null, 2, 7.5, 8.5, 9.5, 10.5, 'k', 'o', '2023-10-19', '2023-10-19', 'c', 'd', 'xxxxxxxxx', '2023-10-19'), + (1, 2, 3, null, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-10-17', '2023-10-17', 'a', 'b', 'yyyyyyyyy', '2023-10-17'), + (2, 3, 2, 1, 5.5, 6.5, 7.5, 8.5, 'o', 'k', null, '2023-10-18', 'a', 'b', 'yyyyyyyyy', '2023-10-18'), + (3, 1, 1, 2, 7.5, 8.5, 9.5, 10.5, 'k', 'o', '2023-10-19', null, 'c', 'd', 'xxxxxxxxx', '2023-10-19'), + (1, 3, 2, 2, 5.5, 6.5, 7.5, 8.5, 'o', 'k', '2023-10-17', '2023-10-17', 'a', 'b', 'yyyyyyyyy', '2023-10-17'); + """ + + qt_nest_exprid_replace """ + select l_shipdate, l_orderkey, t.O_ORDERDATE, t.o_orderkey, + count(t.O_ORDERDATE) over (partition by lineitem_inner_1.l_orderkey order by lineitem_inner_1.l_orderkey) as window_count + from lineitem_inner_1 + inner join (select O_ORDERDATE, o_orderkey, count(O_ORDERDATE) over (partition by O_ORDERDATE order by o_orderkey ) from orders_inner_1 where o_orderkey=2 group by O_ORDERDATE, o_orderkey) as t + on lineitem_inner_1.l_orderkey = t.o_orderkey + where t.o_orderkey=2 + group by l_shipdate, l_orderkey, t.O_ORDERDATE, t.o_orderkey + order by 1,2,3,4,5 + """ + sql "drop table if exists test1" + sql "drop table if exists test2" + sql "create table test1(a int, b int) distributed by hash(a) properties('replication_num'='1');" + sql "insert into test1 values(1,1),(2,1),(3,1);" + sql "create table test2(a int, b int) distributed by hash(a) properties('replication_num'='1');" + sql "insert into test2 values(1,105),(2,105);" + qt_full_join_uniform_should_not_eliminate_group_by_key "select t2.b,t1.b from test1 t1 full join (select * from test2 where b=105) t2 on t1.a=t2.a group by t2.b,t1.b order by 1,2;" + qt_full2 "select t2.b,t1.b from (select * from test2 where b=105) t1 full join test1 t2 on t1.a=t2.a group by t2.b,t1.b order by 1,2;" + + qt_left_join_right_side_should_not_eliminate_group_by_key "select t2.b,t1.b from test1 t1 left join (select * from test2 where b=105) t2 on t1.a=t2.a group by t2.b,t1.b order by 1,2;" + qt_left_join_left_side_should_eliminate_group_by_key "select t2.b,t1.b from test1 t1 left join (select * from test2 where b=105) t2 on t1.a=t2.a where t1.b=1 group by t2.b,t1.b order by 1,2;" + + qt_right_join_left_side_should_not_eliminate_group_by_key "select t2.b,t1.b from (select * from test2 where b=105) t1 right join test1 t2 on t1.a=t2.a group by t2.b,t1.b order by 1,2;" + qt_right_join_right_side_should_eliminate_group_by_key "select t2.b,t1.b from (select * from test2 where b=105) t1 right join test1 t2 on t1.a=t2.a where t2.b=1 group by t2.b,t1.b order by 1,2;" + + qt_left_semi_left_side "select t1.b from test1 t1 left semi join (select * from test2 where b=105) t2 on t1.a=t2.a where t1.b=1 group by t1.b,t1.a order by 1;" + qt_left_anti_left_side "select t1.b from test1 t1 left anti join (select * from test2 where b=105) t2 on t1.a=t2.a where t1.b=1 group by t1.b,t1.a order by 1;" + qt_right_semi_right_side "select t2.b from test1 t1 right semi join (select * from test2 where b=105) t2 on t1.a=t2.a group by t2.b,t2.a order by 1;" + qt_right_anti_right_side "select t2.b from test1 t1 right anti join (select * from test2 where b=105) t2 on t1.a=t2.a group by t2.b,t2.a order by 1;" +} \ No newline at end of file diff --git a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy index f082b3bdefd20c..356b96267a88f5 100644 --- a/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy +++ b/regression-test/suites/nereids_rules_p0/mv/agg_without_roll_up/aggregate_without_roll_up.groovy @@ -372,7 +372,7 @@ suite("aggregate_without_roll_up") { "max(o_totalprice) as max_total, " + "min(o_totalprice) as min_total, " + "count(*) as count_all, " + - "count(distinct case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end) as distinct_count " + + "bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as distinct_count " + "from lineitem " + "left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate " + "group by " + @@ -570,7 +570,7 @@ suite("aggregate_without_roll_up") { "max(o_totalprice) as max_total, " + "min(o_totalprice) as min_total, " + "count(*) as count_all, " + - "count(distinct case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end) as distinct_count " + + "bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as distinct_count " + "from lineitem " + "left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate " + "group by " + @@ -660,7 +660,7 @@ suite("aggregate_without_roll_up") { "max(o_totalprice) as max_total, " + "min(o_totalprice) as min_total, " + "count(*) as count_all, " + - "count(distinct case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end) as distinct_count " + + "bitmap_union(to_bitmap(case when o_shippriority > 1 and o_orderkey IN (1, 3) then o_custkey else null end)) as distinct_count " + "from lineitem " + "left join orders on lineitem.l_orderkey = orders.o_orderkey and l_shipdate = o_orderdate " + "group by " +