Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 4, 2024
1 parent fee6ac8 commit 919132b
Show file tree
Hide file tree
Showing 13 changed files with 430 additions and 565 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import org.apache.doris.nereids.util.Utils;

import com.google.common.base.Suppliers;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;

import java.util.List;
Expand Down Expand Up @@ -62,7 +63,7 @@ public class Scope {
private final List<Slot> slots;
private final Optional<SubqueryExpr> ownerSubquery;
private final Set<Slot> correlatedSlots;
private final Supplier<ArrayListMultimap<String, Slot>> nameToSlot;
private final Supplier<ListMultimap<String, Slot>> nameToSlot;

public Scope(List<Slot> slots) {
this(Optional.empty(), slots, Optional.empty());
Expand Down Expand Up @@ -93,12 +94,20 @@ public Set<Slot> getCorrelatedSlots() {
return correlatedSlots;
}

/** findSlotIgnoreCase */
public List<Slot> findSlotIgnoreCase(String slotName) {
// Builder<Slot> candidateSlots = ImmutableList.builder();
// for (Slot slot : slots) {
// if (slot.getName().equalsIgnoreCase(slotName)) {
// candidateSlots.add(slot);
// }
// }
// return candidateSlots.build();
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
}

private ArrayListMultimap<String, Slot> buildNameToSlot() {
ArrayListMultimap<String, Slot> map = ArrayListMultimap.create(slots.size(), 2);
private ListMultimap<String, Slot> buildNameToSlot() {
ListMultimap<String, Slot> map = LinkedListMultimap.create(slots.size());
for (Slot slot : slots) {
map.put(slot.getName().toUpperCase(Locale.ROOT), slot);
}
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,11 @@ public Expression visitUnboundAlias(UnboundAlias unboundAlias, ExpressionRewrite
@Override
public Slot visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
Optional<Scope> outerScope = getDefaultScope().getOuterScope();
Optional<List<Slot>> boundedOpt = Optional.of(
bindSlotByMultiScopes(unboundSlot, getScopes())
);
Optional<List<Slot>> boundedOpt = Optional.of(bindSlotByInnerScopes(unboundSlot));
boolean foundInThisScope = !boundedOpt.get().isEmpty();
// Currently only looking for symbols on the previous level.
if (bindSlotInOuterScope && !foundInThisScope && outerScope.isPresent()) {
boundedOpt = Optional.of(bindSlot(unboundSlot, outerScope.get()));
boundedOpt = Optional.of(bindSlotByScope(unboundSlot, outerScope.get()));
}
List<Slot> bounded = boundedOpt.get();
switch (bounded.size()) {
Expand Down Expand Up @@ -589,17 +587,22 @@ && compareDbName(qualifierStar.get(1), boundSlotQualifier.get(1))
return new BoundStar(slots);
}

protected List<Slot> bindSlotByInnerScopes(UnboundSlot unboundSlot) {
return bindSlotByMultiScopes(unboundSlot, getScopes());
}

private List<Slot> bindSlotByMultiScopes(UnboundSlot unboundSlot, List<Scope> scopes) {
for (Scope candidateScope : scopes) {
List<Slot> slots = bindSlot(unboundSlot, candidateScope);
List<Slot> slots = bindSlotByScope(unboundSlot, candidateScope);
if (!slots.isEmpty()) {
return slots;
}
}
return ImmutableList.of();
}

private List<Slot> bindSlot(UnboundSlot unboundSlot, Scope scope) {
/** bindSlotByScope */
public List<Slot> bindSlotByScope(UnboundSlot unboundSlot, Scope scope) {
// return scope.getSlots().stream().distinct().filter(boundSlot -> {
// if (boundSlot instanceof SlotReference
// && ((SlotReference) boundSlot).hasSubColPath()) {
Expand Down Expand Up @@ -723,7 +726,7 @@ private UnboundFunction bindHighOrderFunction(UnboundFunction unboundFunction, E
.build());
}

private boolean shouldSlotBindBy(int namePartSize, Slot boundSlot) {
private boolean shouldBindSlotBy(int namePartSize, Slot boundSlot) {
if (boundSlot instanceof SlotReference
&& ((SlotReference) boundSlot).hasSubColPath()) {
// already bounded
Expand All @@ -739,7 +742,7 @@ private List<Slot> bindSingleSlotByName(String name, Scope scope) {
int namePartSize = 1;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldSlotBindBy(namePartSize, boundSlot)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
// set sql case as alias
Expand All @@ -752,7 +755,7 @@ private List<Slot> bindSingleSlotByTable(String table, String name, Scope scope)
int namePartSize = 2;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldSlotBindBy(namePartSize, boundSlot)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
List<String> boundSlotQualifier = boundSlot.getQualifier();
Expand All @@ -770,7 +773,7 @@ private List<Slot> bindSingleSlotByDb(String db, String table, String name, Scop
int namePartSize = 3;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldSlotBindBy(namePartSize, boundSlot)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
List<String> boundSlotQualifier = boundSlot.getQualifier();
Expand All @@ -789,7 +792,7 @@ private List<Slot> bindSingleSlotByCatalog(String catalog, String db, String tab
int namePartSize = 4;
Builder<Slot> usedSlots = ImmutableList.builderWithExpectedSize(1);
for (Slot boundSlot : scope.findSlotIgnoreCase(name)) {
if (!shouldSlotBindBy(namePartSize, boundSlot)) {
if (!shouldBindSlotBy(namePartSize, boundSlot)) {
continue;
}
List<String> boundSlotQualifier = boundSlot.getQualifier();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -179,6 +180,18 @@ default void foreach(Consumer<TreeNode<NODE_TYPE>> func) {
}
}

/** foreachBreath */
default void foreachBreath(Predicate<TreeNode<NODE_TYPE>> func) {
LinkedList<TreeNode<NODE_TYPE>> queue = new LinkedList<>();
queue.add(this);
while (!queue.isEmpty()) {
TreeNode<NODE_TYPE> current = queue.pollFirst();
if (!func.test(current)) {
queue.addAll(current.children());
}
}
}

default void foreachUp(Consumer<TreeNode<NODE_TYPE>> func) {
for (NODE_TYPE child : children()) {
child.foreach(func);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ public LogicalGenerate(List<Function> generators, List<Slot> generatorOutput, Li
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_GENERATE, groupExpression, logicalProperties, child);
this.generators = ImmutableList.copyOf(generators);
this.generatorOutput = ImmutableList.copyOf(generatorOutput);
this.expandColumnAlias = ImmutableList.copyOf(expandColumnAlias);
this.generators = Utils.fastToImmutableList(generators);
this.generatorOutput = Utils.fastToImmutableList(generatorOutput);
this.expandColumnAlias = Utils.fastToImmutableList(expandColumnAlias);
}

public List<Function> getGenerators() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ private LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts,
// Just use in withXXX method. Don't need check/copyOf()
super(PlanType.LOGICAL_JOIN, groupExpression, logicalProperties, children);
this.joinType = Objects.requireNonNull(joinType, "joinType can not be null");
this.hashJoinConjuncts = ImmutableList.copyOf(hashJoinConjuncts);
this.otherJoinConjuncts = ImmutableList.copyOf(otherJoinConjuncts);
this.markJoinConjuncts = ImmutableList.copyOf(markJoinConjuncts);
this.hashJoinConjuncts = Utils.fastToImmutableList(hashJoinConjuncts);
this.otherJoinConjuncts = Utils.fastToImmutableList(otherJoinConjuncts);
this.markJoinConjuncts = Utils.fastToImmutableList(markJoinConjuncts);
this.hint = Objects.requireNonNull(hint, "hint can not be null");
if (joinReorderContext != null) {
this.joinReorderContext.copyFrom(joinReorderContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private LogicalProject(List<NamedExpression> projects, List<NamedExpression> exc
this.projects = projects.isEmpty()
? ImmutableList.of(ExpressionUtils.selectMinimumColumn(child.get(0).getOutput()))
: projects;
this.excepts = ImmutableList.copyOf(excepts);
this.excepts = Utils.fastToImmutableList(excepts);
this.isDistinct = isDistinct;
this.canEliminate = canEliminate;
}
Expand Down Expand Up @@ -173,7 +173,7 @@ public int hashCode() {
@Override
public LogicalProject<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, ImmutableList.copyOf(children));
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, Utils.fastToImmutableList(children));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
import org.apache.doris.qe.SessionVariable;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -115,31 +113,38 @@ public List<List<NamedExpression>> collectChildrenProjections() {
* Generate new output for SetOperation.
*/
public List<NamedExpression> buildNewOutputs() {
ImmutableList.Builder<NamedExpression> newOutputs = new Builder<>();
for (Slot slot : resetNullableForLeftOutputs()) {
List<Slot> slots = resetNullableForLeftOutputs();
ImmutableList.Builder<NamedExpression> newOutputs = ImmutableList.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
newOutputs.add(new SlotReference(slot.toSql(), slot.getDataType(), slot.nullable()));
}
return newOutputs.build();
}

// If the right child is nullable, need to ensure that the left child is also nullable
private List<Slot> resetNullableForLeftOutputs() {
List<Slot> resetNullableForLeftOutputs = new ArrayList<>();
for (int i = 0; i < child(1).getOutput().size(); ++i) {
int rightChildOutputSize = child(1).getOutput().size();
ImmutableList.Builder<Slot> resetNullableForLeftOutputs
= ImmutableList.builderWithExpectedSize(rightChildOutputSize);
for (int i = 0; i < rightChildOutputSize; ++i) {
if (child(1).getOutput().get(i).nullable() && !child(0).getOutput().get(i).nullable()) {
resetNullableForLeftOutputs.add(child(0).getOutput().get(i).withNullable(true));
} else {
resetNullableForLeftOutputs.add(child(0).getOutput().get(i));
}
}
return ImmutableList.copyOf(resetNullableForLeftOutputs);
return resetNullableForLeftOutputs.build();
}

private List<List<NamedExpression>> castCommonDataTypeOutputs() {
List<NamedExpression> newLeftOutputs = new ArrayList<>();
List<NamedExpression> newRightOutputs = new ArrayList<>();
int childOutputSize = child(0).getOutput().size();
ImmutableList.Builder<NamedExpression> newLeftOutputs = ImmutableList.builderWithExpectedSize(
childOutputSize);
ImmutableList.Builder<NamedExpression> newRightOutputs = ImmutableList.builderWithExpectedSize(
childOutputSize
);
// Ensure that the output types of the left and right children are consistent and expand upward.
for (int i = 0; i < child(0).getOutput().size(); ++i) {
for (int i = 0; i < childOutputSize; ++i) {
Slot left = child(0).getOutput().get(i);
Slot right = child(1).getOutput().get(i);
DataType compatibleType = getAssignmentCompatibleType(left.getDataType(), right.getDataType());
Expand All @@ -155,10 +160,7 @@ private List<List<NamedExpression>> castCommonDataTypeOutputs() {
newRightOutputs.add((NamedExpression) newRight);
}

List<List<NamedExpression>> resultExpressions = new ArrayList<>();
resultExpressions.add(newLeftOutputs);
resultExpressions.add(newRightOutputs);
return ImmutableList.copyOf(resultExpressions);
return ImmutableList.of(newLeftOutputs.build(), newRightOutputs.build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public LogicalUnion(Qualifier qualifier, List<NamedExpression> outputs, List<Lis
List<List<NamedExpression>> constantExprsList, boolean hasPushedFilter, List<Plan> children) {
super(PlanType.LOGICAL_UNION, qualifier, outputs, childrenOutputs, children);
this.hasPushedFilter = hasPushedFilter;
this.constantExprsList = ImmutableList.copyOf(
this.constantExprsList = Utils.fastToImmutableList(
Objects.requireNonNull(constantExprsList, "constantExprsList should not be null"));
}

Expand All @@ -81,7 +81,7 @@ public LogicalUnion(Qualifier qualifier, List<NamedExpression> outputs, List<Lis
super(PlanType.LOGICAL_UNION, qualifier, outputs, childrenOutputs,
groupExpression, logicalProperties, children);
this.hasPushedFilter = hasPushedFilter;
this.constantExprsList = ImmutableList.copyOf(
this.constantExprsList = Utils.fastToImmutableList(
Objects.requireNonNull(constantExprsList, "constantExprsList should not be null"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* Infer output column name when it refers an expression and not has an alias manually.
Expand All @@ -38,17 +40,27 @@ public class InferPlanOutputAlias {

private final List<Slot> currentOutputs;
private final List<NamedExpression> finalOutputs;
private final Set<Integer> shouldProcessOutputIndex;

/** InferPlanOutputAlias */
public InferPlanOutputAlias(List<Slot> currentOutputs) {
this.currentOutputs = currentOutputs;
this.finalOutputs = new ArrayList<>(currentOutputs);
this.shouldProcessOutputIndex = new HashSet<>();
for (int i = 0; i < currentOutputs.size(); i++) {
shouldProcessOutputIndex.add(i);
}
}

/** infer */
public List<NamedExpression> infer(Plan plan, ImmutableMultimap<ExprId, Integer> currentExprIdAndIndexMap) {
ImmutableSet<ExprId> currentOutputExprIdSet = currentExprIdAndIndexMap.keySet();
plan.foreach(p -> {
for (Expression expression : plan.getExpressions()) {
// Breath First Search
plan.foreachBreath(childPlan -> {
if (shouldProcessOutputIndex.isEmpty()) {
return true;
}
for (Expression expression : ((Plan) childPlan).getExpressions()) {
if (!(expression instanceof Alias)) {
continue;
}
Expand All @@ -58,14 +70,22 @@ public List<NamedExpression> infer(Plan plan, ImmutableMultimap<ExprId, Integer>
if (currentOutputExprIdSet.contains(projectItem.getExprId())
&& projectItem.isNameFromChild()) {
String inferredAliasName = projectItem.child().getExpressionName();
ImmutableCollection<Integer> outPutExprIndexes = currentExprIdAndIndexMap.get(exprId);
ImmutableCollection<Integer> outputExprIndexes = currentExprIdAndIndexMap.get(exprId);
// replace output name by inferred name
for (Integer index : outPutExprIndexes) {
for (Integer index : outputExprIndexes) {
Slot slot = currentOutputs.get(index);
finalOutputs.set(index, slot.withName("__" + inferredAliasName + "_" + index));
shouldProcessOutputIndex.remove(index);

if (shouldProcessOutputIndex.isEmpty()) {
// replace finished
return true;
}
}
}
}
// continue replace
return false;
});
return finalOutputs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -605,7 +606,7 @@ public static boolean containsType(List<? extends Expression> expressions, Class
return anyMatch(expressions, type::isInstance);
}

public static <E> Set<E> collect(List<? extends Expression> expressions,
public static <E> Set<E> collect(Collection<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
Expand Down Expand Up @@ -645,7 +646,7 @@ public static <E> Set<E> mutableCollect(List<? extends Expression> expressions,
.collect(Collectors.toSet());
}

public static <E> List<E> collectAll(List<? extends Expression> expressions,
public static <E> List<E> collectAll(Collection<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
Expand Down Expand Up @@ -785,4 +786,17 @@ public Boolean visit(Expression expr, Void context) {
}
}, null);
}

/** distinctSlotByName */
public static List<Slot> distinctSlotByName(List<Slot> slots) {
Set<String> existSlotNames = new HashSet<>(slots.size() * 2);
Builder<Slot> distinctSlots = ImmutableList.builderWithExpectedSize(slots.size());
for (Slot slot : slots) {
String name = slot.getName();
if (existSlotNames.add(name)) {
distinctSlots.add(slot);
}
}
return distinctSlots.build();
}
}
Loading

0 comments on commit 919132b

Please sign in to comment.