Skip to content

Commit

Permalink
branch-2.1: [test](nereids) add test simplify comparison predicate #4…
Browse files Browse the repository at this point in the history
…4886 (#45804)

cherry pick from #44886
  • Loading branch information
yujun777 authored Dec 24, 2024
1 parent f0031d9 commit 1314a2b
Show file tree
Hide file tree
Showing 2 changed files with 333 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,23 @@
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -283,10 +292,163 @@ void testDoubleLiteral() {
Expression rewrittenExpression = executor.rewrite(expression, context);
Assertions.assertEquals(left.child(0).getDataType(), rewrittenExpression.child(1).getDataType());
Assertions.assertEquals(rewrittenExpression.child(0).getDataType(), rewrittenExpression.child(1).getDataType());

Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE);
Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE);
Expression intSlot = new SlotReference("a", IntegerType.INSTANCE);
Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE);

// tiny int, literal not exceeds data type limit
assertRewrite(new EqualTo(new Cast(tinyIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12)));

// small int
assertRewrite(new EqualTo(new Cast(smallIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new EqualTo(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new EqualTo(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(smallIntSlot));
assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThan(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12)));

// int
assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(intSlot, new IntegerLiteral(12)));
assertRewrite(new EqualTo(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(intSlot, new IntegerLiteral(12)));
assertRewrite(new EqualTo(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(intSlot));
assertRewrite(new NullSafeEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(intSlot, new IntegerLiteral(12)));
assertRewrite(new GreaterThanEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThan(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThanEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(intSlot, new IntegerLiteral(12)));

// big int
assertRewrite(new EqualTo(new Cast(bigIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)),
new EqualTo(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)),
new EqualTo(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
ExpressionUtils.falseOrNull(bigIntSlot));
assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThan(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThan(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThan(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)),
new LessThanEqual(bigIntSlot, new BigIntLiteral(12L)));
}

@Test
void testIntCmpDecimalV3Literal() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(SimplifyComparisonPredicate.INSTANCE)
));

Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE);
Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE);
Expression intSlot = new SlotReference("a", IntegerType.INSTANCE);
Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE);

// tiny int, literal not exceeds data type limit
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(tinyIntSlot));
assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12)));
assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13)));
assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12)));

// small int
assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(smallIntSlot));
assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12)));
assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(smallIntSlot, new SmallIntLiteral((short) 13)));
assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12)));

// int
assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(intSlot, new IntegerLiteral(12)));
assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(intSlot));
assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(intSlot, new IntegerLiteral(12)));
assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(intSlot, new IntegerLiteral(13)));
assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(intSlot, new IntegerLiteral(12)));

// big int
assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))),
new EqualTo(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
ExpressionUtils.falseOrNull(bigIntSlot));
assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
BooleanLiteral.FALSE);
assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThan(bigIntSlot, new BigIntLiteral(12L)));
assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThan(bigIntSlot, new BigIntLiteral(13L)));
assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))),
new LessThanEqual(bigIntSlot, new BigIntLiteral(12L)));
}

@Test
void testDecimalV3Literal() {
void testDecimalCmpDecimalV3Literal() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(SimplifyComparisonPredicate.INSTANCE)
));
Expand Down
Loading

0 comments on commit 1314a2b

Please sign in to comment.