Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 16, 2024
1 parent 9b8a06d commit 9505ebd
Show file tree
Hide file tree
Showing 17 changed files with 70 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@
public class ExpressionOptimization extends ExpressionRewrite {
public static final List<ExpressionRewriteRule> OPTIMIZE_REWRITE_RULES = ImmutableList.of(
bottomUp(
new ExtractCommonFactorRule(),
new DistinctPredicatesRule(),
new SimplifyComparisonPredicate(),
new SimplifyInPredicate(),
new SimplifyDecimalV3Comparison(),
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE,
SimplifyInPredicate.INSTANCE,
SimplifyDecimalV3Comparison.INSTANCE,
OrToIn.INSTANCE,
new SimplifyRange(),
new DateFunctionRewrite(),
SimplifyRange.INSTANCE,
DateFunctionRewrite.INSTANCE,
ArrayContainToArrayOverlap.INSTANCE,
CaseWhenToIf.INSTANCE,
TopnToMax.INSTANCE
),
NullSafeEqualToEqual.INSTANCE
TopnToMax.INSTANCE,
NullSafeEqualToEqual.INSTANCE
)
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
*
*/
public class DateFunctionRewrite implements ExpressionPatternRuleFactory {
public static DateFunctionRewrite INSTANCE = new DateFunctionRewrite();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
* transform (a = 1) or (a = 1) to (a = 1)
*/
public class DistinctPredicatesRule implements ExpressionPatternRuleFactory {
public static final DistinctPredicatesRule INSTANCE = new DistinctPredicatesRule();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
*/
@Developing
public class ExtractCommonFactorRule implements ExpressionPatternRuleFactory {
public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,34 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* convert "<=>" to "=", if any side is not nullable
* convert "A <=> null" to "A is null"
*/
public class NullSafeEqualToEqual extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
ExpressionRewriteRule<ExpressionRewriteContext> {
public class NullSafeEqualToEqual implements ExpressionPatternRuleFactory {
public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual();

@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
return expr.accept(this, null);
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(NullSafeEqual.class).then(NullSafeEqualToEqual::rewrite)
);
}

@Override
public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext ctx) {
public static Expression rewrite(NullSafeEqual nullSafeEqual) {
if (nullSafeEqual.left() instanceof NullLiteral) {
if (nullSafeEqual.right().nullable()) {
return new IsNull(nullSafeEqual.right());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
* - cast(cast(1 as bigint) as string) -> cast(1 as string).
*/
public class SimplifyCastRule implements ExpressionPatternRuleFactory {
public static SimplifyCastRule INSTANCE = new SimplifyCastRule();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
* cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type)
*/
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory {
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();

enum AdjustType {
LOWER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
* this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6
*/
public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory {
public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
* SimplifyInPredicate
*/
public class SimplifyInPredicate implements ExpressionPatternRuleFactory {
public static final SimplifyInPredicate INSTANCE = new SimplifyInPredicate();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
* todo: support a > 10 and (a < 10 or a > 20 ) => a > 20
*/
public class SimplifyRange implements ExpressionPatternRuleFactory {
public static final SimplifyRange INSTANCE = new SimplifyRange();

@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ void testNormalizeExpressionRewrite() {
@Test
void testDistinctPredicatesRewrite() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new DistinctPredicatesRule())
bottomUp(DistinctPredicatesRule.INSTANCE)
));

assertRewrite("a = 1", "a = 1");
Expand All @@ -110,7 +110,7 @@ void testDistinctPredicatesRewrite() {
@Test
void testExtractCommonFactorRewrite() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new ExtractCommonFactorRule())
bottomUp(ExtractCommonFactorRule.INSTANCE)
));

assertRewrite("a", "a");
Expand Down Expand Up @@ -226,7 +226,7 @@ void testSimplifyCastRule() {
@Test
void testSimplifyDecimalV3Comparison() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new SimplifyDecimalV3Comparison())
bottomUp(SimplifyDecimalV3Comparison.INSTANCE)
));

// do rewrite
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public void test() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
FoldConstantRule.INSTANCE,
bottomUp(
new SimplifyInPredicate()
SimplifyInPredicate.INSTANCE
)
));
Map<String, Slot> mem = Maps.newHashMap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public SimplifyRangeTest() {
@Test
public void testSimplify() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new SimplifyRange())
bottomUp(SimplifyRange.INSTANCE)
));
assertRewrite("TA", "TA");
assertRewrite("(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)", "(TA >= 1 and TA <=3 ) or (TA > 5 and TA < 7)");
Expand Down Expand Up @@ -163,7 +163,7 @@ public void testSimplify() {
@Test
public void testSimplifyDate() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new SimplifyRange())
bottomUp(SimplifyRange.INSTANCE)
));
assertRewrite("TA", "TA");
assertRewrite(
Expand Down Expand Up @@ -231,7 +231,7 @@ public void testSimplifyDate() {
@Test
public void testSimplifyDateTime() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new SimplifyRange())
bottomUp(SimplifyRange.INSTANCE)
));
assertRewrite("TA", "TA");
assertRewrite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,29 @@ class NullSafeEqualToEqualTest extends ExpressionRewriteTestHelper {
// "A<=> Null" to "A is null"
@Test
void testNullSafeEqualToIsNull() {
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(NullSafeEqualToEqual.INSTANCE)
));
SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), new IsNull(slot));
}

// "A<=> Null" to "False", when A is not nullable
@Test
void testNullSafeEqualToFalse() {
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(NullSafeEqualToEqual.INSTANCE)
));
SlotReference slot = new SlotReference("a", StringType.INSTANCE, false);
assertRewrite(new NullSafeEqual(slot, NullLiteral.INSTANCE), BooleanLiteral.FALSE);
}

// "A<=> "abc" to "A = "abc"
@Test
void testNullSafeEqualToEqual() {
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(NullSafeEqualToEqual.INSTANCE)
));
SlotReference slot = new SlotReference("a", StringType.INSTANCE, true);
StringLiteral str = new StringLiteral("abc");
assertRewrite(new NullSafeEqual(slot, str), new EqualTo(slot, str));
Expand All @@ -61,7 +67,9 @@ void testNullSafeEqualToEqual() {
// "A<=>B" not changed
@Test
void testNullSafeEqualNotChanged() {
executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE));
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(NullSafeEqualToEqual.INSTANCE)
));
SlotReference a = new SlotReference("a", StringType.INSTANCE, true);
SlotReference b = new SlotReference("b", StringType.INSTANCE, true);
assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void testSimplifyComparisonPredicateRule() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
new SimplifyCastRule(),
new SimplifyComparisonPredicate()
SimplifyComparisonPredicate.INSTANCE
)
));

Expand Down Expand Up @@ -94,7 +94,7 @@ void testDateTimeV2CmpDateTimeV2() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
new SimplifyCastRule(),
new SimplifyComparisonPredicate()
SimplifyComparisonPredicate.INSTANCE
)
));

Expand All @@ -120,10 +120,9 @@ void testDateTimeV2CmpDateTimeV2() {
void testRound() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
new SimplifyCastRule(),
new SimplifyComparisonPredicate()
SimplifyCastRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE
)
// new SimplifyComparisonPredicate()
));

Expression left = new Cast(new DateTimeLiteral("2021-01-02 00:00:00.00"), DateTimeV2Type.of(1));
Expand All @@ -139,7 +138,7 @@ void testRound() {
@Test
void testDoubleLiteral() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new SimplifyComparisonPredicate())
bottomUp(SimplifyComparisonPredicate.INSTANCE)
));

Expression leftChild = new BigIntLiteral(999);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void testSimplifyDecimalV3Comparison() {
Map<String, Slot> nameToSlot = new HashMap<>();
nameToSlot.put("col1", new SlotReference("col1", DecimalV3Type.createDecimalV3Type(15, 2)));
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(new SimplifyDecimalV3Comparison())
bottomUp(SimplifyDecimalV3Comparison.INSTANCE)
));
assertRewriteAfterSimplify("cast(col1 as decimalv3(27, 9)) > 0.6", "cast(col1 as decimalv3(27, 9)) > 0.6", nameToSlot);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class OrToInTest extends ExpressionRewriteTestHelper {
void test1() {
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Set<InPredicate> inPredicates = rewritten.collect(e -> e instanceof InPredicate);
Assertions.assertEquals(1, inPredicates.size());
InPredicate inPredicate = inPredicates.iterator().next();
Expand All @@ -62,7 +62,7 @@ void test1() {
void test2() {
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((((col1 = 1) AND (col1 = 3)) AND (col2 = 3)) OR (col2 = 4))",
rewritten.toSql());
}
Expand All @@ -71,7 +71,7 @@ void test2() {
void test3() {
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
List<InPredicate> inPredicates = rewritten.collectToList(e -> e instanceof InPredicate);
Assertions.assertEquals(2, inPredicates.size());
InPredicate in1 = inPredicates.get(0);
Expand All @@ -95,7 +95,7 @@ void test4() {
String expr = "case when col = 1 or col = 2 or col = 3 then 1"
+ " when col = 4 or col = 5 or col = 6 then 1 else 0 end";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN (4, 5, 6) THEN 1 ELSE 0 END",
rewritten.toSql());
}
Expand All @@ -104,7 +104,7 @@ void test4() {
void test5() {
String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col = 5))";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4, 5)))",
rewritten.toSql());
}
Expand All @@ -113,23 +113,23 @@ void test5() {
void test6() {
String expr = "col = 1 or col = 2 or col in (1, 2, 3)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("col IN (1, 2, 3)", rewritten.toSql());
}

@Test
void test7() {
String expr = "A = 1 or A = 2 or abs(A)=5 or A in (1, 2, 3) or B = 1 or B = 2 or B in (1, 2, 3) or B+1 in (4, 5, 7)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("(((A IN (1, 2, 3) OR B IN (1, 2, 3)) OR (abs(A) = 5)) OR (B + 1) IN (4, 5, 7))", rewritten.toSql());
}

@Test
void test8() {
String expr = "col = 1 or (col = 2 and (col = 3 or col = '4' or col = 5.0))";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN ('4', 3, 5.0)))",
rewritten.toSql());
}
Expand All @@ -139,7 +139,7 @@ void testEnsureOrder() {
// ensure not rewrite to col2 in (1, 2) or cor 1 in (1, 2)
String expr = "col1 IN (1, 2) OR col2 IN (1, 2)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewriteTree(expression, new ExpressionRewriteContext(null));
Expression rewritten = OrToIn.INSTANCE.rewriteTree(expression, new ExpressionRewriteContext(null));
Assertions.assertEquals("(col1 IN (1, 2) OR col2 IN (1, 2))",
rewritten.toSql());
}
Expand Down

0 comments on commit 9505ebd

Please sign in to comment.