Skip to content

Commit

Permalink
refactor expression rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 25, 2024
1 parent 69821c0 commit e0de4d3
Show file tree
Hide file tree
Showing 169 changed files with 4,779 additions and 1,707 deletions.
2 changes: 1 addition & 1 deletion fe/fe-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,7 @@ under the License.
<configuration>
<proc>only</proc>
<compilerArgs>
<arg>-AplanPath=${basedir}/src/main/java/org/apache/doris/nereids</arg>
<arg>-Apath=${basedir}/src/main/java/org/apache/doris/nereids</arg>
</compilerArgs>
<includes>
<include>org/apache/doris/nereids/pattern/generator/PatternDescribableProcessPoint.java</include>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,14 @@ public boolean isMinValue() {
switch (type.getPrimitiveType()) {
case DATE:
case DATEV2:
return this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0;
return year == 0 && month == 1 && day == 1
&& this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0;
case DATETIME:
return this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0;
return year == 0 && month == 1 && day == 1
&& this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0;
case DATETIMEV2:
return this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0;
return year == 0 && month == 1 && day == 1
&& this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0;
default:
return false;
}
Expand Down
19 changes: 10 additions & 9 deletions fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -1072,12 +1072,14 @@ public List<Long> selectNonEmptyPartitionIds(Collection<Long> partitionIds) {
return CloudPartition.selectNonEmptyPartitionIds(partitions);
}

return partitionIds.stream()
.map(this::getPartition)
.filter(p -> p != null)
.filter(Partition::hasData)
.map(Partition::getId)
.collect(Collectors.toList());
List<Long> nonEmptyIds = Lists.newArrayListWithCapacity(partitionIds.size());
for (Long partitionId : partitionIds) {
Partition partition = getPartition(partitionId);
if (partition != null && partition.hasData()) {
nonEmptyIds.add(partitionId);
}
}
return nonEmptyIds;
}

public int getPartitionNum() {
Expand Down Expand Up @@ -2538,9 +2540,8 @@ public Set<Long> getPartitionKeys() {
}

public boolean isDupKeysOrMergeOnWrite() {
return getKeysType() == KeysType.DUP_KEYS
|| (getKeysType() == KeysType.UNIQUE_KEYS
&& getEnableUniqueKeyMergeOnWrite());
return keysType == KeysType.DUP_KEYS
|| (keysType == KeysType.UNIQUE_KEYS && getEnableUniqueKeyMergeOnWrite());
}

public void initAutoIncrementGenerator(long dbId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ public boolean checkCloudPriv(String cloudName,

public boolean checkColPriv(String ctl, String db, String tbl, String col, PrivPredicate wanted) {
Optional<Privilege> colPrivilege = wanted.getColPrivilege();
Preconditions.checkState(colPrivilege.isPresent(), "this privPredicate should not use checkColPriv:" + wanted);
if (!colPrivilege.isPresent()) {
throw new IllegalStateException("this privPredicate should not use checkColPriv:" + wanted);
}
return checkTblPriv(ctl, db, tbl, wanted) || onlyCheckColPriv(ctl, db, tbl, col, colPrivilege.get());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -134,6 +135,11 @@ public class CascadesContext implements ScheduleContext {
// trigger by rule and show by `explain plan process` statement
private final List<PlanProcess> planProcesses = new ArrayList<>();

// this field is modified by FoldConstantRuleOnFE, it matters current traverse
// into AggregateFunction with distinct, we can not fold constant in this case
private int distinctAggLevel;
private final boolean isEnableExprTrace;

/**
* Constructor of OptimizerContext.
*
Expand All @@ -156,6 +162,13 @@ private CascadesContext(Optional<CascadesContext> parent, Optional<CTEId> curren
this.subqueryExprIsAnalyzed = new HashMap<>();
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
this.materializationContexts = new ArrayList<>();
if (statementContext.getConnectContext() != null) {
ConnectContext connectContext = statementContext.getConnectContext();
SessionVariable sessionVariable = connectContext.getSessionVariable();
this.isEnableExprTrace = sessionVariable != null && sessionVariable.isEnableExprTrace();
} else {
this.isEnableExprTrace = false;
}
}

/**
Expand Down Expand Up @@ -256,7 +269,7 @@ public void setTables(List<TableIf> tables) {
this.tables = tables.stream().collect(Collectors.toMap(TableIf::getId, t -> t, (t1, t2) -> t1));
}

public ConnectContext getConnectContext() {
public final ConnectContext getConnectContext() {
return statementContext.getConnectContext();
}

Expand Down Expand Up @@ -366,12 +379,18 @@ public <T> T getAndCacheSessionVariable(String cacheName,
return defaultValue;
}

return getStatementContext().getOrRegisterCache(cacheName,
() -> variableSupplier.apply(connectContext.getSessionVariable()));
}

/** getAndCacheDisableRules */
public final BitSet getAndCacheDisableRules() {
ConnectContext connectContext = getConnectContext();
StatementContext statementContext = getStatementContext();
if (statementContext == null) {
return defaultValue;
if (connectContext == null || statementContext == null) {
return new BitSet();
}
return statementContext.getOrRegisterCache(cacheName,
() -> variableSupplier.apply(connectContext.getSessionVariable()));
return statementContext.getOrCacheDisableRules(connectContext.getSessionVariable());
}

private CascadesContext execute(Job job) {
Expand Down Expand Up @@ -722,4 +741,20 @@ public void printPlanProcess() {
LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape);
}
}

public void incrementDistinctAggLevel() {
this.distinctAggLevel++;
}

public void decrementDistinctAggLevel() {
this.distinctAggLevel--;
}

public int getDistinctAggLevel() {
return distinctAggLevel;
}

public boolean isEnableExprTrace() {
return isEnableExprTrace;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.qe.SessionVariable;

import com.google.common.base.Stopwatch;
import com.google.common.base.Supplier;
Expand All @@ -45,6 +46,7 @@
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
Expand Down Expand Up @@ -117,6 +119,8 @@ public class StatementContext {
// Relation for example LogicalOlapScan
private final Map<Slot, Relation> slotToRelation = Maps.newHashMap();

private BitSet disableRules;

public StatementContext() {
this.connectContext = ConnectContext.get();
}
Expand Down Expand Up @@ -259,11 +263,22 @@ public synchronized <T> T getOrRegisterCache(String key, Supplier<T> cacheSuppli
return supplier.get();
}

public synchronized BitSet getOrCacheDisableRules(SessionVariable sessionVariable) {
if (this.disableRules != null) {
return this.disableRules;
}
this.disableRules = sessionVariable.getDisableNereidsRules();
return this.disableRules;
}

/**
* Some value of the cacheKey may change, invalid cache when value change
*/
public synchronized void invalidCache(String cacheKey) {
contextCacheMap.remove(cacheKey);
if (cacheKey.equalsIgnoreCase(SessionVariable.DISABLE_NEREIDS_RULES)) {
this.disableRules = null;
}
}

public ColumnAliasGenerator getColumnAliasGenerator() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Sets;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -63,6 +64,7 @@ public class Scope {
private final List<Slot> slots;
private final Optional<SubqueryExpr> ownerSubquery;
private final Set<Slot> correlatedSlots;
private final boolean buildNameToSlot;
private final Supplier<ListMultimap<String, Slot>> nameToSlot;

public Scope(List<? extends Slot> slots) {
Expand All @@ -75,7 +77,8 @@ public Scope(Optional<Scope> outerScope, List<? extends Slot> slots, Optional<Su
this.slots = Utils.fastToImmutableList(Objects.requireNonNull(slots, "slots can not be null"));
this.ownerSubquery = Objects.requireNonNull(subqueryExpr, "subqueryExpr can not be null");
this.correlatedSlots = Sets.newLinkedHashSet();
this.nameToSlot = Suppliers.memoize(this::buildNameToSlot);
this.buildNameToSlot = slots.size() > 500;
this.nameToSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToSlot) : null;
}

public List<Slot> getSlots() {
Expand All @@ -96,7 +99,19 @@ public Set<Slot> getCorrelatedSlots() {

/** findSlotIgnoreCase */
public List<Slot> findSlotIgnoreCase(String slotName) {
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
if (!buildNameToSlot) {
Object[] array = new Object[slots.size()];
int filterIndex = 0;
for (int i = 0; i < slots.size(); i++) {
Slot slot = slots.get(i);
if (slot.getName().equalsIgnoreCase(slotName)) {
array[filterIndex++] = slot;
}
}
return (List) Arrays.asList(array).subList(0, filterIndex);
} else {
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
}
}

private ListMultimap<String, Slot> buildNameToSlot() {
Expand Down
11 changes: 4 additions & 7 deletions fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,14 @@
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;

import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
* Abstract class for all job using for analyze and optimize query plan in Nereids.
Expand All @@ -57,7 +55,7 @@ public abstract class Job implements TracerSupplier {
protected JobType type;
protected JobContext context;
protected boolean once;
protected final Set<Integer> disableRules;
protected final BitSet disableRules;

protected Map<CTEId, Statistics> cteIdToStats;

Expand Down Expand Up @@ -129,8 +127,7 @@ protected void countJobExecutionTimesOfGroupExpressions(GroupExpression groupExp
groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan()));
}

public static Set<Integer> getDisableRules(JobContext context) {
return context.getCascadesContext().getAndCacheSessionVariable(
SessionVariable.DISABLE_NEREIDS_RULES, ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
public static BitSet getDisableRules(JobContext context) {
return context.getCascadesContext().getAndCacheDisableRules();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit;
Expand Down Expand Up @@ -152,8 +153,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
// such as group by key matching and replaced
// but we need to do some normalization before subquery unnesting,
// such as extract common expression.
new ExpressionNormalization(),
new ExpressionOptimization(),
new ExpressionNormalizationAndOptimization(),
new AvgDistinctToSumDivCount(),
new CountDistinctRewrite(),
new ExtractFilterFromCrossJoin()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;

import java.util.BitSet;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;

/**
Expand All @@ -50,8 +50,8 @@ public CustomRewriteJob(Supplier<CustomRewriter> rewriter, RuleType ruleType) {

@Override
public void execute(JobContext context) {
Set<Integer> disableRules = Job.getDisableRules(context);
if (disableRules.contains(ruleType.type())) {
BitSet disableRules = Job.getDisableRules(context);
if (disableRules.get(ruleType.type())) {
return;
}
CascadesContext cascadesContext = context.getCascadesContext();
Expand Down
Loading

0 comments on commit e0de4d3

Please sign in to comment.