Skip to content

Commit

Permalink
[feat](nereids) add rewrite rule :EliminateGroupByKeyByUniform (#43391)…
Browse files Browse the repository at this point in the history
… (#45075)

cherry-pick #43391 to branch-3.0
  • Loading branch information
feiniaofeiafei authored Dec 10, 2024
1 parent 9b2ae0c commit 8a61eb9
Show file tree
Hide file tree
Showing 18 changed files with 1,421 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import org.apache.doris.nereids.rules.rewrite.EliminateFilter;
import org.apache.doris.nereids.rules.rewrite.EliminateGroupBy;
import org.apache.doris.nereids.rules.rewrite.EliminateGroupByKey;
import org.apache.doris.nereids.rules.rewrite.EliminateGroupByKeyByUniform;
import org.apache.doris.nereids.rules.rewrite.EliminateJoinByFK;
import org.apache.doris.nereids.rules.rewrite.EliminateJoinByUnique;
import org.apache.doris.nereids.rules.rewrite.EliminateJoinCondition;
Expand Down Expand Up @@ -356,6 +357,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
topDown(new EliminateJoinByUnique())
),
topic("eliminate Aggregate according to fd items",
custom(RuleType.ELIMINATE_GROUP_BY_KEY_BY_UNIFORM, EliminateGroupByKeyByUniform::new),
topDown(new EliminateGroupByKey()),
topDown(new PushDownAggThroughJoinOnPkFk()),
topDown(new PullUpJoinFromUnionAll())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,23 @@

package org.apache.doris.nereids.properties;

import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.util.ImmutableEqualSet;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -46,16 +51,16 @@
public class DataTrait {

public static final DataTrait EMPTY_TRAIT
= new DataTrait(new NestedSet().toImmutable(),
new NestedSet().toImmutable(), new ImmutableSet.Builder<FdItem>().build(),
= new DataTrait(new UniqueDescription().toImmutable(),
new UniformDescription().toImmutable(), new ImmutableSet.Builder<FdItem>().build(),
ImmutableEqualSet.empty(), new FuncDepsDG.Builder().build());
private final NestedSet uniqueSet;
private final NestedSet uniformSet;
private final UniqueDescription uniqueSet;
private final UniformDescription uniformSet;
private final ImmutableSet<FdItem> fdItems;
private final ImmutableEqualSet<Slot> equalSet;
private final FuncDepsDG fdDg;

private DataTrait(NestedSet uniqueSet, NestedSet uniformSet, ImmutableSet<FdItem> fdItems,
private DataTrait(UniqueDescription uniqueSet, UniformDescription uniformSet, ImmutableSet<FdItem> fdItems,
ImmutableEqualSet<Slot> equalSet, FuncDepsDG fdDg) {
this.uniqueSet = uniqueSet;
this.uniformSet = uniformSet;
Expand Down Expand Up @@ -86,8 +91,7 @@ public boolean isUniform(Slot slot) {
}

public boolean isUniform(Set<Slot> slotSet) {
return !slotSet.isEmpty()
&& uniformSet.slots.containsAll(slotSet);
return uniformSet.contains(slotSet);
}

public boolean isUniqueAndNotNull(Slot slot) {
Expand All @@ -102,11 +106,25 @@ public boolean isUniqueAndNotNull(Set<Slot> slotSet) {
}

public boolean isUniformAndNotNull(Slot slot) {
return !slot.nullable() && isUniform(slot);
return uniformSet.isUniformAndNotNull(slot);
}

/** isUniformAndNotNull for slot set */
public boolean isUniformAndNotNull(ImmutableSet<Slot> slotSet) {
return slotSet.stream().noneMatch(Slot::nullable) && isUniform(slotSet);
for (Slot slot : slotSet) {
if (!uniformSet.isUniformAndNotNull(slot)) {
return false;
}
}
return true;
}

public boolean isUniformAndHasConstValue(Slot slot) {
return uniformSet.isUniformAndHasConstValue(slot);
}

public Optional<Expression> getUniformValue(Slot slot) {
return uniformSet.slotUniformValue.get(slot);
}

public boolean isNullSafeEqual(Slot l, Slot r) {
Expand Down Expand Up @@ -143,23 +161,23 @@ public String toString() {
* Builder of trait
*/
public static class Builder {
private final NestedSet uniqueSet;
private final NestedSet uniformSet;
private final UniqueDescription uniqueSet;
private final UniformDescription uniformSet;
private ImmutableSet<FdItem> fdItems;
private final ImmutableEqualSet.Builder<Slot> equalSetBuilder;
private final FuncDepsDG.Builder fdDgBuilder;

public Builder() {
uniqueSet = new NestedSet();
uniformSet = new NestedSet();
uniqueSet = new UniqueDescription();
uniformSet = new UniformDescription();
fdItems = new ImmutableSet.Builder<FdItem>().build();
equalSetBuilder = new ImmutableEqualSet.Builder<>();
fdDgBuilder = new FuncDepsDG.Builder();
}

public Builder(DataTrait other) {
this.uniformSet = new NestedSet(other.uniformSet);
this.uniqueSet = new NestedSet(other.uniqueSet);
this.uniformSet = new UniformDescription(other.uniformSet);
this.uniqueSet = new UniqueDescription(other.uniqueSet);
this.fdItems = ImmutableSet.copyOf(other.fdItems);
equalSetBuilder = new ImmutableEqualSet.Builder<>(other.equalSet);
fdDgBuilder = new FuncDepsDG.Builder(other.fdDg);
Expand All @@ -173,6 +191,14 @@ public void addUniformSlot(DataTrait dataTrait) {
uniformSet.add(dataTrait.uniformSet);
}

public void addUniformSlotForOuterJoinNullableSide(DataTrait dataTrait) {
uniformSet.addUniformSlotForOuterJoinNullableSide(dataTrait.uniformSet);
}

public void addUniformSlotAndLiteral(Slot slot, Expression literal) {
uniformSet.add(slot, literal);
}

public void addUniqueSlot(Slot slot) {
uniqueSet.add(slot);
}
Expand Down Expand Up @@ -261,8 +287,21 @@ public void addUniqueByEqualSet(Set<Slot> equalSet) {
* if there is a uniform slot in the equivalence set, then all slots of an equivalence set are uniform
*/
public void addUniformByEqualSet(Set<Slot> equalSet) {
if (uniformSet.isIntersect(uniformSet.slots, equalSet)) {
uniformSet.slots.addAll(equalSet);
List<Slot> intersectionList = uniformSet.slotUniformValue.keySet().stream()
.filter(equalSet::contains)
.collect(Collectors.toList());
if (intersectionList.isEmpty()) {
return;
}
Expression expr = null;
for (Slot slot : intersectionList) {
if (uniformSet.slotUniformValue.get(slot).isPresent()) {
expr = uniformSet.slotUniformValue.get(slot).get();
break;
}
}
for (Slot equal : equalSet) {
uniformSet.add(equal, expr);
}
}

Expand Down Expand Up @@ -293,9 +332,11 @@ public List<Set<Slot>> getAllUniqueAndNotNull() {
*/
public List<Set<Slot>> getAllUniformAndNotNull() {
List<Set<Slot>> res = new ArrayList<>();
for (Slot s : uniformSet.slots) {
if (!s.nullable()) {
res.add(ImmutableSet.of(s));
for (Map.Entry<Slot, Optional<Expression>> entry : uniformSet.slotUniformValue.entrySet()) {
if (!entry.getKey().nullable()) {
res.add(ImmutableSet.of(entry.getKey()));
} else if (entry.getValue().isPresent() && !entry.getValue().get().nullable()) {
res.add(ImmutableSet.of(entry.getKey()));
}
}
return res;
Expand Down Expand Up @@ -338,21 +379,21 @@ public void replaceFuncDepsBy(Map<Slot, Slot> replaceMap) {
}
}

static class NestedSet {
static class UniqueDescription {
Set<Slot> slots;
Set<ImmutableSet<Slot>> slotSets;

NestedSet() {
UniqueDescription() {
slots = new HashSet<>();
slotSets = new HashSet<>();
}

NestedSet(NestedSet o) {
UniqueDescription(UniqueDescription o) {
this.slots = new HashSet<>(o.slots);
this.slotSets = new HashSet<>(o.slotSets);
}

NestedSet(Set<Slot> slots, Set<ImmutableSet<Slot>> slotSets) {
UniqueDescription(Set<Slot> slots, Set<ImmutableSet<Slot>> slotSets) {
this.slots = slots;
this.slotSets = slotSets;
}
Expand Down Expand Up @@ -408,9 +449,9 @@ public void add(ImmutableSet<Slot> slotSet) {
slotSets.add(slotSet);
}

public void add(NestedSet nestedSet) {
slots.addAll(nestedSet.slots);
slotSets.addAll(nestedSet.slotSets);
public void add(UniqueDescription uniqueDescription) {
slots.addAll(uniqueDescription.slots);
slotSets.addAll(uniqueDescription.slotSets);
}

public boolean isIntersect(Set<Slot> set1, Set<Slot> set2) {
Expand Down Expand Up @@ -446,8 +487,120 @@ public void replace(Map<Slot, Slot> replaceMap) {
.collect(Collectors.toSet());
}

public NestedSet toImmutable() {
return new NestedSet(ImmutableSet.copyOf(slots), ImmutableSet.copyOf(slotSets));
public UniqueDescription toImmutable() {
return new UniqueDescription(ImmutableSet.copyOf(slots), ImmutableSet.copyOf(slotSets));
}
}

static class UniformDescription {
// slot and its uniform expression(literal or const expression)
// some slot can get uniform values, others can not.
// e.g.select a from t where a=10 group by a, b;
// in LogicalAggregate, a UniformDescription with map {a : 10} can be obtained.
// which means a is uniform and the uniform value is 10.
Map<Slot, Optional<Expression>> slotUniformValue;

public UniformDescription() {
slotUniformValue = new LinkedHashMap<>();
}

public UniformDescription(UniformDescription ud) {
slotUniformValue = new LinkedHashMap<>(ud.slotUniformValue);
}

public UniformDescription(Map<Slot, Optional<Expression>> slotUniformValue) {
this.slotUniformValue = slotUniformValue;
}

public UniformDescription toImmutable() {
return new UniformDescription(ImmutableMap.copyOf(slotUniformValue));
}

public boolean isEmpty() {
return slotUniformValue.isEmpty();
}

public boolean contains(Slot slot) {
return slotUniformValue.containsKey(slot);
}

public boolean contains(Set<Slot> slots) {
return !slots.isEmpty() && slotUniformValue.keySet().containsAll(slots);
}

public void add(Slot slot) {
slotUniformValue.putIfAbsent(slot, Optional.empty());
}

public void add(Set<Slot> slots) {
for (Slot s : slots) {
slotUniformValue.putIfAbsent(s, Optional.empty());
}
}

public void add(UniformDescription ud) {
slotUniformValue.putAll(ud.slotUniformValue);
for (Map.Entry<Slot, Optional<Expression>> entry : ud.slotUniformValue.entrySet()) {
add(entry.getKey(), entry.getValue().orElse(null));
}
}

public void add(Slot slot, Expression literal) {
if (null == literal) {
slotUniformValue.putIfAbsent(slot, Optional.empty());
} else {
slotUniformValue.put(slot, Optional.of(literal));
}
}

public void addUniformSlotForOuterJoinNullableSide(UniformDescription ud) {
for (Map.Entry<Slot, Optional<Expression>> entry : ud.slotUniformValue.entrySet()) {
if ((!entry.getValue().isPresent() && entry.getKey().nullable())
|| (entry.getValue().isPresent() && entry.getValue().get() instanceof NullLiteral)) {
add(entry.getKey(), entry.getValue().orElse(null));
}
}
}

public void removeNotContain(Set<Slot> slotSet) {
if (slotSet.isEmpty()) {
return;
}
Map<Slot, Optional<Expression>> newSlotUniformValue = new LinkedHashMap<>();
for (Map.Entry<Slot, Optional<Expression>> entry : slotUniformValue.entrySet()) {
if (slotSet.contains(entry.getKey())) {
newSlotUniformValue.put(entry.getKey(), entry.getValue());
}
}
this.slotUniformValue = newSlotUniformValue;
}

public void replace(Map<Slot, Slot> replaceMap) {
Map<Slot, Optional<Expression>> newSlotUniformValue = new LinkedHashMap<>();
for (Map.Entry<Slot, Optional<Expression>> entry : slotUniformValue.entrySet()) {
Slot newKey = replaceMap.getOrDefault(entry.getKey(), entry.getKey());
newSlotUniformValue.put(newKey, entry.getValue());
}
slotUniformValue = newSlotUniformValue;
}

// The current implementation logic is: if a slot key exists in map slotUniformValue,
// its value is present and is not nullable,
// or if a slot key exists in map slotUniformValue and the slot is not nullable
// it indicates that this slot is uniform and not null.
public boolean isUniformAndNotNull(Slot slot) {
return slotUniformValue.containsKey(slot)
&& (!slot.nullable() || slotUniformValue.get(slot).isPresent()
&& !slotUniformValue.get(slot).get().nullable());
}

public boolean isUniformAndHasConstValue(Slot slot) {
return slotUniformValue.containsKey(slot) && slotUniformValue.get(slot).isPresent();
}

@Override
public String toString() {
return "{" + slotUniformValue + "}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ public enum RuleType {
REWRITE_SORT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_HAVING_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_WINDOW_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SET_OPERATION_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE),
Expand Down Expand Up @@ -238,6 +243,7 @@ public enum RuleType {
ELIMINATE_JOIN_BY_UK(RuleTypeClass.REWRITE),
ELIMINATE_JOIN_BY_FK(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY_KEY(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY_KEY_BY_UNIFORM(RuleTypeClass.REWRITE),
ELIMINATE_FILTER_GROUP_BY_KEY(RuleTypeClass.REWRITE),
ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE),
ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
* expression of plan rewrite rule.
*/
public class ExpressionRewrite implements RewriteRuleFactory {
private final ExpressionRuleExecutor rewriter;
protected final ExpressionRuleExecutor rewriter;

public ExpressionRewrite(ExpressionRewriteRule... rules) {
this.rewriter = new ExpressionRuleExecutor(ImmutableList.copyOf(rules));
Expand Down
Loading

0 comments on commit 8a61eb9

Please sign in to comment.