Skip to content

Commit

Permalink
refactor expression rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 26, 2024
1 parent afa3c29 commit 9d2700d
Show file tree
Hide file tree
Showing 219 changed files with 4,885 additions and 1,809 deletions.
2 changes: 1 addition & 1 deletion fe/fe-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ under the License.
<configuration>
<proc>only</proc>
<compilerArgs>
<arg>-AplanPath=${basedir}/src/main/java/org/apache/doris/nereids</arg>
<arg>-Apath=${basedir}/src/main/java/org/apache/doris/nereids</arg>
</compilerArgs>
<includes>
<include>org/apache/doris/nereids/pattern/generator/PatternDescribableProcessPoint.java</include>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,14 @@ public boolean isMinValue() {
switch (type.getPrimitiveType()) {
case DATE:
case DATEV2:
return this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0;
return year == 0 && month == 1 && day == 1
&& this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0;
case DATETIME:
return this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0;
return year == 0 && month == 1 && day == 1
&& this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0;
case DATETIMEV2:
return this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0;
return year == 0 && month == 1 && day == 1
&& this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0;
default:
return false;
}
Expand Down
19 changes: 10 additions & 9 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -1072,12 +1072,14 @@ public List<Long> selectNonEmptyPartitionIds(Collection<Long> partitionIds) {
return CloudPartition.selectNonEmptyPartitionIds(partitions);
}

return partitionIds.stream()
.map(this::getPartition)
.filter(p -> p != null)
.filter(Partition::hasData)
.map(Partition::getId)
.collect(Collectors.toList());
List<Long> nonEmptyIds = Lists.newArrayListWithCapacity(partitionIds.size());
for (Long partitionId : partitionIds) {
Partition partition = getPartition(partitionId);
if (partition != null && partition.hasData()) {
nonEmptyIds.add(partitionId);
}
}
return nonEmptyIds;
}

public int getPartitionNum() {
Expand Down Expand Up @@ -2538,9 +2540,8 @@ public Set<Long> getPartitionKeys() {
}

public boolean isDupKeysOrMergeOnWrite() {
return getKeysType() == KeysType.DUP_KEYS
|| (getKeysType() == KeysType.UNIQUE_KEYS
&& getEnableUniqueKeyMergeOnWrite());
return keysType == KeysType.DUP_KEYS
|| (keysType == KeysType.UNIQUE_KEYS && getEnableUniqueKeyMergeOnWrite());
}

public void initAutoIncrementGenerator(long dbId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private Set<BaseTableInfo> getMtmvsByBaseTable(BaseTableInfo table) {
* @return
*/
public Set<MTMV> getAvailableMTMVs(List<BaseTableInfo> tableInfos, ConnectContext ctx) {
Set<MTMV> res = Sets.newHashSet();
Set<MTMV> res = Sets.newLinkedHashSet();
Set<BaseTableInfo> mvInfos = getMTMVInfos(tableInfos);
for (BaseTableInfo tableInfo : mvInfos) {
try {
Expand All @@ -90,7 +90,7 @@ public boolean isMVPartitionValid(MTMV mtmv, ConnectContext ctx) {
}

private Set<BaseTableInfo> getMTMVInfos(List<BaseTableInfo> tableInfos) {
Set<BaseTableInfo> mvInfos = Sets.newHashSet();
Set<BaseTableInfo> mvInfos = Sets.newLinkedHashSet();
for (BaseTableInfo tableInfo : tableInfos) {
mvInfos.addAll(getMtmvsByBaseTable(tableInfo));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ public boolean checkCloudPriv(String cloudName,

public boolean checkColPriv(String ctl, String db, String tbl, String col, PrivPredicate wanted) {
Optional<Privilege> colPrivilege = wanted.getColPrivilege();
Preconditions.checkState(colPrivilege.isPresent(), "this privPredicate should not use checkColPriv:" + wanted);
if (!colPrivilege.isPresent()) {
throw new IllegalStateException("this privPredicate should not use checkColPriv:" + wanted);
}
return checkTblPriv(ctl, db, tbl, wanted) || onlyCheckColPriv(ctl, db, tbl, col, colPrivilege.get());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -134,6 +135,11 @@ public class CascadesContext implements ScheduleContext {
// trigger by rule and show by `explain plan process` statement
private final List<PlanProcess> planProcesses = new ArrayList<>();

// this field is modified by FoldConstantRuleOnFE, it matters current traverse
// into AggregateFunction with distinct, we can not fold constant in this case
private int distinctAggLevel;
private final boolean isEnableExprTrace;

/**
* Constructor of OptimizerContext.
*
Expand All @@ -156,6 +162,13 @@ private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> curren
this.subqueryExprIsAnalyzed = new HashMap<>();
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
this.materializationContexts = new ArrayList<>();
if (statementContext.getConnectContext() != null) {
ConnectContext connectContext = statementContext.getConnectContext();
SessionVariable sessionVariable = connectContext.getSessionVariable();
this.isEnableExprTrace = sessionVariable != null && sessionVariable.isEnableExprTrace();
} else {
this.isEnableExprTrace = false;
}
}

/**
Expand Down Expand Up @@ -256,7 +269,7 @@ public void setTables(List<TableIf> tables) {
this.tables = tables.stream().collect(Collectors.toMap(TableIf::getId, t -> t, (t1, t2) -> t1));
}

public ConnectContext getConnectContext() {
public final ConnectContext getConnectContext() {
return statementContext.getConnectContext();
}

Expand Down Expand Up @@ -366,12 +379,18 @@ public <T> T getAndCacheSessionVariable(String cacheName,
return defaultValue;
}

return getStatementContext().getOrRegisterCache(cacheName,
() -> variableSupplier.apply(connectContext.getSessionVariable()));
}

/** getAndCacheDisableRules */
public final BitSet getAndCacheDisableRules() {
ConnectContext connectContext = getConnectContext();
StatementContext statementContext = getStatementContext();
if (statementContext == null) {
return defaultValue;
if (connectContext == null || statementContext == null) {
return new BitSet();
}
return statementContext.getOrRegisterCache(cacheName,
() -> variableSupplier.apply(connectContext.getSessionVariable()));
return statementContext.getOrCacheDisableRules(connectContext.getSessionVariable());
}

private CascadesContext execute(Job job) {
Expand Down Expand Up @@ -722,4 +741,20 @@ public void printPlanProcess() {
LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape);
}
}

public void incrementDistinctAggLevel() {
this.distinctAggLevel++;
}

public void decrementDistinctAggLevel() {
this.distinctAggLevel--;
}

public int getDistinctAggLevel() {
return distinctAggLevel;
}

public boolean isEnableExprTrace() {
return isEnableExprTrace;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.qe.SessionVariable;

import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
Expand All @@ -45,6 +46,7 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
Expand Down Expand Up @@ -117,6 +119,8 @@ public class StatementContext {
// Relation for example LogicalOlapScan
private final Map<Slot, Relation> slotToRelation = Maps.newHashMap();

private BitSet disableRules;

public StatementContext() {
this.connectContext = ConnectContext.get();
}
Expand Down Expand Up @@ -259,11 +263,22 @@ public synchronized <T> T getOrRegisterCache(String key, Supplier<T> cacheSuppli
return supplier.get();
}

public synchronized BitSet getOrCacheDisableRules(SessionVariable sessionVariable) {
if (this.disableRules != null) {
return this.disableRules;
}
this.disableRules = sessionVariable.getDisableNereidsRules();
return this.disableRules;
}

/**
* Some value of the cacheKey may change, invalid cache when value change
*/
public synchronized void invalidCache(String cacheKey) {
contextCacheMap.remove(cacheKey);
if (cacheKey.equalsIgnoreCase(SessionVariable.DISABLE_NEREIDS_RULES)) {
this.disableRules = null;
}
}

public ColumnAliasGenerator getColumnAliasGenerator() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -63,6 +64,7 @@ public class Scope {
private final List<Slot> slots;
private final Optional<SubqueryExpr> ownerSubquery;
private final Set<Slot> correlatedSlots;
private final boolean buildNameToSlot;
private final Supplier<ListMultimap<String, Slot>> nameToSlot;

public Scope(List<? extends Slot> slots) {
Expand All @@ -75,7 +77,8 @@ public Scope(Optional<Scope> outerScope, List<? extends Slot> slots, Optional<Su
this.slots = Utils.fastToImmutableList(Objects.requireNonNull(slots, "slots can not be null"));
this.ownerSubquery = Objects.requireNonNull(subqueryExpr, "subqueryExpr can not be null");
this.correlatedSlots = Sets.newLinkedHashSet();
this.nameToSlot = Suppliers.memoize(this::buildNameToSlot);
this.buildNameToSlot = slots.size() > 500;
this.nameToSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToSlot) : null;
}

public List<Slot> getSlots() {
Expand All @@ -96,7 +99,19 @@ public Set<Slot> getCorrelatedSlots() {

/** findSlotIgnoreCase */
public List<Slot> findSlotIgnoreCase(String slotName) {
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
if (!buildNameToSlot) {
Object[] array = new Object[slots.size()];
int filterIndex = 0;
for (int i = 0; i < slots.size(); i++) {
Slot slot = slots.get(i);
if (slot.getName().equalsIgnoreCase(slotName)) {
array[filterIndex++] = slot;
}
}
return (List) Arrays.asList(array).subList(0, filterIndex);
} else {
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
}
}

private ListMultimap<String, Slot> buildNameToSlot() {
Expand Down
11 changes: 4 additions & 7 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,14 @@
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;

import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
* Abstract class for all job using for analyze and optimize query plan in Nereids.
Expand All @@ -57,7 +55,7 @@ public abstract class Job implements TracerSupplier {
protected JobType type;
protected JobContext context;
protected boolean once;
protected final Set<Integer> disableRules;
protected final BitSet disableRules;

protected Map<CTEId, Statistics> cteIdToStats;

Expand Down Expand Up @@ -129,8 +127,7 @@ protected void countJobExecutionTimesOfGroupExpressions(GroupExpression groupExp
groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan()));
}

public static Set<Integer> getDisableRules(JobContext context) {
return context.getCascadesContext().getAndCacheSessionVariable(
SessionVariable.DISABLE_NEREIDS_RULES, ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
public static BitSet getDisableRules(JobContext context) {
return context.getCascadesContext().getAndCacheDisableRules();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit;
Expand Down Expand Up @@ -152,8 +153,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
// such as group by key matching and replaced
// but we need to do some normalization before subquery unnesting,
// such as extract common expression.
new ExpressionNormalization(),
new ExpressionOptimization(),
new ExpressionNormalizationAndOptimization(),
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),
new ExtractFilterFromCrossJoin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -64,7 +65,7 @@
*/
public class HyperGraph {
// record all edges that can be placed on the subgraph
private final Map<Long, BitSet> treeEdgesCache = new HashMap<>();
private final Map<Long, BitSet> treeEdgesCache = new LinkedHashMap<>();
private final List<JoinEdge> joinEdges;
private final List<FilterEdge> filterEdges;
private final List<AbstractNode> nodes;
Expand Down Expand Up @@ -330,9 +331,9 @@ public static class Builder {
private final List<AbstractNode> nodes = new ArrayList<>();

// These hyperGraphs should be replaced nodes when building all
private final Map<Long, List<HyperGraph>> replacedHyperGraphs = new HashMap<>();
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
private final Map<Long, List<NamedExpression>> complexProject = new HashMap<>();
private final Map<Long, List<HyperGraph>> replacedHyperGraphs = new LinkedHashMap<>();
private final HashMap<Slot, Long> slotToNodeMap = new LinkedHashMap<>();
private final Map<Long, List<NamedExpression>> complexProject = new LinkedHashMap<>();
private Set<Slot> finalOutputs;

public List<AbstractNode> getNodes() {
Expand Down Expand Up @@ -522,7 +523,7 @@ private long calNodeMap(Set<Slot> slots) {
*/
private BitSet addJoin(LogicalJoin<?, ?> join,
Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) {
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
Map<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new LinkedHashMap<>();
for (Expression expression : join.getHashJoinConjuncts()) {
// TODO: avoid calling calculateEnds if calNodeMap's results are same
Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;

import java.util.BitSet;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;

/**
Expand All @@ -50,8 +50,8 @@ public CustomRewriteJob(Supplier<CustomRewriter> rewriter, RuleType ruleType) {

@Override
public void execute(JobContext context) {
Set<Integer> disableRules = Job.getDisableRules(context);
if (disableRules.contains(ruleType.type())) {
BitSet disableRules = Job.getDisableRules(context);
if (disableRules.get(ruleType.type())) {
return;
}
CascadesContext cascadesContext = context.getCascadesContext();
Expand Down
Loading

0 comments on commit 9d2700d

Please sign in to comment.