diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 49e2643c90a5ac7..7bd37d6df2dec15 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -49,9 +49,9 @@ public class ExpressionOptimization extends ExpressionRewrite { new SimplifyRange(), new DateFunctionRewrite(), ArrayContainToArrayOverlap.INSTANCE, - CaseWhenToIf.INSTANCE + CaseWhenToIf.INSTANCE, + TopnToMax.INSTANCE ), - TopnToMax.INSTANCE, NullSafeEqualToEqual.INSTANCE ); private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index 2f834484e8622d0..c6c425f78be71a9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.rules.expression.ExpressionRewrite; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; @@ -67,10 +68,10 @@ public List> buildRules() { ); } - public Expression rewriteTree(Expression expr) { + public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) { return new ExpressionRewrite( ExpressionRewrite.bottomUp(this) - ).rewrite(expr, null); + ).rewrite(expr, context); } public static Expression rewrite(Or or) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java index 8626008342b041d..147f82913a0a464 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/PartitionPruner.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.PartitionItem; import org.apache.doris.catalog.RangePartitionItem; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; @@ -121,7 +122,8 @@ public static List prune(List partitionSlots, Expression partitionPr kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, partitionTableType)); } - partitionPredicate = OrToIn.INSTANCE.rewriteTree(partitionPredicate); + partitionPredicate = OrToIn.INSTANCE.rewriteTree( + partitionPredicate, new ExpressionRewriteContext(cascadesContext)); PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate); //TODO: we keep default partition because it's too hard to prune it, we return false in canPrune(). return partitionPruner.prune(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java index 30e76cfe226f5b4..c46cbc831e61b2e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/TopnToMax.java @@ -17,39 +17,37 @@ package org.apache.doris.nereids.rules.expression.rules; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; -import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.TopN; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; + +import com.google.common.collect.ImmutableList; + +import java.util.List; /** * Convert topn(x, 1) to max(x) */ -public class TopnToMax extends DefaultExpressionRewriter implements - ExpressionRewriteRule { +public class TopnToMax implements ExpressionPatternRuleFactory { public static final TopnToMax INSTANCE = new TopnToMax(); @Override - public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { - return expr.accept(this, null); + public List> buildRules() { + return ImmutableList.of( + matchesTopType(TopN.class).then(TopnToMax::rewrite) + ); } - @Override - public Expression visitAggregateFunction(AggregateFunction aggregateFunction, ExpressionRewriteContext context) { - if (!(aggregateFunction instanceof TopN)) { - return aggregateFunction; - } - TopN topN = (TopN) aggregateFunction; + public static Expression rewrite(TopN topN) { if (topN.arity() == 2 && topN.child(1) instanceof IntegerLikeLiteral && ((IntegerLikeLiteral) topN.child(1)).getIntValue() == 1) { return new Max(topN.child(0)); } else { - return aggregateFunction; + return topN; } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/TopnToMaxTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/TopnToMaxTest.java index 8f3c682de5b8c6d..c0595136614cf2a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/TopnToMaxTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/TopnToMaxTest.java @@ -32,7 +32,9 @@ class TopnToMaxTest extends ExpressionRewriteTestHelper { @Test void testSimplifyComparisonPredicateRule() { - executor = new ExpressionRuleExecutor(ImmutableList.of(TopnToMax.INSTANCE)); + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(TopnToMax.INSTANCE) + )); Slot slot = new SlotReference("a", StringType.INSTANCE); assertRewrite(new TopN(slot, Literal.of(1)), new Max(slot)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java index f92e679411307ff..4127390c6764bed 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java @@ -39,7 +39,7 @@ class OrToInTest extends ExpressionRewriteTestHelper { void test1() { String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Set inPredicates = rewritten.collect(e -> e instanceof InPredicate); Assertions.assertEquals(1, inPredicates.size()); InPredicate inPredicate = inPredicates.iterator().next(); @@ -62,7 +62,7 @@ void test1() { void test2() { String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))", rewritten.toSql()); } @@ -71,7 +71,7 @@ void test2() { void test3() { String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); List inPredicates = rewritten.collectToList(e -> e instanceof InPredicate); Assertions.assertEquals(2, inPredicates.size()); InPredicate in1 = inPredicates.get(0); @@ -95,7 +95,7 @@ void test4() { String expr = "case when col = 1 or col = 2 or col = 3 then 1" + " when col = 4 or col = 5 or col = 6 then 1 else 0 end"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END", rewritten.toSql()); } @@ -104,7 +104,7 @@ void test4() { void test5() { String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))", rewritten.toSql()); } @@ -113,7 +113,7 @@ void test5() { void test6() { String expr = "col = 1 or col = 2 or col in (1, 2, 3)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("col IN (1, 2, 3)", rewritten.toSql()); } @@ -121,7 +121,7 @@ void test6() { void test7() { String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql()); } @@ -129,7 +129,7 @@ void test7() { void test8() { String expr = "col = 1 or (col = 2 and (col = 3 or col = '4' or col = 5.0))"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))", rewritten.toSql()); } @@ -139,7 +139,7 @@ void testEnsureOrder() { // ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2) String expr = "col1 IN (1, 2) OR col2 IN (1, 2)"; Expression expression = PARSER.parseExpression(expr); - Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null)); + Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null)); Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))", rewritten.toSql()); }