Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] remove partition predicate from cardinality estimation (backport #50631) #50688

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import com.starrocks.sql.optimizer.transformer.SqlToScalarOperatorTranslator;
import com.starrocks.sql.plan.ScalarOperatorToExpr;
import org.apache.commons.collections.CollectionUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -220,6 +221,62 @@ public List<Long> prune() throws AnalysisException {
}
}

/**
 * Tells whether a single conjunct has one of the simple shapes partition pruning accepts.
 * <ul>
 * <li>a binary predicate whose type is equal-or-range and for whose second child
 * {@code evaluateConstant} returns a non-null result;</li>
 * <li>an IN predicate (not NOT IN) whose children after the first are all constant.</li>
 * </ul>
 * Any other conjunct is rejected.
 * TODO: support more cases
 *
 * @param conjunct the conjunct to inspect
 * @return true when the conjunct matches one of the shapes above
 */
public static boolean canPruneWithConjunct(ScalarOperator conjunct) {
    if (conjunct instanceof BinaryPredicateOperator) {
        BinaryPredicateOperator binary = conjunct.cast();
        if (!binary.getBinaryType().isEqualOrRange()) {
            return false;
        }
        return evaluateConstant(binary.getChild(1)) != null;
    }
    if (conjunct instanceof InPredicateOperator) {
        InPredicateOperator in = conjunct.cast();
        if (in.isNotIn()) {
            return false;
        }
        // The first child is skipped; only the remaining children must be constant.
        return in.getChildren().stream().skip(1).allMatch(ScalarOperator::isConstant);
    }
    return false;
}

/**
 * Collects the column names that may be treated like partition columns of the scanned table.
 * The result starts with the table's partition column names. For every partition column that is
 * a generated column whose expression translates to a call accepted by
 * {@code ScalarOperatorEvaluator.isMonotonicFunction}, the names of the column references
 * extracted from that call are appended as well.
 * Example:
 * - conjunct: dt >= '2024-01-01'
 * - generate-expr: month=date_trunc('MONTH', dt)
 *
 * @param scanOperator scan whose table and column references are inspected
 * @return a new list of column names; empty when the table has no partition columns
 */
public static List<String> deduceGenerateColumns(LogicalScanOperator scanOperator) {
    List<String> partitionColumnNames = scanOperator.getTable().getPartitionColumnNames();
    if (CollectionUtils.isEmpty(partitionColumnNames)) {
        return Lists.newArrayList();
    }
    List<String> deduced = Lists.newArrayList(partitionColumnNames);

    // Maps a slot to the scan's column reference that has the same column name.
    java.util.function.Function<SlotRef, ColumnRefOperator> slotRefResolver =
            slot -> scanOperator.getColumnNameToColRefMap().get(slot.getColumnName());
    // Copies the type of the matching column reference onto the slot.
    Consumer<SlotRef> slotRefConsumer = slot -> {
        ColumnRefOperator matched = scanOperator.getColumnNameToColRefMap().get(slot.getColumnName());
        slot.setType(matched.getType());
    };

    for (String name : partitionColumnNames) {
        Column column = scanOperator.getTable().getColumn(name);
        if (column == null || !column.isGeneratedColumn()) {
            continue;
        }
        Expr generatedExpr = column.getGeneratedColumnExpr(scanOperator.getTable().getBaseSchema());
        ExpressionAnalyzer.analyzeExpressionResolveSlot(generatedExpr, ConnectContext.get(), slotRefConsumer);
        ScalarOperator translated =
                SqlToScalarOperatorTranslator.translateWithSlotRef(generatedExpr, slotRefResolver);

        if (!(translated instanceof CallOperator)) {
            continue;
        }
        if (!ScalarOperatorEvaluator.INSTANCE.isMonotonicFunction((CallOperator) translated)) {
            continue;
        }
        for (ColumnRefOperator ref : Utils.extractColumnRef(translated)) {
            deduced.add(ref.getName());
        }
    }

    return deduced;
}

public void prepareDeduceExtraConjuncts(LogicalScanOperator scanOperator) {
this.deduceExtraConjuncts = true;
this.scanOperator = scanOperator;
Expand Down Expand Up @@ -393,6 +450,11 @@ private Pair<Set<Long>, Boolean> evalPartitionPruneFilter(ScalarOperator operato
}

/** Instance-level shortcut: delegates to the static overload using this object's partitionColumnRefs. */
private boolean isSinglePartitionColumn(ScalarOperator predicate) {
    List<ColumnRefOperator> refs = this.partitionColumnRefs;
    return isSinglePartitionColumn(predicate, refs);
}

private static boolean isSinglePartitionColumn(ScalarOperator predicate,
List<ColumnRefOperator> partitionColumnRefs) {
List<ColumnRefOperator> columnRefOperatorList = Utils.extractColumnRef(predicate);
if (columnRefOperatorList.size() == 1 && partitionColumnRefs.contains(columnRefOperatorList.get(0))) {
// such int_part_column + 1 = 11 can't prune partition
Expand All @@ -405,7 +467,7 @@ private boolean isSinglePartitionColumn(ScalarOperator predicate) {
return false;
}

private LiteralExpr castLiteralExpr(LiteralExpr literalExpr, Type type) {
private static LiteralExpr castLiteralExpr(LiteralExpr literalExpr, Type type) {
LiteralExpr result = null;
String value = literalExpr.getStringValue();
if (literalExpr.getType() == Type.DATE && type.isNumericType()) {
Expand Down Expand Up @@ -441,7 +503,7 @@ private ConcurrentNavigableMap<LiteralExpr, Set<Long>> getCastPartitionValueMap(
return newPartitionValueMap;
}

private ConstantOperator evaluateConstant(ScalarOperator operator) {
private static ConstantOperator evaluateConstant(ScalarOperator operator) {
if (operator.isConstantRef()) {
return (ConstantOperator) operator;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.starrocks.sql.optimizer.statistics;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -139,12 +140,12 @@
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.InPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.PredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.operator.scalar.SubfieldOperator;
import com.starrocks.sql.optimizer.operator.stream.LogicalBinlogScanOperator;
import com.starrocks.sql.optimizer.operator.stream.PhysicalStreamScanOperator;
import com.starrocks.sql.optimizer.rule.transformation.ListPartitionPruner;
import com.starrocks.statistic.StatisticUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -190,6 +191,13 @@ public StatisticsCalculator(ExpressionContext expressionContext,
this.optimizerContext = optimizerContext;
}

/**
 * Test-only constructor. Every collaborator is left null, so any method that
 * dereferences expressionContext, columnRefFactory or optimizerContext will fail
 * on an instance created this way.
 */
@VisibleForTesting
public StatisticsCalculator() {
    this.optimizerContext = null;
    this.columnRefFactory = null;
    this.expressionContext = null;
}

/** Passes this calculator as the visitor to the operator held by expressionContext. */
public void estimatorStats() {
    this.expressionContext.getOp().accept(this, this.expressionContext);
}
Expand Down Expand Up @@ -1670,39 +1678,43 @@ public Void visitPhysicalNoCTE(PhysicalNoCTEOperator node, ExpressionContext con
}

// Avoid using the partition columns to filter rows twice
private ScalarOperator removePartitionPredicate(ScalarOperator predicate, Operator operator,
@VisibleForTesting
public ScalarOperator removePartitionPredicate(ScalarOperator predicate, Operator operator,
OptimizerContext optimizerContext) {
if (operator instanceof LogicalIcebergScanOperator && !optimizerContext.isObtainedFromInternalStatistics()) {
LogicalIcebergScanOperator icebergScanOperator = operator.cast();
List<String> partitionColNames = icebergScanOperator.getTable().getPartitionColumnNames();
boolean isTableTypeSupported = operator instanceof LogicalIcebergScanOperator ||
isOlapScanListPartitionTable(operator);
if (isTableTypeSupported && !optimizerContext.isObtainedFromInternalStatistics()) {
LogicalScanOperator scanOperator = operator.cast();
List<String> partitionColNames = scanOperator.getTable().getPartitionColumnNames();
partitionColNames.addAll(ListPartitionPruner.deduceGenerateColumns(scanOperator));

List<ScalarOperator> conjuncts = Utils.extractConjuncts(predicate);
List<ScalarOperator> newPredicates = Lists.newArrayList();
for (ScalarOperator scalarOperator : conjuncts) {
if (scalarOperator instanceof BinaryPredicateOperator) {
BinaryPredicateOperator bop = scalarOperator.cast();
if (bop.getBinaryType().isEqualOrRange()
&& bop.getChild(1).isConstantRef()
&& isPartitionCol(bop.getChild(0), partitionColNames)) {
// do nothing
} else {
newPredicates.add(scalarOperator);
}
} else if (scalarOperator instanceof InPredicateOperator) {
InPredicateOperator inOp = scalarOperator.cast();
if (!inOp.isNotIn()
&& inOp.getChildren().stream().skip(1).allMatch(ScalarOperator::isConstant)
&& isPartitionCol(inOp.getChild(0), partitionColNames)) {
// do nothing
} else {
newPredicates.add(scalarOperator);
}
boolean isPartitionCol = isPartitionCol(scalarOperator.getChild(0), partitionColNames);
if (isPartitionCol && ListPartitionPruner.canPruneWithConjunct(scalarOperator)) {
// drop this predicate
} else {
newPredicates.add(scalarOperator);
}
}
return newPredicates.size() < 1 ? ConstantOperator.TRUE : Utils.compoundAnd(newPredicates);
return newPredicates.isEmpty() ? ConstantOperator.TRUE : Utils.compoundAnd(newPredicates);
}
return predicate;
}

// NOTE: Why only list partitions?
// Each partition of a list-partitioned table has only one unique value, while a range partition
// does not. So only the partition predicate of a list-partitioned table can be removed without
// affecting the cardinality estimation.
private boolean isOlapScanListPartitionTable(Operator operator) {
    if (operator instanceof LogicalOlapScanOperator) {
        LogicalOlapScanOperator olapScan = operator.cast();
        OlapTable olapTable = (OlapTable) olapScan.getTable();
        return olapTable.getPartitionInfo().isListPartition();
    }
    return false;
}

private boolean isPartitionCol(ScalarOperator scalarOperator, Collection<String> partitionColumns) {
if (scalarOperator.isColumnRef()) {
String colName = ((ColumnRefOperator) scalarOperator).getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@
package com.starrocks.sql.plan;

import com.starrocks.common.FeConstants;
import com.starrocks.common.Pair;
import com.starrocks.sql.optimizer.Memo;
import com.starrocks.sql.optimizer.OptimizerContext;
import com.starrocks.sql.optimizer.base.ColumnRefFactory;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.statistics.StatisticsCalculator;
import com.starrocks.utframe.UtFrameUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

Expand Down Expand Up @@ -202,6 +210,39 @@ public void testNullException() throws Exception {
assertCContains(plan, "partitions=0/4");
}

// Plans the given SQL, then returns the predicate of the physical plan's root operator paired
// with the logical scan operator found three inputs below the logical plan root.
private static Pair<ScalarOperator, LogicalScanOperator> buildConjunctAndScan(String sql) throws Exception {
    ExecPlan plan = UtFrameUtils.getPlanAndFragment(connectContext, sql).second;
    LogicalScanOperator scan = (LogicalScanOperator) plan.getLogicalPlan()
            .getRoot()
            .inputAt(0)
            .inputAt(0)
            .inputAt(0)
            .getOp();
    ScalarOperator conjunct = plan.getPhysicalPlan().getOp().getPredicate();
    return Pair.create(conjunct, scan);
}

// Runs removePartitionPredicate on the predicate and scan built from sql, then compares the
// string form of the returned predicate with expected.
private void testRemovePredicate(String sql, String expected) throws Exception {
    Pair<ScalarOperator, LogicalScanOperator> conjunctAndScan = buildConjunctAndScan(sql);
    ScalarOperator conjunct = conjunctAndScan.first;
    LogicalScanOperator scan = conjunctAndScan.second;

    StatisticsCalculator statisticsCalculator = new StatisticsCalculator();
    OptimizerContext optimizerContext = new OptimizerContext(new Memo(), new ColumnRefFactory());
    ScalarOperator remaining = statisticsCalculator.removePartitionPredicate(conjunct, scan, optimizerContext);

    Assert.assertEquals(expected, remaining.toString());
}

@Test
public void testGeneratedColumnPrune_RemovePredicate() throws Exception {
    // Each row is {sql, expected string form of the predicate left after removal}.
    String[][] cases = {
            // predicates that are removed entirely
            {"select * from t_gen_col where c1 = '2024-01-01' ", "true"},
            {"select * from t_gen_col where c1 = '2024-01-01' and c2 > 100", "true"},
            {"select * from t_gen_col where c1 >= '2024-01-01' and c1 <= '2024-01-03' " +
                    "and c2 > 100", "true"},
            {"select * from t_gen_col where c2 in (1, 2,3)", "true"},
            {"select * from t_gen_col where c2 = cast('123' as int)", "true"},

            // can not be removed
            {"select * from t_gen_col where c1 = random() and c2 > 100",
                    "cast(1: c1 as double) = random(1)"},
            {"select * from t_gen_col where c2 + 100 > c1 + 1",
                    "cast(add(2: c2, 100) as double) > add(cast(1: c1 as double), 1)"},
    };
    for (String[] sqlAndExpected : cases) {
        testRemovePredicate(sqlAndExpected[0], sqlAndExpected[1]);
    }
}

@Test
public void testGeneratedColumnPrune() throws Exception {
// c2
Expand Down
10 changes: 5 additions & 5 deletions test/sql/test_list_partition/R/test_list_partition_cardinality
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 4')
-- result:
None
-- !result
Expand All @@ -55,7 +55,7 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 1000')
-- result:
None
-- !result
Expand All @@ -67,15 +67,15 @@ function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 3')
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=3', 'cardinality: 1')
-- result:
None
-- !result
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 1000')
-- result:
None
-- !result
-- !result
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ ANALYZE FULL TABLE partitions_multi_column_1 WITH SYNC MODE;
SELECT count(*) FROM partitions_multi_column_1;

function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=0', 'EMPTYSET')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=1', 'cardinality: 4')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=2', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=3', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c1=4', 'cardinality: 1000')

function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=0', 'EMPTYSET')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=1', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 2')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=2', 'cardinality: 3')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=3', 'cardinality: 1')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 500')
function: assert_explain_verbose_contains('SELECT COUNT(*) FROM partitions_multi_column_1 WHERE c2=7', 'cardinality: 1000')
Loading