Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Prune group bys and eliminate aggs by UK and FK #52201

Merged
merged 8 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ public class SessionVariable implements Serializable, Writable, Cloneable {
public static final String ENABLE_UKFK_JOIN_REORDER = "enable_ukfk_join_reorder";
public static final String MAX_UKFK_JOIN_REORDER_SCALE_RATIO = "max_ukfk_join_reorder_scale_ratio";
public static final String MAX_UKFK_JOIN_REORDER_FK_ROWS = "max_ukfk_join_reorder_fk_rows";
public static final String ENABLE_ELIMINATE_AGG = "enable_eliminate_agg";

// If set to true, some statements will be forwarded to the leader FE to get the result.

Expand Down Expand Up @@ -1242,6 +1243,9 @@ public static MaterializedViewRewriteMode parse(String str) {
@VarAttr(name = MAX_UKFK_JOIN_REORDER_FK_ROWS, flag = VariableMgr.INVISIBLE)
private int maxUKFKJoinReorderFKRows = 100000000;

// Session variable exposed under the name held by ENABLE_ELIMINATE_AGG; enabled by default.
// NOTE(review): presumably gates the optimizer's EliminateAggRule rewrite — confirm against the rule's check.
@VarAttr(name = ENABLE_ELIMINATE_AGG)
private boolean enableEliminateAgg = true;

@VariableMgr.VarAttr(name = FORWARD_TO_LEADER, alias = FORWARD_TO_MASTER)
private boolean forwardToLeader = false;

Expand Down Expand Up @@ -2865,6 +2869,14 @@ public boolean isEnableTablePruneOnUpdate() {
return enableTablePruneOnUpdate;
}

/**
 * Returns the current value of the {@code enableEliminateAgg} field.
 *
 * @return the stored flag value
 */
public boolean isEnableEliminateAgg() {
return enableEliminateAgg;
}

/**
 * Overwrites the {@code enableEliminateAgg} field.
 *
 * @param enableEliminateAgg the new flag value to store
 */
public void setEnableEliminateAgg(boolean enableEliminateAgg) {
this.enableEliminateAgg = enableEliminateAgg;
}

/**
 * Returns the current value of the {@code enableUKFKOpt} field.
 *
 * @return the stored flag value
 */
public boolean isEnableUKFKOpt() {
return enableUKFKOpt;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public MvPlanContext optimize(MaterializedView mv,
optimizerConfig.disableRule(RuleType.TF_PRUNE_EMPTY_SCAN);
optimizerConfig.disableRule(RuleType.TF_MV_TEXT_MATCH_REWRITE_RULE);
optimizerConfig.disableRule(RuleType.TF_MV_TRANSPARENT_REWRITE_RULE);
optimizerConfig.disableRule(RuleType.TF_ELIMINATE_AGG);
// For sync mv, no rewrite query by original sync mv rule to avoid useless rewrite.
if (mv.getRefreshScheme().isSync()) {
optimizerConfig.disableRule(RuleType.TF_MATERIALIZED_VIEW);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,11 +530,12 @@ private OptExpression logicalRuleRewrite(

ruleRewriteIterative(tree, rootTaskContext, new MergeTwoProjectRule());
ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.ELIMINATE_OP_WITH_CONSTANT);
ruleRewriteOnlyOnce(tree, rootTaskContext, EliminateAggRule.getInstance());
ruleRewriteOnlyOnce(tree, rootTaskContext, new PushDownPredicateRankingWindowRule());

ruleRewriteOnlyOnce(tree, rootTaskContext, new ConvertToEqualForNullRule());
ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PRUNE_COLUMNS);
// Put EliminateAggRule after PRUNE_COLUMNS to give a chance to prune group bys before eliminating aggregations.
ruleRewriteOnlyOnce(tree, rootTaskContext, EliminateAggRule.getInstance());
ruleRewriteIterative(tree, rootTaskContext, RuleSetType.PRUNE_UKFK_JOIN);
deriveLogicalProperty(tree);

Expand Down Expand Up @@ -782,6 +783,7 @@ private OptExpression pushDownAggregation(OptExpression tree, TaskContext rootTa
deriveLogicalProperty(tree);
rootTaskContext.setRequiredColumns(requiredColumns.clone());
ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PRUNE_COLUMNS);
ruleRewriteOnlyOnce(tree, rootTaskContext, EliminateAggRule.getInstance());
}

return tree;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
package com.starrocks.sql.optimizer;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.starrocks.analysis.JoinOperator;
import com.starrocks.catalog.BaseTableInfo;
Expand All @@ -25,6 +26,7 @@
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.operator.Operator;
import com.starrocks.sql.optimizer.operator.UKFKConstraints;
import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
import com.starrocks.sql.optimizer.operator.physical.PhysicalJoinOperator;
Expand All @@ -51,29 +53,67 @@ public static void collectColumnConstraints(OptExpression root) {
if (!ConnectContext.get().getSessionVariable().isEnableUKFKOpt()) {
return;
}
collectColumnConstraintsForce(root);
}

/**
 * Runs a fresh {@code UKFKConstraintsCollector} over the expression tree rooted at {@code root}.
 * This method itself performs no session-variable check before starting the traversal.
 *
 * @param root root of the expression tree to visit
 */
public static void collectColumnConstraintsForce(OptExpression root) {
    root.getOp().accept(new UKFKConstraintsCollector(), root, null);
}

@Override
public Void visit(OptExpression optExpression, Void context) {
visitChildren(optExpression, context);
if (optExpression.getConstraints() != null) {
if (!visitChildren(optExpression, context)) {
return null;
}
optExpression.setConstraints(new UKFKConstraints());
return null;
}

private void visitChildren(OptExpression optExpression, Void context) {
/**
 * Derives the constraints of an aggregation node from its first input.
 *
 * <p>Foreign keys and relaxed unique keys of the input are inherited, restricted to this node's
 * output columns. In addition, when the aggregation is not a local-only one, its grouping keys
 * are non-empty and all of them appear in the output, the grouping keys are registered as an
 * aggregate unique key (with no backing table constraint and an empty non-UK column set).
 *
 * @param optExpression the aggregation expression being visited
 * @param context unused visitor context
 * @return always {@code null}
 */
@Override
public Void visitLogicalAggregate(OptExpression optExpression, Void context) {
    // visitChildren reports whether this node has to (re)collect its constraints.
    final boolean needsRecollect = visitChildren(optExpression, context);
    if (!needsRecollect) {
        return null;
    }

    final LogicalAggregationOperator aggregation = optExpression.getOp().cast();

    final ColumnRefSet outputCols = optExpression.getRowOutputInfo().getOutputColumnRefSet();
    final UKFKConstraints inputConstraints = optExpression.inputAt(0).getConstraints();

    final UKFKConstraints derived = new UKFKConstraints();
    derived.inheritForeignKey(inputConstraints, outputCols);
    derived.inheritRelaxedUniqueKey(inputConstraints, outputCols);

    final boolean localOnly = aggregation.isOnlyLocalAggregate();
    if (!localOnly) {
        final ColumnRefSet groupingCols = new ColumnRefSet(aggregation.getGroupingKeys());
        final boolean groupingKeysInOutput = !groupingCols.isEmpty() && outputCols.containsAll(groupingCols);
        if (groupingKeysInOutput) {
            final UKFKConstraints.UniqueConstraintWrapper groupingUk =
                    new UKFKConstraints.UniqueConstraintWrapper(null, new ColumnRefSet(), false, groupingCols);
            derived.addAggUniqueKey(groupingUk);
        }
    }

    optExpression.setConstraints(derived);
    return null;
}

/**
 * Visits every input of {@code optExpression} with this collector.
 *
 * @param optExpression the node whose inputs are visited
 * @param context unused visitor context
 * @return {@code true} if the current node needs to recollect its constraints, i.e. the constraints
 *         of at least one input differ from what they were before the visit, or the node has no
 *         constraints yet
 */
private boolean visitChildren(OptExpression optExpression, Void context) {
    boolean anyInputChanged = false;
    for (OptExpression input : optExpression.getInputs()) {
        final UKFKConstraints before = input.getConstraints();
        input.getOp().accept(this, input, context);
        // Every input is visited even after a change was already detected.
        if (!Objects.equals(before, input.getConstraints())) {
            anyInputChanged = true;
        }
    }
    if (anyInputChanged) {
        return true;
    }
    return optExpression.getConstraints() == null;
}

private void inheritFromSingleChild(OptExpression optExpression, Void context) {
visitChildren(optExpression, context);
if (optExpression.getConstraints() != null) {
if (!visitChildren(optExpression, context)) {
return;
}
UKFKConstraints childConstraints = optExpression.inputAt(0).getConstraints();
Expand All @@ -85,8 +125,7 @@ private void inheritFromSingleChild(OptExpression optExpression, Void context) {

@Override
public Void visitLogicalTableScan(OptExpression optExpression, Void context) {
visitChildren(optExpression, context);
if (optExpression.getConstraints() != null) {
if (!visitChildren(optExpression, context)) {
return null;
}
if (!(optExpression.getOp() instanceof LogicalOlapScanOperator)) {
Expand All @@ -108,8 +147,7 @@ public Void visitLogicalTableScan(OptExpression optExpression, Void context) {

@Override
public Void visitPhysicalOlapScan(OptExpression optExpression, Void context) {
visitChildren(optExpression, context);
if (optExpression.getConstraints() != null) {
if (!visitChildren(optExpression, context)) {
return null;
}
if (!(optExpression.getOp() instanceof PhysicalOlapScanOperator)) {
Expand All @@ -134,53 +172,38 @@ private void visitOlapTable(OptExpression optExpression, OlapTable table,
if (table.hasUniqueConstraints()) {
List<UniqueConstraint> ukConstraints = table.getUniqueConstraints();
for (UniqueConstraint ukConstraint : ukConstraints) {
// For now, we only handle one column primary key or foreign key
if (ukConstraint.getUniqueColumnNames(table).size() == 1) {
String ukColumn = ukConstraint.getUniqueColumnNames(table).get(0);
ColumnRefSet nonUkColumnRefs = new ColumnRefSet(table.getColumns().stream()
.map(Column::getName)
.filter(columnNameToColRefMap::containsKey)
.filter(name -> !Objects.equals(ukColumn, name))
.map(columnNameToColRefMap::get)
.collect(Collectors.toList()));

ColumnRefOperator columnRefOperator = columnNameToColRefMap.get(ukColumn);
if (columnRefOperator != null && outputColumns.contains(columnRefOperator)) {
constraint.addUniqueKey(columnRefOperator.getId(),
new UKFKConstraints.UniqueConstraintWrapper(ukConstraint,
nonUkColumnRefs, usedColumns.isEmpty()));
constraint.addTableUniqueKey(columnRefOperator.getId(),
new UKFKConstraints.UniqueConstraintWrapper(ukConstraint,
nonUkColumnRefs, usedColumns.isEmpty()));
}
} else {
List<String> ukColNames = ukConstraint.getUniqueColumnNames(table);
boolean containsAllUk = true;
for (String colName : ukColNames) {
ColumnRefOperator columnRefOperator = columnNameToColRefMap.get(colName);
if (columnRefOperator == null || !outputColumns.contains(columnRefOperator)) {
containsAllUk = false;
break;
}
}
List<String> ukColNames = ukConstraint.getUniqueColumnNames(table);
boolean containsAllUk = ukColNames.stream().allMatch(colName ->
columnNameToColRefMap.containsKey(colName) && outputColumns.contains(columnNameToColRefMap.get(colName)));
if (!containsAllUk) {
continue;
}

if (containsAllUk) {
ColumnRefSet nonUkColumnRefs = new ColumnRefSet(table.getColumns().stream()
.map(Column::getName)
.filter(columnNameToColRefMap::containsKey)
.filter(name -> !ukColNames.contains(name))
.map(columnNameToColRefMap::get)
.collect(Collectors.toList()));
for (String colName : ukColNames) {
ColumnRefOperator columnRefOperator = columnNameToColRefMap.get(colName);
constraint.addTableUniqueKey(columnRefOperator.getId(),
new UKFKConstraints.UniqueConstraintWrapper(ukConstraint,
nonUkColumnRefs, usedColumns.isEmpty()));
}
}
ColumnRefSet ukColumnRefs = new ColumnRefSet();
ukColNames.stream()
.map(columnNameToColRefMap::get)
.forEach(ukColumnRefs::union);
ColumnRefSet nonUkColumnRefs = new ColumnRefSet();
table.getColumns().stream()
.map(Column::getName)
.filter(columnNameToColRefMap::containsKey)
.filter(name -> !ukColNames.contains(name))
.map(columnNameToColRefMap::get)
.forEach(nonUkColumnRefs::union);

UKFKConstraints.UniqueConstraintWrapper uk = new UKFKConstraints.UniqueConstraintWrapper(ukConstraint,
nonUkColumnRefs, usedColumns.isEmpty(), ukColumnRefs);
constraint.addAggUniqueKey(uk);

// For now, we only handle one column primary key or foreign key
if (ukColNames.size() == 1) {
Preconditions.checkState(ukColumnRefs.size() == 1,
"the size of ukColumnRefs MUST be the same as that of ukColNames");
constraint.addUniqueKey(ukColumnRefs.getFirstId(), uk);
}
}
}

if (table.hasForeignKeyConstraints()) {
Column firstKeyColumn = table.getKeyColumns().get(0);
ColumnRefOperator firstKeyColumnRef = columnNameToColRefMap.get(firstKeyColumn.getName());
Expand Down Expand Up @@ -235,9 +258,7 @@ public Void visitPhysicalJoin(OptExpression optExpression, Void context) {

private void visitJoinOperator(OptExpression optExpression, Void context, JoinOperator joinType,
ScalarOperator onPredicates) {
visitChildren(optExpression, context);

if (optExpression.getConstraints() != null) {
if (!visitChildren(optExpression, context)) {
return;
}

Expand Down Expand Up @@ -268,15 +289,68 @@ public static UKFKConstraints buildJoinColumnConstraint(Operator operator, JoinO
if (property != null) {
constraint.setJoinProperty(property);

if ((joinType.isLeftSemiJoin() && property.isLeftUK) ||
(joinType.isRightSemiJoin() && !property.isLeftUK)) {
UKFKConstraints ukConstraints = property.isLeftUK ? leftConstraints : rightConstraints;
UKFKConstraints fkConstraints = property.isLeftUK ? rightConstraints : leftConstraints;
List<UKFKConstraints.UniqueConstraintWrapper> ukChildMultiUKs = ukConstraints.getAggUniqueKeys();
List<UKFKConstraints.UniqueConstraintWrapper> fkChildMultiUKs = fkConstraints.getAggUniqueKeys();
ColumnRefSet fkColumnRef = new ColumnRefSet(property.fkColumnRef.getId());

// 1. Inherit unique key constraints.
// If it is a left semi join on the UK child side, all rows of the UK child will be preserved.
boolean inheritUKChildUK = (joinType.isLeftSemiJoin() && property.isLeftUK) ||
(joinType.isRightSemiJoin() && !property.isLeftUK);
if (inheritUKChildUK) {
// The unique property is preserved
if (outputColumns.contains(property.ukColumnRef)) {
constraint.addUniqueKey(property.ukColumnRef.getId(), property.ukConstraint);
}
constraint.inheritAggUniqueKey(ukConstraints, outputColumns);
}

// 2. Inherit aggregate unique key constraints.
// 2.1 from FK child side.
// If it is not an outer join on the UK child side, the FK child side will not produce duplicate rows (NULL rows).
boolean inheritFKChildAggUK = (property.isLeftUK && !joinType.isLeftOuterJoin() && !joinType.isFullOuterJoin()) ||
(!property.isLeftUK && !joinType.isRightOuterJoin() && !joinType.isFullOuterJoin());
if (inheritFKChildAggUK) {
constraint.inheritAggUniqueKey(fkConstraints, outputColumns);
}

// 2.2 from UK child side.
// If it is not an outer join on the FK child side, the UK child side will not produce duplicate rows (NULL rows).
//
// Assume that fk_table has a unique key (c11, c12) and a foreign key c11 referencing c22 of uk_table,
// and that uk_table has two unique keys, c21 and c22.
// Then, after INNER JOIN (fk_table.c11 = uk_table.c21), (c12, c21) and (c12, c22) are both unique.
boolean inheritUKChildAggUK = (property.isLeftUK && !joinType.isRightOuterJoin() && !joinType.isFullOuterJoin()) ||
(!property.isLeftUK && !joinType.isLeftOuterJoin() && !joinType.isFullOuterJoin());
if (inheritUKChildAggUK) {
for (UKFKConstraints.UniqueConstraintWrapper fkChildMultiUK : fkChildMultiUKs) {
if (!fkChildMultiUK.ukColumnRefs.containsAll(fkColumnRef)) {
continue;
}

ColumnRefSet ukScopedColumnRefs = fkChildMultiUK.ukColumnRefs.clone();
ukScopedColumnRefs.except(fkColumnRef);
if (!outputColumns.containsAll(ukScopedColumnRefs)) {
continue;
}

ukChildMultiUKs.stream()
.filter(uk -> outputColumns.containsAll(uk.ukColumnRefs))
.forEach(uk -> {
ColumnRefSet newUKColumnRefs = uk.ukColumnRefs.clone();
newUKColumnRefs.union(ukScopedColumnRefs);
constraint.addAggUniqueKey(new UKFKConstraints.UniqueConstraintWrapper(null,
uk.nonUKColumnRefs, false, newUKColumnRefs));
});
}
}
}

constraint.inheritRelaxedUniqueKey(leftConstraints, outputColumns);
constraint.inheritRelaxedUniqueKey(rightConstraints, outputColumns);

// All foreign properties can be preserved
constraint.inheritForeignKey(leftConstraints, outputColumns);
constraint.inheritForeignKey(rightConstraints, outputColumns);
Expand Down
Loading
Loading