diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/ListPartitionPruner.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/ListPartitionPruner.java index 698bc97db9773..03571678c12a7 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/ListPartitionPruner.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/ListPartitionPruner.java @@ -49,6 +49,7 @@ import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; import com.starrocks.sql.optimizer.transformer.SqlToScalarOperatorTranslator; import com.starrocks.sql.plan.ScalarOperatorToExpr; +import org.apache.commons.collections.CollectionUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -220,6 +221,62 @@ public List prune() throws AnalysisException { } } + /** + * Prunable: equal/range comparison with a constant right side, or IN (not NOT IN) over constants. + * TODO: support more cases + */ + public static boolean canPruneWithConjunct(ScalarOperator conjunct) { + if (conjunct instanceof BinaryPredicateOperator) { + BinaryPredicateOperator bop = conjunct.cast(); + return bop.getBinaryType().isEqualOrRange() && evaluateConstant(bop.getChild(1)) != null; + } else if (conjunct instanceof InPredicateOperator) { + InPredicateOperator inOp = conjunct.cast(); + return !inOp.isNotIn() && inOp.getChildren().stream().skip(1).allMatch(ScalarOperator::isConstant); + } + return false; + } + + /** + * Returns partition column names plus the columns read by monotonic generated partition columns. + * Example: + * - conjunct: dt >= '2024-01-01' + * - generate-expr: month=date_trunc('MONTH', dt) + */ + public static List deduceGenerateColumns(LogicalScanOperator scanOperator) { + List partitionColumnNames = scanOperator.getTable().getPartitionColumnNames(); + if (CollectionUtils.isEmpty(partitionColumnNames)) { + return Lists.newArrayList(); + } + List result = Lists.newArrayList(partitionColumnNames); + + java.util.function.Function 
slotRefResolver = (slot) -> { + return scanOperator.getColumnNameToColRefMap().get(slot.getColumnName()); + }; + Consumer slotRefConsumer = (slot) -> { + ColumnRefOperator ref = scanOperator.getColumnNameToColRefMap().get(slot.getColumnName()); + slot.setType(ref.getType()); + }; + for (String partitionColumn : partitionColumnNames) { + Column column = scanOperator.getTable().getColumn(partitionColumn); + if (column != null && column.isGeneratedColumn()) { + Expr generatedExpr = column.getGeneratedColumnExpr(scanOperator.getTable().getBaseSchema()); + ExpressionAnalyzer.analyzeExpressionResolveSlot(generatedExpr, ConnectContext.get(), slotRefConsumer); + ScalarOperator call = + SqlToScalarOperatorTranslator.translateWithSlotRef(generatedExpr, slotRefResolver); + + if (call instanceof CallOperator && + ScalarOperatorEvaluator.INSTANCE.isMonotonicFunction((CallOperator) call)) { + List columnRefOperatorList = Utils.extractColumnRef(call); + for (ColumnRefOperator ref : columnRefOperatorList) { + result.add(ref.getName()); + } + } + } + } + + return result; + } + public void prepareDeduceExtraConjuncts(LogicalScanOperator scanOperator) { this.deduceExtraConjuncts = true; this.scanOperator = scanOperator; @@ -393,6 +450,11 @@ private Pair, Boolean> evalPartitionPruneFilter(ScalarOperator operato } private boolean isSinglePartitionColumn(ScalarOperator predicate) { + return isSinglePartitionColumn(predicate, partitionColumnRefs); + } + + private static boolean isSinglePartitionColumn(ScalarOperator predicate, + List partitionColumnRefs) { List columnRefOperatorList = Utils.extractColumnRef(predicate); if (columnRefOperatorList.size() == 1 && partitionColumnRefs.contains(columnRefOperatorList.get(0))) { // such int_part_column + 1 = 11 can't prune partition @@ -405,7 +467,7 @@ private boolean isSinglePartitionColumn(ScalarOperator predicate) { return false; } - private LiteralExpr castLiteralExpr(LiteralExpr literalExpr, Type type) { + private static LiteralExpr 
castLiteralExpr(LiteralExpr literalExpr, Type type) { LiteralExpr result = null; String value = literalExpr.getStringValue(); if (literalExpr.getType() == Type.DATE && type.isNumericType()) { @@ -441,7 +503,7 @@ private ConcurrentNavigableMap> getCastPartitionValueMap( return newPartitionValueMap; } - private ConstantOperator evaluateConstant(ScalarOperator operator) { + private static ConstantOperator evaluateConstant(ScalarOperator operator) { if (operator.isConstantRef()) { return (ConstantOperator) operator; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsCalculator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsCalculator.java index e0dc88b6a9978..a0c47266eccaf 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsCalculator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsCalculator.java @@ -14,6 +14,7 @@ package com.starrocks.sql.optimizer.statistics; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -139,12 +140,12 @@ import com.starrocks.sql.optimizer.operator.scalar.CastOperator; import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; -import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.PredicateOperator; import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; import com.starrocks.sql.optimizer.operator.scalar.SubfieldOperator; import com.starrocks.sql.optimizer.operator.stream.LogicalBinlogScanOperator; import com.starrocks.sql.optimizer.operator.stream.PhysicalStreamScanOperator; +import com.starrocks.sql.optimizer.rule.transformation.ListPartitionPruner; import com.starrocks.statistic.StatisticUtils; 
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -190,6 +191,13 @@ public StatisticsCalculator(ExpressionContext expressionContext, this.optimizerContext = optimizerContext; } + @VisibleForTesting + public StatisticsCalculator() { + this.expressionContext = null; + this.columnRefFactory = null; + this.optimizerContext = null; + } + public void estimatorStats() { expressionContext.getOp().accept(this, expressionContext); } @@ -1670,39 +1678,43 @@ public Void visitPhysicalNoCTE(PhysicalNoCTEOperator node, ExpressionContext con } // avoid use partition cols filter rows twice - private ScalarOperator removePartitionPredicate(ScalarOperator predicate, Operator operator, + @VisibleForTesting + public ScalarOperator removePartitionPredicate(ScalarOperator predicate, Operator operator, OptimizerContext optimizerContext) { - if (operator instanceof LogicalIcebergScanOperator && !optimizerContext.isObtainedFromInternalStatistics()) { - LogicalIcebergScanOperator icebergScanOperator = operator.cast(); - List partitionColNames = icebergScanOperator.getTable().getPartitionColumnNames(); + boolean isTableTypeSupported = operator instanceof LogicalIcebergScanOperator || + isOlapScanListPartitionTable(operator); + if (isTableTypeSupported && !optimizerContext.isObtainedFromInternalStatistics()) { + LogicalScanOperator scanOperator = operator.cast(); + // deduceGenerateColumns() returns a fresh list that already contains the partition column names + List partitionColNames = ListPartitionPruner.deduceGenerateColumns(scanOperator); + List conjuncts = Utils.extractConjuncts(predicate); List newPredicates = Lists.newArrayList(); for (ScalarOperator scalarOperator : conjuncts) { - if (scalarOperator instanceof BinaryPredicateOperator) { - BinaryPredicateOperator bop = scalarOperator.cast(); - if (bop.getBinaryType().isEqualOrRange() - && bop.getChild(1).isConstantRef() - && isPartitionCol(bop.getChild(0), partitionColNames)) { - // do nothing - } else 
{ - newPredicates.add(scalarOperator); - } - } else if (scalarOperator instanceof InPredicateOperator) { - InPredicateOperator inOp = scalarOperator.cast(); - if (!inOp.isNotIn() - && inOp.getChildren().stream().skip(1).allMatch(ScalarOperator::isConstant) - && isPartitionCol(inOp.getChild(0), partitionColNames)) { - // do nothing - } else { - newPredicates.add(scalarOperator); - } + boolean canPrune = ListPartitionPruner.canPruneWithConjunct(scalarOperator); + if (canPrune && isPartitionCol(scalarOperator.getChild(0), partitionColNames)) { + // shape was checked first, so child 0 exists; drop this partition predicate + } else { + newPredicates.add(scalarOperator); } } - return newPredicates.size() < 1 ? ConstantOperator.TRUE : Utils.compoundAnd(newPredicates); + return newPredicates.isEmpty() ? ConstantOperator.TRUE : Utils.compoundAnd(newPredicates); } return predicate; } + // NOTE: Why only list partitions? + // A list partition has one unique value per partition, but a range partition doesn't. + // So only the partition predicate of a list partition can be removed without affecting the cardinality estimation + private boolean isOlapScanListPartitionTable(Operator operator) { + if (!(operator instanceof LogicalOlapScanOperator)) { + return false; + } + LogicalOlapScanOperator scan = operator.cast(); + OlapTable table = (OlapTable) scan.getTable(); + return table.getPartitionInfo().isListPartition(); + } + private boolean isPartitionCol(ScalarOperator scalarOperator, Collection partitionColumns) { if (scalarOperator.isColumnRef()) { String colName = ((ColumnRefOperator) scalarOperator).getName(); diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PartitionPruneTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PartitionPruneTest.java index 9507d5f400d28..407712ced37c2 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PartitionPruneTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PartitionPruneTest.java @@ -16,7 +16,15 @@ package com.starrocks.sql.plan; import 
com.starrocks.common.FeConstants; +import com.starrocks.common.Pair; +import com.starrocks.sql.optimizer.Memo; +import com.starrocks.sql.optimizer.OptimizerContext; +import com.starrocks.sql.optimizer.base.ColumnRefFactory; +import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.statistics.StatisticsCalculator; import com.starrocks.utframe.UtFrameUtils; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; @@ -202,6 +210,39 @@ public void testNullException() throws Exception { assertCContains(plan, "partitions=0/4"); } + private static Pair buildConjunctAndScan(String sql) throws Exception { + Pair pair = UtFrameUtils.getPlanAndFragment(connectContext, sql); + ExecPlan execPlan = pair.second; + LogicalScanOperator scanOperator = + (LogicalScanOperator) execPlan.getLogicalPlan().getRoot().inputAt(0).inputAt(0).inputAt(0).getOp(); + ScalarOperator predicate = execPlan.getPhysicalPlan().getOp().getPredicate(); + return Pair.create(predicate, scanOperator); + } + + private void testRemovePredicate(String sql, String expected) throws Exception { + Pair pair = buildConjunctAndScan(sql); + StatisticsCalculator calculator = new StatisticsCalculator(); + OptimizerContext context = new OptimizerContext(new Memo(), new ColumnRefFactory()); + ScalarOperator newPredicate = calculator.removePartitionPredicate(pair.first, pair.second, context); + Assert.assertEquals(expected, newPredicate.toString()); + } + + @Test + public void testGeneratedColumnPrune_RemovePredicate() throws Exception { + testRemovePredicate("select * from t_gen_col where c1 = '2024-01-01' ", "true"); + testRemovePredicate("select * from t_gen_col where c1 = '2024-01-01' and c2 > 100", "true"); + testRemovePredicate("select * from t_gen_col where c1 >= '2024-01-01' and c1 <= '2024-01-03' " + + "and c2 > 100", "true"); + testRemovePredicate("select * from t_gen_col 
where c2 in (1, 2,3)", "true"); + testRemovePredicate("select * from t_gen_col where c2 = cast('123' as int)", "true"); + + // cannot be removed: the right side is not constant, or the left side is not a bare column + testRemovePredicate("select * from t_gen_col where c1 = random() and c2 > 100", + "cast(1: c1 as double) = random(1)"); + testRemovePredicate("select * from t_gen_col where c2 + 100 > c1 + 1", + "cast(add(2: c2, 100) as double) > add(cast(1: c1 as double), 1)"); + } + @Test public void testGeneratedColumnPrune() throws Exception { // c2 diff --git a/test/sql/test_list_partition/R/test_list_partition_cardinality b/test/sql/test_list_partition/R/test_list_partition_cardinality index 3331bc007ba63..7964664fbeab5 100644 --- a/test/sql/test_list_partition/R/test_list_partition_cardinality +++ b/test/sql/test_list_partition/R/test_list_partition_cardinality @@ -43,7 +43,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi -- result: None -- !result -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 2') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 4') -- result: None -- !result @@ -55,7 +55,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi -- result: None -- !result -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 500') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 1000') -- result: None -- !result @@ -67,7 +67,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi -- result: None -- !result -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 2') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 3') -- result: None -- 
!result @@ -75,7 +75,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi -- result: None -- !result -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 500') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 1000') -- result: None --- !result \ No newline at end of file +-- !result diff --git a/test/sql/test_list_partition/T/test_list_partition_cardinality b/test/sql/test_list_partition/T/test_list_partition_cardinality index c2d13a17b5e90..a1f8c361ebf7c 100644 --- a/test/sql/test_list_partition/T/test_list_partition_cardinality +++ b/test/sql/test_list_partition/T/test_list_partition_cardinality @@ -29,13 +29,13 @@ ANALYZE FULL TABLE partitions_multi_column_1 WITH SYNC MODE; SELECT count(*) FROM partitions_multi_column_1; function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=0', 'EMPTYSET') -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 2') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 4') function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=2', 'cardinality: 1') function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=3', 'cardinality: 1') -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 500') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 1000') function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=0', 'EMPTYSET') function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=1', 'cardinality: 1') -function: 
assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 2') +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 3') function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=3', 'cardinality: 1') -function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 500') \ No newline at end of file +function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 1000')