Skip to content

Commit

Permalink
branch-2.1: [test](nereids) add arthmetic comparison ut #45690 (#45893)
Browse files Browse the repository at this point in the history
cherry pick from #45690
  • Loading branch information
yujun777 authored Dec 27, 2024
1 parent fcc4d0d commit ae68580
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
Expand Down Expand Up @@ -131,6 +132,8 @@ protected DataType getType(char t) {
return BooleanType.INSTANCE;
case 'C':
return DateV2Type.INSTANCE;
case 'A':
return DateTimeV2Type.SYSTEM_DEFAULT;
case 'M':
return DecimalV3Type.SYSTEM_DEFAULT;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,70 +21,171 @@
import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.types.IntegerType;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.HashMap;
import java.util.Map;

class SimplifyArithmeticComparisonRuleTest extends ExpressionRewriteTestHelper {

@Test
public void testProcess() {
Map<String, Slot> nameToSlot = new HashMap<>();
nameToSlot.put("a", new SlotReference("a", IntegerType.INSTANCE));
public void testNumeric() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
ExpressionRewrite.bottomUp(SimplifyArithmeticComparisonRule.INSTANCE)
));
assertRewriteAfterSimplify("a + 1 > 1", "a > cast((1 - 1) as INT)", nameToSlot);
assertRewriteAfterSimplify("a - 1 > 1", "a > cast((1 + 1) as INT)", nameToSlot);
assertRewriteAfterSimplify("a / -2 > 1", "cast((1 * -2) as INT) > a", nameToSlot);

// test tinyint type
assertRewriteAfterSimplify("TA + 2 > 1", "cast(TA as SMALLINT) > (1 - 2)");
assertRewriteAfterSimplify("TA - 2 > 1", "cast(TA as SMALLINT) > (1 + 2)");
assertRewriteAfterSimplify("1 + TA > 2", "cast(TA as SMALLINT) > (2 - 1)");
assertRewriteAfterSimplify("-1 + TA > 2", "cast(TA as SMALLINT) > (2 - (-1))");
assertRewriteAfterSimplify("1 - TA > 2", "cast(TA as SMALLINT) < (1 - 2))");
assertRewriteAfterSimplify("-1 - TA > 2", "cast(TA as SMALLINT) < ((-1) - 2)");
assertRewriteAfterSimplify("2 * TA > 1", "((2 * TA) > 1)");
assertRewriteAfterSimplify("-2 * TA > 1", "((-2 * TA) > 1)");
assertRewriteAfterSimplify("2 / TA > 1", "((2 / TA) > 1)");
assertRewriteAfterSimplify("-2 / TA > 1", "((-2 / TA) > 1)");
assertRewriteAfterSimplify("TA * 2 > 1", "((TA * 2) > 1)");
assertRewriteAfterSimplify("TA * (-2) > 1", "((TA * (-2)) > 1)");
assertRewriteAfterSimplify("TA / 2 > 1", "cast(TA as SMALLINT) > (1 * 2)");
assertRewriteAfterSimplify("TA / -2 > 1", "(1 * -2) > cast(TA as SMALLINT)");

// test integer type
assertRewriteAfterSimplify("IA + 2 > 1", "IA > cast((1 - 2) as INT)");
assertRewriteAfterSimplify("IA - 2 > 1", "IA > cast((1 + 2) as INT)");
assertRewriteAfterSimplify("1 + IA > 2", "IA > cast((2 - 1) as INT)");
assertRewriteAfterSimplify("-1 + IA > 2", "IA > cast((2 - (-1)) as INT)");
assertRewriteAfterSimplify("1 - IA > 2", "IA < cast((1 - 2) as INT)");
assertRewriteAfterSimplify("-1 - IA > 2", "IA < cast(((-1) - 2) as INT)");
assertRewriteAfterSimplify("2 * IA > 1", "((2 * IA) > 1)");
assertRewriteAfterSimplify("-2 * IA > 1", "((-2 * IA) > 1)");
assertRewriteAfterSimplify("2 / IA > 1", "((2 / IA) > 1)");
assertRewriteAfterSimplify("-2 / IA > 1", "((-2 / IA) > 1)");
assertRewriteAfterSimplify("IA * 2 > 1", "((IA * 2) > 1)");
assertRewriteAfterSimplify("IA * (-2) > 1", "((IA * (-2)) > 1)");
assertRewriteAfterSimplify("IA / 2 > 1", "(IA > cast((1 * 2) as INT))");
assertRewriteAfterSimplify("IA / -2 > 1", "cast((1 * -2) as INT) > IA");

// test integer type
assertRewriteAfterSimplify("1 + a > 2", "a > cast((2 - 1) as INT)", nameToSlot);
assertRewriteAfterSimplify("-1 + a > 2", "a > cast((2 - (-1)) as INT)", nameToSlot);
assertRewriteAfterSimplify("1 - a > 2", "a < cast((1 - 2) as INT)", nameToSlot);
assertRewriteAfterSimplify("-1 - a > 2", "a < cast(((-1) - 2) as INT)", nameToSlot);
assertRewriteAfterSimplify("2 * a > 1", "((2 * a) > 1)", nameToSlot);
assertRewriteAfterSimplify("-2 * a > 1", "((-2 * a) > 1)", nameToSlot);
assertRewriteAfterSimplify("2 / a > 1", "((2 / a) > 1)", nameToSlot);
assertRewriteAfterSimplify("-2 / a > 1", "((-2 / a) > 1)", nameToSlot);
assertRewriteAfterSimplify("a * 2 > 1", "((a * 2) > 1)", nameToSlot);
assertRewriteAfterSimplify("a * (-2) > 1", "((a * (-2)) > 1)", nameToSlot);
assertRewriteAfterSimplify("a / 2 > 1", "(a > cast((1 * 2) as INT))", nameToSlot);
assertRewriteAfterSimplify("TA + 2 > 200", "cast(TA as INT) > (200 - 2)");
assertRewriteAfterSimplify("TA - 2 > 200", "cast(TA as INT) > (200 + 2)");
assertRewriteAfterSimplify("1 + TA > 200", "cast(TA as INT) > (200 - 1)");
assertRewriteAfterSimplify("-1 + TA > 200", "cast(TA as INT) > (200 - (-1))");
assertRewriteAfterSimplify("1 - TA > 200", "cast(TA as INT) < (1 - 200))");
assertRewriteAfterSimplify("-1 - TA > 200", "cast(TA as INT) < ((-1) - 200)");
assertRewriteAfterSimplify("2 * TA > 200", "((2 * TA) > 200)");
assertRewriteAfterSimplify("-2 * TA > 200", "((-2 * TA) > 200)");
assertRewriteAfterSimplify("2 / TA > 200", "((2 / TA) > 200)");
assertRewriteAfterSimplify("-2 / TA > 200", "((-2 / TA) > 200)");
assertRewriteAfterSimplify("TA * 2 > 200", "((TA * 2) > 200)");
assertRewriteAfterSimplify("TA * (-2) > 200", "((TA * (-2)) > 200)");
assertRewriteAfterSimplify("TA / 2 > 200", "cast(TA as INT) > (200 * 2)");
assertRewriteAfterSimplify("TA / -2 > 200", "(200 * -2) > cast(TA as INT)");

// test decimal type
assertRewriteAfterSimplify("1.1 + a > 2.22", "(cast(a as DECIMALV3(12, 2)) > cast((2.22 - 1.1) as DECIMALV3(12, 2)))", nameToSlot);
assertRewriteAfterSimplify("-1.1 + a > 2.22", "(cast(a as DECIMALV3(12, 2)) > cast((2.22 - (-1.1)) as DECIMALV3(12, 2)))", nameToSlot);
assertRewriteAfterSimplify("1.1 - a > 2.22", "(cast(a as DECIMALV3(11, 1)) < cast((1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot);
assertRewriteAfterSimplify("-1.1 - a > 2.22", "(cast(a as DECIMALV3(11, 1)) < cast((-1.1 - 2.22) as DECIMALV3(11, 1)))", nameToSlot);
assertRewriteAfterSimplify("2.22 * a > 1.1", "((2.22 * a) > 1.1)", nameToSlot);
assertRewriteAfterSimplify("-2.22 * a > 1.1", "-2.22 * a > 1.1", nameToSlot);
assertRewriteAfterSimplify("2.22 / a > 1.1", "((2.22 / a) > 1.1)", nameToSlot);
assertRewriteAfterSimplify("-2.22 / a > 1.1", "((-2.22 / a) > 1.1)", nameToSlot);
assertRewriteAfterSimplify("a * 2.22 > 1.1", "a * 2.22 > 1.1", nameToSlot);
assertRewriteAfterSimplify("a * (-2.22) > 1.1", "a * (-2.22) > 1.1", nameToSlot);
assertRewriteAfterSimplify("a / 2.22 > 1.1", "(cast(a as DECIMALV3(13, 3)) > cast((1.1 * 2.22) as DECIMALV3(13, 3)))", nameToSlot);
assertRewriteAfterSimplify("a / (-2.22) > 1.1", "(cast((1.1 * -2.22) as DECIMALV3(13, 3)) > cast(a as DECIMALV3(13, 3)))", nameToSlot);

// test (1 + a) can be processed
assertRewriteAfterSimplify("2 - (1 + a) > 3", "(a < ((2 - 3) - 1))", nameToSlot);
assertRewriteAfterSimplify("(1 - a) / 2 > 3", "(a < (1 - 6))", nameToSlot);
assertRewriteAfterSimplify("1 - a / 2 > 3", "(a < ((1 - 3) * 2))", nameToSlot);
assertRewriteAfterSimplify("(1 - (a + 4)) / 2 > 3", "(cast(a as BIGINT) < ((1 - 6) - 4))", nameToSlot);
assertRewriteAfterSimplify("2 * (1 + a) > 1", "(2 * (1 + a)) > 1", nameToSlot);
assertRewriteAfterSimplify("1.1 + IA > 2.22", "(cast(IA as DECIMALV3(12, 2)) > cast((2.22 - 1.1) as DECIMALV3(12, 2)))");
assertRewriteAfterSimplify("-1.1 + IA > 2.22", "(cast(IA as DECIMALV3(12, 2)) > cast((2.22 - (-1.1)) as DECIMALV3(12, 2)))");
assertRewriteAfterSimplify("1.1 - IA > 2.22", "(cast(IA as DECIMALV3(11, 1)) < cast((1.1 - 2.22) as DECIMALV3(11, 1)))");
assertRewriteAfterSimplify("-1.1 - IA > 2.22", "(cast(IA as DECIMALV3(11, 1)) < cast((-1.1 - 2.22) as DECIMALV3(11, 1)))");
assertRewriteAfterSimplify("2.22 * IA > 1.1", "((2.22 * IA) > 1.1)");
assertRewriteAfterSimplify("-2.22 * IA > 1.1", "-2.22 * IA > 1.1");
assertRewriteAfterSimplify("2.22 / IA > 1.1", "((2.22 / IA) > 1.1)");
assertRewriteAfterSimplify("-2.22 / IA > 1.1", "((-2.22 / IA) > 1.1)");
assertRewriteAfterSimplify("IA * 2.22 > 1.1", "IA * 2.22 > 1.1");
assertRewriteAfterSimplify("IA * (-2.22) > 1.1", "IA * (-2.22) > 1.1");
assertRewriteAfterSimplify("IA / 2.22 > 1.1", "(cast(IA as DECIMALV3(13, 3)) > cast((1.1 * 2.22) as DECIMALV3(13, 3)))");
assertRewriteAfterSimplify("IA / (-2.22) > 1.1", "(cast((1.1 * -2.22) as DECIMALV3(13, 3)) > cast(IA as DECIMALV3(13, 3)))");

// test (1 + IA) can be processed
assertRewriteAfterSimplify("2 - (1 + IA) > 3", "(IA < ((2 - 3) - 1))");
assertRewriteAfterSimplify("(1 - IA) / 2 > 3", "(IA < (1 - 6))");
assertRewriteAfterSimplify("1 - IA / 2 > 3", "(IA < ((1 - 3) * 2))");
assertRewriteAfterSimplify("(1 - (IA + 4)) / 2 > 3", "(cast(IA as BIGINT) < ((1 - 6) - 4))");
assertRewriteAfterSimplify("2 * (1 + IA) > 1", "(2 * (1 + IA)) > 1");

// test (IA + IB) can be processed
assertRewriteAfterSimplify("2 - (1 + (IA + IB)) > 3", "(IA + IB) < cast(((2 - 3) - 1) as BIGINT)");
assertRewriteAfterSimplify("(1 - (IA + IB)) / 2 > 3", "(IA + IB) < cast((1 - 6) as BIGINT)");
assertRewriteAfterSimplify("1 - (IA + IB) / 2 > 3", "(IA + IB) < cast(((1 - 3) * 2) as BIGINT)");
assertRewriteAfterSimplify("2 * (1 + (IA + IB)) > 1", "(2 * (1 + (IA + IB))) > 1");
}

@Test
public void testDateLike() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
bottomUp(
SimplifyArithmeticRule.INSTANCE,
SimplifyArithmeticComparisonRule.INSTANCE
)
));

// test datetimev2 type
assertRewriteAfterTypeCoercion("years_add(AA, 1) > '2021-01-01 00:00:00'", "(years_add(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("years_sub(AA, 1) > '2021-01-01 00:00:00'", "(years_sub(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("months_add(AA, 1) > '2021-01-01 00:00:00'", "(months_add(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("months_sub(AA, 1) > '2021-01-01 00:00:00'", "(months_sub(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("weeks_add(AA, 1) > '2021-01-01 00:00:00'", "AA > weeks_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("weeks_sub(AA, 1) > '2021-01-01 00:00:00'", "AA > weeks_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("days_add(AA, 1) > '2021-01-01 00:00:00'", "AA > days_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("days_sub(AA, 1) > '2021-01-01 00:00:00'", "AA > days_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("hours_add(AA, 1) > '2021-01-01 00:00:00'", "AA > hours_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("hours_sub(AA, 1) > '2021-01-01 00:00:00'", "AA > hours_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_add(AA, 1) > '2021-01-01 00:00:00'", "AA > minutes_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_sub(AA, 1) > '2021-01-01 00:00:00'", "AA > minutes_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_add(AA, 1) > '2021-01-01 00:00:00'", "AA > seconds_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_sub(AA, 1) > '2021-01-01 00:00:00'", "AA > seconds_add('2021-01-01 00:00:00', 1)");

assertRewriteAfterTypeCoercion("years_add(AA, 1) > '2021-01-01'", "(years_add(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("years_sub(AA, 1) > '2021-01-01'", "(years_sub(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("months_add(AA, 1) > '2021-01-01'", "(months_add(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("months_sub(AA, 1) > '2021-01-01'", "(months_sub(AA, 1) > '2021-01-01 00:00:00')");
assertRewriteAfterTypeCoercion("weeks_add(AA, 1) > '2021-01-01'", "AA > weeks_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("weeks_sub(AA, 1) > '2021-01-01'", "AA > weeks_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("days_add(AA, 1) > '2021-01-01'", "AA > days_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("days_sub(AA, 1) > '2021-01-01'", "AA > days_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("hours_add(AA, 1) > '2021-01-01'", "AA > hours_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("hours_sub(AA, 1) > '2021-01-01'", "AA > hours_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_add(AA, 1) > '2021-01-01'", "AA > minutes_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_sub(AA, 1) > '2021-01-01'", "AA > minutes_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_add(AA, 1) > '2021-01-01'", "AA > seconds_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_sub(AA, 1) > '2021-01-01'", "AA > seconds_add('2021-01-01 00:00:00', 1)");

// test date type
assertRewriteAfterTypeCoercion("years_add(CA, 1) > '2021-01-01'", "years_add(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("years_sub(CA, 1) > '2021-01-01'", "years_sub(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("months_add(CA, 1) > '2021-01-01'", "months_add(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("months_sub(CA, 1) > '2021-01-01'", "months_sub(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("weeks_add(CA, 1) > '2021-01-01'", "CA > weeks_sub(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("weeks_sub(CA, 1) > '2021-01-01'", "CA > weeks_add(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("days_add(CA, 1) > '2021-01-01'", "CA > days_sub(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("days_sub(CA, 1) > '2021-01-01'", "CA > days_add(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("hours_add(CA, 1) > '2021-01-01'", "cast(CA as datetime) > hours_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("hours_sub(CA, 1) > '2021-01-01'", "cast(CA as datetime) > hours_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_add(CA, 1) > '2021-01-01'", "cast(CA as datetime) > minutes_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_sub(CA, 1) > '2021-01-01'", "cast(CA as datetime) > minutes_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_add(CA, 1) > '2021-01-01'", "cast(CA as datetime) > seconds_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_sub(CA, 1) > '2021-01-01'", "cast(CA as datetime) > seconds_add('2021-01-01 00:00:00', 1)");

assertRewriteAfterTypeCoercion("years_add(CA, 1) > '2021-01-01 00:00:00'", "years_add(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("years_sub(CA, 1) > '2021-01-01 00:00:00'", "years_sub(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("months_add(CA, 1) > '2021-01-01 00:00:00'", "months_add(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("months_sub(CA, 1) > '2021-01-01 00:00:00'", "months_sub(CA, 1) > cast('2021-01-01' as date)");
assertRewriteAfterTypeCoercion("weeks_add(CA, 1) > '2021-01-01 00:00:00'", "CA > weeks_sub(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("weeks_sub(CA, 1) > '2021-01-01 00:00:00'", "CA > weeks_add(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("days_add(CA, 1) > '2021-01-01 00:00:00'", "CA > days_sub(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("days_sub(CA, 1) > '2021-01-01 00:00:00'", "CA > days_add(cast('2021-01-01' as date), 1)");
assertRewriteAfterTypeCoercion("hours_add(CA, 1) > '2021-01-01 00:00:00'", "cast(CA as datetime) > hours_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("hours_sub(CA, 1) > '2021-01-01 00:00:00'", "cast(CA as datetime) > hours_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_add(CA, 1) > '2021-01-01 00:00:00'", "cast(CA as datetime) > minutes_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("minutes_sub(CA, 1) > '2021-01-01 00:00:00'", "cast(CA as datetime) > minutes_add('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_add(CA, 1) > '2021-01-01 00:00:00'", "cast(CA as datetime) > seconds_sub('2021-01-01 00:00:00', 1)");
assertRewriteAfterTypeCoercion("seconds_sub(CA, 1) > '2021-01-01 00:00:00'", "cast(CA as datetime) > seconds_add('2021-01-01 00:00:00', 1)");
}

private void assertRewriteAfterSimplify(String expr, String expected, Map<String, Slot> slotNameToSlot) {
private void assertRewriteAfterSimplify(String expr, String expected) {
Expression needRewriteExpression = PARSER.parseExpression(expr);
if (slotNameToSlot != null) {
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, slotNameToSlot);
}
needRewriteExpression = replaceUnboundSlot(needRewriteExpression, Maps.newHashMap());
Expression rewritten = executor.rewrite(needRewriteExpression, context);
Expression expectedExpression = PARSER.parseExpression(expected);
Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
Expand Down

0 comments on commit ae68580

Please sign in to comment.