[Enhancement] remove partition predicate from cardinality estimation (backport #50631) (#50688)

Co-authored-by: Murphy <[email protected]>
mergify[bot] and murphyatwork committed Sep 6, 2024
1 parent 5d84a78 commit 6db0586
Showing 5 changed files with 150 additions and 35 deletions.
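In short: when partition pruning has already restricted a scan using a simple predicate on a partition column (for example WHERE c1 = 1 on a list-partitioned table), applying that same predicate again during cardinality estimation scales the estimate down twice. This commit strips such conjuncts from the estimator's predicate for list-partitioned OLAP tables and Iceberg scans, including conjuncts on the base column of a monotonic generated partition column, so the estimate reflects the rows of the pruned partitions directly; the updated test expectations at the end of the diff show the correspondingly larger estimates.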
@@ -49,6 +49,7 @@
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import com.starrocks.sql.optimizer.transformer.SqlToScalarOperatorTranslator;
import com.starrocks.sql.plan.ScalarOperatorToExpr;
import org.apache.commons.collections.CollectionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

@@ -220,6 +221,62 @@ public List<Long> prune() throws AnalysisException {
}
}

/**
* TODO: support more cases
* Only simple conjuncts (an equality or range comparison against a constant, or an IN list of constants) can be used for pruning.
*/
public static boolean canPruneWithConjunct(ScalarOperator conjunct) {
if (conjunct instanceof BinaryPredicateOperator) {
BinaryPredicateOperator bop = conjunct.cast();
return bop.getBinaryType().isEqualOrRange() && evaluateConstant(bop.getChild(1)) != null;
} else if (conjunct instanceof InPredicateOperator) {
InPredicateOperator inOp = conjunct.cast();
return !inOp.isNotIn() && inOp.getChildren().stream().skip(1).allMatch(ScalarOperator::isConstant);
}
return false;
}

/**
* Deduce the base columns of generated partition columns, so that a conjunct on the base column
* can be used to derive extra pruneable conjuncts.
* Example:
* - conjunct: dt >= '2024-01-01'
* - generated-column expression: month = date_trunc('MONTH', dt)
*/
public static List<String> deduceGenerateColumns(LogicalScanOperator scanOperator) {
List<String> partitionColumnNames = scanOperator.getTable().getPartitionColumnNames();
if (CollectionUtils.isEmpty(partitionColumnNames)) {
return Lists.newArrayList();
}
List<String> result = Lists.newArrayList(partitionColumnNames);

java.util.function.Function<SlotRef, ColumnRefOperator> slotRefResolver = (slot) -> {
return scanOperator.getColumnNameToColRefMap().get(slot.getColumnName());
};
Consumer<SlotRef> slotRefConsumer = (slot) -> {
ColumnRefOperator ref = scanOperator.getColumnNameToColRefMap().get(slot.getColumnName());
slot.setType(ref.getType());
};
for (String partitionColumn : partitionColumnNames) {
Column column = scanOperator.getTable().getColumn(partitionColumn);
if (column != null && column.isGeneratedColumn()) {
Expr generatedExpr = column.getGeneratedColumnExpr(scanOperator.getTable().getBaseSchema());
ExpressionAnalyzer.analyzeExpressionResolveSlot(generatedExpr, ConnectContext.get(), slotRefConsumer);
ScalarOperator call =
SqlToScalarOperatorTranslator.translateWithSlotRef(generatedExpr, slotRefResolver);

if (call instanceof CallOperator &&
ScalarOperatorEvaluator.INSTANCE.isMonotonicFunction((CallOperator) call)) {
List<ColumnRefOperator> columnRefOperatorList = Utils.extractColumnRef(call);
for (ColumnRefOperator ref : columnRefOperatorList) {
result.add(ref.getName());
}
}
}
}

return result;
}

public void prepareDeduceExtraConjuncts(LogicalScanOperator scanOperator) {
this.deduceExtraConjuncts = true;
this.scanOperator = scanOperator;
@@ -393,6 +450,11 @@ private Pair<Set<Long>, Boolean> evalPartitionPruneFilter(ScalarOperator operato
}

private boolean isSinglePartitionColumn(ScalarOperator predicate) {
return isSinglePartitionColumn(predicate, partitionColumnRefs);
}

private static boolean isSinglePartitionColumn(ScalarOperator predicate,
List<ColumnRefOperator> partitionColumnRefs) {
List<ColumnRefOperator> columnRefOperatorList = Utils.extractColumnRef(predicate);
if (columnRefOperatorList.size() == 1 && partitionColumnRefs.contains(columnRefOperatorList.get(0))) {
// e.g. int_part_column + 1 = 11 cannot be used to prune partitions
@@ -405,7 +467,7 @@ private boolean isSinglePartitionColumn(ScalarOperator predicate) {
return false;
}

private LiteralExpr castLiteralExpr(LiteralExpr literalExpr, Type type) {
private static LiteralExpr castLiteralExpr(LiteralExpr literalExpr, Type type) {
LiteralExpr result = null;
String value = literalExpr.getStringValue();
if (literalExpr.getType() == Type.DATE && type.isNumericType()) {
@@ -441,7 +503,7 @@ private ConcurrentNavigableMap<LiteralExpr, Set<Long>> getCastPartitionValueMap(
return newPartitionValueMap;
}

private ConstantOperator evaluateConstant(ScalarOperator operator) {
private static ConstantOperator evaluateConstant(ScalarOperator operator) {
if (operator.isConstantRef()) {
return (ConstantOperator) operator;
}
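The two helpers added above to ListPartitionPruner, canPruneWithConjunct and deduceGenerateColumns, can be summarized with a small dependency-free sketch. This is an illustration only and not StarRocks code: Conjunct, GeneratedColumn and the MONOTONIC set are hypothetical stand-ins for the ScalarOperator classes, the generated-column metadata and ScalarOperatorEvaluator.INSTANCE.isMonotonicFunction (Java 17).

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

final class PartitionPruneSketch {
    // A conjunct reduced to what matters here: the predicate kind and whether every
    // non-column operand is a constant.
    enum Kind { EQ, RANGE, IN, NOT_IN, OTHER }
    record Conjunct(Kind kind, boolean constantOperands) { }

    // Mirrors canPruneWithConjunct: only simple, constant-bound conjuncts qualify.
    static boolean canPruneWithConjunct(Conjunct c) {
        return switch (c.kind()) {
            case EQ, RANGE, IN -> c.constantOperands();
            default -> false; // NOT IN, expressions over the column, non-constant operands
        };
    }

    // A generated partition column such as month = date_trunc('MONTH', dt).
    record GeneratedColumn(String name, String function, String baseColumn) { }

    // Hypothetical stand-in for ScalarOperatorEvaluator.INSTANCE.isMonotonicFunction.
    private static final Set<String> MONOTONIC = Set.of("date_trunc");

    // Mirrors deduceGenerateColumns: prune candidates are the partition columns plus the
    // base columns of generated partition columns whose defining expression is monotonic.
    static List<String> deduceGenerateColumns(List<String> partitionColumns,
                                              List<GeneratedColumn> generatedColumns) {
        List<String> result = new ArrayList<>(partitionColumns);
        for (GeneratedColumn g : generatedColumns) {
            if (partitionColumns.contains(g.name()) && MONOTONIC.contains(g.function())) {
                result.add(g.baseColumn());
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // dt >= '2024-01-01' qualifies; dt + 1 = 11 or dt = random() does not.
        System.out.println(canPruneWithConjunct(new Conjunct(Kind.RANGE, true)));  // true
        System.out.println(canPruneWithConjunct(new Conjunct(Kind.OTHER, false))); // false
        // With month = date_trunc('MONTH', dt) as the partition column, conjuncts on dt
        // can also be used for pruning and can be dropped from cardinality estimation.
        System.out.println(deduceGenerateColumns(
                List.of("month"),
                List.of(new GeneratedColumn("month", "date_trunc", "dt")))); // [month, dt]
    }
}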
@@ -14,6 +14,7 @@

package com.starrocks.sql.optimizer.statistics;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@@ -139,12 +140,12 @@
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.PredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.operator.scalar.SubfieldOperator;
import com.starrocks.sql.optimizer.operator.stream.LogicalBinlogScanOperator;
import com.starrocks.sql.optimizer.operator.stream.PhysicalStreamScanOperator;
import com.starrocks.sql.optimizer.rule.transformation.ListPartitionPruner;
import com.starrocks.statistic.StatisticUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -190,6 +191,13 @@ public StatisticsCalculator(ExpressionContext expressionContext,
this.optimizerContext = optimizerContext;
}

@VisibleForTesting
public StatisticsCalculator() {
this.expressionContext = null;
this.columnRefFactory = null;
this.optimizerContext = null;
}

public void estimatorStats() {
expressionContext.getOp().accept(this, expressionContext);
}
@@ -1670,39 +1678,43 @@ public Void visitPhysicalNoCTE(PhysicalNoCTEOperator node, ExpressionContext con
}

// Avoid using partition columns to filter rows twice
private ScalarOperator removePartitionPredicate(ScalarOperator predicate, Operator operator,
@VisibleForTesting
public ScalarOperator removePartitionPredicate(ScalarOperator predicate, Operator operator,
OptimizerContext optimizerContext) {
if (operator instanceof LogicalIcebergScanOperator && !optimizerContext.isObtainedFromInternalStatistics()) {
LogicalIcebergScanOperator icebergScanOperator = operator.cast();
List<String> partitionColNames = icebergScanOperator.getTable().getPartitionColumnNames();
boolean isTableTypeSupported = operator instanceof LogicalIcebergScanOperator ||
isOlapScanListPartitionTable(operator);
if (isTableTypeSupported && !optimizerContext.isObtainedFromInternalStatistics()) {
LogicalScanOperator scanOperator = operator.cast();
List<String> partitionColNames = scanOperator.getTable().getPartitionColumnNames();
partitionColNames.addAll(ListPartitionPruner.deduceGenerateColumns(scanOperator));

List<ScalarOperator> conjuncts = Utils.extractConjuncts(predicate);
List<ScalarOperator> newPredicates = Lists.newArrayList();
for (ScalarOperator scalarOperator : conjuncts) {
if (scalarOperator instanceof BinaryPredicateOperator) {
BinaryPredicateOperator bop = scalarOperator.cast();
if (bop.getBinaryType().isEqualOrRange()
&& bop.getChild(1).isConstantRef()
&& isPartitionCol(bop.getChild(0), partitionColNames)) {
// do nothing
} else {
newPredicates.add(scalarOperator);
}
} else if (scalarOperator instanceof InPredicateOperator) {
InPredicateOperator inOp = scalarOperator.cast();
if (!inOp.isNotIn()
&& inOp.getChildren().stream().skip(1).allMatch(ScalarOperator::isConstant)
&& isPartitionCol(inOp.getChild(0), partitionColNames)) {
// do nothing
} else {
newPredicates.add(scalarOperator);
}
boolean isPartitionCol = isPartitionCol(scalarOperator.getChild(0), partitionColNames);
if (isPartitionCol && ListPartitionPruner.canPruneWithConjunct(scalarOperator)) {
// drop this predicate
} else {
newPredicates.add(scalarOperator);
}
}
return newPredicates.size() < 1 ? ConstantOperator.TRUE : Utils.compoundAnd(newPredicates);
return newPredicates.isEmpty() ? ConstantOperator.TRUE : Utils.compoundAnd(newPredicates);
}
return predicate;
}

// NOTE: Why only list partitions?
// A list partition only has one unique value per partition, while a range partition doesn't.
// So only the partition predicates of a list-partitioned table can be removed without affecting the cardinality estimation.
private boolean isOlapScanListPartitionTable(Operator operator) {
if (!(operator instanceof LogicalOlapScanOperator)) {
return false;
}
LogicalOlapScanOperator scan = operator.cast();
OlapTable table = (OlapTable) scan.getTable();
return table.getPartitionInfo().isListPartition();
}

private boolean isPartitionCol(ScalarOperator scalarOperator, Collection<String> partitionColumns) {
if (scalarOperator.isColumnRef()) {
String colName = ((ColumnRefOperator) scalarOperator).getName();
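To make the NOTE above concrete, here is a worked example with hypothetical numbers chosen to line up with the test expectations later in this diff (the real values depend on the collected statistics); the sketch is illustrative only and not part of the commit.

final class ListVsRangeEstimateSketch {
    // The old behaviour re-applied the partition-column predicate after pruning;
    // the new behaviour reports the pruned row count directly.
    static double estimate(double prunedRows, double predicateSelectivity, boolean reapplyPredicate) {
        return reapplyPredicate ? prunedRows * predicateSelectivity : prunedRows;
    }

    public static void main(String[] args) {
        double prunedRows = 4;     // rows in the partitions kept by pruning on c1 = 1
        double selectivity = 0.5;  // roughly 1 / NDV(c1) as seen by the estimator
        // List partition: every surviving row already satisfies c1 = 1, so re-applying
        // the predicate under-estimates the scan.
        System.out.println(estimate(prunedRows, selectivity, true));   // 2.0 (old estimate)
        System.out.println(estimate(prunedRows, selectivity, false));  // 4.0 (new estimate)
        // Range partition: a kept partition may also contain rows with other c1 values,
        // so the predicate still filters rows and must stay in the estimate.
    }
}

This is also why removePartitionPredicate falls back to ConstantOperator.TRUE when every conjunct is dropped: a predicate fully absorbed by pruning should not change the estimate at all.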
@@ -16,7 +16,15 @@
package com.starrocks.sql.plan;

import com.starrocks.common.FeConstants;
import com.starrocks.common.Pair;
import com.starrocks.sql.optimizer.Memo;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.statistics.StatisticsCalculator;
import com.starrocks.utframe.UtFrameUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

@@ -202,6 +210,39 @@ public void testNullException() throws Exception {
assertCContains(plan, "partitions=0/4");
}

private static Pair<ScalarOperator, LogicalScanOperator> buildConjunctAndScan(String sql) throws Exception {
Pair<String, ExecPlan> pair = UtFrameUtils.getPlanAndFragment(connectContext, sql);
ExecPlan execPlan = pair.second;
LogicalScanOperator scanOperator =
(LogicalScanOperator) execPlan.getLogicalPlan().getRoot().inputAt(0).inputAt(0).inputAt(0).getOp();
ScalarOperator predicate = execPlan.getPhysicalPlan().getOp().getPredicate();
return Pair.create(predicate, scanOperator);
}

private void testRemovePredicate(String sql, String expected) throws Exception {
Pair<ScalarOperator, LogicalScanOperator> pair = buildConjunctAndScan(sql);
StatisticsCalculator calculator = new StatisticsCalculator();
OptimizerContext context = new OptimizerContext(new Memo(), new ColumnRefFactory());
ScalarOperator newPredicate = calculator.removePartitionPredicate(pair.first, pair.second, context);
Assert.assertEquals(expected, newPredicate.toString());
}

@Test
public void testGeneratedColumnPrune_RemovePredicate() throws Exception {
testRemovePredicate("select * from t_gen_col where c1 = '2024-01-01' ", "true");
testRemovePredicate("select * from t_gen_col where c1 = '2024-01-01' and c2 > 100", "true");
testRemovePredicate("select * from t_gen_col where c1 >= '2024-01-01' and c1 <= '2024-01-03' " +
"and c2 > 100", "true");
testRemovePredicate("select * from t_gen_col where c2 in (1, 2,3)", "true");
testRemovePredicate("select * from t_gen_col where c2 = cast('123' as int)", "true");

// can not be removed
testRemovePredicate("select * from t_gen_col where c1 = random() and c2 > 100",
"cast(1: c1 as double) = random(1)");
testRemovePredicate("select * from t_gen_col where c2 + 100 > c1 + 1",
"cast(add(2: c2, 100) as double) > add(cast(1: c1 as double), 1)");
}

@Test
public void testGeneratedColumnPrune() throws Exception {
// c2
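A note on the expected values in testGeneratedColumnPrune_RemovePredicate above: "true" means every conjunct was dropped and the predicate collapsed to ConstantOperator.TRUE, because each conjunct is a simple constant comparison or IN list over a partition column of t_gen_col (or over a base column deduced from a generated partition column). The last two cases keep their predicates: random() is not a constant, and c2 + 100 > c1 + 1 does not compare a single column against a constant.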
test/sql/test_list_partition/R/test_list_partition_cardinality (10 changes: 5 additions & 5 deletions)
@@ -43,7 +43,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 4')
-- result:
None
-- !result
@@ -55,7 +55,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 1000')
-- result:
None
-- !result
@@ -67,15 +67,15 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 3')
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=3', 'cardinality: 1')
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 1000')
-- result:
None
-- !result
@@ -29,13 +29,13 @@ ANALYZE FULL TABLE partitions_multi_column_1 WITH SYNC MODE;
SELECT count(*) FROM partitions_multi_column_1;

function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=0', 'EMPTYSET')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 4')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=2', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=3', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 1000')

function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=0', 'EMPTYSET')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=1', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 3')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=3', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 1000')
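The updated expectations follow directly from dropping the partition-column predicate during estimation: the estimator now reports the row count of the pruned partitions (4 instead of 2 for c1=1, 1000 instead of 500 for c1=4 and c2=7) instead of scaling it down again by the predicate's selectivity, as sketched in ListVsRangeEstimateSketch above.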
