From 2a33bb368adfa56b556247d45f5316221790fd0e Mon Sep 17 00:00:00 2001 From: yujun Date: Tue, 3 Dec 2024 09:41:14 +0800 Subject: [PATCH 1/5] add test simplify comparison predicate --- .../test_simplify_comparison_predicate.groovy | 170 ++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy new file mode 100644 index 00000000000000..7909226d99a7ee --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// TODO: date datetime comparison still has bug, need fix. +suite('test_simplify_comparison_predicate', 'nonConcurrent') { + def tbl = 'test_simplify_comparison_predicate_tbl' + def checkExplain = { expression, resExpression -> + def checker = { explainString, exception, startTime, endTime -> + assertNull(exception) + def foundOutputExprs = false + def succ = false + for (def line : explainString.split('\n')) { + if (foundOutputExprs) { + assertTrue(line.contains(resExpression), "'${line}' no contains '${resExpression}'") + succ = true + break + } + if (line.contains('OUTPUT EXPRS:')) { + foundOutputExprs = true + } + } + assertTrue(foundOutputExprs) + assertTrue(succ) + } + + explain { + sql "SELECT ${expression} FROM ${tbl}" + check checker + } + } + def testSimplify = { checkNullColumn, checkNotNullColumn, expression, resExpression -> + def types = [''] + def column = '' + if (expression.contains('{int_like_column}')) { + column = '{int_like_column}' + types = ['tinyint', 'smallint', 'int', 'bigint'] + } else if (expression.contains('{decimal_column}')) { + column = '{decimal_column}' + types = ['decimal_3_0', 'decimal_5_2'] + } else if (expression.contains('{date_column}')) { + column = '{date_column}' + types = ['date', 'datev1'] + } else if (expression.contains('{datetime_column}')) { + column = '{datetime_column}' + types = ['datetime_0', 'datetime_3', 'datetimev1'] + } + for (def type : types) { + if (type == '') { + checkExplain expression, resExpression + } else { + if (checkNullColumn) { + checkExplain expression.replace(column, "c_${type}_null"), resExpression.replace(column, "c_${type}_null") + } + if (checkNotNullColumn) { + checkExplain expression.replace(column, "c_${type}"), resExpression.replace(column, "c_${type}") + } + } + } + } + + setFeConfigTemporary([disable_datev1:false, disable_decimalv2:false]) { + sql """ + DROP TABLE IF EXISTS ${tbl} FORCE; + + CREATE TABLE ${tbl} ( + c_tinyint tinyint not null default 1, + c_tinyint_null tinyint, + c_smallint smallint not null default 1, + c_smallint_null smallint, + c_int int not null default 1, + c_int_null int, + c_bigint bigint not null default 1, + c_bigint_null bigint, + c_decimal_3_0 decimal(3, 0) not null default 1, + c_decimal_3_0_null decimal(3, 0), + c_decimal_5_2 decimal(5, 2) not null default 1, + c_decimal_5_2_null decimal(5, 2), + c_date date not null default '2000-01-01', + c_date_null date, + c_datev1 datev1 not null default '2000-01-01', + c_datev1_null datev1 null, + c_datetime_0 datetime(0) not null default '2000-01-01 00:00:00', + c_datetime_0_null datetime(0), + c_datetime_3 datetime(3) not null default '2000-01-01 00:00:00', + c_datetime_3_null datetime(3), + c_datetimev1 datetimev1 not null default '2000-01-01 00:00:00', + c_datetimev1_null datetimev1 + ) + PROPERTIES ('replication_num' = '1'); + + INSERT INTO ${tbl} VALUES(); + """ + + testSimplify true, true, '{int_like_column} = CAST(1.00 as DOUBLE)', '({int_like_column} = 1)' + testSimplify true, false, '{int_like_column} = CAST(1.01 as DOUBLE)', 'AND[{int_like_column} IS NULL,NULL]' + testSimplify false, true, '{int_like_column} = CAST(1.01 as DOUBLE)', 'FALSE' + testSimplify true, true, '{int_like_column} <=> CAST(1.01 as DOUBLE)', 'FALSE' + testSimplify true, true, '{int_like_column} > CAST(1.00 as DOUBLE)', '({int_like_column} > 1)' + testSimplify true, true, '{int_like_column} < CAST(1.00 as DOUBLE)', '({int_like_column} < 1)' + testSimplify true, true, '{int_like_column} > CAST(1.01 as DOUBLE)', '({int_like_column} > 1)' + testSimplify true, true, '{int_like_column} >= CAST(1.01 as DOUBLE)', '({int_like_column} >= 2)' + testSimplify true, true, '{int_like_column} <= CAST(1.01 as DOUBLE)', '({int_like_column} <= 1)' + testSimplify true, true, '{int_like_column} < CAST(1.01 as DOUBLE)', '({int_like_column} < 2)' + testSimplify true, true, '{int_like_column} = 1.00', '({int_like_column} = 1)' + testSimplify true, true, '{int_like_column} > 1.00', '({int_like_column} > 1)' + testSimplify true, true, '{int_like_column} < 1.00', '({int_like_column} < 1)' + testSimplify true, false, '{int_like_column} = 1.01', 'AND[{int_like_column} IS NULL,NULL]' + testSimplify false, true, '{int_like_column} = 1.01', 'FALSE' + testSimplify true, true, '{int_like_column} <=> 1.01', 'FALSE' + testSimplify true, true, '{int_like_column} > 1.01', '({int_like_column} > 1)' + testSimplify true, true, '{int_like_column} >= 1.01', '({int_like_column} >= 2)' + testSimplify true, true, '{int_like_column} <= 1.01', '({int_like_column} <= 1)' + testSimplify true, true, '{int_like_column} < 1.01', '({int_like_column} < 2)' + testSimplify false, false, 'CAST(c_decimal_3_0_null as DECIMAL(10, 5)) = CAST(1.00 as DECIMAL(10, 5))', '(c_decimal_3_0_null = 1)' + testSimplify false, false, 'CAST(c_decimal_3_0_null as DECIMAL(10, 5)) = CAST(1.1 as DECIMAL(10, 5))', 'AND[c_decimal_3_0_null IS NULL,NULL]' + testSimplify false, false, 'CAST(c_decimal_3_0_null as DECIMAL(10, 5)) > CAST(1.1 as DECIMAL(10, 5))', '(c_decimal_3_0_null > 1)' + testSimplify false, false, 'CAST(c_decimal_3_0_null as DECIMAL(10, 5)) >= CAST(1.1 as DECIMAL(10, 5))', '(c_decimal_3_0_null >= 2)' + testSimplify false, false, 'CAST(c_decimal_3_0_null as DECIMAL(10, 5)) < CAST(1.1 as DECIMAL(10, 5))', '(c_decimal_3_0_null < 2)' + testSimplify false, false, 'CAST(c_decimal_3_0_null as DECIMAL(10, 5)) <= CAST(1.1 as DECIMAL(10, 5))', '(c_decimal_3_0_null <= 1)' + testSimplify false, false, 'c_decimal_5_2_null = CAST(1.0 as DECIMAL(10, 5))', '(c_decimal_5_2_null = 1.00)' + testSimplify false, false, 'c_decimal_5_2_null = CAST(1.1 as DECIMAL(10, 5))', '(c_decimal_5_2_null = 1.10)' + testSimplify false, false, 'c_decimal_5_2_null = CAST(1.12 as DECIMAL(10, 5))', '(c_decimal_5_2_null = 1.12)' + testSimplify false, false, 'c_decimal_5_2_null = CAST(1.123 as DECIMAL(10, 5))', 'AND[c_decimal_5_2_null IS NULL,NULL]' + testSimplify false, false, 'c_decimal_5_2 = CAST(1.123 as DECIMAL(10, 5))', 'FALSE' + testSimplify false, false, 'c_decimal_5_2_null > CAST(1.123 as DECIMAL(10, 5))', 'c_decimal_5_2_null > 1.12' + testSimplify false, false, 'c_decimal_5_2_null >= CAST(1.123 as DECIMAL(10, 5))', 'c_decimal_5_2_null >= 1.13' + testSimplify false, false, 'c_decimal_5_2_null <= CAST(1.123 as DECIMAL(10, 5))', 'c_decimal_5_2_null <= 1.12' + testSimplify false, false, 'c_decimal_5_2_null < CAST(1.123 as DECIMAL(10, 5))', 'c_decimal_5_2_null < 1.13' + testSimplify false, false, "CAST(c_datetime_0 AS DATETIME(5)) = '2000-01-01'", "(c_datetime_0 = '2000-01-01 00:00:00')" + testSimplify false, false, "CAST(c_datetime_0 AS DATETIME(5)) = '2000-01-01 00:00:00.1'", 'FALSE' + testSimplify false, false, "CAST(c_datetime_0_null AS DATETIME(5)) = '2000-01-01 00:00:00.1'", 'AND[c_datetime_0_null IS NULL,NULL]' + testSimplify false, false, "CAST(c_datetime_0_null AS DATETIME(5)) <=> '2000-01-01 00:00:00.1'", 'FALSE' + testSimplify false, false, "CAST(c_datetime_0 AS DATETIME(5)) >= '2000-01-01 00:00:00.1'", "(c_datetime_0 >= '2000-01-01 00:00:01')" + testSimplify false, false, "CAST(c_datetime_0 AS DATETIME(5)) > '2000-01-01 00:00:00.1'", "(c_datetime_0 > '2000-01-01 00:00:00')" + testSimplify false, false, "CAST(c_datetime_0 AS DATETIME(5)) <= '2000-01-01 00:00:00.1'", "(c_datetime_0 <= '2000-01-01 00:00:00')" + testSimplify false, false, "CAST(c_datetime_0 AS DATETIME(5)) < '2000-01-01 00:00:00.1'", "(c_datetime_0 < '2000-01-01 00:00:01')" + testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) = '2000-01-01'", "(c_datetime_3 = '2000-01-01 00:00:00.000')" + testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) = '2000-01-01 00:00:00.1234'", 'FALSE' + testSimplify false, false, "CAST(c_datetime_3_null AS DATETIME(5)) = '2000-01-01 00:00:00.1234'", 'AND[c_datetime_3_null IS NULL,NULL]' + testSimplify false, false, "CAST(c_datetime_3_null AS DATETIME(5)) <=> '2000-01-01 00:00:00.1234'", 'FALSE' + testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) >= '2000-01-01 00:00:00.1234'", "(c_datetime_3 >= '2000-01-01 00:00:00.124')" + testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) > '2000-01-01 00:00:00.1234'", "(c_datetime_3 > '2000-01-01 00:00:00.123')" + testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) <= '2000-01-01 00:00:00.1234'", "(c_datetime_3 <= '2000-01-01 00:00:00.123')" + testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) < '2000-01-01 00:00:00.1234'", "(c_datetime_3 < '2000-01-01 00:00:00.124')" + // testSimplify false, false, "c_date = '2000-01-01 00:00:01'", 'FALSE' + // testSimplify false, false, "CAST(c_date_null AS DATETIME(5)) = '2000-01-01 00:00:01'", 'AND[c_date_null IS NULL,NULL]' + // testSimplify false, false, "CAST(c_date_null AS DATETIME(5)) <=> '2000-01-01 00:00:01'", 'FALSE' + testSimplify false, false, "CAST(c_date AS DATETIME(5)) > '2000-01-01 00:00:01'", "c_date > '2000-01-01'" + testSimplify false, false, "CAST(c_date AS DATETIME(5)) >= '2000-01-01 00:00:01'", "c_date >= '2000-01-02'" + testSimplify false, false, "CAST(c_date AS DATETIME(5)) <= '2000-01-01 00:00:01'", "c_date <= '2000-01-01'" + testSimplify false, false, "CAST(c_date AS DATETIME(5)) < '2000-01-01 00:00:01'", "c_date < '2000-01-02'" + + sql "DROP TABLE IF EXISTS ${tbl} FORCE" + } +} From 849c25f8475ffe1760b3b348eb1d9510d97809be Mon Sep 17 00:00:00 2001 From: yujun Date: Mon, 23 Dec 2024 10:39:49 +0800 Subject: [PATCH 2/5] add double literal compare --- .../SimplifyComparisonPredicateTest.java | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 028f1c4864f099..652f77733662a3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; @@ -54,6 +55,7 @@ import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.SmallIntType; import org.apache.doris.nereids.types.TinyIntType; @@ -272,6 +274,102 @@ void testDoubleLiteral() { Expression rewrittenExpression = executor.rewrite(expression, context); Assertions.assertEquals(left.child(0).getDataType(), rewrittenExpression.child(1).getDataType()); Assertions.assertEquals(rewrittenExpression.child(0).getDataType(), rewrittenExpression.child(1).getDataType()); + + Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE); + Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE); + Expression intSlot = new SlotReference("a", IntegerType.INSTANCE); + Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE); + + // tiny int, literal not exceeds data type limit + assertRewrite(new EqualTo(new Cast(tinyIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)), + new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)), + new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12))); + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13))); + assertRewrite(new LessThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13))); + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); + + + // tiny int, literal exceeds data type limit + assertRewrite(new EqualTo(new Cast(tinyIntSlot, FloatType.INSTANCE), new FloatLiteral(200.0f)), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.0f)), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new LessThan(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)), + ExpressionUtils.trueOrNull(tinyIntSlot)); + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(200.3f)), + ExpressionUtils.trueOrNull(tinyIntSlot)); + + // small int + assertRewrite(new EqualTo(new Cast(smallIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)), + new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); + assertRewrite(new EqualTo(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)), + new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); + assertRewrite(new EqualTo(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + ExpressionUtils.falseOrNull(smallIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12))); + assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13))); + assertRewrite(new LessThan(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThan(smallIntSlot, new SmallIntLiteral((short) 13))); + assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12))); + + // int + assertRewrite(new EqualTo(new Cast(intSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)), + new EqualTo(intSlot, new IntegerLiteral(12))); + assertRewrite(new EqualTo(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)), + new EqualTo(intSlot, new IntegerLiteral(12))); + assertRewrite(new EqualTo(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + ExpressionUtils.falseOrNull(intSlot)); + assertRewrite(new NullSafeEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThan(intSlot, new IntegerLiteral(12))); + assertRewrite(new GreaterThanEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThanEqual(intSlot, new IntegerLiteral(13))); + assertRewrite(new LessThan(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThan(intSlot, new IntegerLiteral(13))); + assertRewrite(new LessThanEqual(new Cast(intSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThanEqual(intSlot, new IntegerLiteral(12))); + + // big int + assertRewrite(new EqualTo(new Cast(bigIntSlot, FloatType.INSTANCE), new FloatLiteral(12.0f)), + new EqualTo(bigIntSlot, new BigIntLiteral(12L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.0f)), + new EqualTo(bigIntSlot, new BigIntLiteral(12L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + ExpressionUtils.falseOrNull(bigIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThan(bigIntSlot, new BigIntLiteral(12L))); + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L))); + assertRewrite(new LessThan(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThan(bigIntSlot, new BigIntLiteral(13L))); + assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), + new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); } @Test From 1b296c211db54c148f7a85804c35016ea0253dd5 Mon Sep 17 00:00:00 2001 From: yujun Date: Mon, 23 Dec 2024 11:03:20 +0800 Subject: [PATCH 3/5] update test --- .../SimplifyComparisonPredicateTest.java | 95 ++++++++++++++++++- .../test_simplify_comparison_predicate.groovy | 6 +- 2 files changed, 96 insertions(+), 5 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 652f77733662a3..f3bd798b50aa6e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -298,7 +298,6 @@ void testDoubleLiteral() { assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DoubleType.INSTANCE), new DoubleLiteral(12.3f)), new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); - // tiny int, literal exceeds data type limit assertRewrite(new EqualTo(new Cast(tinyIntSlot, FloatType.INSTANCE), new FloatLiteral(200.0f)), ExpressionUtils.falseOrNull(tinyIntSlot)); @@ -373,7 +372,99 @@ void testDoubleLiteral() { } @Test - void testDecimalV3Literal() { + void testIntCmpDecimalV3Literal() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp(SimplifyComparisonPredicate.INSTANCE) + )); + + Expression tinyIntSlot = new SlotReference("a", TinyIntType.INSTANCE); + Expression smallIntSlot = new SlotReference("a", SmallIntType.INSTANCE); + Expression intSlot = new SlotReference("a", IntegerType.INSTANCE); + Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE); + + // tiny int, literal not exceeds data type limit + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12))); + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13))); + assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13))); + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); + + // tiny int, literal exceeds data type limit + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.0"))), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + ExpressionUtils.falseOrNull(tinyIntSlot)); + assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + ExpressionUtils.trueOrNull(tinyIntSlot)); + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + ExpressionUtils.trueOrNull(tinyIntSlot)); + + // small int + assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); + assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + ExpressionUtils.falseOrNull(smallIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12))); + assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13))); + assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThan(smallIntSlot, new SmallIntLiteral((short) 13))); + assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12))); + + // int + assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(intSlot, new IntegerLiteral(12))); + assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + ExpressionUtils.falseOrNull(intSlot)); + assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThan(intSlot, new IntegerLiteral(12))); + assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThanEqual(intSlot, new IntegerLiteral(13))); + assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThan(intSlot, new IntegerLiteral(13))); + assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThanEqual(intSlot, new IntegerLiteral(12))); + + // big int + assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + new EqualTo(bigIntSlot, new BigIntLiteral(12L))); + assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + ExpressionUtils.falseOrNull(bigIntSlot)); + assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + BooleanLiteral.FALSE); + assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThan(bigIntSlot, new BigIntLiteral(12L))); + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L))); + assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThan(bigIntSlot, new BigIntLiteral(13L))); + assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); + } + + @Test + void testDecimalCmpDecimalV3Literal() { executor = new ExpressionRuleExecutor(ImmutableList.of( bottomUp(SimplifyComparisonPredicate.INSTANCE) )); diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy index 7909226d99a7ee..af975aeeaa22e7 100644 --- a/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy +++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_comparison_predicate.groovy @@ -157,9 +157,9 @@ suite('test_simplify_comparison_predicate', 'nonConcurrent') { testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) > '2000-01-01 00:00:00.1234'", "(c_datetime_3 > '2000-01-01 00:00:00.123')" testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) <= '2000-01-01 00:00:00.1234'", "(c_datetime_3 <= '2000-01-01 00:00:00.123')" testSimplify false, false, "CAST(c_datetime_3 AS DATETIME(5)) < '2000-01-01 00:00:00.1234'", "(c_datetime_3 < '2000-01-01 00:00:00.124')" - // testSimplify false, false, "c_date = '2000-01-01 00:00:01'", 'FALSE' - // testSimplify false, false, "CAST(c_date_null AS DATETIME(5)) = '2000-01-01 00:00:01'", 'AND[c_date_null IS NULL,NULL]' - // testSimplify false, false, "CAST(c_date_null AS DATETIME(5)) <=> '2000-01-01 00:00:01'", 'FALSE' + testSimplify false, false, "c_date = '2000-01-01 00:00:01'", 'FALSE' + testSimplify false, false, "CAST(c_date_null AS DATETIME(5)) = '2000-01-01 00:00:01'", 'AND[c_date_null IS NULL,NULL]' + testSimplify false, false, "CAST(c_date_null AS DATETIME(5)) <=> '2000-01-01 00:00:01'", 'FALSE' testSimplify false, false, "CAST(c_date AS DATETIME(5)) > '2000-01-01 00:00:01'", "c_date > '2000-01-01'" testSimplify false, false, "CAST(c_date AS DATETIME(5)) >= '2000-01-01 00:00:01'", "c_date >= '2000-01-02'" testSimplify false, false, "CAST(c_date AS DATETIME(5)) <= '2000-01-01 00:00:01'", "c_date <= '2000-01-01'" From 0073c714cec71128d95ccb312d4aaa435ab54010 Mon Sep 17 00:00:00 2001 From: yujun Date: Mon, 23 Dec 2024 12:00:05 +0800 Subject: [PATCH 4/5] fix compile --- .../SimplifyComparisonPredicateTest.java | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index f3bd798b50aa6e..c51c6d43dbaed7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -383,19 +383,19 @@ void testIntCmpDecimalV3Literal() { Expression bigIntSlot = new SlotReference("a", BigIntType.INSTANCE); // tiny int, literal not exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), new EqualTo(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThan(tinyIntSlot, new TinyIntLiteral((byte) 12))); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThan(tinyIntSlot, new TinyIntLiteral((byte) 13))); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); // tiny int, literal exceeds data type limit @@ -415,51 +415,51 @@ void testIntCmpDecimalV3Literal() { ExpressionUtils.trueOrNull(tinyIntSlot)); // small int - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), new EqualTo(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(smallIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThan(smallIntSlot, new SmallIntLiteral((short) 12))); - assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThanEqual(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThan(smallIntSlot, new SmallIntLiteral((short) 13))); - assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(smallIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThanEqual(smallIntSlot, new SmallIntLiteral((short) 12))); // int - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), new EqualTo(intSlot, new IntegerLiteral(12))); - assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(intSlot)); - assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThan(intSlot, new IntegerLiteral(12))); - assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThanEqual(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThan(intSlot, new IntegerLiteral(13))); - assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(intSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThanEqual(intSlot, new IntegerLiteral(12))); // big int - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.0"))), + assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.0"))), new EqualTo(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new EqualTo(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), ExpressionUtils.falseOrNull(bigIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new NullSafeEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThan(bigIntSlot, new BigIntLiteral(12L))); - assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new GreaterThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new GreaterThanEqual(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThan(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThan(bigIntSlot, new BigIntLiteral(13L))); - assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3,1)), new DecimalV3Literal(new BigDecimal("12.3"))), + assertRewrite(new LessThanEqual(new Cast(bigIntSlot, DecimalV3Type.createDecimalV3Type(3, 1)), new DecimalV3Literal(new BigDecimal("12.3"))), new LessThanEqual(bigIntSlot, new BigIntLiteral(12L))); } From 654edea96c91108d83befda8859ae0828df44522 Mon Sep 17 00:00:00 2001 From: yujun Date: Mon, 23 Dec 2024 12:08:10 +0800 Subject: [PATCH 5/5] fix compile --- .../rules/SimplifyComparisonPredicateTest.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index c51c6d43dbaed7..9df4b1bc4737f8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -399,19 +399,19 @@ void testIntCmpDecimalV3Literal() { new LessThanEqual(tinyIntSlot, new TinyIntLiteral((byte) 12))); // tiny int, literal exceeds data type limit - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.0"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.0"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new EqualTo(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new NullSafeEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), BooleanLiteral.FALSE); - assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new GreaterThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new GreaterThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.falseOrNull(tinyIntSlot)); - assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new LessThan(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.trueOrNull(tinyIntSlot)); - assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4,1)), new DecimalV3Literal(new BigDecimal("200.3"))), + assertRewrite(new LessThanEqual(new Cast(tinyIntSlot, DecimalV3Type.createDecimalV3Type(4, 1)), new DecimalV3Literal(new BigDecimal("200.3"))), ExpressionUtils.trueOrNull(tinyIntSlot)); // small int