From e0de4d328c6a55c951014011ace817e323b14811 Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Mon, 11 Mar 2024 14:34:04 +0800 Subject: [PATCH] refactor expression rewriter --- fe/fe-core/pom.xml | 2 +- .../apache/doris/analysis/DateLiteral.java | 9 +- .../org/apache/doris/catalog/OlapTable.java | 19 +- .../apache/doris/mysql/privilege/Role.java | 4 +- .../apache/doris/nereids/CascadesContext.java | 45 +++- .../doris/nereids/StatementContext.java | 15 ++ .../apache/doris/nereids/analyzer/Scope.java | 19 +- .../org/apache/doris/nereids/jobs/Job.java | 11 +- .../doris/nereids/jobs/executor/Rewriter.java | 4 +- .../jobs/rewrite/CustomRewriteJob.java | 6 +- .../rewrite/PlanTreeRewriteBottomUpJob.java | 113 ++++---- .../jobs/rewrite/PlanTreeRewriteJob.java | 64 +++-- .../rewrite/PlanTreeRewriteTopDownJob.java | 41 ++- .../jobs/rewrite/RewriteJobContext.java | 10 +- .../jobs/rewrite/RootPlanTreeRewriteJob.java | 16 +- .../pattern/ExpressionPatternRules.java | 112 ++++++++ .../ExpressionPatternTraverseListeners.java | 112 ++++++++ .../nereids/pattern/ParentTypeIdMapping.java | 59 +++++ .../apache/doris/nereids/pattern/Pattern.java | 4 + .../doris/nereids/pattern/TypeMappings.java | 133 ++++++++++ .../ExpressionTypeMappingGenerator.java | 159 ++++++++++++ ...atorAnalyzer.java => JavaAstAnalyzer.java} | 93 +++---- .../LogicalBinaryPatternGenerator.java | 4 +- .../LogicalLeafPatternGenerator.java | 4 +- .../LogicalUnaryPatternGenerator.java | 4 +- .../PatternDescribableProcessor.java | 34 ++- .../PhysicalBinaryPatternGenerator.java | 4 +- .../PhysicalLeafPatternGenerator.java | 4 +- .../PhysicalUnaryPatternGenerator.java | 4 +- ...nerator.java => PlanPatternGenerator.java} | 18 +- .../PlanPatternGeneratorAnalyzer.java | 73 ++++++ .../generator/PlanTypeMappingGenerator.java | 159 ++++++++++++ .../processor/post/RuntimeFilterPruner.java | 17 +- .../nereids/processor/post/Validator.java | 10 +- .../properties/FunctionalDependencies.java | 24 +- .../nereids/properties/LogicalProperties.java | 50 ++-- .../org/apache/doris/nereids/rules/Rule.java | 6 +- .../AdjustAggregateNullableForEmptySet.java | 29 ++- .../rules/analysis/BindExpression.java | 28 +- .../rules/analysis/BindSlotWithPaths.java | 29 +-- .../rules/analysis/CheckAfterRewrite.java | 85 +++--- .../nereids/rules/analysis/CheckAnalysis.java | 36 +-- .../analysis/EliminateGroupByConstant.java | 2 +- .../rules/analysis/ExpressionAnalyzer.java | 2 +- .../rules/analysis/FillUpMissingSlots.java | 21 +- .../rules/analysis/NormalizeAggregate.java | 19 +- .../ReplaceExpressionByChildOutput.java | 48 ++-- .../rules/analysis/SubqueryToApply.java | 77 ++++-- .../ExpressionBottomUpRewriter.java | 124 +++++++++ .../expression/ExpressionListenerMatcher.java | 41 +++ .../expression/ExpressionMatchingAction.java | 25 ++ .../expression/ExpressionMatchingContext.java | 46 ++++ .../expression/ExpressionNormalization.java | 29 ++- ...xpressionNormalizationAndOptimization.java | 33 +++ .../expression/ExpressionOptimization.java | 26 +- .../ExpressionPatternMatchRule.java | 64 +++++ .../expression/ExpressionPatternMatcher.java | 41 +++ .../ExpressionPatternRuleFactory.java | 84 ++++++ .../rules/expression/ExpressionRewrite.java | 51 +++- .../expression/ExpressionRewriteContext.java | 4 +- .../expression/ExpressionRuleExecutor.java | 16 +- .../ExpressionTraverseListener.java | 31 +++ .../ExpressionTraverseListenerFactory.java | 79 ++++++ .../ExpressionTraverseListenerMapping.java | 59 +++++ .../rules/expression/check/CheckCast.java | 24 +- .../rules/ArrayContainToArrayOverlap.java | 94 ++++--- .../rules/expression/rules/CaseWhenToIf.java | 18 +- .../expression/rules/ConvertAggStateCast.java | 33 +-- .../expression/rules/DateFunctionRewrite.java | 34 ++- .../rules/DigitalMaskingConvert.java | 23 +- .../rules/DistinctPredicatesRule.java | 18 +- .../rules/ExtractCommonFactorRule.java | 222 +++++++++++++--- .../expression/rules/FoldConstantRule.java | 32 ++- .../rules/FoldConstantRuleOnBE.java | 46 +++- .../rules/FoldConstantRuleOnFE.java | 170 ++++++++++-- .../expression/rules/InPredicateDedup.java | 40 +-- .../rules/InPredicateToEqualToRule.java | 25 +- .../rules/NormalizeBinaryPredicatesRule.java | 21 +- .../rules/NullSafeEqualToEqual.java | 21 +- .../rules/OneListPartitionEvaluator.java | 2 +- .../rules/OneRangePartitionEvaluator.java | 120 ++++++--- .../rules/expression/rules/OrToIn.java | 36 ++- .../expression/rules/PartitionPruner.java | 23 +- .../rules/PartitionRangeExpander.java | 115 +++++---- .../PredicateRewriteForPartitionPrune.java | 4 +- .../rules/RangePartitionValueIterator.java | 64 +++++ .../rules/ReplaceVariableByLiteral.java | 17 +- .../SimplifyArithmeticComparisonRule.java | 105 ++++---- .../rules/SimplifyArithmeticRule.java | 70 ++--- .../expression/rules/SimplifyCastRule.java | 21 +- .../rules/SimplifyComparisonPredicate.java | 37 ++- .../rules/SimplifyDecimalV3Comparison.java | 24 +- .../expression/rules/SimplifyInPredicate.java | 20 +- .../expression/rules/SimplifyNotExprRule.java | 34 ++- .../rules/expression/rules/SimplifyRange.java | 71 +++--- .../rules/SupportJavaDateFormatter.java | 44 ++-- .../rules/expression/rules/TopnToMax.java | 29 +-- .../TryEliminateUninterestedPredicates.java | 14 +- .../implementation/AggregateStrategies.java | 2 +- .../nereids/rules/rewrite/AdjustNullable.java | 12 +- .../rules/rewrite/CheckMatchExpression.java | 7 +- .../rules/rewrite/CheckPrivileges.java | 29 ++- .../nereids/rules/rewrite/ColumnPruning.java | 98 +++---- .../rules/rewrite/CountDistinctRewrite.java | 60 +++-- .../rules/rewrite/CountLiteralRewrite.java | 37 ++- .../rules/rewrite/EliminateFilter.java | 7 +- .../rules/rewrite/EliminateGroupBy.java | 56 ++-- .../rules/rewrite/EliminateMarkJoin.java | 17 +- .../rules/rewrite/EliminateNotNull.java | 39 +-- .../rewrite/EliminateOrderByConstant.java | 16 +- .../ExtractAndNormalizeWindowExpression.java | 161 ++++++------ ...tSingleTableExpressionFromDisjunction.java | 9 +- .../nereids/rules/rewrite/MergeAggregate.java | 2 +- .../nereids/rules/rewrite/MergeProjects.java | 10 +- .../nereids/rules/rewrite/NormalizeSort.java | 59 +++-- .../rules/rewrite/NormalizeToSlot.java | 43 ++-- .../rules/rewrite/PruneOlapScanPartition.java | 51 ++-- .../rules/rewrite/PullUpPredicates.java | 75 +++--- .../PushDownFilterThroughAggregation.java | 12 +- .../rewrite/PushDownFilterThroughProject.java | 13 +- .../rules/rewrite/SimplifyAggGroupBy.java | 23 +- .../AbstractSelectMaterializedIndexRule.java | 15 +- .../SelectMaterializedIndexWithAggregate.java | 9 +- ...lectMaterializedIndexWithoutAggregate.java | 45 ++-- .../doris/nereids/stats/StatsCalculator.java | 11 +- .../doris/nereids/trees/AbstractTreeNode.java | 22 +- .../apache/doris/nereids/trees/TreeNode.java | 17 ++ .../trees/expressions/BinaryOperator.java | 6 - .../expressions/ComparisonPredicate.java | 4 +- .../nereids/trees/expressions/Expression.java | 99 +++++-- .../trees/expressions/InPredicate.java | 5 +- .../trees/expressions/SlotReference.java | 7 +- .../functions/ComputeSignatureHelper.java | 11 +- .../functions/agg/AggregateFunction.java | 17 +- .../scalar/PushDownToProjectionFunction.java | 7 +- .../expressions/literal/DateLiteral.java | 39 ++- .../visitor/DefaultExpressionRewriter.java | 10 +- .../nereids/trees/plans/AbstractPlan.java | 28 +- .../doris/nereids/trees/plans/Plan.java | 61 +++-- .../trees/plans/algebra/Aggregate.java | 17 +- .../nereids/trees/plans/algebra/Project.java | 27 +- .../trees/plans/logical/LogicalAggregate.java | 8 +- .../plans/logical/LogicalCatalogRelation.java | 132 +++++----- .../trees/plans/logical/LogicalOlapScan.java | 65 +++-- .../trees/plans/logical/LogicalProject.java | 8 +- .../trees/plans/logical/LogicalSort.java | 19 +- .../trees/plans/logical/LogicalTopN.java | 13 +- .../doris/nereids/util/ExpressionUtils.java | 241 ++++++++++++++---- .../apache/doris/nereids/util/PlanUtils.java | 24 ++ .../doris/nereids/util/TypeCoercionUtils.java | 19 +- .../org/apache/doris/nereids/util/Utils.java | 64 ++++- .../org/apache/doris/qe/SessionVariable.java | 32 ++- .../expression/ExpressionRewriteTest.java | 64 +++-- .../ExpressionRewriteTestHelper.java | 2 +- .../rules/expression/FoldConstantTest.java | 36 ++- .../expression/PredicatesSplitterTest.java | 2 +- .../SimplifyArithmeticRuleTest.java | 56 ++-- .../expression/SimplifyInPredicateTest.java | 8 +- .../rules/expression/SimplifyRangeTest.java | 18 +- .../rules/NullSafeEqualToEqualTest.java | 20 +- .../SimplifyArithmeticComparisonRuleTest.java | 7 +- .../rules/SimplifyCastRuleTest.java | 7 +- .../SimplifyComparisonPredicateTest.java | 35 ++- .../SimplifyDecimalV3ComparisonTest.java | 6 +- .../rules/expression/rules/TopnToMaxTest.java | 4 +- .../nereids/rules/rewrite/OrToInTest.java | 19 +- .../PushDownFilterThroughAggregationTest.java | 4 +- .../functions/ComputeSignatureHelperTest.java | 11 + .../test_alter_table_replace.groovy | 2 +- 169 files changed, 4779 insertions(+), 1707 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java rename fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/{PatternGeneratorAnalyzer.java => JavaAstAnalyzer.java} (75%) rename fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/{PatternGenerator.java => PlanPatternGenerator.java} (96%) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java diff --git a/fe/fe-core/pom.xml b/fe/fe-core/pom.xml index 21dd17d60b4bf0..7dfef816f6e5d3 100644 --- a/fe/fe-core/pom.xml +++ b/fe/fe-core/pom.xml @@ -1046,7 +1046,7 @@ under the License. only - -AplanPath=${basedir}/src/main/java/org/apache/doris/nereids + -Apath=${basedir}/src/main/java/org/apache/doris/nereids org/apache/doris/nereids/pattern/generator/PatternDescribableProcessPoint.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java index 97922cee126871..a8148237fb7a52 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DateLiteral.java @@ -569,11 +569,14 @@ public boolean isMinValue() { switch (type.getPrimitiveType()) { case DATE: case DATEV2: - return this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0; + return year == 0 && month == 1 && day == 1 + && this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0; case DATETIME: - return this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0; + return year == 0 && month == 1 && day == 1 + && this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0; case DATETIMEV2: - return this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0; + return year == 0 && month == 1 && day == 1 + && this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0; default: return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java index b6239f486cdeec..072886a0a3c06e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java @@ -1072,12 +1072,14 @@ public List selectNonEmptyPartitionIds(Collection partitionIds) { return CloudPartition.selectNonEmptyPartitionIds(partitions); } - return partitionIds.stream() - .map(this::getPartition) - .filter(p -> p != null) - .filter(Partition::hasData) - .map(Partition::getId) - .collect(Collectors.toList()); + List nonEmptyIds = Lists.newArrayListWithCapacity(partitionIds.size()); + for (Long partitionId : partitionIds) { + Partition partition = getPartition(partitionId); + if (partition != null && partition.hasData()) { + nonEmptyIds.add(partitionId); + } + } + return nonEmptyIds; } public int getPartitionNum() { @@ -2538,9 +2540,8 @@ public Set getPartitionKeys() { } public boolean isDupKeysOrMergeOnWrite() { - return getKeysType() == KeysType.DUP_KEYS - || (getKeysType() == KeysType.UNIQUE_KEYS - && getEnableUniqueKeyMergeOnWrite()); + return keysType == KeysType.DUP_KEYS + || (keysType == KeysType.UNIQUE_KEYS && getEnableUniqueKeyMergeOnWrite()); } public void initAutoIncrementGenerator(long dbId) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java b/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java index b22849ea75ef39..ed7081abb7c37a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java +++ b/fe/fe-core/src/main/java/org/apache/doris/mysql/privilege/Role.java @@ -422,7 +422,9 @@ public boolean checkCloudPriv(String cloudName, public boolean checkColPriv(String ctl, String db, String tbl, String col, PrivPredicate wanted) { Optional colPrivilege = wanted.getColPrivilege(); - Preconditions.checkState(colPrivilege.isPresent(), "this privPredicate should not use checkColPriv:" + wanted); + if (!colPrivilege.isPresent()) { + throw new IllegalStateException("this privPredicate should not use checkColPriv:" + wanted); + } return checkTblPriv(ctl, db, tbl, wanted) || onlyCheckColPriv(ctl, db, tbl, col, colPrivilege.get()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 8e4a47938e49c4..7e199f72d6ec9c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -76,6 +76,7 @@ import org.apache.logging.log4j.Logger; import java.util.ArrayList; +import java.util.BitSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -134,6 +135,11 @@ public class CascadesContext implements ScheduleContext { // trigger by rule and show by `explain plan process` statement private final List planProcesses = new ArrayList<>(); + // this field is modified by FoldConstantRuleOnFE, it matters current traverse + // into AggregateFunction with distinct, we can not fold constant in this case + private int distinctAggLevel; + private final boolean isEnableExprTrace; + /** * Constructor of OptimizerContext. * @@ -156,6 +162,13 @@ private CascadesContext(Optional parent, Optional curren this.subqueryExprIsAnalyzed = new HashMap<>(); this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable()); this.materializationContexts = new ArrayList<>(); + if (statementContext.getConnectContext() != null) { + ConnectContext connectContext = statementContext.getConnectContext(); + SessionVariable sessionVariable = connectContext.getSessionVariable(); + this.isEnableExprTrace = sessionVariable != null && sessionVariable.isEnableExprTrace(); + } else { + this.isEnableExprTrace = false; + } } /** @@ -256,7 +269,7 @@ public void setTables(List tables) { this.tables = tables.stream().collect(Collectors.toMap(TableIf::getId, t -> t, (t1, t2) -> t1)); } - public ConnectContext getConnectContext() { + public final ConnectContext getConnectContext() { return statementContext.getConnectContext(); } @@ -366,12 +379,18 @@ public T getAndCacheSessionVariable(String cacheName, return defaultValue; } + return getStatementContext().getOrRegisterCache(cacheName, + () -> variableSupplier.apply(connectContext.getSessionVariable())); + } + + /** getAndCacheDisableRules */ + public final BitSet getAndCacheDisableRules() { + ConnectContext connectContext = getConnectContext(); StatementContext statementContext = getStatementContext(); - if (statementContext == null) { - return defaultValue; + if (connectContext == null || statementContext == null) { + return new BitSet(); } - return statementContext.getOrRegisterCache(cacheName, - () -> variableSupplier.apply(connectContext.getSessionVariable())); + return statementContext.getOrCacheDisableRules(connectContext.getSessionVariable()); } private CascadesContext execute(Job job) { @@ -722,4 +741,20 @@ public void printPlanProcess() { LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape); } } + + public void incrementDistinctAggLevel() { + this.distinctAggLevel++; + } + + public void decrementDistinctAggLevel() { + this.distinctAggLevel--; + } + + public int getDistinctAggLevel() { + return distinctAggLevel; + } + + public boolean isEnableExprTrace() { + return isEnableExprTrace; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 5c894fd46ef2b9..7b444995120cab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.OriginStatement; +import org.apache.doris.qe.SessionVariable; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; @@ -45,6 +46,7 @@ import com.google.common.collect.Sets; import java.util.ArrayList; +import java.util.BitSet; import java.util.Collection; import java.util.Comparator; import java.util.HashMap; @@ -117,6 +119,8 @@ public class StatementContext { // Relation for example LogicalOlapScan private final Map slotToRelation = Maps.newHashMap(); + private BitSet disableRules; + public StatementContext() { this.connectContext = ConnectContext.get(); } @@ -259,11 +263,22 @@ public synchronized T getOrRegisterCache(String key, Supplier cacheSuppli return supplier.get(); } + public synchronized BitSet getOrCacheDisableRules(SessionVariable sessionVariable) { + if (this.disableRules != null) { + return this.disableRules; + } + this.disableRules = sessionVariable.getDisableNereidsRules(); + return this.disableRules; + } + /** * Some value of the cacheKey may change, invalid cache when value change */ public synchronized void invalidCache(String cacheKey) { contextCacheMap.remove(cacheKey); + if (cacheKey.equalsIgnoreCase(SessionVariable.DISABLE_NEREIDS_RULES)) { + this.disableRules = null; + } } public ColumnAliasGenerator getColumnAliasGenerator() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java index a95e562f7e029c..dbcbea7c104b5a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/analyzer/Scope.java @@ -26,6 +26,7 @@ import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; +import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Objects; @@ -63,6 +64,7 @@ public class Scope { private final List slots; private final Optional ownerSubquery; private final Set correlatedSlots; + private final boolean buildNameToSlot; private final Supplier> nameToSlot; public Scope(List slots) { @@ -75,7 +77,8 @@ public Scope(Optional outerScope, List slots, Optional 500; + this.nameToSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToSlot) : null; } public List getSlots() { @@ -96,7 +99,19 @@ public Set getCorrelatedSlots() { /** findSlotIgnoreCase */ public List findSlotIgnoreCase(String slotName) { - return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT)); + if (!buildNameToSlot) { + Object[] array = new Object[slots.size()]; + int filterIndex = 0; + for (int i = 0; i < slots.size(); i++) { + Slot slot = slots.get(i); + if (slot.getName().equalsIgnoreCase(slotName)) { + array[filterIndex++] = slot; + } + } + return (List) Arrays.asList(array).subList(0, filterIndex); + } else { + return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT)); + } } private ListMultimap buildNameToSlot() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java index a9739cbb9e22ff..41e5e1b8d7e75e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java @@ -34,16 +34,14 @@ import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.qe.ConnectContext; -import org.apache.doris.qe.SessionVariable; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableSet; +import java.util.BitSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; /** * Abstract class for all job using for analyze and optimize query plan in Nereids. @@ -57,7 +55,7 @@ public abstract class Job implements TracerSupplier { protected JobType type; protected JobContext context; protected boolean once; - protected final Set disableRules; + protected final BitSet disableRules; protected Map cteIdToStats; @@ -129,8 +127,7 @@ protected void countJobExecutionTimesOfGroupExpressions(GroupExpression groupExp groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan())); } - public static Set getDisableRules(JobContext context) { - return context.getCascadesContext().getAndCacheSessionVariable( - SessionVariable.DISABLE_NEREIDS_RULES, ImmutableSet.of(), SessionVariable::getDisableNereidsRules); + public static BitSet getDisableRules(JobContext context) { + return context.getCascadesContext().getAndCacheDisableRules(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 65998416fb0973..89db86ad2ab225 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.rules.analysis.NormalizeAggregate; import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite; import org.apache.doris.nereids.rules.expression.ExpressionNormalization; +import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization; import org.apache.doris.nereids.rules.expression.ExpressionOptimization; import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit; @@ -152,8 +153,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // such as group by key matching and replaced // but we need to do some normalization before subquery unnesting, // such as extract common expression. - new ExpressionNormalization(), - new ExpressionOptimization(), + new ExpressionNormalizationAndOptimization(), new AvgDistinctToSumDivCount(), new CountDistinctRewrite(), new ExtractFilterFromCrossJoin() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java index 35e04b9f33fd85..0e58f1bc976ba5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/CustomRewriteJob.java @@ -25,8 +25,8 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import java.util.BitSet; import java.util.Objects; -import java.util.Set; import java.util.function.Supplier; /** @@ -50,8 +50,8 @@ public CustomRewriteJob(Supplier rewriter, RuleType ruleType) { @Override public void execute(JobContext context) { - Set disableRules = Job.getDisableRules(context); - if (disableRules.contains(ruleType.type())) { + BitSet disableRules = Job.getDisableRules(context); + if (disableRules.get(ruleType.type())) { return; } CascadesContext cascadesContext = context.getCascadesContext(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java index 4f623e5450060f..60555a9cc04ad6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteBottomUpJob.java @@ -39,9 +39,9 @@ public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob { // Different 'RewriteState' has different actions, // so we will do specified action for each node based on their 'RewriteState'. private static final String REWRITE_STATE_KEY = "rewrite_state"; - private final RewriteJobContext rewriteJobContext; private final List rules; + private final int batchId; enum RewriteState { // 'REWRITE_THIS' means the current plan node can be handled immediately. If the plan state is 'REWRITE_THIS', @@ -59,22 +59,15 @@ public PlanTreeRewriteBottomUpJob(RewriteJobContext rewriteJobContext, JobContex super(JobType.BOTTOM_UP_REWRITE, context); this.rewriteJobContext = Objects.requireNonNull(rewriteJobContext, "rewriteContext cannot be null"); this.rules = Objects.requireNonNull(rules, "rules cannot be null"); + this.batchId = rewriteJobContext.batchId; } @Override public void execute() { - // For the bottom-up rewrite job, we need to reset the state of its children - // if the plan has changed after the rewrite. So we use the 'childrenVisited' to check this situation. - boolean clearStatePhase = !rewriteJobContext.childrenVisited; - if (clearStatePhase) { - traverseClearState(); - return; - } - // We'll do different actions based on their different states. // You can check the comment in 'RewriteState' structure for more details. Plan plan = rewriteJobContext.plan; - RewriteState state = getState(plan); + RewriteState state = getState(plan, batchId); switch (state) { case REWRITE_THIS: rewriteThis(); @@ -90,33 +83,13 @@ public void execute() { } } - private void traverseClearState() { - // Reset the state for current node. - RewriteJobContext clearedStateContext = rewriteJobContext.withChildrenVisited(true); - setState(clearedStateContext.plan, RewriteState.REWRITE_THIS); - pushJob(new PlanTreeRewriteBottomUpJob(clearedStateContext, context, rules)); - - // Generate the new rewrite job for its children. Because the character of stack is 'first in, last out', - // so we can traverse reset the state for the plan node until the leaf node. - List children = clearedStateContext.plan.children(); - for (int i = children.size() - 1; i >= 0; i--) { - Plan child = children.get(i); - RewriteJobContext childRewriteJobContext = new RewriteJobContext( - child, clearedStateContext, i, false); - // NOTICE: this relay on pull up cte anchor - if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { - pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules)); - } - } - } - private void rewriteThis() { // Link the current node with the sub-plan to get the current plan which is used in the rewrite phase later. Plan plan = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext); RewriteResult rewriteResult = rewrite(plan, rules, rewriteJobContext); if (rewriteResult.hasNewPlan) { RewriteJobContext newJobContext = rewriteJobContext.withPlan(rewriteResult.plan); - RewriteState state = getState(rewriteResult.plan); + RewriteState state = getState(rewriteResult.plan, batchId); // Some eliminate rule will return a rewritten plan, for example the current node is eliminated // and return the child plan. So we don't need to handle it again. if (state == RewriteState.REWRITTEN) { @@ -125,40 +98,82 @@ private void rewriteThis() { } // After the rewrite take effect, we should handle the children part again. pushJob(new PlanTreeRewriteBottomUpJob(newJobContext, context, rules)); - setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN); + setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN, batchId); } else { // No new plan is generated, so just set the state of the current plan to 'REWRITTEN'. - setState(rewriteResult.plan, RewriteState.REWRITTEN); + setState(rewriteResult.plan, RewriteState.REWRITTEN, batchId); rewriteJobContext.setResult(rewriteResult.plan); } } private void ensureChildrenRewritten() { - // Similar to the function 'traverseClearState'. Plan plan = rewriteJobContext.plan; - setState(plan, RewriteState.REWRITE_THIS); + int batchId = rewriteJobContext.batchId; + setState(plan, RewriteState.REWRITE_THIS, batchId); pushJob(new PlanTreeRewriteBottomUpJob(rewriteJobContext, context, rules)); + // some rule return new plan tree, which the number of new plan node > 1, + // we should transform this new plan nodes too. + // NOTICE: this relay on pull up cte anchor + if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { + pushChildrenJobs(plan); + } + } + + private void pushChildrenJobs(Plan plan) { List children = plan.children(); - for (int i = children.size() - 1; i >= 0; i--) { - Plan child = children.get(i); - // some rule return new plan tree, which the number of new plan node > 1, - // we should transform this new plan nodes too. - RewriteJobContext childRewriteJobContext = new RewriteJobContext( - child, rewriteJobContext, i, false); - // NOTICE: this relay on pull up cte anchor - if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { + switch (children.size()) { + case 0: return; + case 1: + Plan child = children.get(0); + RewriteJobContext childRewriteJobContext = new RewriteJobContext( + child, rewriteJobContext, 0, false, batchId); pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules)); - } + return; + case 2: + Plan right = children.get(1); + RewriteJobContext rightRewriteJobContext = new RewriteJobContext( + right, rewriteJobContext, 1, false, batchId); + pushJob(new PlanTreeRewriteBottomUpJob(rightRewriteJobContext, context, rules)); + + Plan left = children.get(0); + RewriteJobContext leftRewriteJobContext = new RewriteJobContext( + left, rewriteJobContext, 0, false, batchId); + pushJob(new PlanTreeRewriteBottomUpJob(leftRewriteJobContext, context, rules)); + return; + default: + for (int i = children.size() - 1; i >= 0; i--) { + child = children.get(i); + childRewriteJobContext = new RewriteJobContext( + child, rewriteJobContext, i, false, batchId); + pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules)); + } + } + } + + private static RewriteState getState(Plan plan, int currentBatchId) { + Optional state = plan.getMutableState(REWRITE_STATE_KEY); + if (!state.isPresent()) { + return RewriteState.ENSURE_CHILDREN_REWRITTEN; + } + RewriteStateContext context = state.get(); + if (context.batchId != currentBatchId) { + return RewriteState.ENSURE_CHILDREN_REWRITTEN; } + return context.rewriteState; } - private static final RewriteState getState(Plan plan) { - Optional state = plan.getMutableState(REWRITE_STATE_KEY); - return state.orElse(RewriteState.ENSURE_CHILDREN_REWRITTEN); + private static void setState(Plan plan, RewriteState state, int batchId) { + plan.setMutableState(REWRITE_STATE_KEY, new RewriteStateContext(state, batchId)); } - private static final void setState(Plan plan, RewriteState state) { - plan.setMutableState(REWRITE_STATE_KEY, state); + private static class RewriteStateContext { + private final RewriteState rewriteState; + private final int batchId; + + public RewriteStateContext(RewriteState rewriteState, int batchId) { + this.rewriteState = rewriteState; + this.batchId = batchId; + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java index affbb9196cc3d5..5e5acc29f66edb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteJob.java @@ -28,6 +28,8 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.trees.plans.Plan; +import com.google.common.collect.ImmutableList; + import java.util.List; /** PlanTreeRewriteJob */ @@ -43,7 +45,7 @@ protected final RewriteResult rewrite(Plan plan, List rules, RewriteJobCon boolean showPlanProcess = cascadesContext.showPlanProcess(); for (Rule rule : rules) { - if (disableRules.contains(rule.getRuleType().type())) { + if (disableRules.get(rule.getRuleType().type())) { continue; } Pattern pattern = (Pattern) rule.getPattern(); @@ -76,26 +78,50 @@ protected final RewriteResult rewrite(Plan plan, List rules, RewriteJobCon return new RewriteResult(false, plan); } - protected final Plan linkChildrenAndParent(Plan plan, RewriteJobContext rewriteJobContext) { - Plan newPlan = linkChildren(plan, rewriteJobContext.childrenContext); - rewriteJobContext.setResult(newPlan); - return newPlan; - } - - protected final Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) { - boolean changed = false; - Plan[] newChildren = new Plan[childrenContext.length]; - for (int i = 0; i < childrenContext.length; ++i) { - Plan result = childrenContext[i].result; - Plan oldChild = plan.child(i); - if (result != null && result != oldChild) { - newChildren[i] = result; - changed = true; - } else { - newChildren[i] = oldChild; + protected static Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) { + List children = plan.children(); + // loop unrolling + switch (children.size()) { + case 0: { + return plan; + } + case 1: { + RewriteJobContext child = childrenContext[0]; + Plan firstResult = child == null ? plan.child(0) : child.result; + return firstResult == null || firstResult == children.get(0) + ? plan : plan.withChildren(ImmutableList.of(firstResult)); + } + case 2: { + RewriteJobContext left = childrenContext[0]; + Plan firstResult = left == null ? plan.child(0) : left.result; + RewriteJobContext right = childrenContext[1]; + Plan secondResult = right == null ? plan.child(1) : right.result; + Plan firstOrigin = children.get(0); + Plan secondOrigin = children.get(1); + boolean firstChanged = firstResult != null && firstResult != firstOrigin; + boolean secondChanged = secondResult != null && secondResult != secondOrigin; + if (firstChanged || secondChanged) { + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(2); + newChildren.add(firstChanged ? firstResult : firstOrigin); + newChildren.add(secondChanged ? secondResult : secondOrigin); + return plan.withChildren(newChildren.build()); + } else { + return plan; + } + } + default: { + boolean changed = false; + int i = 0; + Plan[] newChildren = new Plan[childrenContext.length]; + for (Plan oldChild : children) { + Plan result = childrenContext[i].result; + changed = result != null && result != oldChild; + newChildren[i] = changed ? result : oldChild; + i++; + } + return changed ? plan.withChildren(newChildren) : plan; } } - return changed ? plan.withChildren(newChildren) : plan; } private String getCurrentPlanTreeString() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java index d8dba41b3788bd..14019bc885e0d0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/PlanTreeRewriteTopDownJob.java @@ -56,21 +56,44 @@ public void execute() { RewriteJobContext newRewriteJobContext = rewriteJobContext.withChildrenVisited(true); pushJob(new PlanTreeRewriteTopDownJob(newRewriteJobContext, context, rules)); - List children = newRewriteJobContext.plan.children(); - for (int i = children.size() - 1; i >= 0; i--) { - RewriteJobContext childRewriteJobContext = new RewriteJobContext( - children.get(i), newRewriteJobContext, i, false); - // NOTICE: this relay on pull up cte anchor - if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) { - pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules)); - } + // NOTICE: this relay on pull up cte anchor + if (!(this.rewriteJobContext.plan instanceof LogicalCTEAnchor)) { + pushChildrenJobs(newRewriteJobContext); } } else { // All the children part are already visited. Just link the children plan to the current node. - Plan result = linkChildrenAndParent(rewriteJobContext.plan, rewriteJobContext); + Plan result = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext); + rewriteJobContext.setResult(result); if (rewriteJobContext.parentContext == null) { context.getCascadesContext().setRewritePlan(result); } } } + + private void pushChildrenJobs(RewriteJobContext rewriteJobContext) { + List children = rewriteJobContext.plan.children(); + switch (children.size()) { + case 0: return; + case 1: + RewriteJobContext childRewriteJobContext = new RewriteJobContext( + children.get(0), rewriteJobContext, 0, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules)); + return; + case 2: + RewriteJobContext rightRewriteJobContext = new RewriteJobContext( + children.get(1), rewriteJobContext, 1, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(rightRewriteJobContext, context, rules)); + + RewriteJobContext leftRewriteJobContext = new RewriteJobContext( + children.get(0), rewriteJobContext, 0, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(leftRewriteJobContext, context, rules)); + return; + default: + for (int i = children.size() - 1; i >= 0; i--) { + childRewriteJobContext = new RewriteJobContext( + children.get(i), rewriteJobContext, i, false, this.rewriteJobContext.batchId); + pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules)); + } + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java index fb0475f7a61a3b..060bb8edd62838 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteJobContext.java @@ -25,6 +25,7 @@ public class RewriteJobContext { final boolean childrenVisited; + final int batchId; final RewriteJobContext parentContext; final int childIndexInParentContext; final Plan plan; @@ -33,7 +34,7 @@ public class RewriteJobContext { /** RewriteJobContext */ public RewriteJobContext(Plan plan, @Nullable RewriteJobContext parentContext, int childIndexInParentContext, - boolean childrenVisited) { + boolean childrenVisited, int batchId) { this.plan = plan; this.parentContext = parentContext; this.childIndexInParentContext = childIndexInParentContext; @@ -42,6 +43,7 @@ public RewriteJobContext(Plan plan, @Nullable RewriteJobContext parentContext, i if (parentContext != null) { parentContext.childrenContext[childIndexInParentContext] = this; } + this.batchId = batchId; } public void setResult(Plan result) { @@ -49,15 +51,15 @@ public void setResult(Plan result) { } public RewriteJobContext withChildrenVisited(boolean childrenVisited) { - return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited); + return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId); } public RewriteJobContext withPlan(Plan plan) { - return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited); + return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId); } public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) { - return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited); + return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId); } public boolean isRewriteRoot() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java index 6bc055a68aa976..d352dfee4a0b20 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RootPlanTreeRewriteJob.java @@ -27,9 +27,11 @@ import java.util.List; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; /** RootPlanTreeRewriteJob */ public class RootPlanTreeRewriteJob implements RewriteJob { + private static final AtomicInteger BATCH_ID = new AtomicInteger(); private final List rules; private final RewriteJobBuilder rewriteJobBuilder; @@ -47,7 +49,9 @@ public void execute(JobContext context) { // get plan from the cascades context Plan root = cascadesContext.getRewritePlan(); // write rewritten root plan to cascades context by the RootRewriteJobContext - RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext(root, false, context); + int batchId = BATCH_ID.incrementAndGet(); + RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext( + root, false, context, batchId); Job rewriteJob = rewriteJobBuilder.build(rewriteJobContext, context, rules); context.getScheduleContext().pushJob(rewriteJob); @@ -71,8 +75,8 @@ public static class RootRewriteJobContext extends RewriteJobContext { private final JobContext jobContext; - RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext) { - super(plan, null, -1, childrenVisited); + RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext, int batchId) { + super(plan, null, -1, childrenVisited, batchId); this.jobContext = Objects.requireNonNull(jobContext, "jobContext cannot be null"); jobContext.getCascadesContext().setCurrentRootRewriteJobContext(this); } @@ -89,17 +93,17 @@ public void setResult(Plan result) { @Override public RewriteJobContext withChildrenVisited(boolean childrenVisited) { - return new RootRewriteJobContext(plan, childrenVisited, jobContext); + return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId); } @Override public RewriteJobContext withPlan(Plan plan) { - return new RootRewriteJobContext(plan, childrenVisited, jobContext); + return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId); } @Override public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) { - return new RootRewriteJobContext(plan, childrenVisited, jobContext); + return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId); } /** linkChildren */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java new file mode 100644 index 00000000000000..523540e6435d89 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternRules.java @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatchRule; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Expression; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +/** ExpressionPatternMapping */ +public class ExpressionPatternRules extends TypeMappings { + private static final Logger LOG = LogManager.getLogger(ExpressionPatternRules.class); + + public ExpressionPatternRules(List typeMappings) { + super(typeMappings); + } + + @Override + protected Set> getChildrenClasses(Class clazz) { + return org.apache.doris.nereids.pattern.GeneratedExpressionRelations.CHILDREN_CLASS_MAP.get(clazz); + } + + /** matchesAndApply */ + public Optional matchesAndApply(Expression expr, ExpressionRewriteContext context, Expression parent) { + List rules = singleMappings.get(expr.getClass()); + ExpressionMatchingContext matchingContext + = new ExpressionMatchingContext<>(expr, parent, context); + switch (rules.size()) { + case 0: { + for (ExpressionPatternMatchRule multiMatchRule : multiMappings) { + if (multiMatchRule.matchesTypeAndPredicates(matchingContext)) { + Expression newExpr = multiMatchRule.apply(matchingContext); + if (!newExpr.equals(expr)) { + if (context.cascadesContext.isEnableExprTrace()) { + traceExprChanged(multiMatchRule, expr, newExpr); + } + return Optional.of(newExpr); + } + } + } + return Optional.empty(); + } + case 1: { + ExpressionPatternMatchRule rule = rules.get(0); + if (rule.matchesPredicates(matchingContext)) { + Expression newExpr = rule.apply(matchingContext); + if (!newExpr.equals(expr)) { + if (context.cascadesContext.isEnableExprTrace()) { + traceExprChanged(rule, expr, newExpr); + } + return Optional.of(newExpr); + } + } + return Optional.empty(); + } + default: { + for (ExpressionPatternMatchRule rule : rules) { + if (rule.matchesPredicates(matchingContext)) { + Expression newExpr = rule.apply(matchingContext); + if (!expr.equals(newExpr)) { + if (context.cascadesContext.isEnableExprTrace()) { + traceExprChanged(rule, expr, newExpr); + } + return Optional.of(newExpr); + } + } + } + return Optional.empty(); + } + } + } + + private static void traceExprChanged(ExpressionPatternMatchRule rule, Expression expr, Expression newExpr) { + try { + Field[] declaredFields = (rule.matchingAction).getClass().getDeclaredFields(); + Class ruleClass; + if (declaredFields.length == 0) { + ruleClass = rule.matchingAction.getClass(); + } else { + Field field = declaredFields[0]; + field.setAccessible(true); + ruleClass = field.get(rule.matchingAction).getClass(); + } + LOG.info("RULE: " + ruleClass + "\nbefore: " + expr + "\nafter: " + newExpr); + } catch (Throwable t) { + LOG.error(t.getMessage(), t); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java new file mode 100644 index 00000000000000..3f3640a43bf8b2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ExpressionPatternTraverseListeners.java @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerMapping; +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +/** ExpressionPatternTraverseListeners */ +public class ExpressionPatternTraverseListeners + extends TypeMappings { + public ExpressionPatternTraverseListeners( + List typeMappings) { + super(typeMappings); + } + + @Override + protected Set> getChildrenClasses(Class clazz) { + return org.apache.doris.nereids.pattern.GeneratedExpressionRelations.CHILDREN_CLASS_MAP.get(clazz); + } + + /** matchesAndCombineListener */ + public @Nullable CombinedListener matchesAndCombineListeners( + Expression expr, ExpressionRewriteContext context, Expression parent) { + List listenerSingleMappings = singleMappings.get(expr.getClass()); + ExpressionMatchingContext matchingContext + = new ExpressionMatchingContext<>(expr, parent, context); + switch (listenerSingleMappings.size()) { + case 0: { + ImmutableList.Builder> matchedListeners + = ImmutableList.builder(); + for (ExpressionTraverseListenerMapping multiMapping : multiMappings) { + if (multiMapping.matchesTypeAndPredicates(matchingContext)) { + matchedListeners.add(multiMapping.listener); + } + } + return CombinedListener.tryCombine(matchedListeners.build(), matchingContext); + } + case 1: { + ExpressionTraverseListenerMapping listenerMapping = listenerSingleMappings.get(0); + if (listenerMapping.matchesPredicates(matchingContext)) { + return CombinedListener.tryCombine(ImmutableList.of(listenerMapping.listener), matchingContext); + } + return null; + } + default: { + ImmutableList.Builder> matchedListeners + = ImmutableList.builder(); + for (ExpressionTraverseListenerMapping singleMapping : listenerSingleMappings) { + if (singleMapping.matchesPredicates(matchingContext)) { + matchedListeners.add(singleMapping.listener); + } + } + return CombinedListener.tryCombine(matchedListeners.build(), matchingContext); + } + } + } + + /** CombinedListener */ + public static class CombinedListener { + private final ExpressionMatchingContext context; + private final List> listeners; + + /** CombinedListener */ + public CombinedListener(ExpressionMatchingContext context, + List> listeners) { + this.context = context; + this.listeners = listeners; + } + + public static @Nullable CombinedListener tryCombine( + List> listenerMappings, + ExpressionMatchingContext context) { + return listenerMappings.isEmpty() ? null : new CombinedListener(context, listenerMappings); + } + + public void onEnter() { + for (ExpressionTraverseListener listener : listeners) { + listener.onEnter(context); + } + } + + public void onExit(Expression rewritten) { + for (ExpressionTraverseListener listener : listeners) { + listener.onExit(context, rewritten); + } + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java new file mode 100644 index 00000000000000..b4623e105238b7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/ParentTypeIdMapping.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.trees.expressions.LessThanEqual; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** ParentTypeIdMapping */ +public class ParentTypeIdMapping { + + private final AtomicInteger idGenerator = new AtomicInteger(); + private final Map, Integer> classId = new ConcurrentHashMap<>(8192); + + /** getId */ + public int getId(Class clazz) { + Integer id = classId.get(clazz); + if (id != null) { + return id; + } + return ensureClassHasId(clazz); + } + + private int ensureClassHasId(Class clazz) { + Class superClass = clazz.getSuperclass(); + if (superClass != null) { + ensureClassHasId(superClass); + } + + for (Class interfaceClass : clazz.getInterfaces()) { + ensureClassHasId(interfaceClass); + } + + return classId.computeIfAbsent(clazz, c -> idGenerator.incrementAndGet()); + } + + public static void main(String[] args) { + ParentTypeIdMapping mapping = new ParentTypeIdMapping(); + int id = mapping.getId(LessThanEqual.class); + System.out.println(id); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java index c47dcd6a725be1..91dd87ba457837 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/Pattern.java @@ -152,6 +152,10 @@ public boolean matchPlanTree(Plan plan) { if (this instanceof SubTreePattern) { return matchPredicates((TYPE) plan); } + return matchChildrenAndSelfPredicates(plan, childPatternNum); + } + + private boolean matchChildrenAndSelfPredicates(Plan plan, int childPatternNum) { List childrenPlan = plan.children(); for (int i = 0; i < childrenPlan.size(); i++) { Plan child = childrenPlan.get(i); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java new file mode 100644 index 00000000000000..4eb5ffc76d22a2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/TypeMappings.java @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern; + +import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping; +import org.apache.doris.nereids.util.Utils; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Lists; + +import java.lang.reflect.Modifier; +import java.util.List; +import java.util.Set; +import javax.annotation.Nullable; + +/** ExpressionPatternMappings */ +public abstract class TypeMappings> { + protected final ListMultimap, T> singleMappings; + protected final List multiMappings; + + /** ExpressionPatternMappings */ + public TypeMappings(List typeMappings) { + this.singleMappings = ArrayListMultimap.create(); + this.multiMappings = Lists.newArrayList(); + + for (T mapping : typeMappings) { + Set> childrenClasses = getChildrenClasses(mapping.getType()); + if (childrenClasses == null || childrenClasses.isEmpty()) { + // add some expressions which no child class + // e.g. LessThanEqual + addSimpleMapping(mapping); + } else if (childrenClasses.size() <= 100) { + // add some expressions which have children classes + // e.g. ComparisonPredicate will be expanded to + // ruleMappings.put(LessThanEqual.class, rule); + // ruleMappings.put(LessThan.class, rule); + // ruleMappings.put(GreaterThan.class, rule); + // ruleMappings.put(GreaterThanEquals.class, rule); + // ... + addThisAndChildrenMapping(mapping, childrenClasses); + } else { + // some expressions have lots of children classes, e.g. Expression, ExpressionTrait, BinaryExpression, + // we will not expand this types to child class, but also add this rules to other type matching. + // for example, if we have three rules to matches this types: LessThanEqual, Abs and Expression, + // then the ruleMappings would be: + // { + // LessThanEqual.class: [rule_of_LessThanEqual, rule_of_Expression], + // Abs.class: [rule_of_Abs, rule_of_Expression] + // } + // + // and the multiMatchRules would be: [rule_of_Expression] + // + // if we matches `a <= 1`, there have two rules would be applied because + // ruleMappings.get(LessThanEqual.class) return two rules; + // if we matches `a = 1`, ruleMappings.get(EqualTo.class) will return empty rules, so we use + // all the rules in multiMatchRules to matches and apply, the rule_of_Expression will be applied. + addMultiMapping(mapping); + } + } + } + + public @Nullable List get(Class clazz) { + return singleMappings.get(clazz); + } + + private void addSimpleMapping(T typeMapping) { + Class clazz = typeMapping.getType(); + int modifiers = clazz.getModifiers(); + if (!Modifier.isAbstract(modifiers)) { + addSingleMapping(clazz, typeMapping); + } + } + + private void addThisAndChildrenMapping( + T typeMapping, Set> childrenClasses) { + Class clazz = typeMapping.getType(); + if (!Modifier.isAbstract(clazz.getModifiers())) { + addSingleMapping(clazz, typeMapping); + } + + for (Class childrenClass : childrenClasses) { + if (!Modifier.isAbstract(childrenClass.getModifiers())) { + addSingleMapping(childrenClass, typeMapping); + } + } + } + + private void addMultiMapping(T multiMapping) { + multiMappings.add(multiMapping); + + Set> existSingleMappingTypes = Utils.fastToImmutableSet(singleMappings.keySet()); + for (Class existSingleType : existSingleMappingTypes) { + Class type = multiMapping.getType(); + if (type.isAssignableFrom(existSingleType)) { + singleMappings.put(existSingleType, multiMapping); + } + } + } + + private void addSingleMapping(Class clazz, T singleMapping) { + if (!singleMappings.containsKey(clazz) && !multiMappings.isEmpty()) { + for (T multiMapping : multiMappings) { + if (multiMapping.getType().isAssignableFrom(clazz)) { + singleMappings.put(clazz, multiMapping); + } + } + } + singleMappings.put(clazz, singleMapping); + } + + protected abstract Set> getChildrenClasses(Class clazz); + + /** TypeMapping */ + public interface TypeMapping { + Class getType(); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java new file mode 100644 index 00000000000000..c5a923153dfeea --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/ExpressionTypeMappingGenerator.java @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern.generator; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.processing.ProcessingEnvironment; +import javax.tools.StandardLocation; + +/** ExpressionTypeMappingGenerator */ +public class ExpressionTypeMappingGenerator { + private final JavaAstAnalyzer analyzer; + + public ExpressionTypeMappingGenerator(JavaAstAnalyzer javaAstAnalyzer) { + this.analyzer = javaAstAnalyzer; + } + + public JavaAstAnalyzer getAnalyzer() { + return analyzer; + } + + /** generate */ + public void generate(ProcessingEnvironment processingEnv) throws IOException { + Set superExpressions = findSuperExpression(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Map> parentNameMap = analyzer.getParentNameMap(); + String code = generateCode(childrenNameMap, parentNameMap, superExpressions); + generateFile(processingEnv, code); + } + + private void generateFile(ProcessingEnvironment processingEnv, String code) throws IOException { + File generatePatternFile = new File(processingEnv.getFiler() + .getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern", + "GeneratedExpressionRelations.java").toUri()); + if (generatePatternFile.exists()) { + generatePatternFile.delete(); + } + if (!generatePatternFile.getParentFile().exists()) { + generatePatternFile.getParentFile().mkdirs(); + } + + // bypass create file for processingEnv.getFiler(), compile GeneratePatterns in next compile term + try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) { + bufferedWriter.write(code); + } + } + + private Set findSuperExpression() { + Map> parentNameMap = analyzer.getParentNameMap(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Set superExpressions = Sets.newLinkedHashSet(); + for (Entry> entry : childrenNameMap.entrySet()) { + String parentName = entry.getKey(); + Set childrenNames = entry.getValue(); + + if (parentName.startsWith("org.apache.doris.nereids.trees.expressions.")) { + for (String childrenName : childrenNames) { + Set parentNames = parentNameMap.get(childrenName); + if (parentNames != null + && parentNames.contains("org.apache.doris.nereids.trees.expressions.Expression")) { + superExpressions.add(parentName); + break; + } + } + } + } + return superExpressions; + } + + private String generateCode(Map> childrenNameMap, + Map> parentNameMap, Set superExpressions) { + String generateCode + = "// Licensed to the Apache Software Foundation (ASF) under one\n" + + "// or more contributor license agreements. See the NOTICE file\n" + + "// distributed with this work for additional information\n" + + "// regarding copyright ownership. The ASF licenses this file\n" + + "// to you under the Apache License, Version 2.0 (the\n" + + "// \"License\"); you may not use this file except in compliance\n" + + "// with the License. You may obtain a copy of the License at\n" + + "//\n" + + "// http://www.apache.org/licenses/LICENSE-2.0\n" + + "//\n" + + "// Unless required by applicable law or agreed to in writing,\n" + + "// software distributed under the License is distributed on an\n" + + "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n" + + "// KIND, either express or implied. See the License for the\n" + + "// specific language governing permissions and limitations\n" + + "// under the License.\n" + + "\n" + + "package org.apache.doris.nereids.pattern;\n" + + "\n" + + "import org.apache.doris.nereids.trees.expressions.Expression;\n" + + "\n" + + "import com.google.common.collect.ImmutableMap;\n" + + "import com.google.common.collect.ImmutableSet;\n" + + "\n" + + "import java.util.Map;\n" + + "import java.util.Set;\n" + + "\n"; + generateCode += "/** GeneratedExpressionRelations */\npublic class GeneratedExpressionRelations {\n"; + String childrenClassesGenericType = ", Set>>"; + generateCode += + " public static final Map" + childrenClassesGenericType + " CHILDREN_CLASS_MAP;\n\n"; + generateCode += + " static {\n" + + " ImmutableMap.Builder" + childrenClassesGenericType + " childrenClassesBuilder\n" + + " = ImmutableMap.builderWithExpectedSize(" + childrenNameMap.size() + ");\n"; + + for (String superExpression : superExpressions) { + Set childrenClasseSet = childrenNameMap.get(superExpression) + .stream() + .filter(childClass -> parentNameMap.get(childClass) + .contains("org.apache.doris.nereids.trees.expressions.Expression") + ) + .collect(Collectors.toSet()); + + List childrenClasses = Lists.newArrayList(childrenClasseSet); + Collections.sort(childrenClasses, Comparator.naturalOrder()); + + String childClassesString = childrenClasses.stream() + .map(childClass -> " " + childClass + ".class") + .collect(Collectors.joining(",\n")); + generateCode += " childrenClassesBuilder.put(\n " + superExpression + + ".class,\n ImmutableSet.>of(\n" + childClassesString + + "\n )\n );\n\n"; + } + + generateCode += " CHILDREN_CLASS_MAP = childrenClassesBuilder.build();\n"; + + return generateCode + " }\n}\n"; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGeneratorAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/JavaAstAnalyzer.java similarity index 75% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGeneratorAnalyzer.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/JavaAstAnalyzer.java index f4a9d128087ae8..cce69151ca2ab7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGeneratorAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/JavaAstAnalyzer.java @@ -29,25 +29,24 @@ import com.google.common.base.Joiner; -import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.IdentityHashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; -/** - * used to analyze plan class extends hierarchy and then generated pattern builder methods. - */ -public class PatternGeneratorAnalyzer { - private final Map name2Ast = new LinkedHashMap<>(); - private final IdentityHashMap ast2Name = new IdentityHashMap<>(); - private final IdentityHashMap> ast2Import = new IdentityHashMap<>(); - private final IdentityHashMap> parentClassMap = new IdentityHashMap<>(); +/** JavaAstAnalyzer */ +public class JavaAstAnalyzer { + protected final Map name2Ast = new LinkedHashMap<>(); + protected final IdentityHashMap ast2Name = new IdentityHashMap<>(); + protected final IdentityHashMap> ast2Import = new IdentityHashMap<>(); + protected final IdentityHashMap> parentClassMap = new IdentityHashMap<>(); + protected final Map> parentNameMap = new LinkedHashMap<>(); + protected final Map> childrenNameMap = new LinkedHashMap<>(); /** add java AST. */ public void addAsts(List typeDeclarations) { @@ -56,14 +55,20 @@ public void addAsts(List typeDeclarations) { } } - /** generate pattern methods. */ - public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) { - analyzeImport(); - analyzeParentClass(); - return doGenerate(className, parentClassName, isMemoPattern); + public IdentityHashMap> getParentClassMap() { + return parentClassMap; + } + + public Map> getParentNameMap() { + return parentNameMap; } - Optional getType(TypeDeclaration typeDeclaration, TypeType type) { + public Map> getChildrenNameMap() { + return childrenNameMap; + } + + /** getType */ + public Optional getType(TypeDeclaration typeDeclaration, TypeType type) { String typeName = analyzeClass(new LinkedHashSet<>(), typeDeclaration, type); if (typeName != null) { TypeDeclaration ast = name2Ast.get(typeName); @@ -73,34 +78,11 @@ Optional getType(TypeDeclaration typeDeclaration, TypeType type return Optional.empty(); } - private String doGenerate(String className, String parentClassName, boolean isMemoPattern) { - Map> planClassMap = parentClassMap.entrySet().stream() - .filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan")) - .filter(kv -> !kv.getKey().name.equals("GroupPlan")) - .filter(kv -> !Modifier.isAbstract(kv.getKey().modifiers.mod) - && kv.getKey() instanceof ClassDeclaration) - .collect(Collectors.toMap(kv -> (ClassDeclaration) kv.getKey(), kv -> kv.getValue())); - - List generators = planClassMap.entrySet() - .stream() - .map(kv -> PatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern)) - .filter(Optional::isPresent) - .map(Optional::get) - .sorted((g1, g2) -> { - // logical first - if (g1.isLogical() != g2.isLogical()) { - return g1.isLogical() ? -1 : 1; - } - // leaf first - if (g1.childrenNum() != g2.childrenNum()) { - return g1.childrenNum() - g2.childrenNum(); - } - // string dict sort - return g1.opType.name.compareTo(g2.opType.name); - }) - .collect(Collectors.toList()); - - return PatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern); + protected void analyze() { + analyzeImport(); + analyzeParentClass(); + analyzeParentName(); + analyzeChildrenName(); } private void analyzeImport() { @@ -148,7 +130,28 @@ private void analyzeParentClass(Set parentClasses, TypeDeclaration typeD parentClasses.addAll(currentParentClasses); } - String analyzeClass(Set parentClasses, TypeDeclaration typeDeclaration, TypeType type) { + private void analyzeParentName() { + for (Entry> entry : parentClassMap.entrySet()) { + String parentName = entry.getKey().getFullQualifiedName(); + parentNameMap.put(parentName, entry.getValue()); + } + } + + private void analyzeChildrenName() { + for (Entry entry : name2Ast.entrySet()) { + Set parentNames = parentClassMap.get(entry.getValue()); + for (String parentName : parentNames) { + Set childrenNames = childrenNameMap.get(parentName); + if (childrenNames == null) { + childrenNames = new LinkedHashSet<>(); + childrenNameMap.put(parentName, childrenNames); + } + childrenNames.add(entry.getKey()); + } + } + } + + private String analyzeClass(Set parentClasses, TypeDeclaration typeDeclaration, TypeType type) { if (type.classOrInterfaceType.isPresent()) { List identifiers = new ArrayList<>(); ClassOrInterfaceType classOrInterfaceType = type.classOrInterfaceType.get(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java index bec3efa270a7f7..8e05a87ad7dc0e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalBinaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for LogicalBinary. */ -public class LogicalBinaryPatternGenerator extends PatternGenerator { +public class LogicalBinaryPatternGenerator extends PlanPatternGenerator { - public LogicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public LogicalBinaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java index fd7b30a8e6f112..b82ac81d42077a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalLeafPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for LogicalLeaf. */ -public class LogicalLeafPatternGenerator extends PatternGenerator { +public class LogicalLeafPatternGenerator extends PlanPatternGenerator { - public LogicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer, + public LogicalLeafPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java index 8ecb7c14e1005c..d2f2b61bf96d71 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/LogicalUnaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for LogicalUnary. */ -public class LogicalUnaryPatternGenerator extends PatternGenerator { +public class LogicalUnaryPatternGenerator extends PlanPatternGenerator { - public LogicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public LogicalUnaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java index 42cf82e3c01414..5ba81bbb96bc93 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternDescribableProcessor.java @@ -60,12 +60,12 @@ @SupportedSourceVersion(SourceVersion.RELEASE_8) @SupportedAnnotationTypes("org.apache.doris.nereids.pattern.generator.PatternDescribable") public class PatternDescribableProcessor extends AbstractProcessor { - private List planPaths; + private List paths; @Override public synchronized void init(ProcessingEnvironment processingEnv) { super.init(processingEnv); - this.planPaths = Arrays.stream(processingEnv.getOptions().get("planPath").split(",")) + this.paths = Arrays.stream(processingEnv.getOptions().get("path").split(",")) .map(path -> path.trim()) .filter(path -> !path.isEmpty()) .collect(Collectors.toSet()) @@ -80,15 +80,25 @@ public boolean process(Set annotations, RoundEnvironment return false; } try { - List planFiles = findJavaFiles(planPaths); - PatternGeneratorAnalyzer patternGeneratorAnalyzer = new PatternGeneratorAnalyzer(); - for (File file : planFiles) { + List javaFiles = findJavaFiles(paths); + JavaAstAnalyzer javaAstAnalyzer = new JavaAstAnalyzer(); + for (File file : javaFiles) { List asts = parseJavaFile(file); - patternGeneratorAnalyzer.addAsts(asts); + javaAstAnalyzer.addAsts(asts); } - doGenerate("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer); - doGenerate("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer); + javaAstAnalyzer.analyze(); + + ExpressionTypeMappingGenerator expressionTypeMappingGenerator + = new ExpressionTypeMappingGenerator(javaAstAnalyzer); + expressionTypeMappingGenerator.generate(processingEnv); + + PlanTypeMappingGenerator planTypeMappingGenerator = new PlanTypeMappingGenerator(javaAstAnalyzer); + planTypeMappingGenerator.generate(processingEnv); + + PlanPatternGeneratorAnalyzer patternGeneratorAnalyzer = new PlanPatternGeneratorAnalyzer(javaAstAnalyzer); + generatePlanPatterns("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer); + generatePlanPatterns("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer); } catch (Throwable t) { String exceptionMsg = Throwables.getStackTraceAsString(t); processingEnv.getMessager().printMessage(Kind.ERROR, @@ -97,8 +107,12 @@ public boolean process(Set annotations, RoundEnvironment return false; } - private void doGenerate(String className, String parentClassName, boolean isMemoPattern, - PatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException { + private void generateExpressionTypeMapping() { + + } + + private void generatePlanPatterns(String className, String parentClassName, boolean isMemoPattern, + PlanPatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException { String generatePatternCode = patternGeneratorAnalyzer.generatePatterns( className, parentClassName, isMemoPattern); File generatePatternFile = new File(processingEnv.getFiler() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java index 72a315574952ac..08e639a924dad3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalBinaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for PhysicalBinary. */ -public class PhysicalBinaryPatternGenerator extends PatternGenerator { +public class PhysicalBinaryPatternGenerator extends PlanPatternGenerator { - public PhysicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public PhysicalBinaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java index f75746b5142f20..27a94edacad2b9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalLeafPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for PhysicalLeaf. */ -public class PhysicalLeafPatternGenerator extends PatternGenerator { +public class PhysicalLeafPatternGenerator extends PlanPatternGenerator { - public PhysicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer, + public PhysicalLeafPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java index 4254e28ee4371f..f69de7e9d6a123 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PhysicalUnaryPatternGenerator.java @@ -23,9 +23,9 @@ import java.util.TreeSet; /** used to generate pattern for PhysicalUnary. */ -public class PhysicalUnaryPatternGenerator extends PatternGenerator { +public class PhysicalUnaryPatternGenerator extends PlanPatternGenerator { - public PhysicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer, + public PhysicalUnaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { super(analyzer, opType, parentClass, isMemoPattern); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGenerator.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGenerator.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGenerator.java index 75c950f8c82bd4..b94c9f489e628c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PatternGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGenerator.java @@ -43,8 +43,8 @@ import java.util.stream.Collectors; /** used to generate pattern by plan. */ -public abstract class PatternGenerator { - protected final PatternGeneratorAnalyzer analyzer; +public abstract class PlanPatternGenerator { + protected final JavaAstAnalyzer analyzer; protected final ClassDeclaration opType; protected final Set parentClass; protected final List enumFieldPatternInfos; @@ -52,9 +52,9 @@ public abstract class PatternGenerator { protected final boolean isMemoPattern; /** constructor. */ - public PatternGenerator(PatternGeneratorAnalyzer analyzer, ClassDeclaration opType, + public PlanPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { - this.analyzer = analyzer; + this.analyzer = analyzer.getAnalyzer(); this.opType = opType; this.parentClass = parentClass; this.enumFieldPatternInfos = getEnumFieldPatternInfos(); @@ -76,8 +76,8 @@ public String getPatternMethodName() { } /** generate code by generators and analyzer. */ - public static String generateCode(String className, String parentClassName, List generators, - PatternGeneratorAnalyzer analyzer, boolean isMemoPattern) { + public static String generateCode(String className, String parentClassName, List generators, + PlanPatternGeneratorAnalyzer analyzer, boolean isMemoPattern) { String generateCode = "// Licensed to the Apache Software Foundation (ASF) under one\n" + "// or more contributor license agreements. See the NOTICE file\n" @@ -206,7 +206,7 @@ protected String childType() { } /** create generator by plan's type. */ - public static Optional create(PatternGeneratorAnalyzer analyzer, + public static Optional create(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType, Set parentClass, boolean isMemoPattern) { if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalLeaf")) { return Optional.of(new LogicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern)); @@ -225,9 +225,9 @@ public static Optional create(PatternGeneratorAnalyzer analyze } } - private static String generateImports(List generators) { + private static String generateImports(List generators) { Set imports = new HashSet<>(); - for (PatternGenerator generator : generators) { + for (PlanPatternGenerator generator : generators) { imports.addAll(generator.getImports()); } List sortedImports = new ArrayList<>(imports); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java new file mode 100644 index 00000000000000..99d7c308dacf0d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanPatternGeneratorAnalyzer.java @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern.generator; + +import org.apache.doris.nereids.pattern.generator.javaast.ClassDeclaration; + +import java.lang.reflect.Modifier; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * used to analyze plan class extends hierarchy and then generated pattern builder methods. + */ +public class PlanPatternGeneratorAnalyzer { + private final JavaAstAnalyzer analyzer; + + public PlanPatternGeneratorAnalyzer(JavaAstAnalyzer analyzer) { + this.analyzer = analyzer; + } + + public JavaAstAnalyzer getAnalyzer() { + return analyzer; + } + + /** generate pattern methods. */ + public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) { + Map> planClassMap = analyzer.getParentClassMap().entrySet().stream() + .filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan")) + .filter(kv -> !kv.getKey().name.equals("GroupPlan")) + .filter(kv -> !Modifier.isAbstract(kv.getKey().modifiers.mod) + && kv.getKey() instanceof ClassDeclaration) + .collect(Collectors.toMap(kv -> (ClassDeclaration) kv.getKey(), kv -> kv.getValue())); + + List generators = planClassMap.entrySet() + .stream() + .map(kv -> PlanPatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern)) + .filter(Optional::isPresent) + .map(Optional::get) + .sorted((g1, g2) -> { + // logical first + if (g1.isLogical() != g2.isLogical()) { + return g1.isLogical() ? -1 : 1; + } + // leaf first + if (g1.childrenNum() != g2.childrenNum()) { + return g1.childrenNum() - g2.childrenNum(); + } + // string dict sort + return g1.opType.name.compareTo(g2.opType.name); + }) + .collect(Collectors.toList()); + + return PlanPatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java new file mode 100644 index 00000000000000..c3b6c765d49383 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/generator/PlanTypeMappingGenerator.java @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.pattern.generator; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.stream.Collectors; +import javax.annotation.processing.ProcessingEnvironment; +import javax.tools.StandardLocation; + +/** PlanTypeMappingGenerator */ +public class PlanTypeMappingGenerator { + private final JavaAstAnalyzer analyzer; + + public PlanTypeMappingGenerator(JavaAstAnalyzer javaAstAnalyzer) { + this.analyzer = javaAstAnalyzer; + } + + public JavaAstAnalyzer getAnalyzer() { + return analyzer; + } + + /** generate */ + public void generate(ProcessingEnvironment processingEnv) throws IOException { + Set superPlans = findSuperPlan(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Map> parentNameMap = analyzer.getParentNameMap(); + String code = generateCode(childrenNameMap, parentNameMap, superPlans); + generateFile(processingEnv, code); + } + + private void generateFile(ProcessingEnvironment processingEnv, String code) throws IOException { + File generatePatternFile = new File(processingEnv.getFiler() + .getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern", + "GeneratedPlanRelations.java").toUri()); + if (generatePatternFile.exists()) { + generatePatternFile.delete(); + } + if (!generatePatternFile.getParentFile().exists()) { + generatePatternFile.getParentFile().mkdirs(); + } + + // bypass create file for processingEnv.getFiler(), compile GeneratePatterns in next compile term + try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) { + bufferedWriter.write(code); + } + } + + private Set findSuperPlan() { + Map> parentNameMap = analyzer.getParentNameMap(); + Map> childrenNameMap = analyzer.getChildrenNameMap(); + Set superPlans = Sets.newLinkedHashSet(); + for (Entry> entry : childrenNameMap.entrySet()) { + String parentName = entry.getKey(); + Set childrenNames = entry.getValue(); + + if (parentName.startsWith("org.apache.doris.nereids.trees.plans.")) { + for (String childrenName : childrenNames) { + Set parentNames = parentNameMap.get(childrenName); + if (parentNames != null + && parentNames.contains("org.apache.doris.nereids.trees.plans.Plan")) { + superPlans.add(parentName); + break; + } + } + } + } + return superPlans; + } + + private String generateCode(Map> childrenNameMap, + Map> parentNameMap, Set superPlans) { + String generateCode + = "// Licensed to the Apache Software Foundation (ASF) under one\n" + + "// or more contributor license agreements. See the NOTICE file\n" + + "// distributed with this work for additional information\n" + + "// regarding copyright ownership. The ASF licenses this file\n" + + "// to you under the Apache License, Version 2.0 (the\n" + + "// \"License\"); you may not use this file except in compliance\n" + + "// with the License. You may obtain a copy of the License at\n" + + "//\n" + + "// http://www.apache.org/licenses/LICENSE-2.0\n" + + "//\n" + + "// Unless required by applicable law or agreed to in writing,\n" + + "// software distributed under the License is distributed on an\n" + + "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n" + + "// KIND, either express or implied. See the License for the\n" + + "// specific language governing permissions and limitations\n" + + "// under the License.\n" + + "\n" + + "package org.apache.doris.nereids.pattern;\n" + + "\n" + + "import org.apache.doris.nereids.trees.plans.Plan;\n" + + "\n" + + "import com.google.common.collect.ImmutableMap;\n" + + "import com.google.common.collect.ImmutableSet;\n" + + "\n" + + "import java.util.Map;\n" + + "import java.util.Set;\n" + + "\n"; + generateCode += "/** GeneratedPlanRelations */\npublic class GeneratedPlanRelations {\n"; + String childrenClassesGenericType = ", Set>>"; + generateCode += + " public static final Map" + childrenClassesGenericType + " CHILDREN_CLASS_MAP;\n\n"; + generateCode += + " static {\n" + + " ImmutableMap.Builder" + childrenClassesGenericType + " childrenClassesBuilder\n" + + " = ImmutableMap.builderWithExpectedSize(" + childrenNameMap.size() + ");\n"; + + for (String superPlan : superPlans) { + Set childrenClasseSet = childrenNameMap.get(superPlan) + .stream() + .filter(childClass -> parentNameMap.get(childClass) + .contains("org.apache.doris.nereids.trees.plans.Plan") + ) + .collect(Collectors.toSet()); + + List childrenClasses = Lists.newArrayList(childrenClasseSet); + Collections.sort(childrenClasses, Comparator.naturalOrder()); + + String childClassesString = childrenClasses.stream() + .map(childClass -> " " + childClass + ".class") + .collect(Collectors.joining(",\n")); + generateCode += " childrenClassesBuilder.put(\n " + superPlan + + ".class,\n ImmutableSet.>of(\n" + childClassesString + + "\n )\n );\n\n"; + } + + generateCode += " CHILDREN_CLASS_MAP = childrenClassesBuilder.build();\n"; + + return generateCode + " }\n}\n"; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java index 4efafe3af90f50..fb6e54e38a8545 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java @@ -195,9 +195,20 @@ private boolean isVisibleColumn(Slot slot) { @Override public PhysicalFilter visitPhysicalFilter(PhysicalFilter filter, CascadesContext context) { filter.child().accept(this, context); - boolean visibleFilter = filter.getExpressions().stream() - .flatMap(expression -> expression.getInputSlots().stream()) - .anyMatch(slot -> isVisibleColumn(slot)); + + boolean visibleFilter = false; + + for (Expression expr : filter.getExpressions()) { + for (Slot inputSlot : expr.getInputSlots()) { + if (isVisibleColumn(inputSlot)) { + visibleFilter = true; + break; + } + } + if (visibleFilter) { + break; + } + } if (visibleFilter) { // skip filters like: __DORIS_DELETE_SIGN__ = 0 context.getRuntimeFilterContext().addEffectiveSrcNode(filter, RuntimeFilterContext.EffectiveSrcType.NATIVE); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java index e73039e9237980..561e09ed404ad2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/Validator.java @@ -26,6 +26,8 @@ import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; +import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -69,7 +71,10 @@ public Plan visitPhysicalFilter(PhysicalFilter filter, CascadesC @Override public Plan visit(Plan plan, CascadesContext context) { - plan.children().forEach(child -> child.accept(this, context)); + for (Plan child : plan.children()) { + child.accept(this, context); + } + Optional opt = checkAllSlotFromChildren(plan); if (opt.isPresent()) { List childrenOutput = plan.children().stream().flatMap(p -> p.getOutput().stream()).collect( @@ -93,8 +98,7 @@ public static Optional checkAllSlotFromChildren(Plan plan) { if (plan instanceof Aggregate) { return Optional.empty(); } - Set childOutputSet = plan.children().stream().flatMap(child -> child.getOutputSet().stream()) - .collect(Collectors.toSet()); + Set childOutputSet = Utils.fastToImmutableSet(PlanUtils.fastGetChildrenOutputs(plan.children())); Set inputSlots = plan.getInputSlots(); for (Slot slot : inputSlots) { if (slot.getName().startsWith("mv") || slot instanceof SlotNotFromChildren) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java index d7b4b3b1c9f34d..c7e6030e13794c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import java.util.HashSet; import java.util.Map; @@ -196,12 +197,23 @@ public boolean containsAnySub(Set slotSet) { } public void removeNotContain(Set slotSet) { - slots = slots.stream() - .filter(slotSet::contains) - .collect(Collectors.toSet()); - slotSets = slotSets.stream() - .filter(slotSet::containsAll) - .collect(Collectors.toSet()); + if (!slotSet.isEmpty()) { + Set newSlots = Sets.newLinkedHashSetWithExpectedSize(slots.size()); + for (Slot slot : slots) { + if (slotSet.contains(slot)) { + newSlots.add(slot); + } + } + this.slots = newSlots; + + Set> newSlotSets = Sets.newLinkedHashSetWithExpectedSize(slots.size()); + for (ImmutableSet set : slotSets) { + if (slotSet.containsAll(set)) { + newSlotSets.add(set); + } + } + this.slotSets = newSlotSets; + } } public void add(Slot slot) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java index 07d2882894288c..7c86579980c8a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java @@ -19,7 +19,6 @@ import org.apache.doris.common.Id; import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import com.google.common.base.Supplier; @@ -62,21 +61,40 @@ public LogicalProperties(Supplier> outputSupplier, this.outputSupplier = Suppliers.memoize( Objects.requireNonNull(outputSupplier, "outputSupplier can not be null") ); - this.outputExprIdsSupplier = Suppliers.memoize( - () -> this.outputSupplier.get().stream().map(NamedExpression::getExprId).map(Id.class::cast) - .collect(ImmutableList.toImmutableList()) - ); - this.outputSetSupplier = Suppliers.memoize( - () -> ImmutableSet.copyOf(this.outputSupplier.get()) - ); - this.outputMapSupplier = Suppliers.memoize( - () -> this.outputSetSupplier.get().stream().collect(ImmutableMap.toImmutableMap(s -> s, s -> s)) - ); - this.outputExprIdSetSupplier = Suppliers.memoize( - () -> this.outputSupplier.get().stream() - .map(NamedExpression::getExprId) - .collect(ImmutableSet.toImmutableSet()) - ); + this.outputExprIdsSupplier = Suppliers.memoize(() -> { + List output = this.outputSupplier.get(); + ImmutableList.Builder exprIdSet + = ImmutableList.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIdSet.add(slot.getExprId()); + } + return exprIdSet.build(); + }); + this.outputSetSupplier = Suppliers.memoize(() -> { + List output = outputSupplier.get(); + ImmutableSet.Builder slots = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + slots.add(slot); + } + return slots.build(); + }); + this.outputMapSupplier = Suppliers.memoize(() -> { + List output = outputSupplier.get(); + ImmutableMap.Builder map = ImmutableMap.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + map.put(slot, slot); + } + return map.build(); + }); + this.outputExprIdSetSupplier = Suppliers.memoize(() -> { + List output = this.outputSupplier.get(); + ImmutableSet.Builder exprIdSet + = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIdSet.add(slot.getExprId()); + } + return exprIdSet.build(); + }); this.fdSupplier = Suppliers.memoize( Objects.requireNonNull(fdSupplier, "FunctionalDependencies can not be null") ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java index a9b4591ad4a0a8..207dd6458c9202 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/Rule.java @@ -24,8 +24,8 @@ import org.apache.doris.nereids.rules.RuleType.RuleTypeClass; import org.apache.doris.nereids.trees.plans.Plan; +import java.util.BitSet; import java.util.List; -import java.util.Set; /** * Abstract class for all rules. @@ -79,8 +79,8 @@ public void acceptPlan(Plan plan) { /** * Filter out already applied rules and rules that are not matched on root node. */ - public boolean isInvalid(Set disableRules, GroupExpression groupExpression) { - return disableRules.contains(this.getRuleType().type()) + public boolean isInvalid(BitSet disableRules, GroupExpression groupExpression) { + return disableRules.get(this.getRuleType().type()) || !groupExpression.notApplied(this) || !this.getPattern().matchRoot(groupExpression.getPlan()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java index 86a70d35ccc087..5543341ae277d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/AdjustAggregateNullableForEmptySet.java @@ -46,21 +46,30 @@ public List buildRules() { RuleType.ADJUST_NULLABLE_FOR_AGGREGATE_SLOT.build( logicalAggregate() .then(agg -> { - List output = agg.getOutputExpressions().stream() - .map(ne -> ((NamedExpression) FunctionReplacer.INSTANCE.replace(ne, - agg.getGroupByExpressions().isEmpty()))) - .collect(ImmutableList.toImmutableList()); - return agg.withAggOutput(output); + List outputExprs = agg.getOutputExpressions(); + boolean noGroupBy = agg.getGroupByExpressions().isEmpty(); + ImmutableList.Builder newOutput + = ImmutableList.builderWithExpectedSize(outputExprs.size()); + for (NamedExpression ne : outputExprs) { + NamedExpression newExpr = + ((NamedExpression) FunctionReplacer.INSTANCE.replace(ne, noGroupBy)); + newOutput.add(newExpr); + } + return agg.withAggOutput(newOutput.build()); }) ), RuleType.ADJUST_NULLABLE_FOR_HAVING_SLOT.build( logicalHaving(logicalAggregate()) .then(having -> { - Set newConjuncts = having.getConjuncts().stream() - .map(ne -> FunctionReplacer.INSTANCE.replace(ne, - having.child().getGroupByExpressions().isEmpty())) - .collect(ImmutableSet.toImmutableSet()); - return new LogicalHaving<>(newConjuncts, having.child()); + Set conjuncts = having.getConjuncts(); + boolean noGroupBy = having.child().getGroupByExpressions().isEmpty(); + ImmutableSet.Builder newConjuncts + = ImmutableSet.builderWithExpectedSize(conjuncts.size()); + for (Expression expr : conjuncts) { + Expression newExpr = FunctionReplacer.INSTANCE.replace(expr, noGroupBy); + newConjuncts.add(newExpr); + } + return new LogicalHaving<>(newConjuncts.build(), having.child()); }) ) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index c2c7f5815d9288..6211f493eaf4e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -333,10 +333,11 @@ private LogicalPlan bindInlineTable(MatchingContext ctx) { List relations = Lists.newArrayListWithCapacity(logicalInlineTable.getConstantExprsList().size()); for (int i = 0; i < logicalInlineTable.getConstantExprsList().size(); i++) { - if (logicalInlineTable.getConstantExprsList().get(i).stream() - .anyMatch(DefaultValueSlot.class::isInstance)) { - throw new AnalysisException("Default expression" - + " can't exist in SELECT statement at row " + (i + 1)); + for (NamedExpression constantExpr : logicalInlineTable.getConstantExprsList().get(i)) { + if (constantExpr instanceof DefaultValueSlot) { + throw new AnalysisException("Default expression" + + " can't exist in SELECT statement at row " + (i + 1)); + } } relations.add(new UnboundOneRowRelation(StatementScopeIdGenerator.newRelationId(), logicalInlineTable.getConstantExprsList().get(i))); @@ -590,7 +591,7 @@ private Plan bindFilter(MatchingContext> ctx) { SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer( filter, cascadesContext, filter.children(), true, true); ImmutableSet.Builder boundConjuncts = ImmutableSet.builderWithExpectedSize( - filter.getConjuncts().size() * 2); + filter.getConjuncts().size()); for (Expression conjunct : filter.getConjuncts()) { Expression boundConjunct = analyzer.analyze(conjunct); boundConjunct = TypeCoercionUtils.castIfNotSameType(boundConjunct, BooleanType.INSTANCE); @@ -828,15 +829,22 @@ private void checkIfOutputAliasNameDuplicatedForGroupBy(Collection e if (output.stream().noneMatch(Alias.class::isInstance)) { return; } - List aliasList = output.stream().filter(Alias.class::isInstance) - .map(Alias.class::cast).collect(Collectors.toList()); + List aliasList = ExpressionUtils.filter(output, Alias.class); List exprAliasList = ExpressionUtils.collectAll(expressions, NamedExpression.class::isInstance); - boolean isGroupByContainAlias = exprAliasList.stream().anyMatch(ne -> - aliasList.stream().anyMatch(alias -> !alias.getExprId().equals(ne.getExprId()) - && alias.getName().equals(ne.getName()))); + boolean isGroupByContainAlias = false; + for (NamedExpression ne : exprAliasList) { + for (Alias alias : aliasList) { + if (!alias.getExprId().equals(ne.getExprId()) && alias.getName().equalsIgnoreCase(ne.getName())) { + isGroupByContainAlias = true; + } + } + if (isGroupByContainAlias) { + break; + } + } if (isGroupByContainAlias && ConnectContext.get() != null diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java index 114e4c1d12051b..714e6e48794ba0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotWithPaths.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -34,7 +33,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * Rule to bind slot with path in query plan. @@ -60,21 +58,18 @@ public List buildRules() { Set pathsSlots = ctx.statementContext.getAllPathsSlots(); // With new logical properties that contains new slots with paths StatementContext stmtCtx = ConnectContext.get().getStatementContext(); - List olapScanPathSlots = pathsSlots.stream().filter( - slot -> { - Preconditions.checkNotNull(stmtCtx.getRelationBySlot(slot), - "[Not implemented] Slot not found in relation map, slot ", slot); - return stmtCtx.getRelationBySlot(slot).getRelationId() - == logicalOlapScan.getRelationId(); - }).collect( - Collectors.toList()); - List newExprs = olapScanPathSlots.stream() - .map(SlotReference.class::cast) - .map(slotReference -> - new Alias(slotReference.getExprId(), - stmtCtx.getOriginalExpr(slotReference), slotReference.getName())) - .collect( - Collectors.toList()); + ImmutableList.Builder newExprsBuilder + = ImmutableList.builderWithExpectedSize(pathsSlots.size()); + for (SlotReference slot : pathsSlots) { + Preconditions.checkNotNull(stmtCtx.getRelationBySlot(slot), + "[Not implemented] Slot not found in relation map, slot ", slot); + if (stmtCtx.getRelationBySlot(slot).getRelationId() + == logicalOlapScan.getRelationId()) { + newExprsBuilder.add(new Alias(slot.getExprId(), + stmtCtx.getOriginalExpr(slot), slot.getName())); + } + } + ImmutableList newExprs = newExprsBuilder.build(); if (newExprs.isEmpty()) { return ctx.root; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java index 754d3efa583fa5..92052bc85ed100 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java @@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; +import com.google.common.collect.ImmutableSet; import org.apache.commons.lang3.StringUtils; import java.util.List; @@ -69,42 +70,43 @@ public Rule build() { } private void checkUnexpectedExpression(Plan plan) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(SubqueryExpr.class::isInstance))) { - throw new AnalysisException("Subquery is not allowed in " + plan.getType()); - } - if (!(plan instanceof Generate)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(TableGeneratingFunction.class::isInstance))) { - throw new AnalysisException("table generating function is not allowed in " + plan.getType()); - } + boolean isGenerate = plan instanceof Generate; + boolean isAgg = plan instanceof LogicalAggregate; + boolean isWindow = plan instanceof LogicalWindow; + boolean notAggAndWindow = !isAgg && !isWindow; + + for (Expression expression : plan.getExpressions()) { + expression.foreach(expr -> { + if (expr instanceof SubqueryExpr) { + throw new AnalysisException("Subquery is not allowed in " + plan.getType()); + } else if (!isGenerate && expr instanceof TableGeneratingFunction) { + throw new AnalysisException("table generating function is not allowed in " + plan.getType()); + } else if (notAggAndWindow && expr instanceof AggregateFunction) { + throw new AnalysisException("aggregate function is not allowed in " + plan.getType()); + } else if (!isAgg && expr instanceof GroupingScalarFunction) { + throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType()); + } else if (!isWindow && expr instanceof WindowExpression) { + throw new AnalysisException("analytic function is not allowed in " + plan.getType()); + } + }); } - if (!(plan instanceof LogicalAggregate || plan instanceof LogicalWindow)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(AggregateFunction.class::isInstance))) { - throw new AnalysisException("aggregate function is not allowed in " + plan.getType()); + } + + private void checkAllSlotReferenceFromChildren(Plan plan) { + Set inputSlots = plan.getInputSlots(); + Set childrenOutput = plan.getChildrenOutputExprIdSet(); + + ImmutableSet.Builder notFromChildrenBuilder = ImmutableSet.builderWithExpectedSize(inputSlots.size()); + for (Slot inputSlot : inputSlots) { + if (!childrenOutput.contains(inputSlot.getExprId())) { + notFromChildrenBuilder.add(inputSlot); } } - if (!(plan instanceof LogicalAggregate)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(GroupingScalarFunction.class::isInstance))) { - throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType()); - } + Set notFromChildren = notFromChildrenBuilder.build(); + if (notFromChildren.isEmpty()) { + return; } - if (!(plan instanceof LogicalWindow)) { - if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(WindowExpression.class::isInstance))) { - throw new AnalysisException("analytic function is not allowed in " + plan.getType()); - } - } - } - private void checkAllSlotReferenceFromChildren(Plan plan) { - Set notFromChildren = plan.getExpressions().stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .collect(Collectors.toSet()); - Set childrenOutput = plan.children().stream() - .flatMap(child -> child.getOutput().stream()) - .map(NamedExpression::getExprId) - .collect(Collectors.toSet()); - notFromChildren = notFromChildren.stream() - .filter(s -> !childrenOutput.contains(s.getExprId())) - .collect(Collectors.toSet()); notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput); if (!notFromChildren.isEmpty()) { if (plan.arity() != 0 && plan.child(0) instanceof LogicalAggregate) { @@ -181,17 +183,18 @@ private void checkMetricTypeIsUsedCorrectly(Plan plan) { } private void checkMatchIsUsedCorrectly(Plan plan) { - if (plan.getExpressions().stream().anyMatch( - expression -> expression instanceof Match)) { - if (plan instanceof LogicalFilter && (plan.child(0) instanceof LogicalOlapScan - || plan.child(0) instanceof LogicalDeferMaterializeOlapScan - || plan.child(0) instanceof LogicalProject + for (Expression expression : plan.getExpressions()) { + if (expression instanceof Match) { + if (plan instanceof LogicalFilter && (plan.child(0) instanceof LogicalOlapScan + || plan.child(0) instanceof LogicalDeferMaterializeOlapScan + || plan.child(0) instanceof LogicalProject && ((LogicalProject) plan.child(0)).hasPushedDownToProjectionFunctions())) { - return; - } else { - throw new AnalysisException(String.format( - "Not support match in %s in plan: %s, only support in olapScan filter", - plan.child(0), plan)); + return; + } else { + throw new AnalysisException(String.format( + "Not support match in %s in plan: %s, only support in olapScan filter", + plan.child(0), plan)); + } } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java index 5a310d697ac798..64fd14019bbbc9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java @@ -45,7 +45,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; /** @@ -117,14 +116,16 @@ private void checkUnexpectedExpressions(Plan plan) { if (unexpectedExpressionTypes.isEmpty()) { return; } - plan.getExpressions().forEach(c -> c.foreachUp(e -> { - for (Class type : unexpectedExpressionTypes) { - if (type.isInstance(e)) { - throw new AnalysisException(plan.getType() + " can not contains " - + type.getSimpleName() + " expression: " + ((Expression) e).toSql()); + for (Expression expr : plan.getExpressions()) { + expr.foreachUp(e -> { + for (Class type : unexpectedExpressionTypes) { + if (type.isInstance(e)) { + throw new AnalysisException(plan.getType() + " can not contains " + + type.getSimpleName() + " expression: " + ((Expression) e).toSql()); + } } - } - })); + }); + } } private void checkExpressionInputTypes(Plan plan) { @@ -157,20 +158,21 @@ private void checkAggregate(LogicalAggregate aggregate) { break; } } - long distinctFunctionNum = aggregateFunctions.stream() - .filter(AggregateFunction::isDistinct) - .count(); + + long distinctFunctionNum = 0; + for (AggregateFunction aggregateFunction : aggregateFunctions) { + distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0; + } if (distinctMultiColumns && distinctFunctionNum > 1) { throw new AnalysisException( "The query contains multi count distinct or sum distinct, each can't have multi columns"); } - Optional expr = aggregate.getGroupByExpressions().stream() - .filter(expression -> expression.containsType(AggregateFunction.class)).findFirst(); - if (expr.isPresent()) { - throw new AnalysisException( - "GROUP BY expression must not contain aggregate functions: " - + expr.get().toSql()); + for (Expression expr : aggregate.getGroupByExpressions()) { + if (expr.anyMatch(AggregateFunction.class::isInstance)) { + throw new AnalysisException( + "GROUP BY expression must not contain aggregate functions: " + expr.toSql()); + } } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java index e683153e9a2966..6b0b1f58c3e773 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateGroupByConstant.java @@ -60,7 +60,7 @@ public Rule build() { // because we rely on expression matching to replace subtree that same as group by expr in output // if we do constant folding before normalize aggregate, the subtree will change and matching fail // such as: select a + 1 + 2 + 3, sum(b) from t group by a + 1 + 2 - Expression foldExpression = FoldConstantRule.INSTANCE.rewrite(expression, context); + Expression foldExpression = FoldConstantRule.evaluate(expression, context); if (!foldExpression.isConstant()) { slotGroupByExprs.add(expression); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index f4c1b428d4147b..56ca1b3a8c4822 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -297,7 +297,7 @@ public Expression visitUnboundFunction(UnboundFunction unboundFunction, Expressi if (unboundFunction.isHighOrder()) { unboundFunction = bindHighOrderFunction(unboundFunction, context); } else { - unboundFunction = (UnboundFunction) rewriteChildren(this, unboundFunction, context); + unboundFunction = (UnboundFunction) super.visit(unboundFunction, context); } // bind function diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index d6c783bbe946d8..82468978a8069a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -316,13 +316,18 @@ private Plan createPlan(Resolver resolver, Aggregate aggregate, } private boolean checkSort(LogicalSort logicalSort) { - return logicalSort.getOrderKeys().stream() - .map(OrderKey::getExpr) - .map(Expression::getInputSlots) - .flatMap(Set::stream) - .anyMatch(s -> !logicalSort.child().getOutputSet().contains(s)) - || logicalSort.getOrderKeys().stream() - .map(OrderKey::getExpr) - .anyMatch(e -> e.containsType(AggregateFunction.class)); + Plan child = logicalSort.child(); + for (OrderKey orderKey : logicalSort.getOrderKeys()) { + Expression expr = orderKey.getExpr(); + if (expr.anyMatch(AggregateFunction.class::isInstance)) { + return true; + } + for (Slot inputSlot : expr.getInputSlots()) { + if (!child.getOutputSet().contains(inputSlot)) { + return true; + } + } + } + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 1105fe2da72be1..82aa38ff7fe423 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -228,11 +228,18 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional normalizedAggOutput = ImmutableList.builder() - .addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator()) - .addAll(normalizedAggFuncsToSlotContext - .pushDownToNamedExpression(normalizedAggFuncs)) - .build(); + + ImmutableList.Builder normalizedAggOutputBuilder + = ImmutableList.builderWithExpectedSize(groupingByExprs.size() + normalizedAggFuncs.size()); + for (NamedExpression pushedGroupByExpr : pushedGroupByExprs) { + normalizedAggOutputBuilder.add(pushedGroupByExpr.toSlot()); + } + for (AggregateFunction normalizedAggFunc : normalizedAggFuncs) { + normalizedAggOutputBuilder.add( + normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFunc) + ); + } + List normalizedAggOutput = normalizedAggOutputBuilder.build(); // create new agg node LogicalAggregate newAggregate = @@ -245,7 +252,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional project = new LogicalProject<>(upperProjects, newAggregate); if (having.isPresent()) { - if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) { + if (ExpressionUtils.containsWindowExpression(upperProjects)) { // when project contains window functions, in order to get the correct result // push having through project to make it the parent node of logicalAgg return project.withChildren(ImmutableList.of( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java index b52e2f0218d04e..cd53086f96625d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -35,7 +36,6 @@ import java.util.List; import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; /** * replace. @@ -47,52 +47,50 @@ public List buildRules() { .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( logicalSort(logicalProject()).then(sort -> { LogicalProject project = sort.child(); - Map sMap = Maps.newHashMap(); - project.getProjects().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .forEach(p -> sMap.put(p.child(), p.toSlot())); + Map sMap = buildOutputAliasMap(project.getProjects()); return replaceSortExpression(sort, sMap); }) )) .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( logicalSort(logicalAggregate()).then(sort -> { LogicalAggregate aggregate = sort.child(); - Map sMap = Maps.newHashMap(); - aggregate.getOutputExpressions().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .forEach(p -> sMap.put(p.child(), p.toSlot())); + Map sMap = buildOutputAliasMap(aggregate.getOutputExpressions()); return replaceSortExpression(sort, sMap); }) )).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build( logicalSort(logicalHaving(logicalAggregate())).then(sort -> { LogicalAggregate aggregate = sort.child().child(); - Map sMap = Maps.newHashMap(); - aggregate.getOutputExpressions().stream() - .filter(Alias.class::isInstance) - .map(Alias.class::cast) - .forEach(p -> sMap.put(p.child(), p.toSlot())); + Map sMap = buildOutputAliasMap(aggregate.getOutputExpressions()); return replaceSortExpression(sort, sMap); }) )) .build(); } + private Map buildOutputAliasMap(List output) { + Map sMap = Maps.newHashMapWithExpectedSize(output.size()); + for (NamedExpression expr : output) { + if (expr instanceof Alias) { + Alias alias = (Alias) expr; + sMap.put(alias.child(), alias.toSlot()); + } + } + return sMap; + } + private LogicalPlan replaceSortExpression(LogicalSort sort, Map sMap) { List orderKeys = sort.getOrderKeys(); - AtomicBoolean changed = new AtomicBoolean(false); - List newKeys = orderKeys.stream().map(k -> { + + boolean changed = false; + ImmutableList.Builder newKeys = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey k : orderKeys) { Expression newExpr = ExpressionUtils.replace(k.getExpr(), sMap); if (newExpr != k.getExpr()) { - changed.set(true); + changed = true; } - return new OrderKey(newExpr, k.isAsc(), k.isNullFirst()); - }).collect(ImmutableList.toImmutableList()); - if (changed.get()) { - return new LogicalSort<>(newKeys, sort.child()); - } else { - return sort; + newKeys.add(new OrderKey(newExpr, k.isAsc(), k.isNullFirst())); } + + return changed ? new LogicalSort<>(newKeys.build(), sort.child()) : sort; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java index b0f78be54a24aa..cfc5b2ba24a11b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot; import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.Alias; @@ -51,6 +52,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -77,24 +79,21 @@ public List buildRules() { RuleType.FILTER_SUBQUERY_TO_APPLY.build( logicalFilter().thenApply(ctx -> { LogicalFilter filter = ctx.root; - ImmutableList> subqueryExprsList = filter.getConjuncts().stream() - .>map(e -> e.collect(SubqueryToApply::canConvertToSupply)) - .collect(ImmutableList.toImmutableList()); - if (subqueryExprsList.stream() - .flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) { + + Set conjuncts = filter.getConjuncts(); + CollectSubquerys collectSubquerys = collectSubquerys(conjuncts); + if (!collectSubquerys.hasSubquery) { return filter; } - ImmutableList shouldOutputMarkJoinSlot = - filter.getConjuncts().stream() - .map(expr -> !(expr instanceof SubqueryExpr) - && expr.containsType(SubqueryExpr.class)) - .collect(ImmutableList.toImmutableList()); - List oldConjuncts = ImmutableList.copyOf(filter.getConjuncts()); - ImmutableList.Builder newConjuncts = new ImmutableList.Builder<>(); + List shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot(conjuncts); + + List oldConjuncts = Utils.fastToImmutableList(conjuncts); + ImmutableSet.Builder newConjuncts = new ImmutableSet.Builder<>(); LogicalPlan applyPlan = null; LogicalPlan tmpPlan = (LogicalPlan) filter.child(); + List> subqueryExprsList = collectSubquerys.subqueies; // Subquery traversal with the conjunct of and as the granularity. for (int i = 0; i < subqueryExprsList.size(); ++i) { Set subqueryExprs = subqueryExprsList.get(i); @@ -119,9 +118,11 @@ public List buildRules() { * if it's semi join with non-null mark slot * we can safely change the mark conjunct to hash conjunct */ + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class) ? ExpressionUtils.canInferNotNullForMarkSlot( - TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null)) + TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, + rewriteContext), rewriteContext) : false; applyPlan = subqueryToApply(subqueryExprs.stream() @@ -132,21 +133,22 @@ public List buildRules() { tmpPlan = applyPlan; newConjuncts.add(conjunct); } - Set conjuncts = ImmutableSet.copyOf(newConjuncts.build()); - Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan); + Plan newFilter = new LogicalFilter<>(newConjuncts.build(), applyPlan); return new LogicalProject<>(filter.getOutput().stream().collect(ImmutableList.toImmutableList()), newFilter); }) ), RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> { LogicalProject project = ctx.root; - ImmutableList> subqueryExprsList = project.getProjects().stream() - .>map(e -> e.collect(SubqueryToApply::canConvertToSupply)) - .collect(ImmutableList.toImmutableList()); - if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) { + + List projects = project.getProjects(); + CollectSubquerys collectSubquerys = collectSubquerys(projects); + if (!collectSubquerys.hasSubquery) { return project; } - List oldProjects = ImmutableList.copyOf(project.getProjects()); + + List> subqueryExprsList = collectSubquerys.subqueies; + List oldProjects = ImmutableList.copyOf(projects); ImmutableList.Builder newProjects = new ImmutableList.Builder<>(); LogicalPlan childPlan = (LogicalPlan) project.child(); LogicalPlan applyPlan; @@ -166,7 +168,7 @@ public List buildRules() { replaceSubquery.replace(oldProjects.get(i), context); applyPlan = subqueryToApply( - subqueryExprs.stream().collect(ImmutableList.toImmutableList()), + Utils.fastToImmutableList(subqueryExprs), childPlan, context.getSubqueryToMarkJoinSlot(), ctx.cascadesContext, Optional.of(newProject), true, false); @@ -240,9 +242,11 @@ public List buildRules() { * if it's semi join with non-null mark slot * we can safely change the mark conjunct to hash conjunct */ + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class) ? ExpressionUtils.canInferNotNullForMarkSlot( - TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null)) + TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext), + rewriteContext) : false; applyPlan = subqueryToApply( subqueryExprs.stream().collect(ImmutableList.toImmutableList()), @@ -566,4 +570,33 @@ private boolean shouldOutputMarkJoinSlot(Expression expr, SearchState searchStat } return false; } + + private List shouldOutputMarkJoinSlot(Collection conjuncts) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(conjuncts.size()); + for (Expression expr : conjuncts) { + result.add(!(expr instanceof SubqueryExpr) && expr.containsType(SubqueryExpr.class)); + } + return result.build(); + } + + private CollectSubquerys collectSubquerys(Collection exprs) { + boolean hasSubqueryExpr = false; + ImmutableList.Builder> subqueryExprsListBuilder = ImmutableList.builder(); + for (Expression expression : exprs) { + Set subqueries = expression.collect(SubqueryToApply::canConvertToSupply); + hasSubqueryExpr |= !subqueries.isEmpty(); + subqueryExprsListBuilder.add(subqueries); + } + return new CollectSubquerys(subqueryExprsListBuilder.build(), hasSubqueryExpr); + } + + private static class CollectSubquerys { + final List> subqueies; + final boolean hasSubquery; + + public CollectSubquerys(List> subqueies, boolean hasSubquery) { + this.subqueies = subqueies; + this.hasSubquery = hasSubquery; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java new file mode 100644 index 00000000000000..932446ce48b16d --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionBottomUpRewriter.java @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.pattern.ExpressionPatternRules; +import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners; +import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners.CombinedListener; +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; + +/** ExpressionBottomUpRewriter */ +public class ExpressionBottomUpRewriter implements ExpressionRewriteRule { + public static final String BATCH_ID_KEY = "batch_id"; + private static final Logger LOG = LogManager.getLogger(ExpressionBottomUpRewriter.class); + private static final AtomicInteger rewriteBatchId = new AtomicInteger(); + private final ExpressionPatternRules rules; + private final ExpressionPatternTraverseListeners listeners; + + public ExpressionBottomUpRewriter(ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) { + this.rules = rules; + this.listeners = listeners; + } + + // entrance + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + int currentBatch = rewriteBatchId.incrementAndGet(); + return rewriteBottomUp(expr, ctx, currentBatch, null, rules, listeners); + } + + private static Expression rewriteBottomUp( + Expression expression, ExpressionRewriteContext context, int currentBatch, @Nullable Expression parent, + ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) { + + Optional rewriteState = expression.getMutableState(BATCH_ID_KEY); + if (!rewriteState.isPresent() || rewriteState.get() != currentBatch) { + CombinedListener listener = null; + boolean hasChildren = expression.arity() > 0; + if (hasChildren) { + listener = listeners.matchesAndCombineListeners(expression, context, parent); + if (listener != null) { + listener.onEnter(); + } + } + + Expression afterRewrite = expression; + try { + Expression beforeRewrite; + afterRewrite = rewriteChildren(expression, context, currentBatch, rules, listeners); + // use rewriteTimes to avoid dead loop + int rewriteTimes = 0; + boolean changed; + do { + beforeRewrite = afterRewrite; + + // rewrite this + Optional applied = rules.matchesAndApply(beforeRewrite, context, parent); + + changed = applied.isPresent(); + if (changed) { + afterRewrite = applied.get(); + // ensure children are rewritten + afterRewrite = rewriteChildren(afterRewrite, context, currentBatch, rules, listeners); + } + rewriteTimes++; + } while (changed && rewriteTimes < 100); + + // set rewritten + afterRewrite.setMutableState(BATCH_ID_KEY, currentBatch); + } finally { + if (hasChildren && listener != null) { + listener.onExit(afterRewrite); + } + } + + return afterRewrite; + } + + // already rewritten + return expression; + } + + private static Expression rewriteChildren(Expression parent, ExpressionRewriteContext context, int currentBatch, + ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) { + boolean changed = false; + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(parent.arity()); + for (Expression child : parent.children()) { + Expression newChild = rewriteBottomUp(child, context, currentBatch, parent, rules, listeners); + changed |= !child.equals(newChild); + newChildren.add(newChild); + } + + Expression result = parent; + if (changed) { + result = parent.withChildren(newChildren.build()); + } + if (changed && context.cascadesContext.isEnableExprTrace()) { + LOG.info("WithChildren: \nbefore: " + parent + "\nafter: " + result); + } + return result; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java new file mode 100644 index 00000000000000..ea67d14e8fe6ae --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionListenerMatcher.java @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Predicate; + +/** ExpressionListenerMatcher */ +public class ExpressionListenerMatcher { + public final Class typePattern; + public final List>> predicates; + public final ExpressionTraverseListener listener; + + public ExpressionListenerMatcher(Class typePattern, + List>> predicates, + ExpressionTraverseListener listener) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = predicates == null ? ImmutableList.of() : predicates; + this.listener = Objects.requireNonNull(listener, "listener can not be null"); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java new file mode 100644 index 00000000000000..a28b96079b51a8 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingAction.java @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +/** ExpressionMatchAction */ +public interface ExpressionMatchingAction { + Expression apply(ExpressionMatchingContext context); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java new file mode 100644 index 00000000000000..953815ad87c5c2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionMatchingContext.java @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.expressions.Expression; + +import java.util.Optional; + +/** ExpressionMatchingContext */ +public class ExpressionMatchingContext { + public final E expr; + public final Optional parent; + public final ExpressionRewriteContext rewriteContext; + public final CascadesContext cascadesContext; + + public ExpressionMatchingContext(E expr, Expression parent, ExpressionRewriteContext context) { + this.expr = expr; + this.parent = Optional.ofNullable(parent); + this.rewriteContext = context; + this.cascadesContext = context.cascadesContext; + } + + public boolean isRoot() { + return !parent.isPresent(); + } + + public Expression parentOr(Expression defaultParent) { + return parent.orElse(defaultParent); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index 9886cb1787e9ed..adf0cb90a958c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -42,20 +42,21 @@ public class ExpressionNormalization extends ExpressionRewrite { // we should run supportJavaDateFormatter before foldConstantRule or be will fold // from_unixtime(timestamp, 'yyyyMMdd') to 'yyyyMMdd' public static final List NORMALIZE_REWRITE_RULES = ImmutableList.of( - SupportJavaDateFormatter.INSTANCE, - ReplaceVariableByLiteral.INSTANCE, - NormalizeBinaryPredicatesRule.INSTANCE, - InPredicateDedup.INSTANCE, - InPredicateToEqualToRule.INSTANCE, - SimplifyNotExprRule.INSTANCE, - SimplifyArithmeticRule.INSTANCE, - FoldConstantRule.INSTANCE, - SimplifyCastRule.INSTANCE, - DigitalMaskingConvert.INSTANCE, - SimplifyArithmeticComparisonRule.INSTANCE, - SupportJavaDateFormatter.INSTANCE, - ConvertAggStateCast.INSTANCE, - CheckCast.INSTANCE + bottomUp( + ReplaceVariableByLiteral.INSTANCE, + SupportJavaDateFormatter.INSTANCE, + NormalizeBinaryPredicatesRule.INSTANCE, + InPredicateDedup.INSTANCE, + InPredicateToEqualToRule.INSTANCE, + SimplifyNotExprRule.INSTANCE, + SimplifyArithmeticRule.INSTANCE, + FoldConstantRule.INSTANCE, + SimplifyCastRule.INSTANCE, + DigitalMaskingConvert.INSTANCE, + SimplifyArithmeticComparisonRule.INSTANCE, + ConvertAggStateCast.INSTANCE, + CheckCast.INSTANCE + ) ); public ExpressionNormalization() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java new file mode 100644 index 00000000000000..d694062ef1f049 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalizationAndOptimization.java @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import com.google.common.collect.ImmutableList; + +/** ExpressionNormalizationAndOptimization */ +public class ExpressionNormalizationAndOptimization extends ExpressionRewrite { + /** ExpressionNormalizationAndOptimization */ + public ExpressionNormalizationAndOptimization() { + super(new ExpressionRuleExecutor( + ImmutableList.builder() + .addAll(ExpressionNormalization.NORMALIZE_REWRITE_RULES) + .addAll(ExpressionOptimization.OPTIMIZE_REWRITE_RULES) + .build() + )); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index fdf9820c582f56..b3bb18163ea2eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -39,18 +39,20 @@ */ public class ExpressionOptimization extends ExpressionRewrite { public static final List OPTIMIZE_REWRITE_RULES = ImmutableList.of( - ExtractCommonFactorRule.INSTANCE, - DistinctPredicatesRule.INSTANCE, - SimplifyComparisonPredicate.INSTANCE, - SimplifyInPredicate.INSTANCE, - SimplifyDecimalV3Comparison.INSTANCE, - SimplifyRange.INSTANCE, - DateFunctionRewrite.INSTANCE, - OrToIn.INSTANCE, - ArrayContainToArrayOverlap.INSTANCE, - CaseWhenToIf.INSTANCE, - TopnToMax.INSTANCE, - NullSafeEqualToEqual.INSTANCE + bottomUp( + ExtractCommonFactorRule.INSTANCE, + DistinctPredicatesRule.INSTANCE, + SimplifyComparisonPredicate.INSTANCE, + SimplifyInPredicate.INSTANCE, + SimplifyDecimalV3Comparison.INSTANCE, + OrToIn.INSTANCE, + SimplifyRange.INSTANCE, + DateFunctionRewrite.INSTANCE, + ArrayContainToArrayOverlap.INSTANCE, + CaseWhenToIf.INSTANCE, + TopnToMax.INSTANCE, + NullSafeEqualToEqual.INSTANCE + ) ); private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java new file mode 100644 index 00000000000000..dbf5c79c96d754 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatchRule.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping; +import org.apache.doris.nereids.trees.expressions.Expression; + +import java.util.List; +import java.util.function.Predicate; + +/** ExpressionPatternMatcherRule */ +public class ExpressionPatternMatchRule implements TypeMapping { + public final Class typePattern; + public final List>> predicates; + public final ExpressionMatchingAction matchingAction; + + public ExpressionPatternMatchRule(ExpressionPatternMatcher patternMatcher) { + this.typePattern = patternMatcher.typePattern; + this.predicates = patternMatcher.predicates; + this.matchingAction = patternMatcher.matchingAction; + } + + /** matches */ + public boolean matchesTypeAndPredicates(ExpressionMatchingContext context) { + return typePattern.isInstance(context.expr) && matchesPredicates(context); + } + + /** matchesPredicates */ + public boolean matchesPredicates(ExpressionMatchingContext context) { + if (!predicates.isEmpty()) { + for (Predicate> predicate : predicates) { + if (!predicate.test(context)) { + return false; + } + } + } + return true; + } + + public Expression apply(ExpressionMatchingContext context) { + Expression newResult = matchingAction.apply(context); + return newResult == null ? context.expr : newResult; + } + + @Override + public Class getType() { + return typePattern; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java new file mode 100644 index 00000000000000..058b1d60b1d013 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternMatcher.java @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Predicate; + +/** ExpressionPattern */ +public class ExpressionPatternMatcher { + public final Class typePattern; + public final List>> predicates; + public final ExpressionMatchingAction matchingAction; + + public ExpressionPatternMatcher(Class typePattern, + List>> predicates, + ExpressionMatchingAction matchingAction) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = predicates == null ? ImmutableList.of() : predicates; + this.matchingAction = Objects.requireNonNull(matchingAction, "matchingAction can not be null"); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java new file mode 100644 index 00000000000000..7fb18735ba5e46 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionPatternRuleFactory.java @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Predicate; + +/** ExpressionPatternRuleFactory */ +public interface ExpressionPatternRuleFactory { + List> buildRules(); + + default ExpressionPatternDescriptor matchesType(Class clazz) { + return new ExpressionPatternDescriptor<>(clazz); + } + + default ExpressionPatternDescriptor root(Class clazz) { + return new ExpressionPatternDescriptor<>(clazz) + .whenCtx(ctx -> ctx.isRoot()); + } + + default ExpressionPatternDescriptor matchesTopType(Class clazz) { + return new ExpressionPatternDescriptor<>(clazz) + .whenCtx(ctx -> ctx.isRoot() || !clazz.isInstance(ctx.parent.get())); + } + + /** ExpressionPatternDescriptor */ + class ExpressionPatternDescriptor { + private final Class typePattern; + private final ImmutableList>> predicates; + + public ExpressionPatternDescriptor(Class typePattern) { + this(typePattern, ImmutableList.of()); + } + + public ExpressionPatternDescriptor( + Class typePattern, ImmutableList>> predicates) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = Objects.requireNonNull(predicates, "predicates can not be null"); + } + + public ExpressionPatternDescriptor when(Predicate predicate) { + return whenCtx(ctx -> predicate.test(ctx.expr)); + } + + public ExpressionPatternDescriptor whenCtx(Predicate> predicate) { + ImmutableList.Builder>> newPredicates + = ImmutableList.builderWithExpectedSize(predicates.size() + 1); + newPredicates.addAll(predicates); + newPredicates.add(predicate); + return new ExpressionPatternDescriptor<>(typePattern, newPredicates.build()); + } + + /** then */ + public ExpressionPatternMatcher then(Function rewriter) { + return new ExpressionPatternMatcher<>( + typePattern, predicates, (context) -> rewriter.apply(context.expr)); + } + + public ExpressionPatternMatcher thenApply(ExpressionMatchingAction action) { + return new ExpressionPatternMatcher<>(typePattern, predicates, action); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java index 912793e61d1b2e..b547f693a7cdb3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewrite.java @@ -18,6 +18,8 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.pattern.ExpressionPatternRules; +import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; @@ -41,7 +43,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.Set; @@ -123,9 +125,7 @@ public Rule build() { LogicalProject project = ctx.root; ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); List projects = project.getProjects(); - List newProjects = projects.stream() - .map(expr -> (NamedExpression) rewriter.rewrite(expr, context)) - .collect(ImmutableList.toImmutableList()); + List newProjects = rewriteAll(projects, rewriter, context); if (projects.equals(newProjects)) { return project; } @@ -160,9 +160,7 @@ public Rule build() { List newGroupByExprs = rewriter.rewrite(groupByExprs, context); List outputExpressions = agg.getOutputExpressions(); - List newOutputExpressions = outputExpressions.stream() - .map(expr -> (NamedExpression) rewriter.rewrite(expr, context)) - .collect(ImmutableList.toImmutableList()); + List newOutputExpressions = rewriteAll(outputExpressions, rewriter, context); if (outputExpressions.equals(newOutputExpressions)) { return agg; } @@ -222,13 +220,16 @@ public Rule build() { return logicalSort().thenApply(ctx -> { LogicalSort sort = ctx.root; List orderKeys = sort.getOrderKeys(); - List rewrittenOrderKeys = new ArrayList<>(); + ImmutableList.Builder rewrittenOrderKeys + = ImmutableList.builderWithExpectedSize(orderKeys.size()); ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext); + boolean changed = false; for (OrderKey k : orderKeys) { Expression expression = rewriter.rewrite(k.getExpr(), context); + changed |= expression != k.getExpr(); rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst())); } - return sort.withOrderKeys(rewrittenOrderKeys); + return changed ? sort.withOrderKeys(rewrittenOrderKeys.build()) : sort; }).toRule(RuleType.REWRITE_SORT_EXPRESSION); } } @@ -270,4 +271,36 @@ public Rule build() { }).toRule(RuleType.REWRITE_REPEAT_EXPRESSION); } } + + /** bottomUp */ + public static ExpressionBottomUpRewriter bottomUp(ExpressionPatternRuleFactory... ruleFactories) { + ImmutableList.Builder rules = ImmutableList.builder(); + ImmutableList.Builder listeners = ImmutableList.builder(); + for (ExpressionPatternRuleFactory ruleFactory : ruleFactories) { + if (ruleFactory instanceof ExpressionTraverseListenerFactory) { + List> listenersMatcher + = ((ExpressionTraverseListenerFactory) ruleFactory).buildListeners(); + for (ExpressionListenerMatcher listenerMatcher : listenersMatcher) { + listeners.add(new ExpressionTraverseListenerMapping(listenerMatcher)); + } + } + for (ExpressionPatternMatcher patternMatcher : ruleFactory.buildRules()) { + rules.add(new ExpressionPatternMatchRule(patternMatcher)); + } + } + + return new ExpressionBottomUpRewriter( + new ExpressionPatternRules(rules.build()), + new ExpressionPatternTraverseListeners(listeners.build()) + ); + } + + public static List rewriteAll( + Collection exprs, ExpressionRuleExecutor rewriter, ExpressionRewriteContext context) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (E expr : exprs) { + result.add((E) rewriter.rewrite(expr, context)); + } + return result.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java index cb50e0d2871e3b..35633e7594f717 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteContext.java @@ -19,6 +19,8 @@ import org.apache.doris.nereids.CascadesContext; +import java.util.Objects; + /** * expression rewrite context. */ @@ -27,7 +29,7 @@ public class ExpressionRewriteContext { public final CascadesContext cascadesContext; public ExpressionRewriteContext(CascadesContext cascadesContext) { - this.cascadesContext = cascadesContext; + this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null"); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java index ac7e6dae6b282d..0f951448dd2582 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleExecutor.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import com.google.common.collect.ImmutableList; @@ -36,7 +37,11 @@ public ExpressionRuleExecutor(List rules) { } public List rewrite(List exprs, ExpressionRewriteContext ctx) { - return exprs.stream().map(expr -> rewrite(expr, ctx)).collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(rewrite(expr, ctx)); + } + return result.build(); } /** @@ -61,8 +66,15 @@ private Expression applyRule(Expression expr, ExpressionRewriteRule rule, Expres return rule.rewrite(expr, ctx); } + /** normalize */ public static Expression normalize(Expression expression) { - return NormalizeBinaryPredicatesRule.INSTANCE.rewrite(expression, null); + return expression.rewriteUp(expr -> { + if (expr instanceof ComparisonPredicate) { + return NormalizeBinaryPredicatesRule.normalize((ComparisonPredicate) expression); + } else { + return expr; + } + }); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java new file mode 100644 index 00000000000000..5df5a6d68185dd --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListener.java @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +/** ExpressionTraverseListener */ +public interface ExpressionTraverseListener { + default void onEnter(ExpressionMatchingContext context) {} + + default void onExit(ExpressionMatchingContext context, Expression rewritten) {} + + default ExpressionTraverseListener as() { + return (ExpressionTraverseListener) this; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java new file mode 100644 index 00000000000000..201362fed781b6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerFactory.java @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; +import java.util.function.Predicate; + +/** ExpressionTraverseListenerFactory */ +public interface ExpressionTraverseListenerFactory { + List> buildListeners(); + + default ListenerDescriptor listenerType(Class clazz) { + return new ListenerDescriptor<>(clazz); + } + + /** listenerTypes */ + default List> listenerTypes(Class... classes) { + ImmutableList.Builder> listeners + = ImmutableList.builderWithExpectedSize(classes.length); + for (Class clazz : classes) { + listeners.add((ListenerDescriptor) listenerType(clazz)); + } + return listeners.build(); + } + + /** ListenerDescriptor */ + class ListenerDescriptor { + + private final Class typePattern; + private final ImmutableList>> predicates; + + public ListenerDescriptor(Class typePattern) { + this(typePattern, ImmutableList.of()); + } + + public ListenerDescriptor( + Class typePattern, ImmutableList>> predicates) { + this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null"); + this.predicates = Objects.requireNonNull(predicates, "predicates can not be null"); + } + + public ListenerDescriptor when(Predicate predicate) { + return whenCtx(ctx -> predicate.test(ctx.expr)); + } + + public ListenerDescriptor whenCtx(Predicate> predicate) { + ImmutableList.Builder>> newPredicates + = ImmutableList.builderWithExpectedSize(predicates.size() + 1); + newPredicates.addAll(predicates); + newPredicates.add(predicate); + return new ListenerDescriptor<>(typePattern, newPredicates.build()); + } + + /** then */ + public ExpressionListenerMatcher then(ExpressionTraverseListener listener) { + return new ExpressionListenerMatcher<>(typePattern, predicates, listener); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java new file mode 100644 index 00000000000000..d99c231110f175 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionTraverseListenerMapping.java @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression; + +import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping; +import org.apache.doris.nereids.trees.expressions.Expression; + +import java.util.List; +import java.util.function.Predicate; + +/** ExpressionTraverseListener */ +public class ExpressionTraverseListenerMapping implements TypeMapping { + public final Class typePattern; + public final List>> predicates; + public final ExpressionTraverseListener listener; + + public ExpressionTraverseListenerMapping(ExpressionListenerMatcher listenerMatcher) { + this.typePattern = listenerMatcher.typePattern; + this.predicates = listenerMatcher.predicates; + this.listener = listenerMatcher.listener; + } + + @Override + public Class getType() { + return typePattern; + } + + /** matches */ + public boolean matchesTypeAndPredicates(ExpressionMatchingContext context) { + return typePattern.isInstance(context.expr) && matchesPredicates(context); + } + + /** matchesPredicates */ + public boolean matchesPredicates(ExpressionMatchingContext context) { + if (!predicates.isEmpty()) { + for (Predicate> predicate : predicates) { + if (!predicate.test(context)) { + return false; + } + } + } + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java index d7a6085dcab550..69a9105d653d81 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/check/CheckCast.java @@ -18,8 +18,8 @@ package org.apache.doris.nereids.rules.expression.check; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.types.ArrayType; @@ -31,18 +31,24 @@ import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.PrimitiveType; +import com.google.common.collect.ImmutableList; + import java.util.List; /** * check cast valid */ -public class CheckCast extends AbstractExpressionRewriteRule { - - public static final CheckCast INSTANCE = new CheckCast(); +public class CheckCast implements ExpressionPatternRuleFactory { + public static CheckCast INSTANCE = new CheckCast(); @Override - public Expression visitCast(Cast cast, ExpressionRewriteContext context) { - rewrite(cast.child(), context); + public List> buildRules() { + return ImmutableList.of( + matchesType(Cast.class).then(CheckCast::check) + ); + } + + private static Expression check(Cast cast) { DataType originalType = cast.child().getDataType(); DataType targetType = cast.getDataType(); if (!check(originalType, targetType)) { @@ -51,7 +57,7 @@ public Expression visitCast(Cast cast, ExpressionRewriteContext context) { return cast; } - private boolean check(DataType originalType, DataType targetType) { + private static boolean check(DataType originalType, DataType targetType) { if (originalType.isVariantType() && (targetType instanceof PrimitiveType || targetType.isArrayType())) { // variant could cast to primitive types and array return true; @@ -99,7 +105,7 @@ private boolean check(DataType originalType, DataType targetType) { * 3. original type is same with target type * 4. target type is null type */ - private boolean checkPrimitiveType(DataType originalType, DataType targetType) { + private static boolean checkPrimitiveType(DataType originalType, DataType targetType) { if (!originalType.isPrimitive() || !targetType.isPrimitive()) { return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java index 7309ef111c925d..f32d76062aaf7c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ArrayContainToArrayOverlap.java @@ -17,26 +17,29 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap; import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; +import com.google.common.collect.Multimaps; +import com.google.common.collect.SetMultimap; -import java.util.HashMap; -import java.util.HashSet; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; +import java.util.Map.Entry; import java.util.Set; -import java.util.stream.Collectors; /** * array_contains ( c_array, '1' ) @@ -44,56 +47,73 @@ * =========================================> * array_overlap(c_array, ['1', '2']) */ -public class ArrayContainToArrayOverlap extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class ArrayContainToArrayOverlap implements ExpressionPatternRuleFactory { public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap(); private static final int REWRITE_PREDICATE_THRESHOLD = 2; @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, ctx); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(Or.class).then(ArrayContainToArrayOverlap::rewrite) + ); } - @Override - public Expression visitOr(Or or, ExpressionRewriteContext ctx) { + private static Expression rewrite(Or or) { List disjuncts = ExpressionUtils.extractDisjunction(or); - Map> containFuncAndOtherFunc = disjuncts.stream() - .collect(Collectors.partitioningBy(this::isValidArrayContains)); - Map> containLiteralSet = new HashMap<>(); - List contains = containFuncAndOtherFunc.get(true); - List others = containFuncAndOtherFunc.get(false); - contains.forEach(containFunc -> - containLiteralSet.computeIfAbsent(containFunc.child(0), k -> new HashSet<>()) - .add((Literal) containFunc.child(1))); + List contains = Lists.newArrayList(); + List others = Lists.newArrayList(); + for (Expression expr : disjuncts) { + if (ArrayContainToArrayOverlap.isValidArrayContains(expr)) { + contains.add(expr); + } else { + others.add(expr); + } + } + + if (contains.size() <= 1) { + return or; + } + + SetMultimap containLiteralSet = Multimaps.newSetMultimap( + new LinkedHashMap<>(), LinkedHashSet::new + ); + for (Expression contain : contains) { + containLiteralSet.put(contain.child(0), (Literal) contain.child(1)); + } Builder newDisjunctsBuilder = new ImmutableList.Builder<>(); - containLiteralSet.forEach((left, literalSet) -> { + for (Entry> kv : containLiteralSet.asMap().entrySet()) { + Expression left = kv.getKey(); + Collection literalSet = kv.getValue(); if (literalSet.size() > REWRITE_PREDICATE_THRESHOLD) { newDisjunctsBuilder.add( - new ArraysOverlap(left, - new ArrayLiteral(ImmutableList.copyOf(literalSet)))); + new ArraysOverlap(left, new ArrayLiteral(Utils.fastToImmutableList(literalSet))) + ); + } + } + + for (Expression contain : contains) { + if (!canCovertToArrayOverlap(contain, containLiteralSet)) { + newDisjunctsBuilder.add(contain); } - }); - - contains.stream() - .filter(e -> !canCovertToArrayOverlap(e, containLiteralSet)) - .forEach(newDisjunctsBuilder::add); - others.stream() - .map(e -> e.accept(this, null)) - .forEach(newDisjunctsBuilder::add); + } + newDisjunctsBuilder.addAll(others); return ExpressionUtils.or(newDisjunctsBuilder.build()); } - private boolean isValidArrayContains(Expression expression) { + private static boolean isValidArrayContains(Expression expression) { return expression instanceof ArrayContains && expression.child(1) instanceof Literal; } - private boolean canCovertToArrayOverlap(Expression expression, Map> containLiteralSet) { - return expression instanceof ArrayContains - && containLiteralSet.getOrDefault(expression.child(0), - new HashSet<>()).size() > REWRITE_PREDICATE_THRESHOLD; + private static boolean canCovertToArrayOverlap( + Expression expression, SetMultimap containLiteralSet) { + if (!(expression instanceof ArrayContains)) { + return false; + } + Set containLiteral = containLiteralSet.get(expression.child(0)); + return containLiteral.size() > REWRITE_PREDICATE_THRESHOLD; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java index 6372338406dd1d..cafb0ecd068ddd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java @@ -17,25 +17,35 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Rewrite rule to convert CASE WHEN to IF. * For example: * CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0) */ -public class CaseWhenToIf extends AbstractExpressionRewriteRule { +public class CaseWhenToIf implements ExpressionPatternRuleFactory { public static CaseWhenToIf INSTANCE = new CaseWhenToIf(); @Override - public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CaseWhen.class).then(CaseWhenToIf::rewrite) + ); + } + + private static Expression rewrite(CaseWhen caseWhen) { Expression expr = caseWhen; if (caseWhen.getWhenClauses().size() == 1) { WhenClause whenClause = caseWhen.getWhenClauses().get(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java index e5748eb1d59e2c..239007015531eb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ConvertAggStateCast.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator; @@ -30,29 +30,30 @@ import com.google.common.collect.ImmutableList; +import java.util.List; + /** * Follow legacy planner cast agg_state combinator's children if we need cast it to another agg_state type when insert */ -public class ConvertAggStateCast extends AbstractExpressionRewriteRule { +public class ConvertAggStateCast implements ExpressionPatternRuleFactory { public static ConvertAggStateCast INSTANCE = new ConvertAggStateCast(); @Override - public Expression visitCast(Cast cast, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(Cast.class).then(ConvertAggStateCast::convert) + ); + } + + private static Expression convert(Cast cast) { Expression child = cast.child(); DataType originalType = child.getDataType(); DataType targetType = cast.getDataType(); if (originalType instanceof AggStateType && targetType instanceof AggStateType && child instanceof StateCombinator) { - AggStateType original = (AggStateType) originalType; AggStateType target = (AggStateType) targetType; - if (original.getSubTypes().size() != target.getSubTypes().size()) { - return processCastChild(cast, context); - } - if (!original.getFunctionName().equalsIgnoreCase(target.getFunctionName())) { - return processCastChild(cast, context); - } ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(child.arity()); for (int i = 0; i < child.arity(); i++) { Expression newChild = TypeCoercionUtils.castIfNotSameType(child.child(i), target.getSubTypes().get(i)); @@ -66,15 +67,7 @@ public Expression visitCast(Cast cast, ExpressionRewriteContext context) { newChildren.add(newChild); } child = child.withChildren(newChildren.build()); - return processCastChild(cast.withChildren(ImmutableList.of(child)), context); - } - return processCastChild(cast, context); - } - - private Expression processCastChild(Cast cast, ExpressionRewriteContext context) { - Expression child = visit(cast.child(), context); - if (child != cast.child()) { - cast = cast.withChildren(ImmutableList.of(child)); + return cast.withChildren(ImmutableList.of(child)); } return cast; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java index e78eeecff0d105..07ec0c3de71d24 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DateFunctionRewrite.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; @@ -34,17 +34,31 @@ import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * F: a DateTime or DateTimeV2 column * Date(F) > 2020-01-01 => F > 2020-01-02 00:00:00 * Date(F) >= 2020-01-01 => F > 2020-01-01 00:00:00 * */ -public class DateFunctionRewrite extends AbstractExpressionRewriteRule { +public class DateFunctionRewrite implements ExpressionPatternRuleFactory { public static DateFunctionRewrite INSTANCE = new DateFunctionRewrite(); @Override - public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(EqualTo.class).then(DateFunctionRewrite::rewriteEqualTo), + matchesType(GreaterThan.class).then(DateFunctionRewrite::rewriteGreaterThan), + matchesType(GreaterThanEqual.class).then(DateFunctionRewrite::rewriteGreaterThanEqual), + matchesType(LessThan.class).then(DateFunctionRewrite::rewriteLessThan), + matchesType(LessThanEqual.class).then(DateFunctionRewrite::rewriteLessThanEqual) + ); + } + + private static Expression rewriteEqualTo(EqualTo equalTo) { if (equalTo.left() instanceof Date) { // V1 if (equalTo.left().child(0).getDataType() instanceof DateTimeType @@ -70,8 +84,7 @@ public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context return equalTo; } - @Override - public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) { + private static Expression rewriteGreaterThan(GreaterThan greaterThan) { if (greaterThan.left() instanceof Date) { // V1 if (greaterThan.left().child(0).getDataType() instanceof DateTimeType @@ -91,8 +104,7 @@ public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteCon return greaterThan; } - @Override - public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) { + private static Expression rewriteGreaterThanEqual(GreaterThanEqual greaterThanEqual) { if (greaterThanEqual.left() instanceof Date) { // V1 if (greaterThanEqual.left().child(0).getDataType() instanceof DateTimeType @@ -111,8 +123,7 @@ public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, Expre return greaterThanEqual; } - @Override - public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) { + private static Expression rewriteLessThan(LessThan lessThan) { if (lessThan.left() instanceof Date) { // V1 if (lessThan.left().child(0).getDataType() instanceof DateTimeType @@ -131,8 +142,7 @@ public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext cont return lessThan; } - @Override - public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) { + private static Expression rewriteLessThanEqual(LessThanEqual lessThanEqual) { if (lessThanEqual.left() instanceof Date) { // V1 if (lessThanEqual.left().child(0).getDataType() instanceof DateTimeType diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java index 5e38c0390b6c93..95d25e3c592454 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DigitalMaskingConvert.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat; import org.apache.doris.nereids.trees.expressions.functions.scalar.DigitalMasking; @@ -26,16 +26,25 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Right; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Convert DigitalMasking to Concat */ -public class DigitalMaskingConvert extends AbstractExpressionRewriteRule { - +public class DigitalMaskingConvert implements ExpressionPatternRuleFactory { public static DigitalMaskingConvert INSTANCE = new DigitalMaskingConvert(); @Override - public Expression visitDigitalMasking(DigitalMasking digitalMasking, ExpressionRewriteContext context) { - return new Concat(new Left(digitalMasking.child(), Literal.of(3)), Literal.of("****"), - new Right(digitalMasking.child(), Literal.of(4))); + public List> buildRules() { + return ImmutableList.of( + matchesType(DigitalMasking.class).then(digitalMasking -> + new Concat( + new Left(digitalMasking.child(), Literal.of(3)), + Literal.of("****"), + new Right(digitalMasking.child(), Literal.of(4))) + ) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java index a3466d395d56e0..cf18886cd85fc3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/DistinctPredicatesRule.java @@ -17,12 +17,13 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.LinkedHashSet; @@ -35,16 +36,21 @@ * transform (a = 1) and (b > 2) and (a = 1) to (a = 1) and (b > 2) * transform (a = 1) or (a = 1) to (a = 1) */ -public class DistinctPredicatesRule extends AbstractExpressionRewriteRule { - +public class DistinctPredicatesRule implements ExpressionPatternRuleFactory { public static final DistinctPredicatesRule INSTANCE = new DistinctPredicatesRule(); @Override - public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CompoundPredicate.class).then(DistinctPredicatesRule::distinct) + ); + } + + private static Expression distinct(CompoundPredicate expr) { List extractExpressions = ExpressionUtils.extract(expr); Set distinctExpressions = new LinkedHashSet<>(extractExpressions); if (distinctExpressions.size() != extractExpressions.size()) { - return ExpressionUtils.combine(expr.getClass(), Lists.newArrayList(distinctExpressions)); + return ExpressionUtils.combineAsLeftDeepTree(expr.getClass(), Lists.newArrayList(distinctExpressions)); } return expr; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java index dd457e01d8d9cc..4032db4aadf550 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ExtractCommonFactorRule.java @@ -18,21 +18,28 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.annotation.Developing; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Multimaps; +import com.google.common.collect.SetMultimap; import com.google.common.collect.Sets; -import java.util.Collections; -import java.util.HashSet; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Map.Entry; import java.util.Set; -import java.util.stream.Collectors; /** * Extract common expr for `CompoundPredicate`. @@ -41,42 +48,197 @@ * transform (a and b) or (a and c) to a and (b or c) */ @Developing -public class ExtractCommonFactorRule extends AbstractExpressionRewriteRule { - +public class ExtractCommonFactorRule implements ExpressionPatternRuleFactory { public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule(); @Override - public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CompoundPredicate.class).then(ExtractCommonFactorRule::extractCommonFactor) + ); + } + + private static Expression extractCommonFactor(CompoundPredicate originExpr) { + // fast return + if (!(originExpr.left() instanceof CompoundPredicate || originExpr.left() instanceof BooleanLiteral) + && !(originExpr.right() instanceof CompoundPredicate || originExpr.right() instanceof BooleanLiteral)) { + return originExpr; + } - Expression rewrittenChildren = ExpressionUtils.combine(expr.getClass(), ExpressionUtils.extract(expr).stream() - .map(predicate -> rewrite(predicate, context)).collect(ImmutableList.toImmutableList())); - if (!(rewrittenChildren instanceof CompoundPredicate)) { - return rewrittenChildren; + // flatten same type to a list + // e.g. ((a and (b or c)) and c) -> [a, (b or c), c] + List flatten = ExpressionUtils.extract(originExpr); + + // combine and delete some boolean literal predicate + // e.g. (a and true) -> true + Expression simplified = ExpressionUtils.combineAsLeftDeepTree(originExpr.getClass(), flatten); + if (!(simplified instanceof CompoundPredicate)) { + return simplified; } - CompoundPredicate compoundPredicate = (CompoundPredicate) rewrittenChildren; + // separate two levels CompoundPredicate to partitions + // e.g. ((a and (b or c)) and c) -> [[a], [b, c], c] + CompoundPredicate leftDeapTree = (CompoundPredicate) simplified; + ImmutableSet.Builder> partitionsBuilder + = ImmutableSet.builderWithExpectedSize(flatten.size()); + for (Expression onPartition : ExpressionUtils.extract(leftDeapTree)) { + if (onPartition instanceof CompoundPredicate) { + partitionsBuilder.add(ExpressionUtils.extract((CompoundPredicate) onPartition)); + } else { + partitionsBuilder.add(ImmutableList.of(onPartition)); + } + } + Set> partitions = partitionsBuilder.build(); - List> partitions = ExpressionUtils.extract(compoundPredicate).stream() - .map(predicate -> predicate instanceof CompoundPredicate ? ExpressionUtils.extract( - (CompoundPredicate) predicate) : Lists.newArrayList(predicate)).collect(Collectors.toList()); + Expression result = extractCommonFactors(originExpr, leftDeapTree, Utils.fastToImmutableList(partitions)); + return result; + } - Set commons = partitions.stream() - .>map(HashSet::new) - .reduce(Sets::intersection) - .orElse(Collections.emptySet()); + private static Expression extractCommonFactors(CompoundPredicate originPredicate, + CompoundPredicate leftDeapTreePredicate, List> initPartitions) { + // extract factor and fill into commonFactorToPartIds + // e.g. + // originPredicate: (a and (b and c)) and (b or c) + // leftDeapTreePredicate: ((a and b) and c) and (b or c) + // initPartitions: [[a], [b], [c], [b, c]] + // + // -> commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}. + // so we can know `b` and `c` is a common factors + SetMultimap commonFactorToPartIds = Multimaps.newSetMultimap( + Maps.newLinkedHashMap(), LinkedHashSet::new + ); + int originExpressionNum = 0; + int partId = 0; + for (List partition : initPartitions) { + for (Expression expression : partition) { + commonFactorToPartIds.put(expression, partId); + originExpressionNum++; + } + partId++; + } - List> uncorrelated = partitions.stream() - .map(predicates -> predicates.stream().filter(p -> !commons.contains(p)).collect(Collectors.toList())) - .collect(Collectors.toList()); + // commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]} + // + // -> reverse key value of commonFactorToPartIds and remove intersecting partition + // + // -> 1. reverse: {[0]: [a], [1, 3]: [b], [2, 3]: [c]} + // -> 2. sort by key size desc: {[1, 3]: [b], [2, 3]: [c], [0]: [a]} + // -> 3. remove intersection partition: {[1, 3]: [b], [2]: [c], [0]: [a]}, + // because first part and second part intersect by partition 3 + LinkedHashMap, Set> commonFactorPartitions + = partitionByMostCommonFactors(commonFactorToPartIds); - Expression combineUncorrelated = ExpressionUtils.combine(compoundPredicate.getClass(), - uncorrelated.stream() - .map(predicates -> ExpressionUtils.combine(compoundPredicate.flipType(), predicates)) - .collect(Collectors.toList())); + int extractedExpressionNum = 0; + for (Set exprs : commonFactorPartitions.values()) { + extractedExpressionNum += exprs.size(); + } + + // no any common factor + if (commonFactorPartitions.entrySet().iterator().next().getKey().size() <= 1 + && !(originPredicate.getWidth() > leftDeapTreePredicate.getWidth()) + && originExpressionNum <= extractedExpressionNum) { + // this condition is important because it can avoid deap loop: + // origin originExpr: A = 1 and (B > 0 and B < 10) + // after ExtractCommonFactorRule: (A = 1 and B > 0) and (B < 10) (left deap tree) + // after SimplifyRange: A = 1 and (B > 0 and B < 10) (right deap tree) + return originPredicate; + } + + // now we can do extract common factors for each part: + // originPredicate: (a and (b and c)) and (b or c) + // leftDeapTreePredicate: ((a and b) and c) and (b or c) + // initPartitions: [[a], [b], [c], [b, c]] + // commonFactorPartitions: {[1, 3]: [b], [0]: [a]} + // + // -> extractedExprs: [ + // b or (false and c) = b, + // a, + // c + // ] + // + // -> result: (b or c) and a and c + ImmutableList.Builder extractedExprs + = ImmutableList.builderWithExpectedSize(commonFactorPartitions.size()); + for (Entry, Set> kv : commonFactorPartitions.entrySet()) { + Expression extracted = doExtractCommonFactors( + leftDeapTreePredicate, initPartitions, kv.getKey(), kv.getValue() + ); + extractedExprs.add(extracted); + } + + // combine and eliminate some boolean literal predicate + return ExpressionUtils.combineAsLeftDeepTree(leftDeapTreePredicate.getClass(), extractedExprs.build()); + } - List finalCompound = Lists.newArrayList(commons); - finalCompound.add(combineUncorrelated); + private static Expression doExtractCommonFactors( + CompoundPredicate originPredicate, + List> partitions, Set partitionIds, Set commonFactors) { + ImmutableList.Builder uncorrelatedExprPartitionsBuilder + = ImmutableList.builderWithExpectedSize(partitionIds.size()); + for (Integer partitionId : partitionIds) { + List partition = partitions.get(partitionId); + ImmutableSet.Builder uncorrelatedBuilder + = ImmutableSet.builderWithExpectedSize(partition.size()); + for (Expression exprOfPart : partition) { + if (!commonFactors.contains(exprOfPart)) { + uncorrelatedBuilder.add(exprOfPart); + } + } + + Set uncorrelated = uncorrelatedBuilder.build(); + Expression partitionWithoutCommonFactor + = ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), uncorrelated); + if (partitionWithoutCommonFactor instanceof CompoundPredicate) { + partitionWithoutCommonFactor = extractCommonFactor((CompoundPredicate) partitionWithoutCommonFactor); + } + uncorrelatedExprPartitionsBuilder.add(partitionWithoutCommonFactor); + } + + ImmutableList uncorrelatedExprPartitions = uncorrelatedExprPartitionsBuilder.build(); + ImmutableList.Builder allExprs = ImmutableList.builderWithExpectedSize(commonFactors.size() + 1); + allExprs.addAll(commonFactors); + + Expression combineUncorrelatedExpr = ExpressionUtils.combineAsLeftDeepTree( + originPredicate.getClass(), uncorrelatedExprPartitions); + allExprs.add(combineUncorrelatedExpr); + + Expression result = ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), allExprs.build()); + return result; + } + + private static LinkedHashMap, Set> partitionByMostCommonFactors( + SetMultimap commonFactorToPartIds) { + SetMultimap, Expression> partWithCommonFactors = Multimaps.newSetMultimap( + Maps.newLinkedHashMap(), LinkedHashSet::new + ); + + for (Entry> factorToId : commonFactorToPartIds.asMap().entrySet()) { + partWithCommonFactors.put((Set) factorToId.getValue(), factorToId.getKey()); + } + + List> sortedPartitionIdHasCommonFactor = Lists.newArrayList(partWithCommonFactors.keySet()); + // place the most common factor at the head of this list + sortedPartitionIdHasCommonFactor.sort((p1, p2) -> p2.size() - p1.size()); + + LinkedHashMap, Set> shouldExtractFactors = Maps.newLinkedHashMap(); + + Set allocatedPartitions = Sets.newLinkedHashSet(); + for (Set originMostCommonFactorPartitions : sortedPartitionIdHasCommonFactor) { + ImmutableSet.Builder notAllocatePartitions = ImmutableSet.builderWithExpectedSize( + originMostCommonFactorPartitions.size()); + for (Integer partId : originMostCommonFactorPartitions) { + if (allocatedPartitions.add(partId)) { + notAllocatePartitions.add(partId); + } + } + + Set mostCommonFactorPartitions = notAllocatePartitions.build(); + if (!mostCommonFactorPartitions.isEmpty()) { + Set commonFactors = partWithCommonFactors.get(originMostCommonFactorPartitions); + shouldExtractFactors.put(mostCommonFactorPartitions, commonFactors); + } + } - return ExpressionUtils.combine(compoundPredicate.flipType(), finalCompound); + return shouldExtractFactors; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java index c801f749ee0ef0..04acb91d9e2d39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRule.java @@ -17,24 +17,46 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Expression; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Constant evaluation of an expression. */ -public class FoldConstantRule extends AbstractExpressionRewriteRule { +public class FoldConstantRule implements ExpressionPatternRuleFactory { public static final FoldConstantRule INSTANCE = new FoldConstantRule(); + private static final ExpressionBottomUpRewriter FULL_FOLD_REWRITER = ExpressionRewrite.bottomUp( + FoldConstantRuleOnFE.VISITOR_INSTANCE, + FoldConstantRuleOnBE.INSTANCE + ); + + /** evaluate by pattern match */ @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + public List> buildRules() { + return ImmutableList.>builder() + .addAll(FoldConstantRuleOnFE.PATTERN_MATCH_INSTANCE.buildRules()) + .addAll(FoldConstantRuleOnBE.INSTANCE.buildRules()) + .build(); + } + + /** evaluate by visitor */ + public static Expression evaluate(Expression expr, ExpressionRewriteContext ctx) { if (ctx.cascadesContext != null && ctx.cascadesContext.getConnectContext() != null && ctx.cascadesContext.getConnectContext().getSessionVariable().isEnableFoldConstantByBe()) { - return new FoldConstantRuleOnBE().rewrite(expr, ctx); + return FULL_FOLD_REWRITER.rewrite(expr, ctx); + } else { + return FoldConstantRuleOnFE.VISITOR_INSTANCE.rewrite(expr, ctx); } - return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java index 38c6a483c9f777..09e9bbe0b91e37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnBE.java @@ -27,8 +27,9 @@ import org.apache.doris.common.util.DebugUtil; import org.apache.doris.common.util.TimeUtils; import org.apache.doris.nereids.glue.translator.ExpressionTranslator; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; @@ -55,6 +56,7 @@ import org.apache.doris.thrift.TQueryOptions; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -73,24 +75,38 @@ /** * Constant evaluation of an expression. */ -public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule { +public class FoldConstantRuleOnBE implements ExpressionPatternRuleFactory { + + public static final FoldConstantRuleOnBE INSTANCE = new FoldConstantRuleOnBE(); private static final Logger LOG = LogManager.getLogger(FoldConstantRuleOnBE.class); - private final IdGenerator idGenerator = ExprId.createGenerator(); @Override - public Expression rewrite(Expression expression, ExpressionRewriteContext ctx) { - expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expression, ctx); - return foldByBE(expression, ctx); + public List> buildRules() { + return ImmutableList.of( + root(Expression.class) + .whenCtx(FoldConstantRuleOnBE::isEnableFoldByBe) + .thenApply(FoldConstantRuleOnBE::foldByBE) + ); + } + + public static boolean isEnableFoldByBe(ExpressionMatchingContext ctx) { + return ctx.cascadesContext != null + && ctx.cascadesContext.getConnectContext() != null + && ctx.cascadesContext.getConnectContext().getSessionVariable().isEnableFoldConstantByBe(); } - private Expression foldByBE(Expression root, ExpressionRewriteContext context) { + /** foldByBE */ + public static Expression foldByBE(ExpressionMatchingContext context) { + IdGenerator idGenerator = ExprId.createGenerator(); + + Expression root = context.expr; Map constMap = Maps.newHashMap(); Map staleConstTExprMap = Maps.newHashMap(); Expression rootWithoutAlias = root; if (root instanceof Alias) { rootWithoutAlias = ((Alias) root).child(); } - collectConst(rootWithoutAlias, constMap, staleConstTExprMap); + collectConst(rootWithoutAlias, constMap, staleConstTExprMap, idGenerator); if (constMap.isEmpty()) { return root; } @@ -103,7 +119,8 @@ private Expression foldByBE(Expression root, ExpressionRewriteContext context) { return root; } - private Expression replace(Expression root, Map constMap, Map resultMap) { + private static Expression replace( + Expression root, Map constMap, Map resultMap) { for (Entry entry : constMap.entrySet()) { if (entry.getValue().equals(root)) { return resultMap.get(entry.getKey()); @@ -121,7 +138,8 @@ private Expression replace(Expression root, Map constMap, Ma return hasNewChildren ? root.withChildren(newChildren) : root; } - private void collectConst(Expression expr, Map constMap, Map tExprMap) { + private static void collectConst(Expression expr, Map constMap, + Map tExprMap, IdGenerator idGenerator) { if (expr.isConstant()) { // Do not constant fold cast(null as dataType) because we cannot preserve the // cast-to-types and that can lead to query failures, e.g., CTAS @@ -157,13 +175,13 @@ private void collectConst(Expression expr, Map constMap, Map } else { for (int i = 0; i < expr.children().size(); i++) { final Expression child = expr.children().get(i); - collectConst(child, constMap, tExprMap); + collectConst(child, constMap, tExprMap, idGenerator); } } } // if sleep(5) will cause rpc timeout - private boolean skipSleepFunction(Expression expr) { + private static boolean skipSleepFunction(Expression expr) { if (expr instanceof Sleep) { Expression param = expr.child(0); if (param instanceof Cast) { @@ -176,7 +194,7 @@ private boolean skipSleepFunction(Expression expr) { return false; } - private Map evalOnBE(Map> paramMap, + private static Map evalOnBE(Map> paramMap, Map constMap, ConnectContext context) { Map resultMap = new HashMap<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index 05165f6c312c56..cf3d1a88d8cff1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -22,7 +22,13 @@ import org.apache.doris.cluster.ClusterNamespace; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionListenerMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener; +import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerFactory; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; @@ -80,6 +86,8 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import org.apache.commons.codec.digest.DigestUtils; @@ -87,13 +95,78 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Predicate; /** * evaluate an expression on fe. */ -public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { +public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule + implements ExpressionPatternRuleFactory, ExpressionTraverseListenerFactory { - public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE(); + public static final FoldConstantRuleOnFE VISITOR_INSTANCE = new FoldConstantRuleOnFE(true); + public static final FoldConstantRuleOnFE PATTERN_MATCH_INSTANCE = new FoldConstantRuleOnFE(false); + + // record whether current expression is in an aggregate function with distinct, + // if is, we will skip to fold constant + private static final ListenAggDistinct LISTEN_AGG_DISTINCT = new ListenAggDistinct(); + private static final CheckWhetherUnderAggDistinct NOT_UNDER_AGG_DISTINCT = new CheckWhetherUnderAggDistinct(); + + private final boolean deepRewrite; + + public FoldConstantRuleOnFE(boolean deepRewrite) { + this.deepRewrite = deepRewrite; + } + + public static Expression evaluate(Expression expression, ExpressionRewriteContext expressionRewriteContext) { + return VISITOR_INSTANCE.rewrite(expression, expressionRewriteContext); + } + + @Override + public List> buildListeners() { + return ImmutableList.of( + listenerType(AggregateFunction.class) + .when(AggregateFunction::isDistinct) + .then(LISTEN_AGG_DISTINCT.as()), + + listenerType(AggregateExpression.class) + .when(AggregateExpression::isDistinct) + .then(LISTEN_AGG_DISTINCT.as()) + ); + } + + @Override + public List> buildRules() { + return ImmutableList.of( + matches(EncryptKeyRef.class, this::visitEncryptKeyRef), + matches(EqualTo.class, this::visitEqualTo), + matches(GreaterThan.class, this::visitGreaterThan), + matches(GreaterThanEqual.class, this::visitGreaterThanEqual), + matches(LessThan.class, this::visitLessThan), + matches(LessThanEqual.class, this::visitLessThanEqual), + matches(NullSafeEqual.class, this::visitNullSafeEqual), + matches(Not.class, this::visitNot), + matches(Database.class, this::visitDatabase), + matches(CurrentUser.class, this::visitCurrentUser), + matches(CurrentCatalog.class, this::visitCurrentCatalog), + matches(User.class, this::visitUser), + matches(ConnectionId.class, this::visitConnectionId), + matches(And.class, this::visitAnd), + matches(Or.class, this::visitOr), + matches(Cast.class, this::visitCast), + matches(BoundFunction.class, this::visitBoundFunction), + matches(BinaryArithmetic.class, this::visitBinaryArithmetic), + matches(CaseWhen.class, this::visitCaseWhen), + matches(If.class, this::visitIf), + matches(InPredicate.class, this::visitInPredicate), + matches(IsNull.class, this::visitIsNull), + matches(TimestampArithmetic.class, this::visitTimestampArithmetic), + matches(Password.class, this::visitPassword), + matches(Array.class, this::visitArray), + matches(Date.class, this::visitDate), + matches(Version.class, this::visitVersion) + ); + } @Override public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { @@ -253,7 +326,7 @@ public Expression visitAnd(And and, ExpressionRewriteContext context) { List nonTrueLiteral = Lists.newArrayList(); int nullCount = 0; for (Expression e : and.children()) { - e = e.accept(this, context); + e = deepRewrite ? e.accept(this, context) : e; if (BooleanLiteral.FALSE.equals(e)) { return BooleanLiteral.FALSE; } else if (e instanceof NullLiteral) { @@ -294,7 +367,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext context) { List nonFalseLiteral = Lists.newArrayList(); int nullCount = 0; for (Expression e : or.children()) { - e = e.accept(this, context); + e = deepRewrite ? e.accept(this, context) : e; if (BooleanLiteral.TRUE.equals(e)) { return BooleanLiteral.TRUE; } else if (e instanceof NullLiteral) { @@ -412,9 +485,13 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont } } - Expression defaultResult = caseWhen.getDefaultValue().isPresent() - ? rewrite(caseWhen.getDefaultValue().get(), context) - : null; + Expression defaultResult = null; + if (caseWhen.getDefaultValue().isPresent()) { + defaultResult = caseWhen.getDefaultValue().get(); + if (deepRewrite) { + defaultResult = rewrite(defaultResult, context); + } + } if (foundNewDefault) { defaultResult = newDefault; } @@ -537,28 +614,83 @@ public Expression visitVersion(Version version, ExpressionRewriteContext context return new StringLiteral(GlobalVariable.version); } - private E rewriteChildren(Expression expr, ExpressionRewriteContext ctx) { - return (E) super.visit(expr, ctx); - } - - private boolean allArgsIsAllLiteral(Expression expression) { - return ExpressionUtils.isAllLiteral(expression.getArguments()); - } - - private boolean argsHasNullLiteral(Expression expression) { - return ExpressionUtils.hasNullLiteral(expression.getArguments()); + private E rewriteChildren(E expr, ExpressionRewriteContext context) { + if (!deepRewrite) { + return expr; + } + switch (expr.arity()) { + case 1: { + Expression originChild = expr.child(0); + Expression newChild = originChild.accept(this, context); + return (originChild != newChild) ? (E) expr.withChildren(ImmutableList.of(newChild)) : expr; + } + case 2: { + Expression originLeft = expr.child(0); + Expression newLeft = originLeft.accept(this, context); + Expression originRight = expr.child(1); + Expression newRight = originRight.accept(this, context); + return (originLeft != newLeft || originRight != newRight) + ? (E) expr.withChildren(ImmutableList.of(newLeft, newRight)) + : expr; + } + case 0: { + return expr; + } + default: { + boolean hasNewChildren = false; + Builder newChildren = ImmutableList.builderWithExpectedSize(expr.arity()); + for (Expression child : expr.children()) { + Expression newChild = child.accept(this, context); + if (newChild != child) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + return hasNewChildren ? (E) expr.withChildren(newChildren.build()) : expr; + } + } } private Optional preProcess(Expression expression) { if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) { return Optional.of(expression); } - if (expression instanceof PropagateNullable && argsHasNullLiteral(expression)) { + if (expression instanceof PropagateNullable && ExpressionUtils.hasNullLiteral(expression.getArguments())) { return Optional.of(new NullLiteral(expression.getDataType())); } - if (!allArgsIsAllLiteral(expression)) { + if (!ExpressionUtils.isAllLiteral(expression.getArguments())) { return Optional.of(expression); } return Optional.empty(); } + + private static class ListenAggDistinct implements ExpressionTraverseListener { + @Override + public void onEnter(ExpressionMatchingContext context) { + context.cascadesContext.incrementDistinctAggLevel(); + } + + @Override + public void onExit(ExpressionMatchingContext context, Expression rewritten) { + context.cascadesContext.decrementDistinctAggLevel(); + } + } + + private static class CheckWhetherUnderAggDistinct implements Predicate> { + @Override + public boolean test(ExpressionMatchingContext context) { + return context.cascadesContext.getDistinctAggLevel() == 0; + } + + public Predicate> as() { + return (Predicate) this; + } + } + + private ExpressionPatternMatcher matches( + Class clazz, BiFunction visitMethod) { + return matchesType(clazz) + .whenCtx(NOT_UNDER_AGG_DISTINCT.as()) + .thenApply(ctx -> visitMethod.apply(ctx.expr, ctx.rewriteContext)); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java index 32f8e46da7553f..3760dcf0e72420 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java @@ -17,13 +17,14 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; -import java.util.ArrayList; -import java.util.HashSet; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + import java.util.List; import java.util.Set; @@ -31,25 +32,32 @@ * Deduplicate InPredicate, For example: * where A in (x, x) ==> where A in (x) */ -public class InPredicateDedup extends AbstractExpressionRewriteRule { - - public static InPredicateDedup INSTANCE = new InPredicateDedup(); +public class InPredicateDedup implements ExpressionPatternRuleFactory { + public static final InPredicateDedup INSTANCE = new InPredicateDedup(); @Override - public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class).then(InPredicateDedup::dedup) + ); + } + + /** dedup */ + public static Expression dedup(InPredicate inPredicate) { // In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options. // It takes a long time to apply this rule. So set a threshold for the max number. - if (inPredicate.getOptions().size() > 200) { + int optionSize = inPredicate.getOptions().size(); + if (optionSize > 200) { return inPredicate; } - Set dedupExpr = new HashSet<>(); - List newOptions = new ArrayList<>(); + ImmutableSet.Builder newOptionsBuilder = ImmutableSet.builderWithExpectedSize(inPredicate.arity()); for (Expression option : inPredicate.getOptions()) { - if (dedupExpr.contains(option)) { - continue; - } - dedupExpr.add(option); - newOptions.add(option); + newOptionsBuilder.add(option); + } + + Set newOptions = newOptionsBuilder.build(); + if (newOptions.size() == optionSize) { + return inPredicate; } return new InPredicate(inPredicate.getCompareExpr(), newOptions); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java index b076cadd53358d..353de7f41f62a1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateToEqualToRule.java @@ -17,12 +17,14 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; +import com.google.common.collect.ImmutableList; + import java.util.List; /** @@ -36,17 +38,16 @@ * NOTICE: it's related with `SimplifyRange`. * They are same processes, so must change synchronously. */ -public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule { - - public static InPredicateToEqualToRule INSTANCE = new InPredicateToEqualToRule(); +public class InPredicateToEqualToRule implements ExpressionPatternRuleFactory { + public static final InPredicateToEqualToRule INSTANCE = new InPredicateToEqualToRule(); @Override - public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { - Expression left = inPredicate.getCompareExpr(); - List right = inPredicate.getOptions(); - if (right.size() != 1) { - return new InPredicate(left.accept(this, context), right); - } - return new EqualTo(left.accept(this, context), right.get(0).accept(this, context)); + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class) + .when(in -> in.getOptions().size() == 1) + .then(in -> new EqualTo(in.getCompareExpr(), in.getOptions().get(0)) + ) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java index 9b1c88b930ba22..e73104793cd916 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NormalizeBinaryPredicatesRule.java @@ -17,22 +17,31 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Normalizes binary predicates of the form 'expr' op 'slot' so that the slot is on the left-hand side. * For example: * 5 > id -> id < 5 */ -public class NormalizeBinaryPredicatesRule extends AbstractExpressionRewriteRule { - - public static NormalizeBinaryPredicatesRule INSTANCE = new NormalizeBinaryPredicatesRule(); +public class NormalizeBinaryPredicatesRule implements ExpressionPatternRuleFactory { + public static final NormalizeBinaryPredicatesRule INSTANCE = new NormalizeBinaryPredicatesRule(); @Override - public Expression visitComparisonPredicate(ComparisonPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class).then(NormalizeBinaryPredicatesRule::normalize) + ); + } + + public static Expression normalize(ComparisonPredicate expr) { return expr.left().isConstant() && !expr.right().isConstant() ? expr.commute() : expr; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java index 6507f49825c7c5..e8eedb1e1980ff 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java @@ -17,31 +17,34 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.ImmutableList; + +import java.util.List; /** * convert "<=>" to "=", if any side is not nullable * convert "A <=> null" to "A is null" */ -public class NullSafeEqualToEqual extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class NullSafeEqualToEqual implements ExpressionPatternRuleFactory { public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesType(NullSafeEqual.class).then(NullSafeEqualToEqual::rewrite) + ); } - @Override - public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext ctx) { + private static Expression rewrite(NullSafeEqual nullSafeEqual) { if (nullSafeEqual.left() instanceof NullLiteral) { if (nullSafeEqual.right().nullable()) { return new IsNull(nullSafeEqual.right()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java index dd71ed8e99ff76..b9bdf520e3d6d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneListPartitionEvaluator.java @@ -82,7 +82,7 @@ public Expression visit(Expression expr, Map context) expr = super.visit(expr, context); if (!(expr instanceof Literal)) { // just forward to fold constant rule - return expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + return FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); } return expr; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java index deccd6cc1a3a54..32e39ef6264937 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OneRangePartitionEvaluator.java @@ -90,7 +90,7 @@ public class OneRangePartitionEvaluator /** OneRangePartitionEvaluator */ public OneRangePartitionEvaluator(long partitionId, List partitionSlots, - RangePartitionItem partitionItem, CascadesContext cascadesContext) { + RangePartitionItem partitionItem, CascadesContext cascadesContext, int expandThreshold) { this.partitionId = partitionId; this.partitionSlots = Objects.requireNonNull(partitionSlots, "partitionSlots cannot be null"); this.partitionItem = Objects.requireNonNull(partitionItem, "partitionItem cannot be null"); @@ -101,41 +101,46 @@ public OneRangePartitionEvaluator(long partitionId, List partitionSlots, this.lowers = toNereidsLiterals(range.lowerEndpoint()); this.uppers = toNereidsLiterals(range.upperEndpoint()); - PartitionRangeExpander expander = new PartitionRangeExpander(); - this.partitionSlotTypes = expander.computePartitionSlotTypes(lowers, uppers); - this.slotToType = Maps.newHashMapWithExpectedSize(16); - for (int i = 0; i < partitionSlots.size(); i++) { - slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i)); - } + this.partitionSlotTypes = PartitionRangeExpander.computePartitionSlotTypes(lowers, uppers); - this.partitionSlotContainsNull = Maps.newHashMapWithExpectedSize(16); - for (int i = 0; i < partitionSlots.size(); i++) { - Slot slot = partitionSlots.get(i); - if (!slot.nullable()) { - partitionSlotContainsNull.put(slot, false); - continue; + if (partitionSlots.size() == 1) { + // fast path + Slot partSlot = partitionSlots.get(0); + this.slotToType = ImmutableMap.of(partSlot, partitionSlotTypes.get(0)); + this.partitionSlotContainsNull + = ImmutableMap.of(partSlot, range.lowerEndpoint().getKeys().get(0).isMinValue()); + } else { + // slow path + this.slotToType = Maps.newHashMap(); + for (int i = 0; i < partitionSlots.size(); i++) { + slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i)); } - PartitionSlotType partitionSlotType = partitionSlotTypes.get(i); - boolean maybeNull = false; - switch (partitionSlotType) { - case CONST: - case RANGE: - maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue(); - break; - case OTHER: - maybeNull = true; - break; - default: - throw new AnalysisException("Unknown partition slot type: " + partitionSlotType); + + this.partitionSlotContainsNull = Maps.newHashMap(); + for (int i = 0; i < partitionSlots.size(); i++) { + Slot slot = partitionSlots.get(i); + if (!slot.nullable()) { + partitionSlotContainsNull.put(slot, false); + continue; + } + PartitionSlotType partitionSlotType = partitionSlotTypes.get(i); + boolean maybeNull; + switch (partitionSlotType) { + case CONST: + case RANGE: + maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue(); + break; + case OTHER: + maybeNull = true; + break; + default: + throw new AnalysisException("Unknown partition slot type: " + partitionSlotType); + } + partitionSlotContainsNull.put(slot, maybeNull); } - partitionSlotContainsNull.put(slot, maybeNull); } - int expandThreshold = cascadesContext.getAndCacheSessionVariable( - "partitionPruningExpandThreshold", - 10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold); - - List> expandInputs = expander.tryExpandRange( + List> expandInputs = PartitionRangeExpander.tryExpandRange( partitionSlots, lowers, uppers, partitionSlotTypes, expandThreshold); // after expand range, we will get 2 dimension list like list: // part_col1: [1], part_col2:[4, 5, 6], we should combine it to @@ -428,10 +433,13 @@ public EvaluateRangeResult visitNot(Not not, EvaluateRangeInput context) { private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRangeInput context) { // evaluate children - List newChildren = new ArrayList<>(); - List childrenResults = new ArrayList<>(); + List children = expr.children(); + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(children.size()); + List childrenResults = new ArrayList<>(children.size()); boolean hasNewChildren = false; - for (Expression child : expr.children()) { + + for (int i = 0; i < children.size(); i++) { + Expression child = children.get(i); EvaluateRangeResult childResult = child.accept(this, context); if (!childResult.result.equals(child)) { hasNewChildren = true; @@ -440,11 +448,11 @@ private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRa newChildren.add(childResult.result); } if (hasNewChildren) { - expr = expr.withChildren(newChildren); + expr = expr.withChildren(newChildren.build()); } // evaluate this - expr = expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + expr = FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); return new EvaluateRangeResult(expr, context.defaultColumnRanges, childrenResults); } @@ -552,9 +560,28 @@ private EvaluateRangeResult mergeRanges( } private List toNereidsLiterals(PartitionKey partitionKey) { - List literals = Lists.newArrayListWithCapacity(partitionKey.getKeys().size()); - for (int i = 0; i < partitionKey.getKeys().size(); i++) { - LiteralExpr literalExpr = partitionKey.getKeys().get(i); + if (partitionKey.getKeys().size() == 1) { + // fast path + return toSingleNereidsLiteral(partitionKey); + } + + // slow path + return toMultiNereidsLiterals(partitionKey); + } + + private List toSingleNereidsLiteral(PartitionKey partitionKey) { + List keys = partitionKey.getKeys(); + LiteralExpr literalExpr = keys.get(0); + PrimitiveType primitiveType = partitionKey.getTypes().get(0); + Type type = Type.fromPrimitiveType(primitiveType); + return ImmutableList.of(Literal.fromLegacyLiteral(literalExpr, type)); + } + + private List toMultiNereidsLiterals(PartitionKey partitionKey) { + List keys = partitionKey.getKeys(); + List literals = Lists.newArrayListWithCapacity(keys.size()); + for (int i = 0; i < keys.size(); i++) { + LiteralExpr literalExpr = keys.get(i); PrimitiveType primitiveType = partitionKey.getTypes().get(i); Type type = Type.fromPrimitiveType(primitiveType); literals.add(Literal.fromLegacyLiteral(literalExpr, type)); @@ -590,8 +617,8 @@ public EvaluateRangeResult visitDate(Date date, EvaluateRangeInput context) { Literal lower = span.lowerEndpoint().getValue(); Literal upper = span.upperEndpoint().getValue(); - Expression lowerDate = new Date(lower).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); - Expression upperDate = new Date(upper).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + Expression lowerDate = FoldConstantRuleOnFE.evaluate(new Date(lower), expressionRewriteContext); + Expression upperDate = FoldConstantRuleOnFE.evaluate(new Date(upper), expressionRewriteContext); if (lowerDate instanceof Literal && upperDate instanceof Literal && lowerDate.equals(upperDate)) { return new EvaluateRangeResult(lowerDate, result.columnRanges, result.childrenResult); @@ -673,7 +700,7 @@ public EvaluateRangeResult(Expression result, Map columnRange public EvaluateRangeResult(Expression result, Map columnRanges, List childrenResult) { - this(result, columnRanges, childrenResult, childrenResult.stream().allMatch(r -> r.isRejectNot())); + this(result, columnRanges, childrenResult, allIsRejectNot(childrenResult)); } public EvaluateRangeResult withRejectNot(boolean rejectNot) { @@ -683,6 +710,15 @@ public EvaluateRangeResult withRejectNot(boolean rejectNot) { public boolean isRejectNot() { return rejectNot; } + + private static boolean allIsRejectNot(List childrenResult) { + for (EvaluateRangeResult evaluateRangeResult : childrenResult) { + if (!evaluateRangeResult.isRejectNot()) { + return false; + } + } + return true; + } } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index b085f70da6e9c9..83da8055037242 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -17,15 +17,17 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; @@ -54,20 +56,25 @@ * adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression * rewrite framework always passes an ExpressionRewriteContext of type context to all rules. */ -public class OrToIn extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class OrToIn implements ExpressionPatternRuleFactory { public static final OrToIn INSTANCE = new OrToIn(); public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2; @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(Or.class).then(OrToIn::rewrite) + ); } - @Override - public Expression visitOr(Or or, ExpressionRewriteContext ctx) { + public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) { + ExpressionBottomUpRewriter bottomUpRewriter = ExpressionRewrite.bottomUp(this); + return bottomUpRewriter.rewrite(expr, context); + } + + private static Expression rewrite(Or or) { // NOTICE: use linked hash map to avoid unstable order or entry. // unstable order entry lead to dead loop since return expression always un-equals to original one. Map> slotNameToLiteral = Maps.newLinkedHashMap(); @@ -80,6 +87,10 @@ public Expression visitOr(Or or, ExpressionRewriteContext ctx) { handleInPredicate((InPredicate) expression, slotNameToLiteral, disConjunctToSlot); } } + if (disConjunctToSlot.isEmpty()) { + return or; + } + List rewrittenOr = new ArrayList<>(); for (Map.Entry> entry : slotNameToLiteral.entrySet()) { Set literals = entry.getValue(); @@ -90,7 +101,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext ctx) { } for (Expression expression : expressions) { if (disConjunctToSlot.get(expression) == null) { - rewrittenOr.add(expression.accept(this, null)); + rewrittenOr.add(expression); } else { Set literals = slotNameToLiteral.get(disConjunctToSlot.get(expression)); if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { @@ -102,7 +113,7 @@ public Expression visitOr(Or or, ExpressionRewriteContext ctx) { return ExpressionUtils.or(rewrittenOr); } - private void handleEqualTo(EqualTo equal, Map> slotNameToLiteral, + private static void handleEqualTo(EqualTo equal, Map> slotNameToLiteral, Map disConjunctToSlot) { Expression left = equal.left(); Expression right = equal.right(); @@ -115,7 +126,7 @@ private void handleEqualTo(EqualTo equal, Map> slo } } - private void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral, + private static void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral, Map disConjunctToSlot) { // TODO a+b in (1,2,3...) is not supported now if (inPredicate.getCompareExpr() instanceof NamedExpression @@ -127,10 +138,9 @@ private void handleInPredicate(InPredicate inPredicate, Map> slotNameToLiteral) { Set literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>()); literals.add(literal); } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java index 9d6f420b47d79b..4a825d7956b839 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.PartitionItem; import org.apache.doris.catalog.RangePartitionItem; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -81,14 +82,19 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) && ((Cast) right).child().getDataType().isDateType()) { DateTimeLiteral dt = (DateTimeLiteral) left; Cast cast = (Cast) right; - return cp.withChildren(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child()); + return cp.withChildren( + ImmutableList.of(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child()) + ); } else if (right instanceof DateTimeLiteral && ((DateTimeLiteral) right).isMidnight() && left instanceof Cast && ((Cast) left).child() instanceof SlotReference && ((Cast) left).child().getDataType().isDateType()) { DateTimeLiteral dt = (DateTimeLiteral) right; Cast cast = (Cast) left; - return cp.withChildren(cast.child(), new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay())); + return cp.withChildren(ImmutableList.of( + cast.child(), + new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay())) + ); } else { return cp; } @@ -115,13 +121,18 @@ public static List prune(List partitionSlots, Expression partitionPr partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext); partitionPredicate = PredicateRewriteForPartitionPrune.rewrite(partitionPredicate, cascadesContext); + int expandThreshold = cascadesContext.getAndCacheSessionVariable( + "partitionPruningExpandThreshold", + 10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold); + List evaluators = Lists.newArrayListWithCapacity(idToPartitions.size()); for (Entry kv : idToPartitions.entrySet()) { evaluators.add(toPartitionEvaluator( - kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, partitionTableType)); + kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, expandThreshold)); } - partitionPredicate = OrToIn.INSTANCE.rewrite(partitionPredicate, null); + partitionPredicate = OrToIn.INSTANCE.rewriteTree( + partitionPredicate, new ExpressionRewriteContext(cascadesContext)); PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate); //TODO: we keep default partition because it's too hard to prune it, we return false in canPrune(). return partitionPruner.prune(); @@ -131,13 +142,13 @@ public static List prune(List partitionSlots, Expression partitionPr * convert partition item to partition evaluator */ public static final OnePartitionEvaluator toPartitionEvaluator(long id, PartitionItem partitionItem, - List partitionSlots, CascadesContext cascadesContext, PartitionTableType partitionTableType) { + List partitionSlots, CascadesContext cascadesContext, int expandThreshold) { if (partitionItem instanceof ListPartitionItem) { return new OneListPartitionEvaluator( id, partitionSlots, (ListPartitionItem) partitionItem, cascadesContext); } else if (partitionItem instanceof RangePartitionItem) { return new OneRangePartitionEvaluator( - id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext); + id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext, expandThreshold); } else { return new UnknownPartitionEvaluator(id, partitionItem); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java index 071ab8f11572c6..01a674488da50a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionRangeExpander.java @@ -41,7 +41,6 @@ import java.time.ZoneOffset; import java.util.Iterator; import java.util.List; -import java.util.NoSuchElementException; import java.util.function.Function; /** @@ -74,10 +73,44 @@ public enum PartitionSlotType { } /** expandRangeLiterals */ - public final List> tryExpandRange( + public static final List> tryExpandRange( List partitionSlots, List lowers, List uppers, List partitionSlotTypes, int expandThreshold) { + if (partitionSlots.size() == 1) { + return tryExpandSingleColumnRange(partitionSlots.get(0), lowers.get(0), + uppers.get(0), expandThreshold); + } else { + // slow path + return commonTryExpandRange(partitionSlots, lowers, uppers, partitionSlotTypes, expandThreshold); + } + } + + private static List> tryExpandSingleColumnRange(Slot partitionSlot, Literal lower, + Literal upper, int expandThreshold) { + // must be range slot + try { + if (canExpandRange(partitionSlot, lower, upper, 1, expandThreshold)) { + Iterator iterator = enumerableIterator( + partitionSlot, lower, upper, true); + if (iterator instanceof SingletonIterator) { + return ImmutableList.of(ImmutableList.of(iterator.next())); + } else { + return ImmutableList.of( + ImmutableList.copyOf(iterator) + ); + } + } else { + return ImmutableList.of(ImmutableList.of(partitionSlot)); + } + } catch (Throwable t) { + // catch for safety, should not invoke here + return ImmutableList.of(ImmutableList.of(partitionSlot)); + } + } + private static List> commonTryExpandRange( + List partitionSlots, List lowers, List uppers, + List partitionSlotTypes, int expandThreshold) { long expandedCount = 1; List> expandedLists = Lists.newArrayListWithCapacity(lowers.size()); for (int i = 0; i < partitionSlotTypes.size(); i++) { @@ -126,7 +159,7 @@ public final List> tryExpandRange( return expandedLists; } - private boolean canExpandRange(Slot slot, Literal lower, Literal upper, + private static boolean canExpandRange(Slot slot, Literal lower, Literal upper, long expandedCount, int expandThreshold) { DataType type = slot.getDataType(); if (!type.isIntegerLikeType() && !type.isDateType() && !type.isDateV2Type()) { @@ -139,7 +172,7 @@ private boolean canExpandRange(Slot slot, Literal lower, Literal upper, } // too much expanded will consuming resources of frontend, // e.g. [1, 100000000), we should skip expand it - return (expandedCount * count) <= expandThreshold; + return count == 1 || (expandedCount * count) <= expandThreshold; } catch (Throwable t) { // e.g. max_value can not expand return false; @@ -147,7 +180,7 @@ private boolean canExpandRange(Slot slot, Literal lower, Literal upper, } /** the types will like this: [CONST, CONST, ..., RANGE, OTHER, OTHER, ...] */ - public List computePartitionSlotTypes(List lowers, List uppers) { + public static List computePartitionSlotTypes(List lowers, List uppers) { PartitionSlotType previousType = PartitionSlotType.CONST; List types = Lists.newArrayListWithCapacity(lowers.size()); for (int i = 0; i < lowers.size(); ++i) { @@ -167,7 +200,7 @@ public List computePartitionSlotTypes(List lowers, L return types; } - private long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) { + private static long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) { if (dataType.isIntegerLikeType()) { BigInteger start = new BigInteger(startInclusive.getStringValue()); BigInteger end = new BigInteger(endExclusive.getStringValue()); @@ -175,6 +208,12 @@ private long enumerableCount(DataType dataType, Literal startInclusive, Literal } else if (dataType.isDateType()) { DateLiteral startInclusiveDate = (DateLiteral) startInclusive; DateLiteral endExclusiveDate = (DateLiteral) endExclusive; + + if (startInclusiveDate.getYear() == endExclusiveDate.getYear() + && startInclusiveDate.getMonth() == endExclusiveDate.getMonth()) { + return endExclusiveDate.getDay() - startInclusiveDate.getDay(); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -192,6 +231,12 @@ private long enumerableCount(DataType dataType, Literal startInclusive, Literal } else if (dataType.isDateV2Type()) { DateV2Literal startInclusiveDate = (DateV2Literal) startInclusive; DateV2Literal endExclusiveDate = (DateV2Literal) endExclusive; + + if (startInclusiveDate.getYear() == endExclusiveDate.getYear() + && startInclusiveDate.getMonth() == endExclusiveDate.getMonth()) { + return endExclusiveDate.getDay() - startInclusiveDate.getDay(); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -212,7 +257,7 @@ private long enumerableCount(DataType dataType, Literal startInclusive, Literal return -1; } - private Iterator enumerableIterator( + private static Iterator enumerableIterator( Slot slot, Literal startInclusive, Literal endLiteral, boolean endExclusive) { DataType dataType = slot.getDataType(); if (dataType.isIntegerLikeType()) { @@ -237,6 +282,12 @@ private Iterator enumerableIterator( } else if (dataType.isDateType()) { DateLiteral startInclusiveDate = (DateLiteral) startInclusive; DateLiteral endLiteralDate = (DateLiteral) endLiteral; + if (endExclusive && startInclusiveDate.getYear() == endLiteralDate.getYear() + && startInclusiveDate.getMonth() == endLiteralDate.getMonth() + && startInclusiveDate.getDay() + 1 == endLiteralDate.getDay()) { + return new SingletonIterator(startInclusive); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -258,6 +309,13 @@ private Iterator enumerableIterator( } else if (dataType.isDateV2Type()) { DateV2Literal startInclusiveDate = (DateV2Literal) startInclusive; DateV2Literal endLiteralDate = (DateV2Literal) endLiteral; + + if (endExclusive && startInclusiveDate.getYear() == endLiteralDate.getYear() + && startInclusiveDate.getMonth() == endLiteralDate.getMonth() + && startInclusiveDate.getDay() + 1 == endLiteralDate.getDay()) { + return new SingletonIterator(startInclusive); + } + LocalDate startDate = LocalDate.of( (int) startInclusiveDate.getYear(), (int) startInclusiveDate.getMonth(), @@ -282,7 +340,7 @@ private Iterator enumerableIterator( return Iterators.singletonIterator(slot); } - private class IntegerLikeRangePartitionValueIterator + private static class IntegerLikeRangePartitionValueIterator extends RangePartitionValueIterator { public IntegerLikeRangePartitionValueIterator(BigInteger startInclusive, BigInteger end, @@ -296,7 +354,7 @@ protected BigInteger doGetNext(BigInteger current) { } } - private class DateLikeRangePartitionValueIterator + private static class DateLikeRangePartitionValueIterator extends RangePartitionValueIterator { public DateLikeRangePartitionValueIterator( @@ -309,43 +367,4 @@ protected LocalDate doGetNext(LocalDate current) { return current.plusDays(1); } } - - private abstract class RangePartitionValueIterator - implements Iterator { - private final C startInclusive; - private final C end; - private final boolean endExclusive; - private C current; - - private final Function toLiteral; - - public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function toLiteral) { - this.startInclusive = startInclusive; - this.end = end; - this.endExclusive = endExclusive; - this.current = this.startInclusive; - this.toLiteral = toLiteral; - } - - @Override - public boolean hasNext() { - if (endExclusive) { - return current.compareTo(end) < 0; - } else { - return current.compareTo(end) <= 0; - } - } - - @Override - public L next() { - if (hasNext()) { - C value = current; - current = doGetNext(current); - return toLiteral.apply(value); - } - throw new NoSuchElementException(); - } - - protected abstract C doGetNext(C current); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java index c227c89b939188..87646fbd582d3c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PredicateRewriteForPartitionPrune.java @@ -70,7 +70,7 @@ public Expression visitInPredicate(InPredicate in, CascadesContext context) { } } if (convertable) { - Expression or = ExpressionUtils.combine(Or.class, splitIn); + Expression or = ExpressionUtils.combineAsLeftDeepTree(Or.class, splitIn); return or; } } else if (dateChild.getDataType() instanceof DateTimeV2Type) { @@ -87,7 +87,7 @@ public Expression visitInPredicate(InPredicate in, CascadesContext context) { } } if (convertable) { - Expression or = ExpressionUtils.combine(Or.class, splitIn); + Expression or = ExpressionUtils.combineAsLeftDeepTree(Or.class, splitIn); return or; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java new file mode 100644 index 00000000000000..79ee33d1ebb815 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangePartitionValueIterator.java @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.trees.expressions.literal.Literal; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.function.Function; + +/** RangePartitionValueIterator */ +public abstract class RangePartitionValueIterator + implements Iterator { + private final C startInclusive; + private final C end; + private final boolean endExclusive; + private C current; + + private final Function toLiteral; + + public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function toLiteral) { + this.startInclusive = startInclusive; + this.end = end; + this.endExclusive = endExclusive; + this.current = this.startInclusive; + this.toLiteral = toLiteral; + } + + @Override + public boolean hasNext() { + if (endExclusive) { + return current.compareTo(end) < 0; + } else { + return current.compareTo(end) <= 0; + } + } + + @Override + public L next() { + if (hasNext()) { + C value = current; + current = doGetNext(current); + return toLiteral.apply(value); + } + throw new NoSuchElementException(); + } + + protected abstract C doGetNext(C current); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java index 3fd5330395e7fc..b4c5552706c589 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/ReplaceVariableByLiteral.java @@ -17,20 +17,25 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Variable; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * replace varaible to real expression */ -public class ReplaceVariableByLiteral extends AbstractExpressionRewriteRule { - +public class ReplaceVariableByLiteral implements ExpressionPatternRuleFactory { public static ReplaceVariableByLiteral INSTANCE = new ReplaceVariableByLiteral(); @Override - public Expression visitVariable(Variable variable, ExpressionRewriteContext context) { - return variable.getRealExpression(); + public List> buildRules() { + return ImmutableList.of( + matchesType(Variable.class).then(Variable::getRealExpression) + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java index 7606d082479227..6d18bc7b3807a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRule.java @@ -17,7 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; @@ -43,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.TypeCoercionUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; @@ -55,11 +57,11 @@ * a + 1 > 1 => a > 0 * a / -2 > 1 => a < -2 */ -public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteRule { - public static final SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule(); +public class SimplifyArithmeticComparisonRule implements ExpressionPatternRuleFactory { + public static SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule(); // don't rearrange multiplication because divide may loss precision - final Map, Class> rearrangementMap = ImmutableMap + private static final Map, Class> REARRANGEMENT_MAP = ImmutableMap ., Class>builder() .put(Add.class, Subtract.class) .put(Subtract.class, Add.class) @@ -81,41 +83,54 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR .build(); @Override - public Expression visitComparisonPredicate(ComparisonPredicate comparison, ExpressionRewriteContext context) { - if (couldRearrange(comparison)) { - ComparisonPredicate newComparison = normalize(comparison); - if (newComparison == null) { - return comparison; - } - try { - List children = - tryRearrangeChildren(newComparison.left(), newComparison.right(), context); - newComparison = (ComparisonPredicate) visitComparisonPredicate( - (ComparisonPredicate) newComparison.withChildren(children), context); - } catch (Exception e) { - return comparison; - } + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class) + .thenApply(ctx -> simplify(ctx.expr, new ExpressionRewriteContext(ctx.cascadesContext))) + ); + } + + /** simplify */ + public static Expression simplify(ComparisonPredicate comparison, ExpressionRewriteContext context) { + if (!couldRearrange(comparison)) { + return comparison; + } + ComparisonPredicate newComparison = normalize(comparison); + if (newComparison == null) { + return comparison; + } + try { + List children = tryRearrangeChildren(newComparison.left(), newComparison.right(), context); + newComparison = (ComparisonPredicate) simplify( + (ComparisonPredicate) newComparison.withChildren(children), context); return TypeCoercionUtils.processComparisonPredicate(newComparison); - } else { + } catch (Exception e) { return comparison; } } - private boolean couldRearrange(ComparisonPredicate cmp) { - return rearrangementMap.containsKey(cmp.left().getClass()) - && !cmp.left().isConstant() - && cmp.left().children().stream().anyMatch(Expression::isConstant); + private static boolean couldRearrange(ComparisonPredicate cmp) { + if (!REARRANGEMENT_MAP.containsKey(cmp.left().getClass()) || cmp.left().isConstant()) { + return false; + } + + for (Expression child : cmp.left().children()) { + if (child.isConstant()) { + return true; + } + } + return false; } - private List tryRearrangeChildren(Expression left, Expression right, + private static List tryRearrangeChildren(Expression left, Expression right, ExpressionRewriteContext context) throws Exception { if (!left.child(1).isConstant()) { throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left)); } - Literal leftLiteral = (Literal) FoldConstantRule.INSTANCE.rewrite(left.child(1), context); + Literal leftLiteral = (Literal) FoldConstantRule.evaluate(left.child(1), context); Expression leftExpr = left.child(0); - Class oppositeOperator = rearrangementMap.get(left.getClass()); + Class oppositeOperator = REARRANGEMENT_MAP.get(left.getClass()); Expression newChild = oppositeOperator.getConstructor(Expression.class, Expression.class) .newInstance(right, leftLiteral); @@ -127,25 +142,25 @@ private List tryRearrangeChildren(Expression left, Expression right, } // Ensure that the second child must be Literal, such as - private @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) { - if (!(comparison.left().child(1) instanceof Literal)) { - Expression left = comparison.left(); - if (comparison.left() instanceof Add) { - // 1 + a > 1 => a + 1 > 1 - Expression newLeft = left.withChildren(left.child(1), left.child(0)); - comparison = (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right()); - } else if (comparison.left() instanceof Subtract) { - // 1 - a > 1 => a + 1 < 1 - Expression newLeft = left.child(0); - Expression newRight = new Add(left.child(1), comparison.right()); - comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight); - comparison = comparison.commute(); - } else { - // Don't normalize division/multiplication because the slot sign is undecided. - return null; - } + private static @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) { + Expression left = comparison.left(); + Expression leftRight = left.child(1); + if (leftRight instanceof Literal) { + return comparison; + } + if (left instanceof Add) { + // 1 + a > 1 => a + 1 > 1 + Expression newLeft = left.withChildren(leftRight, left.child(0)); + return (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right()); + } else if (left instanceof Subtract) { + // 1 - a > 1 => a + 1 < 1 + Expression newLeft = left.child(0); + Expression newRight = new Add(leftRight, comparison.right()); + comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight); + return comparison.commute(); + } else { + // Don't normalize division/multiplication because the slot sign is undecided. + return null; } - return comparison; } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java index fc7431a9994d98..b9fd91f64387ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.Divide; @@ -27,7 +27,9 @@ import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.nereids.util.TypeUtils; +import org.apache.doris.nereids.util.Utils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; @@ -43,27 +45,24 @@ * * TODO: handle cases like: '1 - IA < 1' to 'IA > 0' */ -public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule { +public class SimplifyArithmeticRule implements ExpressionPatternRuleFactory { public static final SimplifyArithmeticRule INSTANCE = new SimplifyArithmeticRule(); @Override - public Expression visitAdd(Add add, ExpressionRewriteContext context) { - return process(add, true); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(BinaryArithmetic.class).then(SimplifyArithmeticRule::simplify) + ); } - @Override - public Expression visitSubtract(Subtract subtract, ExpressionRewriteContext context) { - return process(subtract, true); - } - - @Override - public Expression visitDivide(Divide divide, ExpressionRewriteContext context) { - return process(divide, false); - } - - @Override - public Expression visitMultiply(Multiply multiply, ExpressionRewriteContext context) { - return process(multiply, false); + /** simplify */ + public static Expression simplify(BinaryArithmetic binaryArithmetic) { + if (binaryArithmetic instanceof Add || binaryArithmetic instanceof Subtract) { + return process(binaryArithmetic, true); + } else if (binaryArithmetic instanceof Multiply || binaryArithmetic instanceof Divide) { + return process(binaryArithmetic, false); + } + return binaryArithmetic; } /** @@ -75,7 +74,7 @@ public Expression visitMultiply(Multiply multiply, ExpressionRewriteContext cont * 3.build new arithmetic expression. * (a + b - c + d) + (1 - 2 - 1) */ - private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { + private static Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { // 1. flatten the arithmetic expression. List flattedExpressions = flatten(arithmetic, isAddOrSub); @@ -83,22 +82,24 @@ private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { List constants = Lists.newArrayList(); // TODO currently we don't process decimal for simplicity. - if (flattedExpressions.stream().anyMatch(operand -> operand.expression.getDataType().isDecimalLikeType())) { - return arithmetic; + for (Operand operand : flattedExpressions) { + if (operand.expression.getDataType().isDecimalLikeType()) { + return arithmetic; + } } // 2. move variables to left side and move constants to right sid. - flattedExpressions.forEach(operand -> { + for (Operand operand : flattedExpressions) { if (operand.expression.isConstant()) { constants.add(operand); } else { variables.add(operand); } - }); + } // 3. build new arithmetic expression. if (!constants.isEmpty()) { boolean isOpposite = !constants.get(0).flag; - Optional c = constants.stream().reduce((x, y) -> { + Optional c = Utils.fastReduce(constants, (x, y) -> { Expression expr; if (isOpposite && y.flag || !isOpposite && !y.flag) { expr = getSubOrDivide(isAddOrSub, x, y); @@ -115,10 +116,10 @@ private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { } } - Optional result = variables.stream().reduce((x, y) -> !y.flag + Optional result = Utils.fastReduce(variables, (x, y) -> !y.flag ? Operand.of(true, getSubOrDivide(isAddOrSub, x, y)) - : Operand.of(true, getAddOrMultiply(isAddOrSub, x, y))); - + : Operand.of(true, getAddOrMultiply(isAddOrSub, x, y)) + ); if (result.isPresent()) { return TypeCoercionUtils.castIfNotSameType(result.get().expression, arithmetic.getDataType()); } else { @@ -126,7 +127,7 @@ private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) { } } - private List flatten(Expression expr, boolean isAddOrSub) { + private static List flatten(Expression expr, boolean isAddOrSub) { List result = Lists.newArrayList(); if (isAddOrSub) { flattenAddSubtract(true, expr, result); @@ -136,7 +137,7 @@ private List flatten(Expression expr, boolean isAddOrSub) { return result; } - private void flattenAddSubtract(boolean flag, Expression expr, List result) { + private static void flattenAddSubtract(boolean flag, Expression expr, List result) { if (TypeUtils.isAddOrSubtract(expr)) { BinaryArithmetic arithmetic = (BinaryArithmetic) expr; flattenAddSubtract(flag, arithmetic.left(), result); @@ -152,7 +153,7 @@ private void flattenAddSubtract(boolean flag, Expression expr, List res } } - private void flattenMultiplyDivide(boolean flag, Expression expr, List result) { + private static void flattenMultiplyDivide(boolean flag, Expression expr, List result) { if (TypeUtils.isMultiplyOrDivide(expr)) { BinaryArithmetic arithmetic = (BinaryArithmetic) expr; flattenMultiplyDivide(flag, arithmetic.left(), result); @@ -168,13 +169,13 @@ private void flattenMultiplyDivide(boolean flag, Expression expr, List } } - private Expression getSubOrDivide(boolean isAddOrSub, Operand x, Operand y) { - return isAddOrSub ? new Subtract(x.expression, y.expression) + private static Expression getSubOrDivide(boolean isSubOrDivide, Operand x, Operand y) { + return isSubOrDivide ? new Subtract(x.expression, y.expression) : new Divide(x.expression, y.expression); } - private Expression getAddOrMultiply(boolean isAddOrSub, Operand x, Operand y) { - return isAddOrSub ? new Add(x.expression, y.expression) + private static Expression getAddOrMultiply(boolean isAddOrMultiply, Operand x, Operand y) { + return isAddOrMultiply ? new Add(x.expression, y.expression) : new Multiply(x.expression, y.expression); } @@ -204,3 +205,4 @@ public String toString() { } } } + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java index 34143043a07022..ded0a2f558f8d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; @@ -37,7 +37,10 @@ import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; +import com.google.common.collect.ImmutableList; + import java.math.BigDecimal; +import java.util.List; /** * Rewrite rule of simplify CAST expression. @@ -46,17 +49,19 @@ * Merge cast like * - cast(cast(1 as bigint) as string) -> cast(1 as string). */ -public class SimplifyCastRule extends AbstractExpressionRewriteRule { - +public class SimplifyCastRule implements ExpressionPatternRuleFactory { public static SimplifyCastRule INSTANCE = new SimplifyCastRule(); @Override - public Expression visitCast(Cast origin, ExpressionRewriteContext context) { - return simplify(origin, context); + public List> buildRules() { + return ImmutableList.of( + matchesType(Cast.class).then(SimplifyCastRule::simplifyCast) + ); } - private Expression simplify(Cast cast, ExpressionRewriteContext context) { - Expression child = rewrite(cast.child(), context); + /** simplifyCast */ + public static Expression simplifyCast(Cast cast) { + Expression child = cast.child(); // remove redundant cast // CAST(value as type), value is type diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index 03958f3d55f6f8..d26b5a53036897 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -18,6 +18,8 @@ package org.apache.doris.nereids.rules.expression.rules; import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Cast; @@ -55,17 +57,18 @@ import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.List; /** * simplify comparison * such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral * cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type) */ -public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule { - +public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory { public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate(); enum AdjustType { @@ -74,10 +77,20 @@ enum AdjustType { NONE } + @Override + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class).then(SimplifyComparisonPredicate::simplify) + ); + } + @Override public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { - cp = (ComparisonPredicate) visit(cp, context); + return simplify(cp); + } + /** simplify */ + public static Expression simplify(ComparisonPredicate cp) { if (cp.left() instanceof Literal && !(cp.right() instanceof Literal)) { cp = cp.commute(); } @@ -146,7 +159,7 @@ private static Expression processComparisonPredicateDateTimeV2Literal( return comparisonPredicate; } - private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) { + private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) { if (left instanceof Cast && right instanceof DateLiteral) { Cast cast = (Cast) left; if (cast.child().getDataType() instanceof DateTimeType) { @@ -196,7 +209,7 @@ private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expressio } } - private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate, + private static Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate, Expression left, Expression right) { if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType() && (right instanceof DoubleLiteral || right instanceof FloatLiteral)) { @@ -209,7 +222,7 @@ private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPr } } - private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, + private static Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate, Expression left, Expression right) { if (left instanceof Cast && right instanceof DecimalV3Literal) { Cast cast = (Cast) left; @@ -264,7 +277,7 @@ private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPr return comparisonPredicate; } - private Expression processIntegerDecimalLiteralComparison( + private static Expression processIntegerDecimalLiteralComparison( ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) { // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) { @@ -306,7 +319,7 @@ private Expression processIntegerDecimalLiteralComparison( return comparisonPredicate; } - private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { + private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) { Preconditions.checkArgument( decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0, "decimal literal must have 0 scale and smaller than Long.MAX_VALUE"); @@ -322,15 +335,15 @@ private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal } } - private Expression migrateToDateTime(DateTimeV2Literal l) { + private static Expression migrateToDateTime(DateTimeV2Literal l) { return new DateTimeLiteral(l.getYear(), l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond()); } - private boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) { + private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) { return cp instanceof EqualTo && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0); } - private Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { + private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay()); if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) { d = ((DateV2Literal) d.plusDays(1)); @@ -338,7 +351,7 @@ private Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) { return d; } - private Expression migrateToDate(DateV2Literal l) { + private static Expression migrateToDate(DateV2Literal l) { return new DateLiteral(l.getYear(), l.getMonth(), l.getDay()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java index 6b0426adaad5e9..c3c3c17dd55f42 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -26,8 +26,10 @@ import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import java.math.BigDecimal; +import java.util.List; /** * if we have a column with decimalv3 type and set enable_decimal_conversion = false. @@ -37,14 +39,20 @@ * and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand * this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6 */ -public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule { - +public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory { public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison(); @Override - public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { - Expression left = rewrite(cp.left(), context); - Expression right = rewrite(cp.right(), context); + public List> buildRules() { + return ImmutableList.of( + matchesType(ComparisonPredicate.class).then(SimplifyDecimalV3Comparison::simplify) + ); + } + + /** simplify */ + public static Expression simplify(ComparisonPredicate cp) { + Expression left = cp.left(); + Expression right = cp.right(); if (left.getDataType() instanceof DecimalV3Type && left instanceof Cast @@ -60,7 +68,7 @@ public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRew } } - private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { + private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros(); int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java index 3e194a4edde398..bf1b194a6ac7f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyInPredicate.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; @@ -33,12 +33,18 @@ /** * SimplifyInPredicate */ -public class SimplifyInPredicate extends AbstractExpressionRewriteRule { - +public class SimplifyInPredicate implements ExpressionPatternRuleFactory { public static final SimplifyInPredicate INSTANCE = new SimplifyInPredicate(); @Override - public Expression visitInPredicate(InPredicate expr, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class).then(SimplifyInPredicate::simplify) + ); + } + + /** simplify */ + public static Expression simplify(InPredicate expr) { if (expr.children().size() > 1) { if (expr.getCompareExpr() instanceof Cast) { Cast cast = (Cast) expr.getCompareExpr(); @@ -58,7 +64,7 @@ && canLosslessConvertToDateV2Literal((DateTimeV2Literal) literal))) { DateTimeV2Type compareType = (DateTimeV2Type) cast.child().getDataType(); if (literals.stream().allMatch(literal -> literal instanceof DateTimeV2Literal && canLosslessConvertToLowScaleLiteral( - (DateTimeV2Literal) literal, compareType.getScale()))) { + (DateTimeV2Literal) literal, compareType.getScale()))) { ImmutableList.Builder children = ImmutableList.builder(); children.add(cast.child()); literals.forEach(l -> children.add(new DateTimeV2Literal(compareType, @@ -86,7 +92,7 @@ private static boolean canLosslessConvertToDateV2Literal(DateTimeV2Literal liter | literal.getMicroSecond()) == 0L; } - private DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) { + private static DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) { return new DateV2Literal(literal.getYear(), literal.getMonth(), literal.getDay()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java index 7268d6e8328a9c..484d68f0d7317d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyNotExprRule.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -28,6 +28,10 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Not; +import com.google.common.collect.ImmutableList; + +import java.util.List; + /** * Rewrite rule of NOT expression. * For example: @@ -42,12 +46,19 @@ * not and(a >= b, a <= c) -> or(a < b, a > c) * not or(a >= b, a <= c) -> and(a < b, a > c) */ -public class SimplifyNotExprRule extends AbstractExpressionRewriteRule { +public class SimplifyNotExprRule implements ExpressionPatternRuleFactory { public static SimplifyNotExprRule INSTANCE = new SimplifyNotExprRule(); @Override - public Expression visitNot(Not not, ExpressionRewriteContext context) { + public List> buildRules() { + return ImmutableList.of( + matchesType(Not.class).then(SimplifyNotExprRule::simplify) + ); + } + + /** simplifyNot */ + public static Expression simplify(Not not) { Expression child = not.child(); if (child instanceof ComparisonPredicate) { ComparisonPredicate cp = (ComparisonPredicate) not.child(); @@ -55,23 +66,22 @@ public Expression visitNot(Not not, ExpressionRewriteContext context) { Expression right = cp.right(); if (child instanceof GreaterThan) { - return new LessThanEqual(left, right).accept(this, context); + return new LessThanEqual(left, right); } else if (child instanceof GreaterThanEqual) { - return new LessThan(left, right).accept(this, context); + return new LessThan(left, right); } else if (child instanceof LessThan) { - return new GreaterThanEqual(left, right).accept(this, context); + return new GreaterThanEqual(left, right); } else if (child instanceof LessThanEqual) { - return new GreaterThan(left, right).accept(this, context); + return new GreaterThan(left, right); } } else if (child instanceof CompoundPredicate) { CompoundPredicate cp = (CompoundPredicate) child; Not left = new Not(cp.left()); Not right = new Not(cp.right()); - return cp.flip(left, right).accept(this, context); + return cp.flip(left, right); } else if (child instanceof Not) { - return child.child(0).accept(this, context); + return child.child(0); } - - return super.visitNot(not, context); + return not; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 6c4fcc3edd11d5..35437bf836117a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -17,9 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; @@ -41,15 +40,17 @@ import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; import com.google.common.collect.Range; import com.google.common.collect.Sets; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.function.BinaryOperator; @@ -74,18 +75,21 @@ * 2. for `Or` expression (similar to `And`). * todo: support a > 10 and (a < 10 or a > 20 ) => a > 20 */ -public class SimplifyRange extends AbstractExpressionRewriteRule { - +public class SimplifyRange implements ExpressionPatternRuleFactory { public static final SimplifyRange INSTANCE = new SimplifyRange(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - if (expr instanceof CompoundPredicate) { - ValueDesc valueDesc = expr.accept(new RangeInference(), null); - Expression simplifiedExpr = valueDesc.toExpression(); - return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr; - } - return expr; + public List> buildRules() { + return ImmutableList.of( + matchesTopType(CompoundPredicate.class).then(SimplifyRange::rewrite) + ); + } + + /** rewrite */ + public static Expression rewrite(CompoundPredicate expr) { + ValueDesc valueDesc = expr.accept(new RangeInference(), null); + Expression simplifiedExpr = valueDesc.toExpression(); + return simplifiedExpr == null ? valueDesc.expr : simplifiedExpr; } private static class RangeInference extends ExpressionVisitor { @@ -96,11 +100,10 @@ public ValueDesc visit(Expression expr, Void context) { } private ValueDesc buildRange(ComparisonPredicate predicate) { - Expression rewrite = ExpressionRuleExecutor.normalize(predicate); - Expression right = rewrite.child(1); + Expression right = predicate.child(1); // only handle `NumericType` and `DateLikeType` if (right.isLiteral() && (right.getDataType().isNumericType() || right.getDataType().isDateLikeType())) { - return ValueDesc.range((ComparisonPredicate) rewrite); + return ValueDesc.range(predicate); } return new UnknownValue(predicate); } @@ -154,18 +157,23 @@ public ValueDesc visitOr(Or or, Void context) { private ValueDesc simplify(Expression originExpr, List predicates, BinaryOperator op, BinaryOperator exprOp) { - Map> groupByReference = predicates.stream() - .map(predicate -> predicate.accept(this, null)) - .collect(Collectors.groupingBy(p -> p.reference, LinkedHashMap::new, Collectors.toList())); + Multimap groupByReference + = Multimaps.newListMultimap(new LinkedHashMap<>(), ArrayList::new); + for (Expression predicate : predicates) { + ValueDesc valueDesc = predicate.accept(this, null); + List valueDescs = (List) groupByReference.get(valueDesc.reference); + valueDescs.add(valueDesc); + } List valuePerRefs = Lists.newArrayList(); - for (Entry> referenceValues : groupByReference.entrySet()) { - List valuePerReference = referenceValues.getValue(); + for (Entry> referenceValues : groupByReference.asMap().entrySet()) { + List valuePerReference = (List) referenceValues.getValue(); // merge per reference - ValueDesc simplifiedValue = valuePerReference.stream() - .reduce(op) - .get(); + ValueDesc simplifiedValue = valuePerReference.get(0); + for (int i = 1; i < valuePerReference.size(); i++) { + simplifiedValue = op.apply(simplifiedValue, valuePerReference.get(i)); + } valuePerRefs.add(simplifiedValue); } @@ -235,6 +243,7 @@ public static ValueDesc range(ComparisonPredicate predicate) { } public static ValueDesc discrete(InPredicate in) { + // Set literals = (Set) Utils.fastToImmutableSet(in.getOptions()); Set literals = in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet()); return new DiscreteValue(in.getCompareExpr(), in, literals); } @@ -417,7 +426,9 @@ public Expression toExpression() { // They are same processes, so must change synchronously. if (values.size() == 1) { return new EqualTo(reference, values.iterator().next()); - } else if (values.size() <= OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { + + // this condition should as same as OrToIn, or else meet dead loop + } else if (values.size() < OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { Iterator iterator = values.iterator(); return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next())); } else { @@ -468,10 +479,12 @@ public Expression toExpression() { if (sourceValues.isEmpty()) { return expr; } - return sourceValues.stream() - .map(ValueDesc::toExpression) - .reduce(mergeExprOp) - .get(); + + Expression result = sourceValues.get(0).toExpression(); + for (int i = 1; i < sourceValues.size(); i++) { + result = mergeExprOp.apply(result, sourceValues.get(i).toExpression()); + } + return result; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java index 17f4b7d239a237..27b929a2b9f865 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SupportJavaDateFormatter.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.scalar.DateFormat; import org.apache.doris.nereids.trees.expressions.functions.scalar.FromUnixtime; @@ -26,54 +26,46 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import java.util.List; -/** SupportJavaDateFormatter */ -public class SupportJavaDateFormatter extends AbstractExpressionRewriteRule { +/** SupportJavaDateFormatter2 */ +public class SupportJavaDateFormatter implements ExpressionPatternRuleFactory { public static final SupportJavaDateFormatter INSTANCE = new SupportJavaDateFormatter(); @Override - public Expression visitDateFormat(DateFormat dateFormat, ExpressionRewriteContext context) { - Expression expr = super.visitDateFormat(dateFormat, context); - if (!(expr instanceof DateFormat)) { - return expr; - } - dateFormat = (DateFormat) expr; + public List> buildRules() { + return ImmutableList.of( + matchesType(DateFormat.class).then(SupportJavaDateFormatter::rewriteDateFormat), + matchesType(FromUnixtime.class).then(SupportJavaDateFormatter::rewriteFromUnixtime), + matchesType(UnixTimestamp.class).then(SupportJavaDateFormatter::rewriteUnixTimestamp) + ); + } + + public static Expression rewriteDateFormat(DateFormat dateFormat) { if (dateFormat.arity() > 1) { return translateJavaFormatter(dateFormat, 1); } return dateFormat; } - @Override - public Expression visitFromUnixtime(FromUnixtime fromUnixtime, ExpressionRewriteContext context) { - Expression expr = super.visitFromUnixtime(fromUnixtime, context); - if (!(expr instanceof FromUnixtime)) { - return expr; - } - fromUnixtime = (FromUnixtime) expr; + public static Expression rewriteFromUnixtime(FromUnixtime fromUnixtime) { if (fromUnixtime.arity() > 1) { return translateJavaFormatter(fromUnixtime, 1); } return fromUnixtime; } - @Override - public Expression visitUnixTimestamp(UnixTimestamp unixTimestamp, ExpressionRewriteContext context) { - Expression expr = super.visitUnixTimestamp(unixTimestamp, context); - if (!(expr instanceof UnixTimestamp)) { - return expr; - } - unixTimestamp = (UnixTimestamp) expr; + public static Expression rewriteUnixTimestamp(UnixTimestamp unixTimestamp) { if (unixTimestamp.arity() > 1) { return translateJavaFormatter(unixTimestamp, 1); } return unixTimestamp; } - private Expression translateJavaFormatter(Expression function, int formatterIndex) { + private static Expression translateJavaFormatter(Expression function, int formatterIndex) { Expression formatterExpr = function.getArgument(formatterIndex); Expression newFormatterExpr = translateJavaFormatter(formatterExpr); if (newFormatterExpr != formatterExpr) { @@ -84,7 +76,7 @@ private Expression translateJavaFormatter(Expression function, int formatterInde return function; } - private Expression translateJavaFormatter(Expression formatterExpr) { + private static Expression translateJavaFormatter(Expression formatterExpr) { if (formatterExpr.isLiteral() && formatterExpr.getDataType().isStringLikeType()) { Literal literal = (Literal) formatterExpr; String originFormatter = literal.getStringValue(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java index 30e76cfe226f5b..318cb6ec6031af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java @@ -17,39 +17,38 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.TopN; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.ImmutableList; + +import java.util.List; /** * Convert topn(x, 1) to max(x) */ -public class TopnToMax extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class TopnToMax implements ExpressionPatternRuleFactory { public static final TopnToMax INSTANCE = new TopnToMax(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(TopN.class).then(TopnToMax::rewrite) + ); } - @Override - public Expression visitAggregateFunction(AggregateFunction aggregateFunction, ExpressionRewriteContext context) { - if (!(aggregateFunction instanceof TopN)) { - return aggregateFunction; - } - TopN topN = (TopN) aggregateFunction; + /** rewrite */ + public static Expression rewrite(TopN topN) { if (topN.arity() == 2 && topN.child(1) instanceof IntegerLikeLiteral && ((IntegerLikeLiteral) topN.child(1)).getIntValue() == 1) { return new Max(topN.child(0)); } else { - return aggregateFunction; + return topN; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java index 3faf56f0f3829e..ce23219bcc93e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TryEliminateUninterestedPredicates.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rules.TryEliminateUninterestedPredicates.Context; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; @@ -51,10 +52,17 @@ private TryEliminateUninterestedPredicates(Set interestedSlots, CascadesCo this.expressionRewriteContext = new ExpressionRewriteContext(cascadesContext); } + /** rewrite */ public static Expression rewrite(Expression expression, Set interestedSlots, CascadesContext cascadesContext) { // before eliminate uninterested predicate, we must push down `Not` under CompoundPredicate - expression = expression.accept(new SimplifyNotExprRule(), null); + expression = expression.rewriteUp(expr -> { + if (expr instanceof Not) { + return SimplifyNotExprRule.simplify((Not) expr); + } else { + return expr; + } + }); TryEliminateUninterestedPredicates rewriter = new TryEliminateUninterestedPredicates( interestedSlots, cascadesContext); return expression.accept(rewriter, new Context()); @@ -89,7 +97,7 @@ public Expression visit(Expression originExpr, Context parentContext) { // -> ((interested slot a) and true) or true // -> (interested slot a) or true // -> true - expr = expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + expr = FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext); } } else { // ((uninterested slot b > 0) + 1) > 1 @@ -122,7 +130,7 @@ public Expression visitAnd(And and, Context parentContext) { if (rightContext.childrenContainsNonInterestedSlots) { newRight = BooleanLiteral.TRUE; } - Expression expr = new And(newLeft, newRight).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext); + Expression expr = FoldConstantRuleOnFE.evaluate(new And(newLeft, newRight), expressionRewriteContext); parentContext.childrenContainsInterestedSlots = rightContext.childrenContainsInterestedSlots || leftContext.childrenContainsInterestedSlots; return expr; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java index 10b21d0b979ae0..61aac4d2407462 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java @@ -1611,7 +1611,7 @@ private Pair, List> countDistinctMultiEx } private boolean containsCountDistinctMultiExpr(LogicalAggregate aggregate) { - return ExpressionUtils.anyMatch(aggregate.getOutputExpressions(), expr -> + return ExpressionUtils.deapAnyMatch(aggregate.getOutputExpressions(), expr -> expr instanceof Count && ((Count) expr).isDistinct() && expr.arity() > 1); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java index f30d55ad0fc294..a608448e023f07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java @@ -268,11 +268,19 @@ private T updateExpression(T input, Map rep } private List updateExpressions(List inputs, Map replaceMap) { - return inputs.stream().map(i -> updateExpression(i, replaceMap)).collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(inputs.size()); + for (T input : inputs) { + result.add(updateExpression(input, replaceMap)); + } + return result.build(); } private Set updateExpressions(Set inputs, Map replaceMap) { - return inputs.stream().map(i -> updateExpression(i, replaceMap)).collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder result = ImmutableSet.builderWithExpectedSize(inputs.size()); + for (T input : inputs) { + result.add(updateExpression(input, replaceMap)); + } + return result.build(); } private Map collectChildrenOutputMap(LogicalPlan plan) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java index 907d34c07c0a12..8c73991f3638aa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.util.ExpressionUtils; import java.util.List; @@ -38,7 +39,7 @@ public class CheckMatchExpression extends OneRewriteRuleFactory { @Override public Rule build() { return logicalFilter(logicalOlapScan()) - .when(filter -> containsMatchExpression(filter.getExpressions())) + .when(filter -> ExpressionUtils.containsType(filter.getExpressions(), Match.class)) .then(this::checkChildren) .toRule(RuleType.CHECK_MATCH_EXPRESSION); } @@ -60,8 +61,4 @@ private Plan checkChildren(LogicalFilter filter) { } return filter; } - - private boolean containsMatchExpression(List expressions) { - return expressions.stream().anyMatch(expr -> expr.anyMatch(Match.class::isInstance)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java index b4d7b8005132ff..70a5c593ee3dc8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckPrivileges.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.rules.analysis.UserAuthentication; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; @@ -30,9 +29,12 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalView; import org.apache.doris.qe.ConnectContext; +import com.google.common.collect.Sets; + +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; /** CheckPrivileges */ public class CheckPrivileges extends ColumnPruning { @@ -65,15 +67,20 @@ public Plan visitLogicalRelation(LogicalRelation relation, PruneContext context) } private Set computeUsedColumns(Plan plan, Set requiredSlots) { - Map idToSlot = plan.getOutputSet() - .stream() - .collect(Collectors.toMap(slot -> slot.getExprId().asInt(), slot -> slot)); - return requiredSlots - .stream() - .map(slot -> idToSlot.get(slot.getExprId().asInt())) - .filter(slot -> slot != null) - .map(NamedExpression::getName) - .collect(Collectors.toSet()); + List outputs = plan.getOutput(); + Map idToSlot = new LinkedHashMap<>(outputs.size()); + for (Slot output : outputs) { + idToSlot.putIfAbsent(output.getExprId().asInt(), output); + } + + Set usedColumns = Sets.newLinkedHashSetWithExpectedSize(requiredSlots.size()); + for (Slot requiredSlot : requiredSlots) { + Slot slot = idToSlot.get(requiredSlot.getExprId().asInt()); + if (slot != null) { + usedColumns.add(slot.getName()); + } + } + return usedColumns; } private void checkColumnPrivileges(TableIf table, Set usedColumns) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java index f33f1658c32e29..e36c0f5172ad70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; @@ -39,18 +40,17 @@ import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.function.Function; -import java.util.stream.Collectors; import java.util.stream.IntStream; /** @@ -97,13 +97,11 @@ public Plan visit(Plan plan, JobContext jobContext) { for (Plan child : plan.children()) { child.accept(this, jobContext); } - plan.getExpressions().stream().filter( - expression -> !(expression instanceof SlotReference) - ).forEach( - expression -> { - keys.addAll(expression.getInputSlots()); - } - ); + for (Expression expression : plan.getExpressions()) { + if (!(expression instanceof SlotReference)) { + keys.addAll(expression.getInputSlots()); + } + } return plan; } } @@ -212,39 +210,42 @@ private Plan pruneAggregate(Aggregate agg, PruneContext context) { } private Plan skipPruneThisAndFirstLevelChildren(Plan plan) { - Set requireAllOutputOfChildren = plan.children() - .stream() - .flatMap(child -> child.getOutputSet().stream()) - .collect(Collectors.toSet()); - return pruneChildren(plan, requireAllOutputOfChildren); + ImmutableSet.Builder requireAllOutputOfChildren = ImmutableSet.builder(); + for (Plan child : plan.children()) { + requireAllOutputOfChildren.addAll(child.getOutput()); + } + return pruneChildren(plan, requireAllOutputOfChildren.build()); } - private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) { + private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) { List groupBy = prunedOutputAgg.getGroupByExpressions(); List output = prunedOutputAgg.getOutputExpressions(); if (!(prunedOutputAgg instanceof LogicalAggregate)) { return prunedOutputAgg; } - // add back group by keys which eliminated by rule ELIMINATE_GROUP_BY_KEY - // if related output expressions are not in pruned output list. - List remainedOutputExprs = Lists.newArrayList(output); - remainedOutputExprs.removeAll(groupBy); - List newOutputList = Lists.newArrayList(); - newOutputList.addAll((List) groupBy); - newOutputList.addAll(remainedOutputExprs); + ImmutableList.Builder newOutputListBuilder + = ImmutableList.builderWithExpectedSize(output.size()); + newOutputListBuilder.addAll((List) groupBy); + for (NamedExpression ne : output) { + if (!groupBy.contains(ne)) { + newOutputListBuilder.add(ne); + } + } - if (!(prunedOutputAgg instanceof LogicalAggregate)) { - return prunedOutputAgg.withAggOutput(newOutputList); - } else { - List newGroupByExprList = newOutputList.stream().filter(e -> - !(prunedOutputAgg.getAggregateFunctions().contains(e) - || e instanceof Alias && prunedOutputAgg.getAggregateFunctions() - .contains(((Alias) e).child())) - ).collect(Collectors.toList()); - return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput(newGroupByExprList, newOutputList); + List newOutputList = newOutputListBuilder.build(); + Set aggregateFunctions = prunedOutputAgg.getAggregateFunctions(); + ImmutableList.Builder newGroupByExprList + = ImmutableList.builderWithExpectedSize(newOutputList.size()); + for (NamedExpression e : newOutputList) { + if (!(aggregateFunctions.contains(e) + || (e instanceof Alias && aggregateFunctions.contains(e.child(0))))) { + newGroupByExprList.add(e); + } } + return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput( + newGroupByExprList.build(), newOutputList); } /** prune output */ @@ -253,9 +254,8 @@ public

P pruneOutput(P plan, List originOutput if (originOutput.isEmpty()) { return plan; } - List prunedOutputs = originOutput.stream() - .filter(output -> context.requiredSlots.contains(output.toSlot())) - .collect(ImmutableList.toImmutableList()); + List prunedOutputs = + Utils.filterImmutableList(originOutput, output -> context.requiredSlots.contains(output.toSlot())); if (prunedOutputs.isEmpty()) { List candidates = Lists.newArrayList(originOutput); @@ -281,7 +281,6 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) } List prunedOutputs = Lists.newArrayList(); List> constantExprsList = union.getConstantExprsList(); - List> prunedConstantExprsList = Lists.newArrayList(); List extractColumnIndex = Lists.newArrayList(); for (int i = 0; i < originOutput.size(); i++) { NamedExpression output = originOutput.get(i); @@ -291,12 +290,14 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) } } int len = extractColumnIndex.size(); + ImmutableList.Builder> prunedConstantExprsList + = ImmutableList.builderWithExpectedSize(constantExprsList.size()); for (List row : constantExprsList) { - ArrayList newRow = new ArrayList<>(len); + ImmutableList.Builder newRow = ImmutableList.builderWithExpectedSize(len); for (int idx : extractColumnIndex) { newRow.add(row.get(idx)); } - prunedConstantExprsList.add(newRow); + prunedConstantExprsList.add(newRow.build()); } if (prunedOutputs.isEmpty()) { @@ -312,7 +313,7 @@ private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) if (prunedOutputs.equals(originOutput)) { return union; } else { - return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList); + return union.withNewOutputsAndConstExprsList(prunedOutputs, prunedConstantExprsList.build()); } } @@ -329,24 +330,31 @@ private

P pruneChildren(P plan, Set parentRequiredSlots) Set currentUsedSlots = plan.getInputSlots(); Set childrenRequiredSlots = parentRequiredSlots.isEmpty() ? currentUsedSlots - : ImmutableSet.builder() + : ImmutableSet.builderWithExpectedSize(parentRequiredSlots.size() + currentUsedSlots.size()) .addAll(parentRequiredSlots) .addAll(currentUsedSlots) .build(); - List newChildren = new ArrayList<>(); + ImmutableList.Builder newChildren = ImmutableList.builderWithExpectedSize(plan.arity()); boolean hasNewChildren = false; for (Plan child : plan.children()) { - Set childOutputSet = child.getOutputSet(); - Set childRequiredSlots = childOutputSet.stream() - .filter(childrenRequiredSlots::contains).collect(Collectors.toSet()); + Set childRequiredSlots; + List childOutputs = child.getOutput(); + ImmutableSet.Builder childRequiredSlotBuilder + = ImmutableSet.builderWithExpectedSize(childOutputs.size()); + for (Slot childOutput : childOutputs) { + if (childrenRequiredSlots.contains(childOutput)) { + childRequiredSlotBuilder.add(childOutput); + } + } + childRequiredSlots = childRequiredSlotBuilder.build(); Plan prunedChild = doPruneChild(plan, child, childRequiredSlots); if (prunedChild != child) { hasNewChildren = true; } newChildren.add(prunedChild); } - return hasNewChildren ? (P) plan.withChildren(newChildren) : plan; + return hasNewChildren ? (P) plan.withChildren(newChildren.build()) : plan; } private Plan doPruneChild(Plan plan, Plan child, Set childRequiredSlots) { @@ -358,7 +366,7 @@ private Plan doPruneChild(Plan plan, Plan child, Set childRequiredSlots) { // the case 2 in the class comment, prune child's output failed if (!isProject && !Sets.difference(prunedChild.getOutputSet(), childRequiredSlots).isEmpty()) { - prunedChild = new LogicalProject<>(ImmutableList.copyOf(childRequiredSlots), prunedChild); + prunedChild = new LogicalProject<>(Utils.fastToImmutableList(childRequiredSlots), prunedChild); } return prunedChild; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java index f2ccf55ac50731..480ff57638458a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountDistinctRewrite.java @@ -24,9 +24,12 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.types.DataType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import java.util.List; @@ -38,35 +41,42 @@ public class CountDistinctRewrite extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalAggregate().then(agg -> { - List output = agg.getOutputExpressions() - .stream() - .map(CountDistinctRewriter::rewrite) - .map(NamedExpression.class::cast) - .collect(ImmutableList.toImmutableList()); - return agg.withAggOutput(output); + return logicalAggregate().when(CountDistinctRewrite::containsCountObject).then(agg -> { + List outputExpressions = agg.getOutputExpressions(); + Builder newOutputs + = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + for (NamedExpression outputExpression : outputExpressions) { + NamedExpression newOutput = (NamedExpression) outputExpression.rewriteUp(expr -> { + if (expr instanceof Count && expr.arity() == 1) { + Expression child = expr.child(0); + if (child.getDataType().isBitmapType()) { + return new BitmapUnionCount(child); + } + if (child.getDataType().isHllType()) { + return new HllUnionAgg(child); + } + } + return expr; + }); + newOutputs.add(newOutput); + } + return agg.withAggOutput(newOutputs.build()); }).toRule(RuleType.COUNT_DISTINCT_REWRITE); } - private static class CountDistinctRewriter extends DefaultExpressionRewriter { - private static final CountDistinctRewriter INSTANCE = new CountDistinctRewriter(); - - public static Expression rewrite(Expression expr) { - return expr.accept(INSTANCE, null); - } - - @Override - public Expression visitCount(Count count, Void context) { - if (count.isDistinct() && count.arity() == 1) { - Expression child = count.child(0); - if (child.getDataType().isBitmapType()) { - return new BitmapUnionCount(child); - } - if (child.getDataType().isHllType()) { - return new HllUnionAgg(child); + private static boolean containsCountObject(LogicalAggregate agg) { + for (NamedExpression ne : agg.getOutputExpressions()) { + boolean needRewrite = ne.anyMatch(expr -> { + if (expr instanceof Count && expr.arity() == 1) { + DataType dataType = expr.child(0).getDataType(); + return dataType.isBitmapType() || dataType.isHllType(); } + return false; + }); + if (needRewrite) { + return true; } - return count; } + return false; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java index dfe13b388f5b56..bfbd6599cf8acf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java @@ -27,13 +27,14 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; /** * count(1) ==> count(*) @@ -50,21 +51,31 @@ public Rule build() { return agg; } - Map> projectsAndAggFunc = newExprs.stream() - .collect(Collectors.partitioningBy(Expression::isConstant)); + List projectFuncs = Lists.newArrayListWithCapacity(newExprs.size()); + Builder aggFuncsBuilder + = ImmutableList.builderWithExpectedSize(newExprs.size()); + for (NamedExpression newExpr : newExprs) { + if (newExpr.isConstant()) { + projectFuncs.add(newExpr); + } else { + aggFuncsBuilder.add(newExpr); + } + } - if (projectsAndAggFunc.get(false).isEmpty()) { + List aggFuncs = aggFuncsBuilder.build(); + if (aggFuncs.isEmpty()) { // if there is no group by keys and other agg func, don't rewrite return null; } else { // if there is group by keys, put count(null) in projects, such as // project(0 as count(null)) // --Aggregate(k1, group by k1) - Plan plan = agg.withAggOutput(projectsAndAggFunc.get(false)); - if (!projectsAndAggFunc.get(true).isEmpty()) { - projectsAndAggFunc.get(false).stream().map(NamedExpression::toSlot) - .forEach(projectsAndAggFunc.get(true)::add); - plan = new LogicalProject<>(projectsAndAggFunc.get(true), plan); + Plan plan = agg.withAggOutput(aggFuncs); + if (!projectFuncs.isEmpty()) { + for (NamedExpression aggFunc : aggFuncs) { + projectFuncs.add(aggFunc.toSlot()); + } + plan = new LogicalProject<>(projectFuncs, plan); } return plan; } @@ -77,9 +88,11 @@ private boolean rewriteCountLiteral(List oldExprs, List replaced = new HashMap<>(); Set oldAggFuncSet = expr.collect(AggregateFunction.class::isInstance); - oldAggFuncSet.stream() - .filter(this::isCountLiteral) - .forEach(c -> replaced.put(c, rewrite((Count) c))); + for (AggregateFunction aggFun : oldAggFuncSet) { + if (isCountLiteral(aggFun)) { + replaced.put(aggFun, rewrite((Count) aggFun)); + } + } expr = expr.rewriteUp(s -> replaced.getOrDefault(s, s)); changed |= !replaced.isEmpty(); newExprs.add((NamedExpression) expr); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java index a3de71a770e30f..ef9e418f58d185 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateFilter.java @@ -43,10 +43,10 @@ public class EliminateFilter implements RewriteRuleFactory { @Override public List buildRules() { return ImmutableList.of(logicalFilter().when( - filter -> filter.getConjuncts().stream().anyMatch(BooleanLiteral.class::isInstance)) + filter -> ExpressionUtils.containsType(filter.getConjuncts(), BooleanLiteral.class)) .thenApply(ctx -> { LogicalFilter filter = ctx.root; - ImmutableSet.Builder newConjuncts = ImmutableSet.builder(); + ImmutableSet.Builder newConjuncts = ImmutableSet.builder(); for (Expression expression : filter.getConjuncts()) { if (expression == BooleanLiteral.FALSE) { return new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), @@ -73,8 +73,7 @@ public List buildRules() { new ExpressionRewriteContext(ctx.cascadesContext); for (Expression expression : filter.getConjuncts()) { Expression newExpr = ExpressionUtils.replace(expression, replaceMap); - Expression foldExpression = - FoldConstantRule.INSTANCE.rewrite(newExpr, context); + Expression foldExpression = FoldConstantRule.evaluate(newExpr, context); if (foldExpression == BooleanLiteral.FALSE) { return new LogicalEmptyRelation( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java index 3b95e9b44e06f0..109cff192f22f6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -31,11 +32,14 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSet.Builder; + import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; /** * Eliminate GroupBy. @@ -45,39 +49,53 @@ public class EliminateGroupBy extends OneRewriteRuleFactory { @Override public Rule build() { return logicalAggregate() - .when(agg -> agg.getGroupByExpressions().stream().allMatch(expr -> expr instanceof Slot)) + .when(agg -> ExpressionUtils.allMatch(agg.getGroupByExpressions(), Slot.class::isInstance)) .then(agg -> { - Set groupby = agg.getGroupByExpressions().stream().map(e -> (Slot) e) - .collect(Collectors.toSet()); + List groupByExpressions = agg.getGroupByExpressions(); + Builder groupBySlots + = ImmutableSet.builderWithExpectedSize(groupByExpressions.size()); + for (Expression groupByExpression : groupByExpressions) { + groupBySlots.add((Slot) groupByExpression); + } Plan child = agg.child(); - boolean unique = child.getLogicalProperties().getFunctionalDependencies() - .isUniqueAndNotNull(groupby); + boolean unique = child.getLogicalProperties() + .getFunctionalDependencies() + .isUniqueAndNotNull(groupBySlots.build()); if (!unique) { return null; } - Set aggregateFunctions = agg.getAggregateFunctions(); - if (!aggregateFunctions.stream().allMatch( - f -> (f instanceof Sum || f instanceof Count || f instanceof Min || f instanceof Max) - && (f.arity() == 1 && f.child(0) instanceof Slot))) { - return null; + for (AggregateFunction f : agg.getAggregateFunctions()) { + if (!((f instanceof Sum || f instanceof Count || f instanceof Min || f instanceof Max) + && (f.arity() == 1 && f.child(0) instanceof Slot))) { + return null; + } } + List outputExpressions = agg.getOutputExpressions(); + + ImmutableList.Builder newOutput + = ImmutableList.builderWithExpectedSize(outputExpressions.size()); - List newOutput = agg.getOutputExpressions().stream().map(ne -> { + for (NamedExpression ne : outputExpressions) { if (ne instanceof Alias && ne.child(0) instanceof AggregateFunction) { AggregateFunction f = (AggregateFunction) ne.child(0); if (f instanceof Sum || f instanceof Min || f instanceof Max) { - return new Alias(ne.getExprId(), f.child(0), ne.getName()); + newOutput.add(new Alias(ne.getExprId(), f.child(0), ne.getName())); } else if (f instanceof Count) { - return (NamedExpression) ne.withChildren( - new If(new IsNull(f.child(0)), Literal.of(0), Literal.of(1))); + newOutput.add((NamedExpression) ne.withChildren( + new If( + new IsNull(f.child(0)), + Literal.of(0), + Literal.of(1) + ) + )); } else { throw new IllegalStateException("Unexpected aggregate function: " + f); } } else { - return ne; + newOutput.add(ne); } - }).collect(Collectors.toList()); - return PlanUtils.projectOrSelf(newOutput, child); + } + return PlanUtils.projectOrSelf(newOutput.build(), child); }).toRule(RuleType.ELIMINATE_GROUP_BY); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java index 2c5a4bbdd14e61..2e426beae46537 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateMarkJoin.java @@ -19,9 +19,11 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; @@ -38,15 +40,22 @@ public class EliminateMarkJoin extends OneRewriteRuleFactory { public Rule build() { return logicalFilter(logicalJoin().when( join -> join.getJoinType().isSemiJoin() && !join.getMarkJoinConjuncts().isEmpty())) - .when(filter -> canSimplifyMarkJoin(filter.getConjuncts())) - .then(filter -> filter.withChildren(eliminateMarkJoin(filter.child()))) + .when(filter -> canSimplifyMarkJoin(filter.getConjuncts(), null)) + .thenApply(ctx -> { + LogicalFilter> filter = ctx.root; + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext); + if (canSimplifyMarkJoin(filter.getConjuncts(), rewriteContext)) { + return filter.withChildren(eliminateMarkJoin(filter.child())); + } + return filter; + }) .toRule(RuleType.ELIMINATE_MARK_JOIN); } - private boolean canSimplifyMarkJoin(Set predicates) { + private boolean canSimplifyMarkJoin(Set predicates, ExpressionRewriteContext rewriteContext) { return ExpressionUtils .canInferNotNullForMarkSlot(TrySimplifyPredicateWithMarkJoinSlot.INSTANCE - .rewrite(ExpressionUtils.and(predicates), null)); + .rewrite(ExpressionUtils.and(predicates), rewriteContext), rewriteContext); } private LogicalJoin eliminateMarkJoin(LogicalJoin join) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java index db95d1fefa03be..22393cb55f8335 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateNotNull.java @@ -24,7 +24,6 @@ import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; @@ -41,7 +40,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Eliminate Predicate `is not null`, like @@ -85,29 +83,34 @@ private List removeGeneratedNotNull(Collection exprs, Ca // remove `name` (it's generated), remove `id` (because `id > 0` already contains it) Set predicatesNotContainIsNotNull = Sets.newHashSet(); List slotsFromIsNotNull = Lists.newArrayList(); - exprs.stream() - .filter(expr -> !(expr instanceof Not) - || !((Not) expr).isGeneratedIsNotNull()) // remove generated `is not null` - .forEach(expr -> { - Optional notNullSlot = TypeUtils.isNotNull(expr); - if (notNullSlot.isPresent()) { - slotsFromIsNotNull.add(notNullSlot.get()); - } else { - predicatesNotContainIsNotNull.add(expr); - } - }); + + for (Expression expr : exprs) { + // remove generated `is not null` + if (!(expr instanceof Not) || !((Not) expr).isGeneratedIsNotNull()) { + Optional notNullSlot = TypeUtils.isNotNull(expr); + if (notNullSlot.isPresent()) { + slotsFromIsNotNull.add(notNullSlot.get()); + } else { + predicatesNotContainIsNotNull.add(expr); + } + } + } + Set inferNonNotSlots = ExpressionUtils.inferNotNullSlots( predicatesNotContainIsNotNull, ctx); - Set keepIsNotNull = slotsFromIsNotNull.stream() - .filter(ExpressionTrait::nullable) - .filter(slot -> !inferNonNotSlots.contains(slot)) - .map(slot -> new Not(new IsNull(slot))).collect(Collectors.toSet()); + ImmutableSet.Builder keepIsNotNull + = ImmutableSet.builderWithExpectedSize(slotsFromIsNotNull.size()); + for (Slot slot : slotsFromIsNotNull) { + if (slot.nullable() && !inferNonNotSlots.contains(slot)) { + keepIsNotNull.add(new Not(new IsNull(slot))); + } + } // merge predicatesNotContainIsNotNull and keepIsNotNull into a new List return ImmutableList.builder() .addAll(predicatesNotContainIsNotNull) - .addAll(keepIsNotNull) + .addAll(keepIsNotNull.build()) .build(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java index 969d6e6b045b9b..021cae2d6533f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOrderByConstant.java @@ -33,13 +33,19 @@ public class EliminateOrderByConstant extends OneRewriteRuleFactory { @Override public Rule build() { return logicalSort().then(sort -> { - List orderKeysWithoutConst = sort - .getOrderKeys() - .stream() - .filter(k -> !(k.getExpr().isConstant())) - .collect(ImmutableList.toImmutableList()); + List orderKeys = sort.getOrderKeys(); + ImmutableList.Builder orderKeysWithoutConstBuilder + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey orderKey : orderKeys) { + if (!orderKey.getExpr().isConstant()) { + orderKeysWithoutConstBuilder.add(orderKey); + } + } + List orderKeysWithoutConst = orderKeysWithoutConstBuilder.build(); if (orderKeysWithoutConst.isEmpty()) { return sort.child(); + } else if (orderKeysWithoutConst.size() == orderKeys.size()) { + return sort; } return sort.withOrderKeys(orderKeysWithoutConst); }).toRule(RuleType.ELIMINATE_ORDER_BY_CONSTANT); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java index 697eb8fa5a3fa9..5ec0f0cd698d5e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractAndNormalizeWindowExpression.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; @@ -30,6 +31,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; import org.apache.doris.nereids.util.ExpressionUtils; @@ -50,75 +52,94 @@ public class ExtractAndNormalizeWindowExpression extends OneRewriteRuleFactory i @Override public Rule build() { - return logicalProject().when(project -> containsWindowExpression(project.getProjects())).then(project -> { - List outputs = - ExpressionUtils.rewriteDownShortCircuit(project.getProjects(), output -> { - if (output instanceof WindowExpression) { - WindowExpression windowExpression = (WindowExpression) output; - Expression expression = ((WindowExpression) output).getFunction(); - if (expression instanceof Sum || expression instanceof Max - || expression instanceof Min || expression instanceof Avg) { - // sum, max, min and avg in window function should be always nullable - windowExpression = ((WindowExpression) output) - .withFunction(((NullableAggregateFunction) expression) - .withAlwaysNullable(true)); + return logicalProject() + .when(project -> ExpressionUtils.containsWindowExpression(project.getProjects())) + .then(this::normalize) + .toRule(RuleType.EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS); + } + + private Plan normalize(LogicalProject project) { + List outputs = + ExpressionUtils.rewriteDownShortCircuit(project.getProjects(), output -> { + if (output instanceof WindowExpression) { + WindowExpression windowExpression = (WindowExpression) output; + Expression expression = ((WindowExpression) output).getFunction(); + if (expression instanceof Sum || expression instanceof Max + || expression instanceof Min || expression instanceof Avg) { + // sum, max, min and avg in window function should be always nullable + windowExpression = ((WindowExpression) output) + .withFunction( + ((NullableAggregateFunction) expression).withAlwaysNullable(true) + ); + } + + ImmutableList.Builder nonLiteralPartitionKeys = + ImmutableList.builderWithExpectedSize(windowExpression.getPartitionKeys().size()); + for (Expression partitionKey : windowExpression.getPartitionKeys()) { + if (!partitionKey.isConstant()) { + nonLiteralPartitionKeys.add(partitionKey); } - // remove literal partition by and order by keys - return windowExpression.withPartitionKeysOrderKeys( - windowExpression.getPartitionKeys().stream() - .filter(partitionExpr -> !partitionExpr.isConstant()) - .collect(Collectors.toList()), - windowExpression.getOrderKeys().stream() - .filter(orderExpression -> !orderExpression - .getOrderKey().getExpr().isConstant()) - .collect(Collectors.toList())); } - return output; - }); - - // 1. handle bottom projects - Set existedAlias = ExpressionUtils.collect(outputs, Alias.class::isInstance); - Set toBePushedDown = collectExpressionsToBePushedDown(outputs); - NormalizeToSlotContext context = NormalizeToSlotContext.buildContext(existedAlias, toBePushedDown); - // set toBePushedDown exprs as NamedExpression, e.g. (a+1) -> Alias(a+1) - Set bottomProjects = context.pushDownToNamedExpression(toBePushedDown); - Plan normalizedChild; - if (bottomProjects.isEmpty()) { - normalizedChild = project.child(); - } else { - normalizedChild = project.withProjectsAndChild( - ImmutableList.copyOf(bottomProjects), project.child()); - } - - // 2. handle window's outputs and windowExprs - // need to replace exprs with SlotReference in WindowSpec, due to LogicalWindow.getExpressions() - - // because alias is pushed down to bottom project - // we need replace alias's child expr with corresponding alias's slot in output - // so create a customNormalizeMap alias's child -> alias.toSlot to do it - Map customNormalizeMap = toBePushedDown.stream() - .filter(expr -> expr instanceof Alias) - .collect(Collectors.toMap(expr -> ((Alias) expr).child(), expr -> ((Alias) expr).toSlot(), - (oldExpr, newExpr) -> oldExpr)); - - List normalizedOutputs = context.normalizeToUseSlotRef(outputs, - (ctx, expr) -> customNormalizeMap.getOrDefault(expr, null)); - Set normalizedWindows = - ExpressionUtils.collect(normalizedOutputs, WindowExpression.class::isInstance); - - existedAlias = ExpressionUtils.collect(normalizedOutputs, Alias.class::isInstance); - NormalizeToSlotContext ctxForWindows = NormalizeToSlotContext.buildContext( - existedAlias, Sets.newHashSet(normalizedWindows)); - - Set normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows); - // only need normalized windowExpressions - LogicalWindow normalizedLogicalWindow = - new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild); - - // 3. handle top projects - List topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs); - return project.withProjectsAndChild(topProjects, normalizedLogicalWindow); - }).toRule(RuleType.EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS); + + ImmutableList.Builder nonLiteralOrderExpressions = + ImmutableList.builderWithExpectedSize(windowExpression.getOrderKeys().size()); + for (OrderExpression orderExpr : windowExpression.getOrderKeys()) { + if (!orderExpr.getOrderKey().getExpr().isConstant()) { + nonLiteralOrderExpressions.add(orderExpr); + } + } + + // remove literal partition by and order by keys + return windowExpression.withPartitionKeysOrderKeys( + nonLiteralPartitionKeys.build(), + nonLiteralOrderExpressions.build() + ); + } + return output; + }); + + // 1. handle bottom projects + Set existedAlias = ExpressionUtils.collect(outputs, Alias.class::isInstance); + Set toBePushedDown = collectExpressionsToBePushedDown(outputs); + NormalizeToSlotContext context = NormalizeToSlotContext.buildContext(existedAlias, toBePushedDown); + // set toBePushedDown exprs as NamedExpression, e.g. (a+1) -> Alias(a+1) + Set bottomProjects = context.pushDownToNamedExpression(toBePushedDown); + Plan normalizedChild; + if (bottomProjects.isEmpty()) { + normalizedChild = project.child(); + } else { + normalizedChild = project.withProjectsAndChild( + ImmutableList.copyOf(bottomProjects), project.child()); + } + + // 2. handle window's outputs and windowExprs + // need to replace exprs with SlotReference in WindowSpec, due to LogicalWindow.getExpressions() + + // because alias is pushed down to bottom project + // we need replace alias's child expr with corresponding alias's slot in output + // so create a customNormalizeMap alias's child -> alias.toSlot to do it + Map customNormalizeMap = toBePushedDown.stream() + .filter(expr -> expr instanceof Alias) + .collect(Collectors.toMap(expr -> ((Alias) expr).child(), expr -> ((Alias) expr).toSlot(), + (oldExpr, newExpr) -> oldExpr)); + + List normalizedOutputs = context.normalizeToUseSlotRef(outputs, + (ctx, expr) -> customNormalizeMap.getOrDefault(expr, null)); + Set normalizedWindows = + ExpressionUtils.collect(normalizedOutputs, WindowExpression.class::isInstance); + + existedAlias = ExpressionUtils.collect(normalizedOutputs, Alias.class::isInstance); + NormalizeToSlotContext ctxForWindows = NormalizeToSlotContext.buildContext( + existedAlias, Sets.newHashSet(normalizedWindows)); + + Set normalizedWindowWithAlias = ctxForWindows.pushDownToNamedExpression(normalizedWindows); + // only need normalized windowExpressions + LogicalWindow normalizedLogicalWindow = + new LogicalWindow<>(ImmutableList.copyOf(normalizedWindowWithAlias), normalizedChild); + + // 3. handle top projects + List topProjects = ctxForWindows.normalizeToUseSlotRef(normalizedOutputs); + return project.withProjectsAndChild(topProjects, normalizedLogicalWindow); } private Set collectExpressionsToBePushedDown(List expressions) { @@ -161,10 +182,4 @@ private Set collectExpressionsToBePushedDown(List e }) .collect(ImmutableSet.toImmutableSet()); } - - private boolean containsWindowExpression(List expressions) { - // WindowExpression in top LogicalProject will be normalized as Alias(SlotReference) after this rule, - // so it will not be normalized infinitely - return expressions.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java index 4ecc79ae94e7b0..2f8e1404b7199e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Paper: Quantifying TPC-H Choke Points and Their Optimizations @@ -84,13 +83,9 @@ private List extractDependentConjuncts(Set conjuncts) { } // only check table in first disjunct. // In our example, qualifiers = { n1, n2 } - Expression first = disjuncts.get(0); - Set qualifiers = first.getInputSlots() - .stream() - .map(slot -> String.join(".", slot.getQualifier())) - .collect(Collectors.toSet()); // try to extract - for (String qualifier : qualifiers) { + for (Slot inputSlot : disjuncts.get(0).getInputSlots()) { + String qualifier = String.join(".", inputSlot.getQualifier()); List extractForAll = Lists.newArrayList(); boolean success = true; for (Expression expr : ExpressionUtils.extractDisjunction(conjunct)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java index 3bdfbc582acc99..9a0b9f8b5e0353 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java @@ -201,7 +201,7 @@ private boolean canMergeAggregateWithProject(LogicalAggregate !(expr instanceof SlotReference) && !(expr instanceof Alias))) { return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java index d152178b5238de..3ea903f8565928 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeProjects.java @@ -20,9 +20,9 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; import java.util.List; @@ -43,8 +43,8 @@ public Rule build() { // TODO modify ExtractAndNormalizeWindowExpression to handle nested window functions // here we just don't merge two projects if there is any window function return logicalProject(logicalProject()) - .whenNot(project -> containsWindowExpression(project.getProjects()) - && containsWindowExpression(project.child().getProjects())) + .whenNot(project -> ExpressionUtils.containsWindowExpression(project.getProjects()) + && ExpressionUtils.containsWindowExpression(project.child().getProjects())) .then(MergeProjects::mergeProjects).toRule(RuleType.MERGE_PROJECTS); } @@ -54,8 +54,4 @@ public static Plan mergeProjects(LogicalProject project) { LogicalProject newProject = childProject.canEliminate() ? project : childProject; return newProject.withProjectsAndChild(projectExpressions, childProject.child(0)); } - - private boolean containsWindowExpression(List expressions) { - return expressions.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java index b36d0e63b85423..b7554582885c0c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeSort.java @@ -24,13 +24,15 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import java.util.List; -import java.util.stream.Stream; /** * SortNode on BE always output order keys because BE needs them to do merge sort. So we normalize LogicalSort as BE @@ -40,29 +42,44 @@ public class NormalizeSort extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalSort().whenNot(sort -> sort.getOrderKeys().stream() - .map(OrderKey::getExpr).allMatch(Slot.class::isInstance)) + return logicalSort().whenNot(this::allOrderKeyIsSlot) .then(sort -> { List newProjects = Lists.newArrayList(); - List newOrderKeys = sort.getOrderKeys().stream() - .map(orderKey -> { - Expression expr = orderKey.getExpr(); - if (!(expr instanceof Slot)) { - Alias alias = new Alias(expr); - newProjects.add(alias); - expr = alias.toSlot(); - } - return orderKey.withExpression(expr); - }).collect(ImmutableList.toImmutableList()); - List bottomProjections = Stream.concat( - sort.child().getOutput().stream(), - newProjects.stream() - ).collect(ImmutableList.toImmutableList()); - List topProjections = sort.getOutput().stream() - .map(NamedExpression.class::cast) - .collect(ImmutableList.toImmutableList()); - return new LogicalProject<>(topProjections, sort.withOrderKeysAndChild(newOrderKeys, + + Builder newOrderKeys = + ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); + for (OrderKey orderKey : sort.getOrderKeys()) { + Expression expr = orderKey.getExpr(); + if (!(expr instanceof Slot)) { + Alias alias = new Alias(expr); + newProjects.add(alias); + expr = alias.toSlot(); + newOrderKeys.add(orderKey.withExpression(expr)); + } else { + newOrderKeys.add(orderKey); + } + } + + List childOutput = sort.child().getOutput(); + List bottomProjections = ImmutableList.builderWithExpectedSize( + childOutput.size() + newProjects.size()) + .addAll(childOutput) + .addAll(newProjects) + .build(); + + List topProjections = (List) sort.getOutput(); + return new LogicalProject<>(topProjections, sort.withOrderKeysAndChild( + newOrderKeys.build(), new LogicalProject<>(bottomProjections, sort.child()))); }).toRule(RuleType.NORMALIZE_SORT); } + + private boolean allOrderKeyIsSlot(LogicalSort sort) { + for (OrderKey orderKey : sort.getOrderKeys()) { + if (!(orderKey.getExpr() instanceof Slot)) { + return false; + } + } + return true; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java index 683841a5f8fb2c..ea2fb8f4beb538 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java @@ -104,15 +104,19 @@ public List normalizeToUseSlotRef(Collection expres */ public List normalizeToUseSlotRef(Collection expressions, BiFunction customNormalize) { - return expressions.stream() - .map(expr -> (E) expr.rewriteDownShortCircuit(child -> { - Expression newChild = customNormalize.apply(this, child); - if (newChild != null && newChild != child) { - return newChild; - } - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); - return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; - })).collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(expressions.size()); + for (E expr : expressions) { + Expression rewriteExpr = expr.rewriteDownShortCircuit(child -> { + Expression newChild = customNormalize.apply(this, child); + if (newChild != null && newChild != child) { + return newChild; + } + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child); + return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr; + }); + result.add((E) rewriteExpr); + } + return result.build(); } public List normalizeToUseSlotRefWithoutWindowFunction( @@ -131,13 +135,20 @@ public List normalizeToUseSlotRefWithoutWindowFunction * bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2; */ public Set pushDownToNamedExpression(Collection needToPushExpressions) { - return needToPushExpressions.stream() - .map(expr -> { - NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); - return normalizeToSlotTriplet == null - ? (NamedExpression) expr - : normalizeToSlotTriplet.pushedExpr; - }).collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder result + = ImmutableSet.builderWithExpectedSize(needToPushExpressions.size()); + for (Expression expr : needToPushExpressions) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); + result.add(normalizeToSlotTriplet == null + ? (NamedExpression) expr + : normalizeToSlotTriplet.pushedExpr); + } + return result.build(); + } + + public NamedExpression pushDownToNamedExpression(Expression expr) { + NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr); + return normalizeToSlotTriplet == null ? (NamedExpression) expr : normalizeToSlotTriplet.pushedExpr; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java index 60df874f2a1004..0aacde1cc1984c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PruneOlapScanPartition.java @@ -36,7 +36,6 @@ import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Maps; import java.util.ArrayList; import java.util.List; @@ -50,6 +49,23 @@ * MergeConsecutiveProjects and all predicate push down related rules. */ public class PruneOlapScanPartition implements RewriteRuleFactory { + + @Override + public List buildRules() { + return ImmutableList.of( + logicalFilter(logicalOlapScan()) + .when(p -> !p.child().isPartitionPruned()) + .thenApply(ctx -> prunePartitions(ctx.cascadesContext, ctx.root.child(), ctx.root)) + .toRule(RuleType.OLAP_SCAN_PARTITION_PRUNE), + + logicalFilter(logicalProject(logicalOlapScan())) + .when(p -> !p.child().child().isPartitionPruned()) + .when(p -> p.child().hasPushedDownToProjectionFunctions()) + .thenApply(ctx -> prunePartitions(ctx.cascadesContext, ctx.root.child().child(), ctx.root)) + .toRule(RuleType.OLAP_SCAN_WITH_PROJECT_PARTITION_PRUNE) + ); + } + private Plan prunePartitions(CascadesContext ctx, LogicalOlapScan scan, LogicalFilter originalFilter) { OlapTable table = scan.getTable(); @@ -59,20 +75,22 @@ private Plan prunePartitions(CascadesContext ctx, } List output = scan.getOutput(); - Map scanOutput = Maps.newHashMapWithExpectedSize(output.size() * 2); - for (Slot slot : output) { - scanOutput.put(slot.getName().toLowerCase(), slot); - } - PartitionInfo partitionInfo = table.getPartitionInfo(); List partitionColumns = partitionInfo.getPartitionColumns(); List partitionSlots = new ArrayList<>(partitionColumns.size()); for (Column column : partitionColumns) { - Slot slot = scanOutput.get(column.getName().toLowerCase()); - if (slot == null) { + Slot partitionSlot = null; + // loop search is faster than build a map + for (Slot slot : output) { + if (slot.getName().equalsIgnoreCase(column.getName())) { + partitionSlot = slot; + break; + } + } + if (partitionSlot == null) { return originalFilter; } else { - partitionSlots.add(slot); + partitionSlots.add(partitionSlot); } } @@ -105,19 +123,4 @@ private Plan prunePartitions(CascadesContext ctx, } return originalFilter.withChildren(ImmutableList.of(rewrittenScan)); } - - @Override - public List buildRules() { - return ImmutableList.of( - logicalFilter(logicalOlapScan()).when(p -> !p.child().isPartitionPruned()).thenApply(ctx -> { - return prunePartitions(ctx.cascadesContext, ctx.root.child(), ctx.root); - }).toRule(RuleType.OLAP_SCAN_PARTITION_PRUNE), - - logicalFilter(logicalProject(logicalOlapScan())) - .when(p -> !p.child().child().isPartitionPruned()) - .when(p -> p.child().hasPushedDownToProjectionFunctions()).thenApply(ctx -> { - return prunePartitions(ctx.cascadesContext, ctx.root.child().child(), ctx.root); - }).toRule(RuleType.OLAP_SCAN_WITH_PROJECT_PARTITION_PRUNE) - ); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java index 26e1358c2e5e11..b02c51b1fe906e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PullUpPredicates.java @@ -31,16 +31,17 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.google.common.collect.Sets; -import java.util.Collection; import java.util.IdentityHashMap; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; import java.util.function.Supplier; -import java.util.stream.Collectors; /** * poll up effective predicates from operator's children. @@ -60,7 +61,7 @@ public ImmutableSet visit(Plan plan, Void context) { @Override public ImmutableSet visitLogicalFilter(LogicalFilter filter, Void context) { return cacheOrElse(filter, () -> { - List predicates = Lists.newArrayList(filter.getConjuncts()); + Set predicates = Sets.newLinkedHashSet(filter.getConjuncts()); predicates.addAll(filter.child().accept(this, context)); return getAvailableExpressions(predicates, filter); }); @@ -82,14 +83,14 @@ public ImmutableSet visitLogicalJoin(LogicalJoin visitLogicalProject(LogicalProject project, Void context) { return cacheOrElse(project, () -> { ImmutableSet childPredicates = project.child().accept(this, context); - - Set allPredicates = Sets.newHashSet(childPredicates); - project.getAliasToProducer().forEach((k, v) -> { - Set expressions = childPredicates.stream() - .map(e -> e.rewriteDownShortCircuit(c -> c.equals(v) ? k : c)).collect(Collectors.toSet()); - allPredicates.addAll(expressions); - }); - + Set allPredicates = Sets.newLinkedHashSet(childPredicates); + for (Entry kv : project.getAliasToProducer().entrySet()) { + Slot k = kv.getKey(); + Expression v = kv.getValue(); + for (Expression childPredicate : childPredicates) { + allPredicates.add(childPredicate.rewriteDownShortCircuit(c -> c.equals(v) ? k : c)); + } + } return getAvailableExpressions(allPredicates, project); }); } @@ -99,21 +100,22 @@ public ImmutableSet visitLogicalAggregate(LogicalAggregate { ImmutableSet childPredicates = aggregate.child().accept(this, context); // TODO - Map expressionSlotMap = aggregate.getOutputExpressions() - .stream() - .filter(this::hasAgg) - .collect(Collectors.toMap( - namedExpr -> { - if (namedExpr instanceof Alias) { - return ((Alias) namedExpr).child(); - } else { - return namedExpr; - } - }, NamedExpression::toSlot) + List outputExpressions = aggregate.getOutputExpressions(); + + Map expressionSlotMap + = Maps.newLinkedHashMapWithExpectedSize(outputExpressions.size()); + for (NamedExpression output : outputExpressions) { + if (hasAgg(output)) { + expressionSlotMap.putIfAbsent( + output instanceof Alias ? output.child(0) : output, output.toSlot() ); - Expression expression = ExpressionUtils.replace(ExpressionUtils.and(Lists.newArrayList(childPredicates)), - expressionSlotMap); - List predicates = ExpressionUtils.extractConjunction(expression); + } + } + Expression expression = ExpressionUtils.replace( + ExpressionUtils.and(Lists.newArrayList(childPredicates)), + expressionSlotMap + ); + Set predicates = Sets.newLinkedHashSet(ExpressionUtils.extractConjunction(expression)); return getAvailableExpressions(predicates, aggregate); }); } @@ -128,12 +130,23 @@ private ImmutableSet cacheOrElse(Plan plan, Supplier getAvailableExpressions(Collection predicates, Plan plan) { - Set expressions = Sets.newHashSet(predicates); - expressions.addAll(PredicatePropagation.infer(expressions)); - return expressions.stream() - .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots())) - .collect(ImmutableSet.toImmutableSet()); + private ImmutableSet getAvailableExpressions(Set predicates, Plan plan) { + Set inferPredicates = PredicatePropagation.infer(predicates); + Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size() + 10); + Set outputSet = plan.getOutputSet(); + + for (Expression predicate : predicates) { + if (outputSet.containsAll(predicate.getInputSlots())) { + newPredicates.add(predicate); + } + } + + for (Expression inferPredicate : inferPredicates) { + if (outputSet.containsAll(inferPredicate.getInputSlots())) { + newPredicates.add(inferPredicate); + } + } + return newPredicates.build(); } private boolean hasAgg(Expression expression) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java index f3a54fd8eeaa8b..798a41b37643dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregation.java @@ -29,7 +29,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Sets; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; /** @@ -60,9 +60,9 @@ public Rule build() { LogicalAggregate aggregate = filter.child(); Set canPushDownSlots = getCanPushDownSlots(aggregate); - Set pushDownPredicates = Sets.newHashSet(); - Set filterPredicates = Sets.newHashSet(); - filter.getConjuncts().forEach(conjunct -> { + Set pushDownPredicates = Sets.newLinkedHashSet(); + Set filterPredicates = Sets.newLinkedHashSet(); + for (Expression conjunct : filter.getConjuncts()) { Set conjunctSlots = conjunct.getInputSlots(); // NOTICE: filter not contain slot should not be pushed. e.g. 'a' = 'b' if (!conjunctSlots.isEmpty() && canPushDownSlots.containsAll(conjunctSlots)) { @@ -70,7 +70,7 @@ public Rule build() { } else { filterPredicates.add(conjunct); } - }); + } if (pushDownPredicates.isEmpty()) { return null; } @@ -84,7 +84,7 @@ public Rule build() { * get the slots that can be pushed down */ public static Set getCanPushDownSlots(LogicalAggregate aggregate) { - Set canPushDownSlots = new HashSet<>(); + Set canPushDownSlots = new LinkedHashSet<>(); if (aggregate.getSourceRepeat().isPresent()) { // When there is a repeat, the push-down condition is consistent with the repeat aggregate.getSourceRepeat().get().getCommonGroupingSetExpressions().stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java index 77c90820a258c4..71834a66b19a2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughProject.java @@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; @@ -49,16 +48,16 @@ public class PushDownFilterThroughProject implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( logicalFilter(logicalProject()) - .whenNot(filter -> filter.child().getProjects().stream().anyMatch( - expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .whenNot(filter -> ExpressionUtils.containsWindowExpression(filter.child().getProjects())) .whenNot(filter -> filter.child().hasPushedDownToProjectionFunctions()) .then(PushDownFilterThroughProject::pushdownFilterThroughProject) .toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_PROJECT), // filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit, // then we should change filter(limit(project)) to project(filter(limit)) logicalFilter(logicalLimit(logicalProject())) - .whenNot(filter -> filter.child().child().getProjects().stream() - .anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) + .whenNot(filter -> + ExpressionUtils.containsWindowExpression(filter.child().child().getProjects()) + ) .whenNot(filter -> filter.child().child().hasPushedDownToProjectionFunctions()) .then(PushDownFilterThroughProject::pushdownFilterThroughLimitProject) .toRule(RuleType.PUSH_DOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT) @@ -111,14 +110,14 @@ private static Pair, Set> splitConjunctsByChildOutpu Set conjuncts, Set childOutputs) { Set pushDownPredicates = Sets.newLinkedHashSet(); Set remainPredicates = Sets.newLinkedHashSet(); - conjuncts.forEach(conjunct -> { + for (Expression conjunct : conjuncts) { Set conjunctSlots = conjunct.getInputSlots(); if (childOutputs.containsAll(conjunctSlots)) { pushDownPredicates.add(conjunct); } else { remainPredicates.add(conjunct); } - }); + } return Pair.of(remainPredicates, pushDownPredicates); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java index cc0c7f12f33cbc..6dc446d88ca882 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java @@ -19,16 +19,18 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import java.util.List; import java.util.Set; -import java.util.stream.Collectors; /** * Simplify Aggregate group by Multiple to One. For example @@ -41,20 +43,25 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory { @Override public Rule build() { return logicalAggregate() - .when(agg -> agg.getGroupByExpressions().size() > 1) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(this::isBinaryArithmeticSlot)) + .when(agg -> agg.getGroupByExpressions().size() > 1 + && ExpressionUtils.allMatch(agg.getGroupByExpressions(), this::isBinaryArithmeticSlot)) .then(agg -> { - Set slots = agg.getGroupByExpressions().stream() - .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet()); + List groupByExpressions = agg.getGroupByExpressions(); + ImmutableSet.Builder inputSlots + = ImmutableSet.builderWithExpectedSize(groupByExpressions.size()); + for (Expression groupByExpression : groupByExpressions) { + inputSlots.addAll(groupByExpression.getInputSlots()); + } + Set slots = inputSlots.build(); if (slots.size() != 1) { return null; } - return agg.withGroupByAndOutput(ImmutableList.copyOf(slots), agg.getOutputExpressions()); + return agg.withGroupByAndOutput(Utils.fastToImmutableList(slots), agg.getOutputExpressions()); }) .toRule(RuleType.SIMPLIFY_AGG_GROUP_BY); } - private boolean isBinaryArithmeticSlot(Expression expr) { + private boolean isBinaryArithmeticSlot(TreeNode expr) { if (expr instanceof Slot) { return true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java index 8ecabcd8918c3c..9b34a62ca02941 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java @@ -372,7 +372,12 @@ private static int indexKeyPrefixMatchCount( } protected static boolean preAggEnabledByHint(LogicalOlapScan olapScan) { - return olapScan.getHints().stream().anyMatch("PREAGGOPEN"::equalsIgnoreCase); + for (String hint : olapScan.getHints()) { + if ("PREAGGOPEN".equalsIgnoreCase(hint)) { + return true; + } + } + return false; } public static String normalizeName(String name) { @@ -385,11 +390,11 @@ public static Expression slotToCaseWhen(Expression expression) { } protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan) { - Map baseSlotToMvSlot = new HashMap<>(); - Map mvNameToMvSlot = new HashMap<>(); if (mvPlan.getSelectedIndexId() == mvPlan.getTable().getBaseIndexId()) { - return new SlotContext(baseSlotToMvSlot, mvNameToMvSlot, new TreeSet()); + return SlotContext.EMPTY; } + Map baseSlotToMvSlot = new HashMap<>(); + Map mvNameToMvSlot = new HashMap<>(); for (Slot mvSlot : mvPlan.getOutputByIndex(mvPlan.getSelectedIndexId())) { boolean isPushed = false; for (Slot baseSlot : mvPlan.getOutput()) { @@ -428,6 +433,8 @@ protected SlotContext generateBaseScanExprToMvExpr(LogicalOlapScan mvPlan) { /** SlotContext */ protected static class SlotContext { + public static final SlotContext EMPTY + = new SlotContext(ImmutableMap.of(), ImmutableMap.of(), ImmutableSet.of()); // base index Slot to selected mv Slot public final Map baseSlotToMvSlot; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java index ea09b25ba6f493..135cdb2a95700f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java @@ -73,8 +73,10 @@ import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import org.apache.doris.planner.PlanNode; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -251,9 +253,10 @@ public List buildRules() { if (result.indexId == scan.getTable().getBaseIndexId()) { LogicalOlapScan mvPlanWithoutAgg = SelectMaterializedIndexWithoutAggregate.select(scan, project::getInputSlots, filter::getConjuncts, - Stream.concat(filter.getExpressions().stream(), - project.getExpressions().stream()) - .collect(ImmutableSet.toImmutableSet())); + Suppliers.memoize(() -> Utils.concatToSet( + filter.getExpressions(), project.getExpressions() + )) + ); SlotContext slotContextWithoutAgg = generateBaseScanExprToMvExpr(mvPlanWithoutAgg); return agg.withChildren(new LogicalProject( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java index 7960dd73df9a8c..e05a1eda3e63fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithoutAggregate.java @@ -32,7 +32,9 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -42,7 +44,6 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Collectors; -import java.util.stream.Stream; /** * Select materialized index, i.e., both for rollup and materialized view when aggregate is not present. @@ -70,11 +71,13 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, project::getInputSlots, filter::getConjuncts, - Stream.concat(filter.getExpressions().stream(), - project.getExpressions().stream()).collect(ImmutableSet.toImmutableSet())); + Suppliers.memoize(() -> + Utils.concatToSet(filter.getExpressions(), project.getExpressions()) + ) + ); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); - return new LogicalProject( + return new LogicalProject<>( generateProjectsAlias(project.getOutput(), slotContext), new ReplaceExpressions(slotContext).replace( project.withChildren(filter.withChildren(mvPlan)), mvPlan)); @@ -90,7 +93,7 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, project::getInputSlots, ImmutableSet::of, - new HashSet<>(project.getExpressions())); + () -> Utils.fastToImmutableSet(project.getExpressions())); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -107,8 +110,10 @@ public List buildRules() { LogicalOlapScan scan = filter.child(); LogicalOlapScan mvPlan = select( scan, filter::getOutputSet, filter::getConjuncts, - Stream.concat(filter.getExpressions().stream(), - filter.getOutputSet().stream()).collect(ImmutableSet.toImmutableSet())); + Suppliers.memoize(() -> + Utils.concatToSet(filter.getExpressions(), filter.getOutputSet()) + ) + ); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -127,7 +132,8 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, project::getInputSlots, ImmutableSet::of, - new HashSet<>(project.getExpressions())); + Suppliers.memoize(() -> Utils.fastToImmutableSet(project.getExpressions())) + ); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -145,7 +151,7 @@ public List buildRules() { LogicalOlapScan mvPlan = select( scan, scan::getOutputSet, ImmutableSet::of, - scan.getOutputSet()); + () -> (Set) scan.getOutputSet()); SlotContext slotContext = generateBaseScanExprToMvExpr(mvPlan); return new LogicalProject( @@ -169,7 +175,7 @@ public static LogicalOlapScan select( LogicalOlapScan scan, Supplier> requiredScanOutputSupplier, Supplier> predicatesSupplier, - Set requiredExpr) { + Supplier> requiredExpr) { OlapTable table = scan.getTable(); long baseIndexId = table.getBaseIndexId(); KeysType keysType = scan.getTable().getKeysType(); @@ -186,21 +192,24 @@ public static LogicalOlapScan select( throw new RuntimeException("Not supported keys type: " + keysType); } - Set requiredSlots = new HashSet<>(); - requiredSlots.addAll(requiredScanOutputSupplier.get()); - requiredSlots.addAll(ExpressionUtils.getInputSlotSet(requiredExpr)); - requiredSlots.addAll(ExpressionUtils.getInputSlotSet(predicatesSupplier.get())); + Supplier> requiredSlots = Suppliers.memoize(() -> { + Set set = new HashSet<>(); + set.addAll(requiredScanOutputSupplier.get()); + set.addAll(ExpressionUtils.getInputSlotSet(requiredExpr.get())); + set.addAll(ExpressionUtils.getInputSlotSet(predicatesSupplier.get())); + return set; + }); if (scan.getTable().isDupKeysOrMergeOnWrite()) { // Set pre-aggregation to `on` to keep consistency with legacy logic. List candidates = scan .getTable().getVisibleIndex().stream().filter(index -> index.getId() != baseIndexId) .filter(index -> !indexHasAggregate(index, scan)).filter(index -> containAllRequiredColumns(index, - scan, requiredScanOutputSupplier.get(), requiredExpr, predicatesSupplier.get())) + scan, requiredScanOutputSupplier.get(), requiredExpr.get(), predicatesSupplier.get())) .collect(Collectors.toList()); long bestIndex = selectBestIndex(candidates, scan, predicatesSupplier.get()); // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex - bestIndex = isSameDataType(scan, bestIndex, requiredSlots) ? bestIndex : baseIndexId; + bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? bestIndex : baseIndexId; return scan.withMaterializedIndexSelected(PreAggStatus.on(), bestIndex); } else { final PreAggStatus preAggStatus; @@ -221,7 +230,7 @@ public static LogicalOlapScan select( List candidates = table.getVisibleIndex().stream() .filter(index -> table.getKeyColumnsByIndexId(index.getId()).size() == baseIndexKeySize) .filter(index -> containAllRequiredColumns(index, scan, requiredScanOutputSupplier.get(), - requiredExpr, predicatesSupplier.get())) + requiredExpr.get(), predicatesSupplier.get())) .collect(Collectors.toList()); if (candidates.size() == 1) { @@ -231,7 +240,7 @@ public static LogicalOlapScan select( long bestIndex = selectBestIndex(candidates, scan, predicatesSupplier.get()); // this is fail-safe for select mv // select baseIndex if bestIndex's slots' data types are different from baseIndex - bestIndex = isSameDataType(scan, bestIndex, requiredSlots) ? bestIndex : baseIndexId; + bestIndex = isSameDataType(scan, bestIndex, requiredSlots.get()) ? bestIndex : baseIndexId; return scan.withMaterializedIndexSelected(preAggStatus, bestIndex); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 2ed1afc56772c2..57a79037d801d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -132,6 +132,7 @@ import org.apache.doris.statistics.StatisticsBuilder; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import org.apache.commons.collections.CollectionUtils; @@ -753,8 +754,14 @@ private ColumnStatistic getColumnStatistic(TableIf table, String colName, long i // 2. Consider the influence of runtime filter // 3. Get NDV and column data size from StatisticManger, StatisticManager doesn't support it now. private Statistics computeCatalogRelation(CatalogRelation catalogRelation) { - Set slotSet = catalogRelation.getOutput().stream().filter(SlotReference.class::isInstance) - .map(s -> (SlotReference) s).collect(Collectors.toSet()); + List output = catalogRelation.getOutput(); + ImmutableSet.Builder slotSetBuilder = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + if (slot instanceof SlotReference) { + slotSetBuilder.add((SlotReference) slot); + } + } + Set slotSet = slotSetBuilder.build(); Map columnStatisticMap = new HashMap<>(); TableIf table = catalogRelation.getTable(); double rowCount = catalogRelation.getTable().getRowCountForNereids(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java index 59d2acbe22bd6b..92bbcdb9b38fd0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java @@ -17,9 +17,12 @@ package org.apache.doris.nereids.trees; +import org.apache.doris.nereids.util.MutableState; +import org.apache.doris.nereids.util.MutableState.EmptyMutableState; import org.apache.doris.nereids.util.Utils; import java.util.List; +import java.util.Optional; /** * Abstract class for plan node in Nereids, include plan node and expression. @@ -30,8 +33,13 @@ public abstract class AbstractTreeNode> implements TreeNode { protected final List children; - // TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression. - // https://github.com/apache/doris/pull/9807#discussion_r884829067 + + // this field is special, because other fields in tree node is immutable, but in some scenes, mutable + // state is necessary. e.g. the rewrite framework need distinguish whether the plan is created by + // rules, the framework can set this field to a state variable to quickly judge without new big plan. + // we should avoid using it as much as possible, because mutable state is easy to cause bugs and + // difficult to locate. + private MutableState mutableState = EmptyMutableState.INSTANCE; protected AbstractTreeNode(NODE_TYPE... children) { // NOTE: ImmutableList.copyOf has additional clone of the list, so here we @@ -55,6 +63,16 @@ public List children() { return children; } + @Override + public Optional getMutableState(String key) { + return mutableState.get(key); + } + + @Override + public void setMutableState(String key, Object state) { + this.mutableState = this.mutableState.set(key, state); + } + public int arity() { return children.size(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index d37070865e22eb..6d1a298eb79fe2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -28,11 +28,13 @@ import java.util.Deque; import java.util.LinkedList; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; +import java.util.function.Supplier; /** * interface for all node in Nereids, include plan node and expression. @@ -48,6 +50,21 @@ public interface TreeNode> { int arity(); + Optional getMutableState(String key); + + /** getOrInitMutableState */ + default T getOrInitMutableState(String key, Supplier initState) { + Optional mutableState = getMutableState(key); + if (!mutableState.isPresent()) { + T state = initState.get(); + setMutableState(key, state); + return state; + } + return mutableState.get(); + } + + void setMutableState(String key, Object value); + default NODE_TYPE withChildren(NODE_TYPE... children) { return withChildren(Utils.fastToImmutableList(children)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java index 01a61d576d25aa..750f3a77881430 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java @@ -24,7 +24,6 @@ import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.Objects; /** * Abstract for all binary operator, include binary arithmetic, compound predicate, comparison predicate. @@ -63,9 +62,4 @@ public String toString() { public String shapeInfo() { return "(" + left().shapeInfo() + " " + symbol + " " + right().shapeInfo() + ")"; } - - @Override - public int hashCode() { - return Objects.hash(symbol, left(), right()); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java index c9d10bde36d3c1..d343f6f93566cd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ComparisonPredicate.java @@ -61,10 +61,10 @@ public DataType inputType() { @Override public void checkLegalityBeforeTypeCoercion() { - children().forEach(c -> { + for (Expression c : children) { if (c.getDataType().isComplexType() && !c.getDataType().isArrayType()) { throw new AnalysisException("comparison predicate could not contains complex type: " + this.toSql()); } - }); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index a7947c82a565bf..75cef0fc94677b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -39,6 +39,7 @@ import org.apache.doris.nereids.types.MapType; import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -70,20 +71,43 @@ public abstract class Expression extends AbstractTreeNode implements protected Expression(Expression... children) { super(children); - int maxChildDepth = 0; - int sumChildWidth = 0; + boolean hasUnbound = false; - boolean compareWidthAndDepth = true; - for (int i = 0; i < children.length; ++i) { - Expression child = children[i]; - maxChildDepth = Math.max(child.depth, maxChildDepth); - sumChildWidth += child.width; - hasUnbound |= child.hasUnbound; - compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); + switch (children.length) { + case 0: + this.depth = 1; + this.width = 1; + this.compareWidthAndDepth = supportCompareWidthAndDepth(); + break; + case 1: + Expression child = children[0]; + this.depth = child.depth + 1; + this.width = child.width; + this.compareWidthAndDepth = child.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + case 2: + Expression left = children[0]; + Expression right = children[1]; + this.depth = Math.max(left.depth, right.depth) + 1; + this.width = left.width + right.width; + this.compareWidthAndDepth = + left.compareWidthAndDepth && right.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + default: + int maxChildDepth = 0; + int sumChildWidth = 0; + boolean compareWidthAndDepth = true; + for (Expression expression : children) { + child = expression; + maxChildDepth = Math.max(child.depth, maxChildDepth); + sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= child.compareWidthAndDepth; + } + this.depth = maxChildDepth + 1; + this.width = sumChildWidth; + this.compareWidthAndDepth = compareWidthAndDepth; } - this.depth = maxChildDepth + 1; - this.width = sumChildWidth + ((children.length == 0) ? 1 : 0); - this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = false; @@ -96,20 +120,43 @@ protected Expression(List children) { protected Expression(List children, boolean inferred) { super(children); - int maxChildDepth = 0; - int sumChildWidth = 0; + boolean hasUnbound = false; - boolean compareWidthAndDepth = true; - for (int i = 0; i < children.size(); ++i) { - Expression child = children.get(i); - maxChildDepth = Math.max(child.depth, maxChildDepth); - sumChildWidth += child.width; - hasUnbound |= child.hasUnbound; - compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); + switch (children.size()) { + case 0: + this.depth = 1; + this.width = 1; + this.compareWidthAndDepth = supportCompareWidthAndDepth(); + break; + case 1: + Expression child = children.get(0); + this.depth = child.depth + 1; + this.width = child.width; + this.compareWidthAndDepth = child.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + case 2: + Expression left = children.get(0); + Expression right = children.get(1); + this.depth = Math.max(left.depth, right.depth) + 1; + this.width = left.width + right.width; + this.compareWidthAndDepth = + left.compareWidthAndDepth && right.compareWidthAndDepth && supportCompareWidthAndDepth(); + break; + default: + int maxChildDepth = 0; + int sumChildWidth = 0; + boolean compareWidthAndDepth = true; + for (Expression expression : children) { + child = expression; + maxChildDepth = Math.max(child.depth, maxChildDepth); + sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= child.compareWidthAndDepth; + } + this.depth = maxChildDepth + 1; + this.width = sumChildWidth; + this.compareWidthAndDepth = compareWidthAndDepth && supportCompareWidthAndDepth(); } - this.depth = maxChildDepth + 1; - this.width = sumChildWidth + ((children.isEmpty()) ? 1 : 0); - this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = inferred; @@ -284,7 +331,7 @@ public boolean isConstant() { if (this instanceof LeafExpression) { return this instanceof Literal; } else { - return !(this instanceof Nondeterministic) && children().stream().allMatch(Expression::isConstant); + return !(this instanceof Nondeterministic) && ExpressionUtils.allMatch(children(), Expression::isConstant); } } @@ -376,7 +423,7 @@ protected boolean extraEquals(Expression that) { @Override public int hashCode() { - return 0; + return getClass().hashCode(); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index bcebdca4f5b8a5..53a753c4535dd1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -28,6 +28,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import java.util.Collection; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; @@ -42,13 +43,13 @@ public class InPredicate extends Expression { private final Expression compareExpr; private final List options; - public InPredicate(Expression compareExpr, List options) { + public InPredicate(Expression compareExpr, Collection options) { super(new Builder().add(compareExpr).addAll(options).build()); this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"); this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); } - public InPredicate(Expression compareExpr, List options, boolean inferred) { + public InPredicate(Expression compareExpr, Collection options, boolean inferred) { super(new Builder().add(compareExpr).addAll(options).build(), inferred); this.compareExpr = Objects.requireNonNull(compareExpr, "Compare Expr cannot be null"); this.options = ImmutableList.copyOf(Objects.requireNonNull(options, "In list cannot be null")); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 7cfaad72a2c546..28cb20ea1cdfa7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -132,8 +132,8 @@ public static SlotReference of(String name, DataType type) { */ public static SlotReference fromColumn(TableIf table, Column column, List qualifier, Relation relation) { DataType dataType = DataType.fromCatalogType(column.getType()); - SlotReference slot = new SlotReference(StatementScopeIdGenerator.newExprId(), column.getName(), dataType, - column.isAllowNull(), qualifier, table, column, Optional.empty(), null); + SlotReference slot = new SlotReference(StatementScopeIdGenerator.newExprId(), () -> column.getName(), dataType, + column.isAllowNull(), qualifier, table, column, () -> Optional.of(column.getName()), null); if (relation != null && ConnectContext.get() != null && ConnectContext.get().getStatementContext() != null) { ConnectContext.get().getStatementContext().addSlotToRelation(slot, relation); @@ -260,6 +260,9 @@ public SlotReference withQualifier(List qualifier) { @Override public SlotReference withName(String name) { + if (this.name.get().equalsIgnoreCase(name)) { + return this; + } return new SlotReference( exprId, () -> name, dataType, nullable, qualifier, table, column, internalName, subColPath); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java index 6d0a5d85de5557..2cdbe43c12ecb5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java @@ -38,6 +38,8 @@ import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -413,9 +415,12 @@ private static FunctionSignature defaultDateTimeV2PrecisionPromotion( return signature; } DateTimeV2Type argType = finalType; - List newArgTypes = signature.argumentsTypes.stream() - .map(at -> TypeCoercionUtils.replaceDateTimeV2WithTarget(at, argType)) - .collect(Collectors.toList()); + + ImmutableList.Builder newArgTypesBuilder = ImmutableList.builderWithExpectedSize(signature.arity); + for (DataType at : signature.argumentsTypes) { + newArgTypesBuilder.add(TypeCoercionUtils.replaceDateTimeV2WithTarget(at, argType)); + } + List newArgTypes = newArgTypesBuilder.build(); signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes); signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes); if (signature.returnType instanceof DateTimeV2Type) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java index 4f53b383d244eb..e45d3fb4da8b0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java @@ -108,11 +108,18 @@ public boolean hasVarArguments() { @Override public String toSql() throws UnboundException { - String args = children() - .stream() - .map(Expression::toSql) - .collect(Collectors.joining(", ")); - return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")"; + StringBuilder sql = new StringBuilder(getName()).append("("); + if (distinct) { + sql.append("DISTINCT "); + } + int arity = arity(); + for (int i = 0; i < arity; i++) { + sql.append(child(i).toSql()); + if (i + 1 < arity) { + sql.append(", "); + } + } + return sql.append(")").toString(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java index 81678153cd6206..d8e3642c36ffdc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/PushDownToProjectionFunction.java @@ -43,10 +43,9 @@ public PushDownToProjectionFunction(String name, Expression... arguments) { */ public static boolean validToPushDown(Expression pushDownExpr) { // Currently only element at for variant type could be pushed down - return pushDownExpr != null && !pushDownExpr.collectToList( - PushDownToProjectionFunction.class::isInstance).stream().filter( - x -> ((Expression) x).getDataType().isVariantType()).collect( - Collectors.toList()).isEmpty(); + return pushDownExpr != null && pushDownExpr.anyMatch(expr -> + expr instanceof PushDownToProjectionFunction && ((Expression) expr).getDataType().isVariantType() + ); } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java index a33cc32c16f2a6..38951ea9e453b4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DateLiteral.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableSet; +import java.time.LocalDate; import java.time.LocalDateTime; import java.time.Year; import java.time.temporal.ChronoField; @@ -158,7 +159,9 @@ private static void replacePunctuation(String s, StringBuilder sb, char c, int i static String normalize(String s) { // merge consecutive space - s = s.replaceAll(" +", " "); + if (s.contains(" ")) { + s = s.replaceAll(" +", " "); + } StringBuilder sb = new StringBuilder(); @@ -261,6 +264,14 @@ static String normalize(String s) { } protected static TemporalAccessor parse(String s) { + // fast parse '2022-01-01' + if (s.length() == 10 && s.charAt(4) == '-' && s.charAt(7) == '-') { + TemporalAccessor date = fastParseDate(s); + if (date != null) { + return date; + } + } + String originalString = s; try { TemporalAccessor dateTime; @@ -477,4 +488,30 @@ public DateTimeLiteral toBeginOfTomorrow() { return toEndOfTheDay(); } } + + private static TemporalAccessor fastParseDate(String date) { + Integer year = readNextInt(date, 0, 4); + Integer month = readNextInt(date, 5, 2); + Integer day = readNextInt(date, 8, 2); + if (year != null && month != null && day != null) { + return LocalDate.of(year, month, day); + } else { + return null; + } + } + + private static Integer readNextInt(String str, int offset, int readLength) { + int value = 0; + int realReadLength = 0; + for (int i = offset; i < str.length(); i++) { + char c = str.charAt(i); + if ('0' <= c && c <= '9') { + realReadLength++; + value = value * 10 + (c - '0'); + } else { + break; + } + } + return readLength == realReadLength ? value : null; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java index 2248666dbca12f..fd25f9368ef0b5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/DefaultExpressionRewriter.java @@ -34,13 +34,13 @@ public Expression visit(Expression expr, C context) { } /** rewriteChildren */ - public static final Expression rewriteChildren( - ExpressionVisitor rewriter, Expression expr, C context) { + public static final E rewriteChildren( + ExpressionVisitor rewriter, E expr, C context) { switch (expr.arity()) { case 1: { Expression originChild = expr.child(0); Expression newChild = originChild.accept(rewriter, context); - return (originChild != newChild) ? expr.withChildren(ImmutableList.of(newChild)) : expr; + return (originChild != newChild) ? (E) expr.withChildren(ImmutableList.of(newChild)) : expr; } case 2: { Expression originLeft = expr.child(0); @@ -48,7 +48,7 @@ public static final Expression rewriteChildren( Expression originRight = expr.child(1); Expression newRight = originRight.accept(rewriter, context); return (originLeft != newLeft || originRight != newRight) - ? expr.withChildren(ImmutableList.of(newLeft, newRight)) + ? (E) expr.withChildren(ImmutableList.of(newLeft, newRight)) : expr; } case 0: { @@ -64,7 +64,7 @@ public static final Expression rewriteChildren( } newChildren.add(newChild); } - return hasNewChildren ? expr.withChildren(newChildren.build()) : expr; + return hasNewChildren ? (E) expr.withChildren(newChildren.build()) : expr; } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java index 4be6d35dc94692..286a92aab768f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java @@ -28,7 +28,6 @@ import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.util.MutableState; -import org.apache.doris.nereids.util.MutableState.EmptyMutableState; import org.apache.doris.nereids.util.TreeStringUtils; import org.apache.doris.statistics.Statistics; @@ -58,13 +57,6 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla protected final Optional groupExpression; protected final Supplier logicalPropertiesSupplier; - // this field is special, because other fields in tree node is immutable, but in some scenes, mutable - // state is necessary. e.g. the rewrite framework need distinguish whether the plan is created by - // rules, the framework can set this field to a state variable to quickly judge without new big plan. - // we should avoid using it as much as possible, because mutable state is easy to cause bugs and - // difficult to locate. - private MutableState mutableState = EmptyMutableState.INSTANCE; - /** * all parameter constructor. */ @@ -108,7 +100,15 @@ public Statistics getStats() { @Override public boolean canBind() { - return !bound() && children().stream().allMatch(Plan::bound); + if (bound()) { + return false; + } + for (Plan child : children()) { + if (!child.bound()) { + return false; + } + } + return true; } /** @@ -185,16 +185,6 @@ public LogicalProperties computeLogicalProperties() { } } - @Override - public Optional getMutableState(String key) { - return mutableState.get(key); - } - - @Override - public void setMutableState(String key, Object state) { - this.mutableState = this.mutableState.set(key, state); - } - public int getId() { return id.asInt(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java index 1b237c72fdc207..d73b7390ce8d59 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java @@ -23,10 +23,10 @@ import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.MutableState; +import org.apache.doris.nereids.util.PlanUtils; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; @@ -36,8 +36,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.function.Supplier; -import java.util.stream.Collectors; /** * Abstract class for all plan node. @@ -104,12 +102,46 @@ default Set getOutputSet() { return ImmutableSet.copyOf(getOutput()); } + /** getOutputExprIds */ default List getOutputExprIds() { - return getOutput().stream().map(NamedExpression::getExprId).collect(Collectors.toList()); + List output = getOutput(); + ImmutableList.Builder exprIds = ImmutableList.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIds.add(slot.getExprId()); + } + return exprIds.build(); } + /** getOutputExprIdSet */ default Set getOutputExprIdSet() { - return getOutput().stream().map(NamedExpression::getExprId).collect(Collectors.toSet()); + List output = getOutput(); + ImmutableSet.Builder exprIds = ImmutableSet.builderWithExpectedSize(output.size()); + for (Slot slot : output) { + exprIds.add(slot.getExprId()); + } + return exprIds.build(); + } + + /** getChildrenOutputExprIdSet */ + default Set getChildrenOutputExprIdSet() { + switch (arity()) { + case 0: return ImmutableSet.of(); + case 1: return child(0).getOutputExprIdSet(); + default: { + int exprIdSize = 0; + for (Plan child : children()) { + exprIdSize += child.getOutput().size(); + } + + ImmutableSet.Builder exprIds = ImmutableSet.builderWithExpectedSize(exprIdSize); + for (Plan child : children()) { + for (Slot slot : child.getOutput()) { + exprIds.add(slot.getExprId()); + } + } + return exprIds.build(); + } + } } /** @@ -119,9 +151,7 @@ default Set getOutputExprIdSet() { * Note that the input slots of subquery's inner plan are not included. */ default Set getInputSlots() { - return getExpressions().stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .collect(ImmutableSet.toImmutableSet()); + return PlanUtils.fastGetInputSlots(this.getExpressions()); } default List computeOutput() { @@ -147,21 +177,6 @@ default Set getInputRelations() { Plan withGroupExprLogicalPropChildren(Optional groupExpression, Optional logicalProperties, List children); - Optional getMutableState(String key); - - /** getOrInitMutableState */ - default T getOrInitMutableState(String key, Supplier initState) { - Optional mutableState = getMutableState(key); - if (!mutableState.isPresent()) { - T state = initState.get(); - setMutableState(key, state); - return state; - } - return mutableState.get(); - } - - void setMutableState(String key, Object value); - /** * a simple version of explain, used to verify plan shape * @param prefix " " diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java index 15fd5bec868eeb..e7d09b8cf8b9ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java @@ -53,10 +53,19 @@ default Set getAggregateFunctions() { return ExpressionUtils.collect(getOutputExpressions(), AggregateFunction.class::isInstance); } + /** getDistinctArguments */ default Set getDistinctArguments() { - return getAggregateFunctions().stream() - .filter(AggregateFunction::isDistinct) - .flatMap(aggregateFunction -> aggregateFunction.getDistinctArguments().stream()) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder distinctArguments = ImmutableSet.builder(); + for (NamedExpression outputExpression : getOutputExpressions()) { + outputExpression.foreach(expr -> { + if (expr instanceof AggregateFunction) { + AggregateFunction aggFun = (AggregateFunction) expr; + if (aggFun.isDistinct()) { + distinctArguments.addAll(aggFun.getDistinctArguments()); + } + } + }); + } + return distinctArguments.build(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java index b734bba576df26..7fa62f7628fc2d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Project.java @@ -74,15 +74,24 @@ default List mergeProjections(Project childProject) { * And check if contains PushDownToProjectionFunction that can pushed down to project */ default boolean hasPushedDownToProjectionFunctions() { - return ConnectContext.get() != null - && ConnectContext.get().getSessionVariable() != null - && ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot() - && getProjects().stream().allMatch(namedExpr -> - namedExpr instanceof SlotReference - || (namedExpr instanceof Alias - && PushDownToProjectionFunction.validToPushDown(((Alias) namedExpr).child()))) - && getProjects().stream().anyMatch((namedExpr -> namedExpr instanceof Alias - && PushDownToProjectionFunction.validToPushDown(((Alias) namedExpr).child()))); + if ((ConnectContext.get() == null + || ConnectContext.get().getSessionVariable() == null + || !ConnectContext.get().getSessionVariable().isEnableRewriteElementAtToSlot())) { + return false; + } + + boolean hasValidAlias = false; + for (NamedExpression namedExpr : getProjects()) { + if (namedExpr instanceof Alias) { + if (!PushDownToProjectionFunction.validToPushDown(((Alias) namedExpr).child())) { + return false; + } + hasValidAlias = true; + } else if (!(namedExpr instanceof SlotReference)) { + return false; + } + } + return hasValidAlias; } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 20647a3808ebe7..fa4f891e7a20b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -200,9 +200,11 @@ public String toString() { @Override public List computeOutput() { - return outputExpressions.stream() - .map(NamedExpression::toSlot) - .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder outputSlots = ImmutableList.builderWithExpectedSize(outputExpressions.size()); + for (NamedExpression outputExpression : outputExpressions) { + outputSlots.add(outputExpression.toSlot()); + } + return outputSlots.build(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java index 4076e8348e2208..b4dbc9444604da 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCatalogRelation.java @@ -22,6 +22,8 @@ import org.apache.doris.catalog.Env; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.TableIf; +import org.apache.doris.catalog.constraint.PrimaryKeyConstraint; +import org.apache.doris.catalog.constraint.UniqueConstraint; import org.apache.doris.datasource.CatalogIf; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.memo.GroupExpression; @@ -31,7 +33,6 @@ import org.apache.doris.nereids.properties.FunctionalDependencies.Builder; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.TableFdItem; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.PlanType; @@ -128,85 +129,80 @@ public String qualifiedName() { @Override public FunctionalDependencies computeFuncDeps(Supplier> outputSupplier) { Builder fdBuilder = new Builder(); - Set output = ImmutableSet.copyOf(outputSupplier.get()); + Set outputSet = Utils.fastToImmutableSet(outputSupplier.get()); if (table instanceof OlapTable && ((OlapTable) table).getKeysType().isAggregationFamily()) { - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && s.getColumn().get().isKey()) - .collect(ImmutableSet.toImmutableSet()); - fdBuilder.addUniqueSlot(slotSet); + ImmutableSet.Builder uniqSlots = ImmutableSet.builderWithExpectedSize(outputSet.size()); + for (Slot slot : outputSet) { + if (!(slot instanceof SlotReference)) { + continue; + } + SlotReference slotRef = (SlotReference) slot; + if (slotRef.getColumn().isPresent() && slotRef.getColumn().get().isKey()) { + uniqSlots.add(slot); + } + } + fdBuilder.addUniqueSlot(uniqSlots.build()); } - table.getPrimaryKeyConstraints().forEach(c -> { - Set columns = c.getPrimaryKeys(this.getTable()); - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - fdBuilder.addUniqueSlot(slotSet); - }); - table.getUniqueConstraints().forEach(c -> { - Set columns = c.getUniqueKeys(this.getTable()); - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - fdBuilder.addUniqueSlot(slotSet); - }); - ImmutableSet fdItems = computeFdItems(outputSupplier); - fdBuilder.addFdItems(fdItems); + + for (PrimaryKeyConstraint c : table.getPrimaryKeyConstraints()) { + Set columns = c.getPrimaryKeys(table); + fdBuilder.addUniqueSlot((ImmutableSet) findSlotsByColumn(outputSet, columns)); + } + + for (UniqueConstraint c : table.getUniqueConstraints()) { + Set columns = c.getUniqueKeys(table); + fdBuilder.addUniqueSlot((ImmutableSet) findSlotsByColumn(outputSet, columns)); + } + fdBuilder.addFdItems(computeFdItems(outputSet)); return fdBuilder.build(); } @Override public ImmutableSet computeFdItems(Supplier> outputSupplier) { - Set output = ImmutableSet.copyOf(outputSupplier.get()); + return computeFdItems(Utils.fastToImmutableSet(outputSupplier.get())); + } + + private ImmutableSet computeFdItems(Set outputSet) { ImmutableSet.Builder builder = ImmutableSet.builder(); - table.getPrimaryKeyConstraints().forEach(c -> { + + for (PrimaryKeyConstraint c : table.getPrimaryKeyConstraints()) { Set columns = c.getPrimaryKeys(this.getTable()); - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem(slotSet, true, - false, ImmutableSet.of(table)); + ImmutableSet slotSet = findSlotsByColumn(outputSet, columns); + TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem( + slotSet, true, false, ImmutableSet.of(table)); builder.add(tableFdItem); - }); - table.getUniqueConstraints().forEach(c -> { + } + + for (UniqueConstraint c : table.getUniqueConstraints()) { Set columns = c.getUniqueKeys(this.getTable()); - boolean allNotNull = columns.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .allMatch(s -> !s.nullable()); - if (allNotNull) { - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem(slotSet, - true, false, ImmutableSet.of(table)); - builder.add(tableFdItem); - } else { - ImmutableSet slotSet = output.stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .filter(s -> s.getColumn().isPresent() - && columns.contains(s.getColumn().get())) - .collect(ImmutableSet.toImmutableSet()); - TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem(slotSet, - true, true, ImmutableSet.of(table)); - builder.add(tableFdItem); + boolean allNotNull = true; + + for (Column column : columns) { + if (column.isAllowNull()) { + allNotNull = false; + break; + } } - }); + + ImmutableSet slotSet = findSlotsByColumn(outputSet, columns); + TableFdItem tableFdItem = FdFactory.INSTANCE.createTableFdItem( + slotSet, true, !allNotNull, ImmutableSet.of(table)); + builder.add(tableFdItem); + } return builder.build(); } + + private ImmutableSet findSlotsByColumn(Set outputSet, Set columns) { + ImmutableSet.Builder slotSet = ImmutableSet.builderWithExpectedSize(columns.size()); + for (Slot slot : outputSet) { + if (!(slot instanceof SlotReference)) { + continue; + } + SlotReference slotRef = (SlotReference) slot; + if (slotRef.getColumn().isPresent() && columns.contains(slotRef.getColumn().get())) { + slotSet.add(slotRef); + } + } + return slotSet.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index cb5d1847fef2a5..d83a2f59f7fb79 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -38,16 +38,17 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.commons.lang3.tuple.Pair; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * Logical OlapScan. @@ -174,9 +175,23 @@ public LogicalOlapScan(RelationId id, Table table, List qualifier, this.indexSelected = indexSelected; this.preAggStatus = preAggStatus; this.manuallySpecifiedPartitions = ImmutableList.copyOf(specifiedPartitions); - this.selectedPartitionIds = selectedPartitionIds.stream() - .filter(partitionId -> this.getTable().getPartition(partitionId) != null) - .collect(Collectors.toList()); + + switch (selectedPartitionIds.size()) { + case 0: { + this.selectedPartitionIds = ImmutableList.of(); + break; + } + default: { + ImmutableList.Builder existPartitions + = ImmutableList.builderWithExpectedSize(selectedPartitionIds.size()); + for (Long partitionId : selectedPartitionIds) { + if (((OlapTable) table).getPartition(partitionId) != null) { + existPartitions.add(partitionId); + } + } + this.selectedPartitionIds = existPartitions.build(); + } + } this.hints = Objects.requireNonNull(hints, "hints can not be null"); this.cacheSlotWithSlotName = Objects.requireNonNull(cacheSlotWithSlotName, "mvNameToSlot can not be null"); @@ -333,14 +348,17 @@ public List computeOutput() { return getOutputByIndex(selectedIndexId); } List baseSchema = table.getBaseSchema(true); + List slotFromColumn = createSlotsVectorized(baseSchema); + Builder slots = ImmutableList.builder(); - for (Column col : baseSchema) { + for (int i = 0; i < baseSchema.size(); i++) { + Column col = baseSchema.get(i); Pair key = Pair.of(selectedIndexId, col.getName()); Slot slot = cacheSlotWithSlotName.get(key); if (slot != null) { slots.add(slot); } else { - slot = SlotReference.fromColumn(table, col, qualified(), this); + slot = slotFromColumn.get(i); cacheSlotWithSlotName.put(key, slot); slots.add(slot); } @@ -363,27 +381,27 @@ public List getOutputByIndex(long indexId) { OlapTable olapTable = (OlapTable) table; // PhysicalStorageLayerAggregateTest has no visible index // when we have a partitioned table without any partition, visible index is empty - if (-1 == indexId || olapTable.getIndexMetaByIndexId(indexId) == null) { - return olapTable.getIndexMetaByIndexId(indexId).getSchema().stream() - .map(c -> generateUniqueSlot(olapTable, c, - indexId == ((OlapTable) table).getBaseIndexId(), indexId)) - .collect(Collectors.toList()); + List schema = olapTable.getIndexMetaByIndexId(indexId).getSchema(); + List slots = Lists.newArrayListWithCapacity(schema.size()); + for (Column c : schema) { + Slot slot = generateUniqueSlot( + olapTable, c, indexId == ((OlapTable) table).getBaseIndexId(), indexId); + slots.add(slot); } - return olapTable.getIndexMetaByIndexId(indexId).getSchema().stream() - .map(s -> generateUniqueSlot(olapTable, s, - indexId == ((OlapTable) table).getBaseIndexId(), indexId)) - .collect(ImmutableList.toImmutableList()); + return slots; } private Slot generateUniqueSlot(OlapTable table, Column column, boolean isBaseIndex, long indexId) { String name = isBaseIndex || directMvScan ? column.getName() : AbstractSelectMaterializedIndexRule.parseMvColumnToMvName(column.getName(), column.isAggregated() ? Optional.of(column.getAggregationType().toSql()) : Optional.empty()); - if (cacheSlotWithSlotName.containsKey(Pair.of(indexId, name))) { - return cacheSlotWithSlotName.get(Pair.of(indexId, name)); + Pair key = Pair.of(indexId, name); + Slot slot = cacheSlotWithSlotName.get(key); + if (slot != null) { + return slot; } - Slot slot = SlotReference.fromColumn(table, column, name, qualified()); - cacheSlotWithSlotName.put(Pair.of(indexId, name), slot); + slot = SlotReference.fromColumn(table, column, name, qualified()); + cacheSlotWithSlotName.put(key, slot); return slot; } @@ -402,4 +420,13 @@ public Optional getTableSample() { public boolean isDirectMvScan() { return directMvScan; } + + private List createSlotsVectorized(List columns) { + List qualified = qualified(); + Object[] slots = new Object[columns.size()]; + for (int i = 0; i < columns.size(); i++) { + slots[i] = SlotReference.fromColumn(table, columns.get(i), qualified, this); + } + return (List) Arrays.asList(slots); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java index d899d228fb66bc..89dd7d49677d3e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalProject.java @@ -114,8 +114,14 @@ public List getExcepts() { return excepts; } + /** isAllSlots */ public boolean isAllSlots() { - return projects.stream().allMatch(NamedExpression::isSlot); + for (NamedExpression project : projects) { + if (!project.isSlot()) { + return false; + } + } + return true; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java index 9d9d321e659636..607fcf25bca7fe 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalSort.java @@ -30,11 +30,13 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; /** * Logical Sort plan. @@ -47,6 +49,7 @@ public class LogicalSort extends LogicalUnary orderKeys; + private final Supplier> expressions; public LogicalSort(List orderKeys, CHILD_TYPE child) { this(orderKeys, Optional.empty(), Optional.empty(), child); @@ -58,7 +61,17 @@ public LogicalSort(List orderKeys, CHILD_TYPE child) { public LogicalSort(List orderKeys, Optional groupExpression, Optional logicalProperties, CHILD_TYPE child) { super(PlanType.LOGICAL_SORT, groupExpression, logicalProperties, child); - this.orderKeys = ImmutableList.copyOf(Objects.requireNonNull(orderKeys, "orderKeys can not be null")); + this.orderKeys = Utils.fastToImmutableList( + Objects.requireNonNull(orderKeys, "orderKeys can not be null") + ); + this.expressions = Suppliers.memoize(() -> { + ImmutableList.Builder exprs + = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey orderKey : orderKeys) { + exprs.add(orderKey.getExpr()); + } + return exprs.build(); + }); } @Override @@ -100,9 +113,7 @@ public R accept(PlanVisitor visitor, C context) { @Override public List getExpressions() { - return orderKeys.stream() - .map(OrderKey::getExpr) - .collect(ImmutableList.toImmutableList()); + return expressions.get(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java index 63e3dc9c0b9717..1fb5dbaab7271c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java @@ -35,6 +35,7 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -51,6 +52,7 @@ public class LogicalTopN extends LogicalUnary orderKeys; private final long limit; private final long offset; + private final Supplier> expressions; public LogicalTopN(List orderKeys, long limit, long offset, CHILD_TYPE child) { this(orderKeys, limit, offset, Optional.empty(), Optional.empty(), child); @@ -65,6 +67,13 @@ public LogicalTopN(List orderKeys, long limit, long offset, Optional { + ImmutableList.Builder exprs = ImmutableList.builderWithExpectedSize(orderKeys.size()); + for (OrderKey orderKey : orderKeys) { + exprs.add(orderKey.getExpr()); + } + return exprs.build(); + }); } @Override @@ -120,9 +129,7 @@ public R accept(PlanVisitor visitor, C context) { @Override public List getExpressions() { - return orderKeys.stream() - .map(OrderKey::getExpr) - .collect(ImmutableList.toImmutableList()); + return expressions.get(); } public LogicalTopN withOrderKeys(List orderKeys) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 1f44d128b23be7..7c9d4a5324fba9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; @@ -162,11 +163,11 @@ public static Optional optionalAnd(Collection collection } public static Expression and(Collection expressions) { - return combine(And.class, expressions); + return combineAsLeftDeepTree(And.class, expressions); } public static Expression and(Expression... expressions) { - return combine(And.class, Lists.newArrayList(expressions)); + return combineAsLeftDeepTree(And.class, Lists.newArrayList(expressions)); } public static Optional optionalOr(List expressions) { @@ -178,17 +179,18 @@ public static Optional optionalOr(List expressions) { } public static Expression or(Expression... expressions) { - return combine(Or.class, Lists.newArrayList(expressions)); + return combineAsLeftDeepTree(Or.class, Lists.newArrayList(expressions)); } public static Expression or(Collection expressions) { - return combine(Or.class, expressions); + return combineAsLeftDeepTree(Or.class, expressions); } /** * Use AND/OR to combine expressions together. */ - public static Expression combine(Class type, Collection expressions) { + public static Expression combineAsLeftDeepTree( + Class type, Collection expressions) { /* * (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E) * â–² â–² â–² â–² â–² â–² @@ -209,9 +211,20 @@ public static Expression combine(Class type, Collection replace(List exprs, Map replaceMap) { - return exprs.stream() - .map(expr -> replace(expr, replaceMap)) - .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(replace(expr, replaceMap)); + } + return result.build(); } public static Set replace(Set exprs, Map replaceMap) { - return exprs.stream() - .map(expr -> replace(expr, replaceMap)) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder result = ImmutableSet.builderWithExpectedSize(exprs.size()); + for (Expression expr : exprs) { + result.add(replace(expr, replaceMap)); + } + return result.build(); } /** @@ -456,34 +473,60 @@ public static List mergeArguments(Object... arguments) { return builder.build(); } + /** isAllLiteral */ public static boolean isAllLiteral(List children) { - return children.stream().allMatch(c -> c instanceof Literal); + for (Expression child : children) { + if (!(child instanceof Literal)) { + return false; + } + } + return true; } + /** matchNumericType */ public static boolean matchNumericType(List children) { - return children.stream().allMatch(c -> c.getDataType().isNumericType()); + for (Expression child : children) { + if (!child.getDataType().isNumericType()) { + return false; + } + } + return true; } + /** matchDateLikeType */ public static boolean matchDateLikeType(List children) { - return children.stream().allMatch(c -> c.getDataType().isDateLikeType()); + for (Expression child : children) { + if (!child.getDataType().isDateLikeType()) { + return false; + } + } + return true; } + /** hasNullLiteral */ public static boolean hasNullLiteral(List children) { - return children.stream().anyMatch(c -> c instanceof NullLiteral); + for (Expression child : children) { + if (child instanceof NullLiteral) { + return true; + } + } + return false; } + /** hasOnlyMetricType */ public static boolean hasOnlyMetricType(List children) { - return children.stream().anyMatch(c -> c.getDataType().isOnlyMetricType()); - } - - public static boolean isAllNullLiteral(List children) { - return children.stream().allMatch(c -> c instanceof NullLiteral); + for (Expression child : children) { + if (child.getDataType().isOnlyMetricType()) { + return true; + } + } + return false; } /** * canInferNotNullForMarkSlot */ - public static boolean canInferNotNullForMarkSlot(Expression predicate) { + public static boolean canInferNotNullForMarkSlot(Expression predicate, ExpressionRewriteContext ctx) { /* * assume predicate is from LogicalFilter * the idea is replacing each mark join slot with null and false literal then run FoldConstant rule @@ -523,9 +566,10 @@ public static boolean canInferNotNullForMarkSlot(Expression predicate) { for (int j = 0; j < markSlotSize; ++j) { replaceMap.put(markJoinSlotReferenceList.get(j), literals.get((i >> j) & 1)); } - Expression evalResult = FoldConstantRule.INSTANCE.rewrite( + Expression evalResult = FoldConstantRule.evaluate( ExpressionUtils.replace(predicate, replaceMap), - new ExpressionRewriteContext(null)); + ctx + ); if (evalResult.equals(BooleanLiteral.TRUE)) { if (meetNullOrFalse) { @@ -553,30 +597,33 @@ private static boolean isNullOrFalse(Expression expression) { * infer notNulls slot from predicate */ public static Set inferNotNullSlots(Set predicates, CascadesContext cascadesContext) { - Set notNullSlots = Sets.newHashSet(); + ImmutableSet.Builder notNullSlots = ImmutableSet.builderWithExpectedSize(predicates.size()); for (Expression predicate : predicates) { for (Slot slot : predicate.getInputSlots()) { Map replaceMap = new HashMap<>(); Literal nullLiteral = new NullLiteral(slot.getDataType()); replaceMap.put(slot, nullLiteral); - Expression evalExpr = FoldConstantRule.INSTANCE.rewrite( + Expression evalExpr = FoldConstantRule.evaluate( ExpressionUtils.replace(predicate, replaceMap), - new ExpressionRewriteContext(cascadesContext)); + new ExpressionRewriteContext(cascadesContext) + ); if (evalExpr.isNullLiteral() || BooleanLiteral.FALSE.equals(evalExpr)) { notNullSlots.add(slot); } } } - return notNullSlots; + return notNullSlots.build(); } /** * infer notNulls slot from predicate */ public static Set inferNotNull(Set predicates, CascadesContext cascadesContext) { - return inferNotNullSlots(predicates, cascadesContext).stream() - .map(slot -> new Not(new IsNull(slot), false)) - .collect(Collectors.toSet()); + ImmutableSet.Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size()); + for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) { + newPredicates.add(new Not(new IsNull(slot), false)); + } + return newPredicates.build(); } /** @@ -584,37 +631,90 @@ public static Set inferNotNull(Set predicates, CascadesC */ public static Set inferNotNull(Set predicates, Set slots, CascadesContext cascadesContext) { - return inferNotNullSlots(predicates, cascadesContext).stream() - .filter(slots::contains) - .map(slot -> new Not(new IsNull(slot), true)) - .collect(Collectors.toSet()); + ImmutableSet.Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size()); + for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) { + if (slots.contains(slot)) { + newPredicates.add(new Not(new IsNull(slot), true)); + } + } + return newPredicates.build(); } - public static List flatExpressions(List> expressions) { - return expressions.stream() - .flatMap(List::stream) - .collect(ImmutableList.toImmutableList()); + /** flatExpressions */ + public static List flatExpressions(List> expressionLists) { + int num = 0; + for (List expressionList : expressionLists) { + num += expressionList.size(); + } + + ImmutableList.Builder flatten = ImmutableList.builderWithExpectedSize(num); + for (List expressionList : expressionLists) { + flatten.addAll(expressionList); + } + return flatten.build(); + } + + /** containsType */ + public static boolean containsType(Collection expressions, Class type) { + for (Expression expression : expressions) { + if (expression.anyMatch(expr -> expr.anyMatch(type::isInstance))) { + return true; + } + } + return false; } - public static boolean anyMatch(List expressions, Predicate> predicate) { - return expressions.stream() - .anyMatch(expr -> expr.anyMatch(predicate)); + /** allMatch */ + public static boolean allMatch( + Collection expressions, Predicate predicate) { + for (Expression expression : expressions) { + if (!predicate.test(expression)) { + return false; + } + } + return true; } - public static boolean noneMatch(List expressions, Predicate> predicate) { - return expressions.stream() - .noneMatch(expr -> expr.anyMatch(predicate)); + /** anyMatch */ + public static boolean anyMatch( + Collection expressions, Predicate predicate) { + for (Expression expression : expressions) { + if (predicate.test(expression)) { + return true; + } + } + return false; } - public static boolean containsType(List expressions, Class type) { - return anyMatch(expressions, type::isInstance); + /** deapAnyMatch */ + public static boolean deapAnyMatch( + Collection expressions, Predicate> predicate) { + for (Expression expression : expressions) { + if (expression.anyMatch(expr -> expr.anyMatch(predicate))) { + return true; + } + } + return false; + } + + /** deapNoneMatch */ + public static boolean deapNoneMatch( + Collection expressions, Predicate> predicate) { + for (Expression expression : expressions) { + if (expression.anyMatch(expr -> expr.anyMatch(predicate))) { + return false; + } + } + return true; } public static Set collect(Collection expressions, Predicate> predicate) { - return expressions.stream() - .flatMap(expr -> expr.>collect(predicate).stream()) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder set = ImmutableSet.builder(); + for (Expression expr : expressions) { + set.addAll(expr.collectToList(predicate)); + } + return set.build(); } /** @@ -652,11 +752,19 @@ public static Set mutableCollect(List expressions, return set; } + /** collectAll */ public static List collectAll(Collection expressions, Predicate> predicate) { - return expressions.stream() - .flatMap(expr -> expr.>collect(predicate).stream()) - .collect(ImmutableList.toImmutableList()); + switch (expressions.size()) { + case 0: return ImmutableList.of(); + default: { + ImmutableList.Builder result = ImmutableList.builder(); + for (Expression expr : expressions) { + result.addAll((Set) expr.collect(predicate)); + } + return result.build(); + } + } } public static List> rollupToGroupingSets(List rollupExpressions) { @@ -718,7 +826,7 @@ private static void cubeToGroupingSets(List cubeExpressions, int act /** * Get input slot set from list of expressions. */ - public static Set getInputSlotSet(Collection exprs) { + public static Set getInputSlotSet(Collection exprs) { Set set = new HashSet<>(); for (Expression expr : exprs) { set.addAll(expr.getInputSlots()); @@ -807,4 +915,25 @@ public static List distinctSlotByName(List slots) { } return distinctSlots.build(); } + + /** containsWindowExpression */ + public static boolean containsWindowExpression(List expressions) { + for (NamedExpression expression : expressions) { + if (expression.anyMatch(WindowExpression.class::isInstance)) { + return true; + } + } + return false; + } + + /** filter */ + public static List filter(List expressions, Class clazz) { + ImmutableList.Builder result = ImmutableList.builderWithExpectedSize(expressions.size()); + for (Expression expression : expressions) { + if (clazz.isInstance(expression)) { + result.add((E) expression); + } + } + return result.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index 759b96c5b73047..3955b2d0f0c6af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -155,6 +155,30 @@ public static List fastGetChildrenOutputs(List children) { return output.build(); } + /** fastGetInputSlots */ + public static Set fastGetInputSlots(List expressions) { + switch (expressions.size()) { + case 1: return expressions.get(0).getInputSlots(); + case 0: return ImmutableSet.of(); + default: { + } + } + + int inputSlotsNum = 0; + // child.inputSlots is cached by Expression.inputSlots, + // we can compute output num without the overhead of re-compute output + for (Expression expr : expressions) { + Set output = expr.getInputSlots(); + inputSlotsNum += output.size(); + } + // generate output list only copy once and without resize the list + ImmutableSet.Builder inputSlots = ImmutableSet.builderWithExpectedSize(inputSlotsNum); + for (Expression expr : expressions) { + inputSlots.addAll(expr.getInputSlots()); + } + return inputSlots.build(); + } + /** * Check if slot is from the plan. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 4ec437055a3e67..afcdb30f2dcf03 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -1047,11 +1047,20 @@ public static Expression processCaseWhen(CaseWhen caseWhen) { public static Expression processCompoundPredicate(CompoundPredicate compoundPredicate) { // check compoundPredicate.checkLegalityBeforeTypeCoercion(); - List children = compoundPredicate.children().stream() - .map(e -> e.getDataType().isNullType() ? new NullLiteral(BooleanType.INSTANCE) - : castIfNotSameType(e, BooleanType.INSTANCE)) - .collect(Collectors.toList()); - return compoundPredicate.withChildren(children); + ImmutableList.Builder newChildren + = ImmutableList.builderWithExpectedSize(compoundPredicate.arity()); + boolean changed = false; + for (Expression child : compoundPredicate.children()) { + Expression newChild; + if (child.getDataType().isNullType()) { + newChild = new NullLiteral(BooleanType.INSTANCE); + } else { + newChild = castIfNotSameType(child, BooleanType.INSTANCE); + } + changed |= child != newChild; + newChildren.add(newChild); + } + return changed ? compoundPredicate.withChildren(newChildren.build()) : compoundPredicate; } private static boolean canCompareDate(DataType t1, DataType t2) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java index df9528cc49e2d6..2c90eefdde00e5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java @@ -26,6 +26,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.commons.lang3.StringUtils; @@ -34,7 +35,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -325,24 +328,50 @@ public static ImmutableList fastToImmutableList(E[] array) { } /** fastToImmutableList */ - public static ImmutableList fastToImmutableList(List originList) { - if (originList instanceof ImmutableList) { - return (ImmutableList) originList; + public static ImmutableList fastToImmutableList(Collection collection) { + if (collection instanceof ImmutableList) { + return (ImmutableList) collection; } - switch (originList.size()) { + switch (collection.size()) { case 0: return ImmutableList.of(); - case 1: return ImmutableList.of(originList.get(0)); + case 1: + return collection instanceof List + ? ImmutableList.of(((List) collection).get(0)) + : ImmutableList.of(collection.iterator().next()); default: { // NOTE: ImmutableList.copyOf(list) has additional clone of the list, so here we // direct generate a ImmutableList - Builder copyChildren = ImmutableList.builderWithExpectedSize(originList.size()); - copyChildren.addAll(originList); + Builder copyChildren = ImmutableList.builderWithExpectedSize(collection.size()); + copyChildren.addAll(collection); return copyChildren.build(); } } } + /** fastToImmutableSet */ + public static ImmutableSet fastToImmutableSet(Collection collection) { + if (collection instanceof ImmutableSet) { + return (ImmutableSet) collection; + } + switch (collection.size()) { + case 0: + return ImmutableSet.of(); + case 1: + return collection instanceof List + ? ImmutableSet.of(((List) collection).get(0)) + : ImmutableSet.of(collection.iterator().next()); + default: + // NOTE: ImmutableList.copyOf(array) has additional clone of the array, so here we + // direct generate a ImmutableList + ImmutableSet.Builder copyChildren = ImmutableSet.builderWithExpectedSize(collection.size()); + for (E child : collection) { + copyChildren.add(child); + } + return copyChildren.build(); + } + } + /** reverseImmutableList */ public static ImmutableList reverseImmutableList(List list) { Builder reverseList = ImmutableList.builderWithExpectedSize(list.size()); @@ -363,4 +392,25 @@ public static ImmutableList filterImmutableList(List list, P } return newList.build(); } + + public static Set concatToSet(Collection left, Collection right) { + ImmutableSet.Builder required = ImmutableSet.builderWithExpectedSize( + left.size() + right.size() + ); + required.addAll(left); + required.addAll(right); + return required.build(); + } + + /** fastReduce */ + public static Optional fastReduce(List list, BiFunction reduceOp) { + if (list.isEmpty()) { + return Optional.empty(); + } + M merge = list.get(0); + for (int i = 1; i < list.size(); i++) { + merge = reduceOp.apply(merge, list.get(i)); + } + return Optional.of(merge); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 67f3569091418a..66789e330573ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -58,6 +58,7 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.Arrays; +import java.util.BitSet; import java.util.HashMap; import java.util.List; import java.util.Locale; @@ -300,6 +301,7 @@ public class SessionVariable implements Serializable, Writable { public static final String NEREIDS_CBO_PENALTY_FACTOR = "nereids_cbo_penalty_factor"; public static final String ENABLE_NEREIDS_TRACE = "enable_nereids_trace"; + public static final String ENABLE_EXPR_TRACE = "enable_expr_trace"; public static final String ENABLE_DPHYP_TRACE = "enable_dphyp_trace"; @@ -1141,6 +1143,9 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) { @VariableMgr.VarAttr(name = ENABLE_NEREIDS_TRACE) private boolean enableNereidsTrace = false; + @VariableMgr.VarAttr(name = ENABLE_EXPR_TRACE) + private boolean enableExprTrace = false; + @VariableMgr.VarAttr(name = ENABLE_DPHYP_TRACE, needForward = true) public boolean enableDpHypTrace = false; @@ -2764,15 +2769,20 @@ public Set getDisableNereidsRuleNames() { .collect(ImmutableSet.toImmutableSet()); } - public Set getDisableNereidsRules() { - return Arrays.stream(disableNereidsRules.split(",[\\s]*")) - .filter(rule -> !rule.isEmpty()) - .map(rule -> rule.toUpperCase(Locale.ROOT)) - .map(rule -> RuleType.valueOf(rule)) - .filter(ruleType -> ruleType != RuleType.CHECK_PRIVILEGES - && ruleType != RuleType.CHECK_ROW_POLICY) - .map(RuleType::type) - .collect(ImmutableSet.toImmutableSet()); + public BitSet getDisableNereidsRules() { + BitSet bitSet = new BitSet(); + for (String ruleName : disableNereidsRules.split(",[\\s]*")) { + if (ruleName.isEmpty()) { + continue; + } + ruleName = ruleName.toUpperCase(Locale.ROOT); + RuleType ruleType = RuleType.valueOf(ruleName); + if (ruleType == RuleType.CHECK_PRIVILEGES || ruleType == RuleType.CHECK_ROW_POLICY) { + continue; + } + bitSet.set(ruleType.type()); + } + return bitSet; } public Set getEnableNereidsRules() { @@ -2807,6 +2817,10 @@ public boolean isEnableNereidsTrace() { return isEnableNereidsPlanner() && enableNereidsTrace; } + public boolean isEnableExprTrace() { + return enableExprTrace; + } + public boolean isEnableSingleReplicaInsert() { return enableSingleReplicaInsert; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index de859058ecc7d4..1894b2c7a9b5ef 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison; import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule; +import org.apache.doris.nereids.rules.expression.rules.SimplifyRange; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; @@ -54,7 +55,9 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper { @Test void testNotRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + ExpressionRewrite.bottomUp(SimplifyNotExprRule.INSTANCE) + )); assertRewrite("not x", "not x"); assertRewrite("not not x", "x"); @@ -79,7 +82,9 @@ void testNotRewrite() { @Test void testNormalizeExpressionRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + ExpressionRewrite.bottomUp(NormalizeBinaryPredicatesRule.INSTANCE) + )); assertRewrite("1 = 1", "1 = 1"); assertRewrite("2 > x", "x < 2"); @@ -91,7 +96,9 @@ void testNormalizeExpressionRewrite() { @Test void testDistinctPredicatesRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(DistinctPredicatesRule.INSTANCE) + )); assertRewrite("a = 1", "a = 1"); assertRewrite("a = 1 and a = 1", "a = 1"); @@ -103,7 +110,9 @@ void testDistinctPredicatesRewrite() { @Test void testExtractCommonFactorRewrite() { - executor = new ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(ExtractCommonFactorRule.INSTANCE) + )); assertRewrite("a", "a"); @@ -112,22 +121,24 @@ void testExtractCommonFactorRewrite() { assertRewrite("a = 1 and b > 2", "a = 1 and b > 2"); assertRewrite("(a and b) or (c and d)", "(a and b) or (c and d)"); - assertRewrite("(a and b) and (c and d)", "((a and b) and c) and d"); + assertRewrite("(a and b) and (c and d)", "((a and b) and (c and d))"); + assertRewrite("(a and (b and c)) and (b or c)", "((b and c) and a)"); assertRewrite("(a or b) and (a or c)", "a or (b and c)"); assertRewrite("(a and b) or (a and c)", "a and (b or c)"); assertRewrite("(a or b) and (a or c) and (a or d)", "a or (b and c and d)"); assertRewrite("(a and b) or (a and c) or (a and d)", "a and (b or c or d)"); - assertRewrite("(a and b) or (a or c) or (a and d)", "((((a and b) or a) or c) or (a and d))"); - assertRewrite("(a and b) or (a and c) or (a or d)", "(((a and b) or (a and c) or a) or d))"); - assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b) or (a and c) or (a and d)"); - assertRewrite("(a or b) or (a and c) or (a or d)", "(((a or b) or (a and c)) or d)"); - assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c) or (a and d)"); + assertRewrite("(a or b) and (a or d)", "a or (b and d)"); + assertRewrite("(a and b) or (a or c) or (a and d)", "a or c"); + assertRewrite("(a and b) or (a and c) or (a or d)", "(a or d)"); + assertRewrite("(a or b) or (a and c) or (a and d)", "(a or b)"); + assertRewrite("(a or b) or (a and c) or (a or d)", "((a or b) or d)"); + assertRewrite("(a or b) or (a or c) or (a and d)", "((a or b) or c)"); assertRewrite("(a or b) or (a or c) or (a or d)", "(((a or b) or c) or d)"); - assertRewrite("(a and b) or (d and c) or (d and e)", "(a and b) or (d and c) or (d and e)"); - assertRewrite("(a or b) and (d or c) and (d or e)", "(a or b) and (d or c) and (d or e)"); + assertRewrite("(a and b) or (d and c) or (d and e)", "((d and (c or e)) or (a and b))"); + assertRewrite("(a or b) and (d or c) and (d or e)", "((d or (c and e)) and (a or b))"); assertRewrite("(a and b) or ((d and c) and (d and e))", "(a and b) or (d and c and e)"); assertRewrite("(a or b) and ((d or c) or (d or e))", "(a or b) and (d or c or e)"); @@ -152,11 +163,14 @@ void testExtractCommonFactorRewrite() { assertRewrite("(a or b) and (a or true)", "a or b"); + assertRewrite("a and (b or ((a and e) or (a and f))) and (b or d)", "(b or ((a and (e or f)) and d)) and a"); } @Test void testInPredicateToEqualToRule() { - executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(InPredicateToEqualToRule.INSTANCE) + )); assertRewrite("a in (1)", "a = 1"); assertRewrite("a not in (1)", "not a = 1"); @@ -172,14 +186,18 @@ void testInPredicateToEqualToRule() { @Test void testInPredicateDedup() { - executor = new ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(InPredicateDedup.INSTANCE) + )); assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)"); } @Test void testSimplifyCastRule() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyCastRule.INSTANCE) + )); // deduplicate assertRewrite("CAST(1 AS tinyint)", "1"); @@ -211,7 +229,9 @@ void testSimplifyCastRule() { @Test void testSimplifyDecimalV3Comparison() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyDecimalV3Comparison.INSTANCE) + )); // do rewrite Expression left = new DecimalV3Literal(new BigDecimal("12345.67")); @@ -226,4 +246,16 @@ void testSimplifyDecimalV3Comparison() { comparison = new EqualTo(new DecimalV3Literal(new BigDecimal("12345.67")), new DecimalV3Literal(new BigDecimal("76543.21"))); assertRewrite(comparison, comparison); } + + @Test + void testDeadLoop() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyRange.INSTANCE, + ExtractCommonFactorRule.INSTANCE + ) + )); + + assertRewrite("a and (b > 0 and b < 10)", "a and (b > 0 and b < 10)"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java index 60d4384207f90f..b252b4650f7315 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTestHelper.java @@ -46,7 +46,7 @@ import java.util.List; import java.util.Map; -public abstract class ExpressionRewriteTestHelper { +public abstract class ExpressionRewriteTestHelper extends ExpressionRewrite { protected static final NereidsParser PARSER = new NereidsParser(); protected ExpressionRuleExecutor executor; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java index 3b8fbc8526b356..747e72b0a9167c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/FoldConstantTest.java @@ -58,7 +58,9 @@ class FoldConstantTest extends ExpressionRewriteTestHelper { @Test void testCaseWhenFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); // assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' < 2 then 2 else 3 end", "2"); // assertRewriteAfterTypeCoercion("case when 1 = 2 then 1 when '1' > 2 then 2 end", "null"); assertRewriteAfterTypeCoercion("case when (1 + 5) / 2 > 2 then 4 when '1' < 2 then 2 else 3 end", "4"); @@ -75,7 +77,9 @@ void testCaseWhenFold() { @Test void testInFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("1 in (1,2,3,4)", "true"); // Type Coercion trans all to string. assertRewriteAfterTypeCoercion("3 in ('1', 2 + 8 / 2, 3, 4)", "true"); @@ -88,7 +92,9 @@ void testInFold() { @Test void testLogicalFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 > 2", "false"); assertRewriteAfterTypeCoercion("10 + 1 > 1 and 1 < 2", "true"); assertRewriteAfterTypeCoercion("null + 1 > 1 and 1 < 2", "null"); @@ -126,7 +132,9 @@ void testLogicalFold() { @Test void testIsNullFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("100 is null", "false"); assertRewriteAfterTypeCoercion("null is null", "true"); assertRewriteAfterTypeCoercion("null is not null", "false"); @@ -137,7 +145,9 @@ void testIsNullFold() { @Test void testNotPredicateFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("not 1 > 2", "true"); assertRewriteAfterTypeCoercion("not null + 1 > 2", "null"); assertRewriteAfterTypeCoercion("not (1 + 5) / 2 + (10 - 1) * 3 > 3 * 5 + 1", "false"); @@ -145,7 +155,9 @@ void testNotPredicateFold() { @Test void testCastFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); // cast '1' as tinyint Cast c = new Cast(Literal.of("1"), TinyIntType.INSTANCE); @@ -156,7 +168,9 @@ void testCastFold() { @Test void testCompareFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewriteAfterTypeCoercion("'1' = 2", "false"); assertRewriteAfterTypeCoercion("1 = 2", "false"); assertRewriteAfterTypeCoercion("1 != 2", "true"); @@ -173,7 +187,9 @@ void testCompareFold() { @Test void testArithmeticFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); assertRewrite("1 + 1", Literal.of((short) 2)); assertRewrite("1 - 1", Literal.of((short) 0)); assertRewrite("100 + 100", Literal.of((short) 200)); @@ -206,7 +222,9 @@ void testArithmeticFold() { @Test void testTimestampFold() { - executor = new ExpressionRuleExecutor(ImmutableList.of(FoldConstantRuleOnFE.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(FoldConstantRuleOnFE.VISITOR_INSTANCE) + )); String interval = "'1991-05-01' - interval 1 day"; Expression e7 = process((TimestampArithmetic) PARSER.parseExpression(interval)); Expression e8 = Config.enable_date_conversion diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java index cab2c2f5a64274..a83ac620164806 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/PredicatesSplitterTest.java @@ -48,7 +48,7 @@ public void testSplitPredicates() { "c = d or a = 10"); assetEquals("a = b and c + d = e and a > 7 and 10 > d", "a = b", - "10 > d and a > 7", + "a > 7 and 10 > d", "c + d = e"); assetEquals("a = b and c + d = e or a > 7 and 10 > d", "", diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java index 4ea50bf1f8817c..6af8682365519f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java @@ -29,9 +29,11 @@ class SimplifyArithmeticRuleTest extends ExpressionRewriteTestHelper { @Test void testSimplifyArithmetic() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE, + bottomUp(SimplifyArithmeticRule.INSTANCE), FunctionBinder.INSTANCE, - FoldConstantRule.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE + ) )); assertRewriteAfterTypeCoercion("IA", "IA"); assertRewriteAfterTypeCoercion("IA + 1", "IA + 1"); @@ -55,7 +57,7 @@ void testSimplifyArithmetic() { @Test void testSimplifyArithmeticRuleOnly() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE + bottomUp(SimplifyArithmeticRule.INSTANCE) )); // add and subtract @@ -67,39 +69,43 @@ void testSimplifyArithmeticRuleOnly() { assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC + 4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))"); // multiply and divide - assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((3 * IC) as DOUBLE))"); + assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))", "((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast((IC * 3) as DOUBLE))"); assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))", "(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) / ((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))"); assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) / ((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))"); - assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((3 * (IC * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))", "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((IC * (3 * 4)) as DOUBLE)) / (cast(2 as DOUBLE) / cast(1 as DOUBLE)))"); // hybrid // root is subtract assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))", "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))"); - assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - IA) - 2) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))"); + assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))", "((cast(((0 - 2) - IA) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3 as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))"); // root is add - assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast((3 - (IC + 4)) as DOUBLE)))"); + assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC + 4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast(((3 - 4) - IC) as DOUBLE)))"); assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC + 4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))"); // root is multiply - assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))"); - assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - IB) - 1) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))"); + assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC + 4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - 1) - IB) as DOUBLE)) / cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))"); // root is divide - assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((0 - IB) - 1) - (3 + (IC + 4))) as DOUBLE)) / cast(2 as DOUBLE))"); - assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - IB) - 1) as DOUBLE)) * cast((3 + (IC * 4)) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DOUBLE) / cast((((((0 - 1) - 3) - 4) - IB) - IC) as DOUBLE)) / cast(2 as DOUBLE))"); + assertRewriteAfterTypeCoercion("(-IA / 2) / ((-IB - 1) / (3 + (IC * 4)))", "(((cast((0 - IA) as DOUBLE) / cast(((0 - 1) - IB) as DOUBLE)) * cast(((IC * 4) + 3) as DOUBLE)) / cast(2 as DOUBLE))"); // unsupported decimal - assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 0)) - MA) - cast(((1 - IB) - (3 + IC)) as DECIMALV3(38, 0)))"); - assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((0 - IB) - 1) - (3 + (IC + 4))) as DECIMALV3(20, 0)))"); + assertRewriteAfterTypeCoercion("-2 - MA - ((1 - IB) - (3 + IC))", "((cast(-2 as DECIMALV3(38, 0)) - MA) - cast((((1 - 3) - IB) - IC) as DECIMALV3(38, 0)))"); + assertRewriteAfterTypeCoercion("-IA / 2.0 * ((-IB - 1) - (3 + (IC + 4)))", "((cast((0 - IA) as DECIMALV3(25, 5)) / 2.0) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as DECIMALV3(20, 0)))"); } @Test void testSimplifyArithmeticComparison() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE, - FoldConstantRule.INSTANCE, - SimplifyArithmeticComparisonRule.INSTANCE, - SimplifyArithmeticRule.INSTANCE, + bottomUp( + SimplifyArithmeticRule.INSTANCE, + FoldConstantRule.INSTANCE, + SimplifyArithmeticComparisonRule.INSTANCE, + SimplifyArithmeticRule.INSTANCE + ), FunctionBinder.INSTANCE, - FoldConstantRule.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE + ) )); assertRewriteAfterTypeCoercion("IA", "IA"); assertRewriteAfterTypeCoercion("IA > IB", "IA > IB"); @@ -134,12 +140,16 @@ void testSimplifyArithmeticComparison() { @Test void testSimplifyDateTimeComparison() { executor = new ExpressionRuleExecutor(ImmutableList.of( - SimplifyArithmeticRule.INSTANCE, - FoldConstantRule.INSTANCE, - SimplifyArithmeticComparisonRule.INSTANCE, - SimplifyArithmeticRule.INSTANCE, + bottomUp( + SimplifyArithmeticRule.INSTANCE, + FoldConstantRule.INSTANCE, + SimplifyArithmeticComparisonRule.INSTANCE, + SimplifyArithmeticRule.INSTANCE + ), FunctionBinder.INSTANCE, - FoldConstantRule.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE + ) )); assertRewriteAfterTypeCoercion("years_add(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2020-01-01 00:00:00')"); assertRewriteAfterTypeCoercion("years_sub(IA, 1) > '2021-01-01 00:00:00'", "(cast(IA as DATETIMEV2(0)) > '2022-01-01 00:00:00')"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java index 87c57889b2f6fc..09fc7346f56659 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyInPredicateTest.java @@ -34,8 +34,10 @@ public class SimplifyInPredicateTest extends ExpressionRewriteTestHelper { @Test public void test() { executor = new ExpressionRuleExecutor(ImmutableList.of( - FoldConstantRule.INSTANCE, - SimplifyInPredicate.INSTANCE + bottomUp( + FoldConstantRule.INSTANCE, + SimplifyInPredicate.INSTANCE + ) )); Map mem = Maps.newHashMap(); Expression rewrittenExpression = PARSER.parseExpression("cast(CA as DATETIME) in ('1992-01-31 00:00:00', '1992-02-01 00:00:00')"); @@ -48,7 +50,9 @@ public void test() { Expression expectedExpression = PARSER.parseExpression("CA in (cast('1992-01-31' as date), cast('1992-02-01' as date))"); expectedExpression = replaceUnboundSlot(expectedExpression, mem); executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( FoldConstantRule.INSTANCE + ) )); expectedExpression = executor.rewrite(expectedExpression, context); Assertions.assertEquals(expectedExpression, rewrittenExpression); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index 16476e371464fd..81accb2964716d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -44,7 +44,7 @@ import java.util.List; import java.util.Map; -public class SimplifyRangeTest { +public class SimplifyRangeTest extends ExpressionRewrite { private static final NereidsParser PARSER = new NereidsParser(); private ExpressionRuleExecutor executor; @@ -58,7 +58,9 @@ public SimplifyRangeTest() { @Test public void testSimplify() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); assertRewrite("TA", "TA"); assertRewrite("(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)", "(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)"); assertRewrite("(TA > 3 and TA < 1) or (TA > 7 and TA < 5)", "FALSE"); @@ -160,8 +162,10 @@ public void testSimplify() { @Test public void testSimplifyDate() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE)); - // assertRewrite("TA", "TA"); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); + assertRewrite("TA", "TA"); assertRewrite( "(TA >= date '2024-01-01' and TA <= date '2024-01-03') or (TA > date '2024-01-05' and TA < date '2024-01-07')", "(TA >= date '2024-01-01' and TA <= date '2024-01-03') or (TA > date '2024-01-05' and TA < date '2024-01-07')"); @@ -226,8 +230,10 @@ public void testSimplifyDate() { @Test public void testSimplifyDateTime() { - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyRange.INSTANCE)); - // assertRewrite("TA", "TA"); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyRange.INSTANCE) + )); + assertRewrite("TA", "TA"); assertRewrite( "(TA >= timestamp '2024-01-01 00:00:00' and TA <= timestamp '2024-01-03 00:00:00') or (TA > timestamp '2024-01-05 00:00:00' and TA < timestamp '2024-01-07 00:00:00')", "(TA >= timestamp '2024-01-01 00:00:00' and TA <= timestamp '2024-01-03 00:00:00') or (TA > timestamp '2024-01-05 00:00:00' and TA < timestamp '2024-01-07 00:00:00')"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java index 140f72c57f4a4c..db1186738da713 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqualTest.java @@ -35,7 +35,9 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper { // "A<=> Null" to "A is null" @Test void testNullSafeEqualToIsNull() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference slot = new SlotReference("a", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot)); } @@ -43,7 +45,9 @@ void testNullSafeEqualToIsNull() { // "A<=> Null" to "False", when A is not nullable @Test void testNullSafeEqualToFalse() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference slot = new SlotReference("a", StringType.INSTANCE, false); assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), BooleanLiteral.FALSE); } @@ -51,7 +55,9 @@ void testNullSafeEqualToFalse() { // "A(nullable)<=>B" not changed @Test void testNullSafeEqualNotChangedLeft() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference a = new SlotReference("a", StringType.INSTANCE, true); SlotReference b = new SlotReference("b", StringType.INSTANCE, false); assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); @@ -60,7 +66,9 @@ void testNullSafeEqualNotChangedLeft() { // "A<=>B(nullable)" not changed @Test void testNullSafeEqualNotChangedRight() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); @@ -69,7 +77,9 @@ void testNullSafeEqualNotChangedRight() { // "A<=>B" changed @Test void testNullSafeEqualToEqual() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(NullSafeEqualToEqual.INSTANCE) + )); SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, false); assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java index fc31daaa9414d7..4d932187611136 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticComparisonRuleTest.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.expression.rules; +import org.apache.doris.nereids.rules.expression.ExpressionRewrite; import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper; import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor; import org.apache.doris.nereids.trees.expressions.Expression; @@ -37,7 +38,9 @@ class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper { public void testProcess() { Map nameToSlot = new HashMap<>(); nameToSlot.put("a", new SlotReference("a", IntegerType.INSTANCE)); - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyArithmeticComparisonRule.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + ExpressionRewrite.bottomUp(SimplifyArithmeticComparisonRule.INSTANCE) + )); assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)", nameToSlot); assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)", nameToSlot); assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a", nameToSlot); @@ -82,7 +85,7 @@ private void assertRewriteAfterSimplify(String expr, String expected, Map 2021-01-01 00:00:00.001) Expression expression = new GreaterThan(left, right); Expression rewrittenExpression = executor.rewrite(typeCoercion(expression), context); - Assertions.assertEquals(left.getDataType(), rewrittenExpression.child(0).getDataType()); + Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType()); // (cast(0001-01-01 01:01:01 as DATETIMEV2(0)) < 2021-01-01 00:00:00.001) expression = new GreaterThan(left, right); rewrittenExpression = executor.rewrite(typeCoercion(expression), context); - Assertions.assertEquals(left.getDataType(), rewrittenExpression.child(0).getDataType()); + Assertions.assertEquals(dt.getDataType(), rewrittenExpression.child(0).getDataType()); } @Test void testRound() { - executor = new ExpressionRuleExecutor( - ImmutableList.of(SimplifyCastRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + SimplifyCastRule.INSTANCE, + SimplifyComparisonPredicate.INSTANCE + ) + )); Expression left = new Cast(new DateTimeLiteral("2021-01-02 00:00:00.00"), DateTimeV2Type.of(1)); Expression right = new DateTimeV2Literal("2021-01-01 23:59:59.99"); @@ -120,13 +132,14 @@ void testRound() { Expression rewrittenExpression = executor.rewrite(typeCoercion(expression), context); // right should round to be 2021-01-02 00:00:00.00 - Assertions.assertEquals(new DateTimeV2Literal("2021-01-02 00:00:00"), rewrittenExpression.child(1)); + Assertions.assertEquals(new DateTimeLiteral("2021-01-02 00:00:00"), rewrittenExpression.child(1)); } @Test void testDoubleLiteral() { - executor = new ExpressionRuleExecutor( - ImmutableList.of(SimplifyComparisonPredicate.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyComparisonPredicate.INSTANCE) + )); Expression leftChild = new BigIntLiteral(999); Expression left = new Cast(leftChild, DoubleType.INSTANCE); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java index ff424e4971145b..ee089a82f88079 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java @@ -39,7 +39,9 @@ public void testSimplifyDecimalV3Comparison() { Config.enable_decimal_conversion = false; Map nameToSlot = new HashMap<>(); nameToSlot.put("col1", new SlotReference("col1", DecimalV3Type.createDecimalV3Type(15, 2))); - executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyDecimalV3Comparison.INSTANCE) + )); assertRewriteAfterSimplify("cast(col1 as decimalv3(27, 9)) > 0.6", "cast(col1 as decimalv3(27, 9)) > 0.6", nameToSlot); } @@ -48,7 +50,7 @@ private void assertRewriteAfterSimplify(String expr, String expected, Map inPredicates = rewritten.collect(e -> e instanceof InPredicate); Assertions.assertEquals(1, inPredicates.size()); InPredicate inPredicate = inPredicates.iterator().next(); @@ -62,7 +61,7 @@ void test1() { void test2() { String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))", rewritten.toSql()); } @@ -71,7 +70,7 @@ void test2() { void test3() { String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); List inPredicates = rewritten.collectToList(e -> e instanceof InPredicate); Assertions.assertEquals(2, inPredicates.size()); InPredicate in1 = inPredicates.get(0); @@ -95,7 +94,7 @@ void test4() { String expr = "case when col = 1 or col = 2 or col = 3 then 1" + " when col = 4 or col = 5 or col = 6 then 1 else 0 end"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END", rewritten.toSql()); } @@ -104,7 +103,7 @@ void test4() { void test5() { String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))", rewritten.toSql()); } @@ -113,7 +112,7 @@ void test5() { void test6() { String expr = "col = 1 or col = 2 or col in (1, 2, 3)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("col IN (1, 2, 3)", rewritten.toSql()); } @@ -121,7 +120,7 @@ void test6() { void test7() { String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql()); } @@ -129,7 +128,7 @@ void test7() { void test8() { String expr = "col = 1 or (col = 2 and (col = 3 or col = '4' or col = 5.0))"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))", rewritten.toSql()); } @@ -139,7 +138,7 @@ void testEnsureOrder() { // ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2) String expr = "col1 IN (1, 2) OR col2 IN (1, 2)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, context); Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))", rewritten.toSql()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java index 36cb8cee8d41ec..8bae1713fe1b51 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java @@ -47,8 +47,8 @@ import java.util.Optional; public class PushDownFilterThroughAggregationTest implements MemoPatternMatchSupported { - private final LogicalOlapScan scan = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student, - ImmutableList.of("")); + private final LogicalOlapScan scan = new LogicalOlapScan( + StatementScopeIdGenerator.newRelationId(), PlanConstructor.student, ImmutableList.of("")); /*- * origin plan: diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java index 3fc00ee4bad2ba..ff518fb9d1fa87 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java @@ -49,6 +49,7 @@ import java.math.BigDecimal; import java.util.Collections; import java.util.List; +import java.util.Optional; public class ComputeSignatureHelperTest { @@ -419,6 +420,16 @@ public int arity() { return 0; } + @Override + public Optional getMutableState(String key) { + return Optional.empty(); + } + + @Override + public void setMutableState(String key, Object value) { + + } + @Override public Expression withChildren(List children) { return null; diff --git a/regression-test/suites/schema_change_p0/test_alter_table_replace.groovy b/regression-test/suites/schema_change_p0/test_alter_table_replace.groovy index b07a54c528865a..f97768e4f3741a 100644 --- a/regression-test/suites/schema_change_p0/test_alter_table_replace.groovy +++ b/regression-test/suites/schema_change_p0/test_alter_table_replace.groovy @@ -96,7 +96,7 @@ suite("test_alter_table_replace") { test { sql """ select * from ${tbNameB} order by user_id""" // check exception message contains - exception "Unknown table '${tbNameB}'" + exception "Table [${tbNameB}] does not exist in database" } sql "DROP TABLE IF EXISTS ${tbNameA} FORCE;"