Skip to content

Commit

Permalink
TopNToMax
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 16, 2024
1 parent f8dea99 commit 9b8a06d
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public class ExpressionOptimization extends ExpressionRewrite {
new SimplifyRange(),
new DateFunctionRewrite(),
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE
CaseWhenToIf.INSTANCE,
TopnToMax.INSTANCE
),
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
Expand Down Expand Up @@ -67,10 +68,10 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
);
}

public Expression rewriteTree(Expression expr) {
public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) {
return new ExpressionRewrite(
ExpressionRewrite.bottomUp(this)
).rewrite(expr, null);
).rewrite(expr, context);
}

public static Expression rewrite(Or or) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.catalog.PartitionItem;
import org.apache.doris.catalog.RangePartitionItem;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
Expand Down Expand Up @@ -121,7 +122,8 @@ public static List<Long> prune(List<Slot> partitionSlots, Expression partitionPr
kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, partitionTableType));
}

partitionPredicate = OrToIn.INSTANCE.rewriteTree(partitionPredicate);
partitionPredicate = OrToIn.INSTANCE.rewriteTree(
partitionPredicate, new ExpressionRewriteContext(cascadesContext));
PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate);
//TODO: we keep default partition because it's too hard to prune it, we return false in canPrune().
return partitionPruner.prune();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,37 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.TopN;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* Convert topn(x, 1) to max(x)
*/
public class TopnToMax extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
ExpressionRewriteRule<ExpressionRewriteContext> {
public class TopnToMax implements ExpressionPatternRuleFactory {

public static final TopnToMax INSTANCE = new TopnToMax();

@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return expr.accept(this, null);
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesTopType(TopN.class).then(TopnToMax::rewrite)
);
}

@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction, ExpressionRewriteContext context) {
if (!(aggregateFunction instanceof TopN)) {
return aggregateFunction;
}
TopN topN = (TopN) aggregateFunction;
public static Expression rewrite(TopN topN) {
if (topN.arity() == 2 && topN.child(1) instanceof IntegerLikeLiteral
&& ((IntegerLikeLiteral) topN.child(1)).getIntValue() == 1) {
return new Max(topN.child(0));
} else {
return aggregateFunction;
return topN;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
class TopnToMaxTest extends ExpressionRewriteTestHelper {
@Test
void testSimplifyComparisonPredicateRule() {
executor = new ExpressionRuleExecutor(ImmutableList.of(TopnToMax.INSTANCE));
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(TopnToMax.INSTANCE)
));

Slot slot = new SlotReference("a", StringType.INSTANCE);
assertRewrite(new TopN(slot, Literal.of(1)), new Max(slot));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
void test1() {
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Set<InPredicate> inPredicates = rewritten.collect(e -> e instanceof InPredicate);
Assertions.assertEquals(1, inPredicates.size());
InPredicate inPredicate = inPredicates.iterator().next();
Expand All @@ -62,7 +62,7 @@ void test1() {
void test2() {
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))",
rewritten.toSql());
}
Expand All @@ -71,7 +71,7 @@ void test2() {
void test3() {
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
List<InPredicate> inPredicates = rewritten.collectToList(e -> e instanceof InPredicate);
Assertions.assertEquals(2, inPredicates.size());
InPredicate in1 = inPredicates.get(0);
Expand All @@ -95,7 +95,7 @@ void test4() {
String expr = "case when col = 1 or col = 2 or col = 3 then 1"
+ " when col = 4 or col = 5 or col = 6 then 1 else 0 end";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END",
rewritten.toSql());
}
Expand All @@ -104,7 +104,7 @@ void test4() {
void test5() {
String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))",
rewritten.toSql());
}
Expand All @@ -113,23 +113,23 @@ void test5() {
void test6() {
String expr = "col = 1 or col = 2 or col in (1, 2, 3)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("col IN (1, 2, 3)", rewritten.toSql());
}

@Test
void test7() {
String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql());
}

@Test
void test8() {
String expr = "col = 1 or (col = 2 and (col = 3 or col = '4' or col = 5.0))";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))",
rewritten.toSql());
}
Expand All @@ -139,7 +139,7 @@ void testEnsureOrder() {
// ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2)
String expr = "col1 IN (1, 2) OR col2 IN (1, 2)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new ExpressionRewriteContext(null));
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))",
rewritten.toSql());
}
Expand Down

0 comments on commit 9b8a06d

Please sign in to comment.