diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 7199d8cdadc2c..b65e7bd8d74b7 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -213,6 +213,7 @@ public class SessionVariable implements Serializable, Writable, Cloneable { public static final String ENABLE_UKFK_JOIN_REORDER = "enable_ukfk_join_reorder"; public static final String MAX_UKFK_JOIN_REORDER_SCALE_RATIO = "max_ukfk_join_reorder_scale_ratio"; public static final String MAX_UKFK_JOIN_REORDER_FK_ROWS = "max_ukfk_join_reorder_fk_rows"; + public static final String ENABLE_ELIMINATE_AGG = "enable_eliminate_agg"; // if set to true, some of stmt will be forwarded to leader FE to get result @@ -1242,6 +1243,9 @@ public static MaterializedViewRewriteMode parse(String str) { @VarAttr(name = MAX_UKFK_JOIN_REORDER_FK_ROWS, flag = VariableMgr.INVISIBLE) private int maxUKFKJoinReorderFKRows = 100000000; + @VarAttr(name = ENABLE_ELIMINATE_AGG) + private boolean enableEliminateAgg = true; + @VariableMgr.VarAttr(name = FORWARD_TO_LEADER, alias = FORWARD_TO_MASTER) private boolean forwardToLeader = false; @@ -2865,6 +2869,14 @@ public boolean isEnableTablePruneOnUpdate() { return enableTablePruneOnUpdate; } + public boolean isEnableEliminateAgg() { + return enableEliminateAgg; + } + + public void setEnableEliminateAgg(boolean enableEliminateAgg) { + this.enableEliminateAgg = enableEliminateAgg; + } + public boolean isEnableUKFKOpt() { return enableUKFKOpt; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java index c677763d6f6bc..8b5758c6dd834 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/MaterializedViewOptimizer.java @@ -50,6 +50,7 @@ public MvPlanContext optimize(MaterializedView mv, optimizerConfig.disableRule(RuleType.TF_PRUNE_EMPTY_SCAN); optimizerConfig.disableRule(RuleType.TF_MV_TEXT_MATCH_REWRITE_RULE); optimizerConfig.disableRule(RuleType.TF_MV_TRANSPARENT_REWRITE_RULE); + optimizerConfig.disableRule(RuleType.TF_ELIMINATE_AGG); // For sync mv, no rewrite query by original sync mv rule to avoid useless rewrite. if (mv.getRefreshScheme().isSync()) { optimizerConfig.disableRule(RuleType.TF_MATERIALIZED_VIEW); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java index 9f3b06a392912..0956a0fb38ef5 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java @@ -530,11 +530,12 @@ private OptExpression logicalRuleRewrite( ruleRewriteIterative(tree, rootTaskContext, new MergeTwoProjectRule()); ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.ELIMINATE_OP_WITH_CONSTANT); - ruleRewriteOnlyOnce(tree, rootTaskContext, EliminateAggRule.getInstance()); ruleRewriteOnlyOnce(tree, rootTaskContext, new PushDownPredicateRankingWindowRule()); ruleRewriteOnlyOnce(tree, rootTaskContext, new ConvertToEqualForNullRule()); ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PRUNE_COLUMNS); + // Put EliminateAggRule after PRUNE_COLUMNS to give a chance to prune group bys before eliminate aggregations. + ruleRewriteOnlyOnce(tree, rootTaskContext, EliminateAggRule.getInstance()); ruleRewriteIterative(tree, rootTaskContext, RuleSetType.PRUNE_UKFK_JOIN); deriveLogicalProperty(tree); @@ -782,6 +783,7 @@ private OptExpression pushDownAggregation(OptExpression tree, TaskContext rootTa deriveLogicalProperty(tree); rootTaskContext.setRequiredColumns(requiredColumns.clone()); ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.PRUNE_COLUMNS); + ruleRewriteOnlyOnce(tree, rootTaskContext, EliminateAggRule.getInstance()); } return tree; diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/UKFKConstraintsCollector.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/UKFKConstraintsCollector.java index 6a5aef0115fe2..eb34216133fd0 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/UKFKConstraintsCollector.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/UKFKConstraintsCollector.java @@ -13,6 +13,7 @@ // limitations under the License. package com.starrocks.sql.optimizer; +import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import com.starrocks.analysis.JoinOperator; import com.starrocks.catalog.BaseTableInfo; @@ -25,6 +26,7 @@ import com.starrocks.sql.optimizer.base.ColumnRefSet; import com.starrocks.sql.optimizer.operator.Operator; import com.starrocks.sql.optimizer.operator.UKFKConstraints; +import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalJoinOperator; import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator; import com.starrocks.sql.optimizer.operator.physical.PhysicalJoinOperator; @@ -51,29 +53,67 @@ public static void collectColumnConstraints(OptExpression root) { if (!ConnectContext.get().getSessionVariable().isEnableUKFKOpt()) { return; } + collectColumnConstraintsForce(root); + } + + public static void collectColumnConstraintsForce(OptExpression root) { UKFKConstraintsCollector collector = new UKFKConstraintsCollector(); root.getOp().accept(collector, root, null); } @Override public Void visit(OptExpression optExpression, Void context) { - visitChildren(optExpression, context); - if (optExpression.getConstraints() != null) { + if (!visitChildren(optExpression, context)) { return null; } optExpression.setConstraints(new UKFKConstraints()); return null; } - private void visitChildren(OptExpression optExpression, Void context) { + @Override + public Void visitLogicalAggregate(OptExpression optExpression, Void context) { + if (!visitChildren(optExpression, context)) { + return null; + } + + LogicalAggregationOperator aggOp = optExpression.getOp().cast(); + + ColumnRefSet outputColumns = optExpression.getRowOutputInfo().getOutputColumnRefSet(); + UKFKConstraints childConstraints = optExpression.inputAt(0).getConstraints(); + + UKFKConstraints constraints = new UKFKConstraints(); + constraints.inheritForeignKey(childConstraints, outputColumns); + constraints.inheritRelaxedUniqueKey(childConstraints, outputColumns); + + if (!aggOp.isOnlyLocalAggregate()) { + ColumnRefSet groupBys = new ColumnRefSet(aggOp.getGroupingKeys()); + if (!groupBys.isEmpty() && outputColumns.containsAll(groupBys)) { + constraints.addAggUniqueKey( + new UKFKConstraints.UniqueConstraintWrapper(null, new ColumnRefSet(), false, groupBys)); + } + } + + optExpression.setConstraints(constraints); + return null; + } + + /** + * Visit each child. + * + * @return true,if the current node needs to recollect constraints. + */ + private boolean visitChildren(OptExpression optExpression, Void context) { + boolean childConstraintsChanged = false; for (OptExpression child : optExpression.getInputs()) { + UKFKConstraints prevConstraints = child.getConstraints(); child.getOp().accept(this, child, context); + childConstraintsChanged |= !Objects.equals(prevConstraints, child.getConstraints()); } + return childConstraintsChanged || optExpression.getConstraints() == null; } private void inheritFromSingleChild(OptExpression optExpression, Void context) { - visitChildren(optExpression, context); - if (optExpression.getConstraints() != null) { + if (!visitChildren(optExpression, context)) { return; } UKFKConstraints childConstraints = optExpression.inputAt(0).getConstraints(); @@ -85,8 +125,7 @@ private void inheritFromSingleChild(OptExpression optExpression, Void context) { @Override public Void visitLogicalTableScan(OptExpression optExpression, Void context) { - visitChildren(optExpression, context); - if (optExpression.getConstraints() != null) { + if (!visitChildren(optExpression, context)) { return null; } if (!(optExpression.getOp() instanceof LogicalOlapScanOperator)) { @@ -108,8 +147,7 @@ public Void visitLogicalTableScan(OptExpression optExpression, Void context) { @Override public Void visitPhysicalOlapScan(OptExpression optExpression, Void context) { - visitChildren(optExpression, context); - if (optExpression.getConstraints() != null) { + if (!visitChildren(optExpression, context)) { return null; } if (!(optExpression.getOp() instanceof PhysicalOlapScanOperator)) { @@ -134,53 +172,38 @@ private void visitOlapTable(OptExpression optExpression, OlapTable table, if (table.hasUniqueConstraints()) { List ukConstraints = table.getUniqueConstraints(); for (UniqueConstraint ukConstraint : ukConstraints) { - // For now, we only handle one column primary key or foreign key - if (ukConstraint.getUniqueColumnNames(table).size() == 1) { - String ukColumn = ukConstraint.getUniqueColumnNames(table).get(0); - ColumnRefSet nonUkColumnRefs = new ColumnRefSet(table.getColumns().stream() - .map(Column::getName) - .filter(columnNameToColRefMap::containsKey) - .filter(name -> !Objects.equals(ukColumn, name)) - .map(columnNameToColRefMap::get) - .collect(Collectors.toList())); - - ColumnRefOperator columnRefOperator = columnNameToColRefMap.get(ukColumn); - if (columnRefOperator != null && outputColumns.contains(columnRefOperator)) { - constraint.addUniqueKey(columnRefOperator.getId(), - new UKFKConstraints.UniqueConstraintWrapper(ukConstraint, - nonUkColumnRefs, usedColumns.isEmpty())); - constraint.addTableUniqueKey(columnRefOperator.getId(), - new UKFKConstraints.UniqueConstraintWrapper(ukConstraint, - nonUkColumnRefs, usedColumns.isEmpty())); - } - } else { - List ukColNames = ukConstraint.getUniqueColumnNames(table); - boolean containsAllUk = true; - for (String colName : ukColNames) { - ColumnRefOperator columnRefOperator = columnNameToColRefMap.get(colName); - if (columnRefOperator == null || !outputColumns.contains(columnRefOperator)) { - containsAllUk = false; - break; - } - } + List ukColNames = ukConstraint.getUniqueColumnNames(table); + boolean containsAllUk = ukColNames.stream().allMatch(colName -> + columnNameToColRefMap.containsKey(colName) && outputColumns.contains(columnNameToColRefMap.get(colName))); + if (!containsAllUk) { + continue; + } - if (containsAllUk) { - ColumnRefSet nonUkColumnRefs = new ColumnRefSet(table.getColumns().stream() - .map(Column::getName) - .filter(columnNameToColRefMap::containsKey) - .filter(name -> !ukColNames.contains(name)) - .map(columnNameToColRefMap::get) - .collect(Collectors.toList())); - for (String colName : ukColNames) { - ColumnRefOperator columnRefOperator = columnNameToColRefMap.get(colName); - constraint.addTableUniqueKey(columnRefOperator.getId(), - new UKFKConstraints.UniqueConstraintWrapper(ukConstraint, - nonUkColumnRefs, usedColumns.isEmpty())); - } - } + ColumnRefSet ukColumnRefs = new ColumnRefSet(); + ukColNames.stream() + .map(columnNameToColRefMap::get) + .forEach(ukColumnRefs::union); + ColumnRefSet nonUkColumnRefs = new ColumnRefSet(); + table.getColumns().stream() + .map(Column::getName) + .filter(columnNameToColRefMap::containsKey) + .filter(name -> !ukColNames.contains(name)) + .map(columnNameToColRefMap::get) + .forEach(nonUkColumnRefs::union); + + UKFKConstraints.UniqueConstraintWrapper uk = new UKFKConstraints.UniqueConstraintWrapper(ukConstraint, + nonUkColumnRefs, usedColumns.isEmpty(), ukColumnRefs); + constraint.addAggUniqueKey(uk); + + // For now, we only handle one column primary key or foreign key + if (ukColNames.size() == 1) { + Preconditions.checkState(ukColumnRefs.size() == 1, + "the size of ukColumnRefs MUST be the same as that of ukColNames"); + constraint.addUniqueKey(ukColumnRefs.getFirstId(), uk); } } } + if (table.hasForeignKeyConstraints()) { Column firstKeyColumn = table.getKeyColumns().get(0); ColumnRefOperator firstKeyColumnRef = columnNameToColRefMap.get(firstKeyColumn.getName()); @@ -235,9 +258,7 @@ public Void visitPhysicalJoin(OptExpression optExpression, Void context) { private void visitJoinOperator(OptExpression optExpression, Void context, JoinOperator joinType, ScalarOperator onPredicates) { - visitChildren(optExpression, context); - - if (optExpression.getConstraints() != null) { + if (!visitChildren(optExpression, context)) { return; } @@ -268,15 +289,68 @@ public static UKFKConstraints buildJoinColumnConstraint(Operator operator, JoinO if (property != null) { constraint.setJoinProperty(property); - if ((joinType.isLeftSemiJoin() && property.isLeftUK) || - (joinType.isRightSemiJoin() && !property.isLeftUK)) { + UKFKConstraints ukConstraints = property.isLeftUK ? leftConstraints : rightConstraints; + UKFKConstraints fkConstraints = property.isLeftUK ? rightConstraints : leftConstraints; + List ukChildMultiUKs = ukConstraints.getAggUniqueKeys(); + List fkChildMultiUKs = fkConstraints.getAggUniqueKeys(); + ColumnRefSet fkColumnRef = new ColumnRefSet(property.fkColumnRef.getId()); + + // 1. Inherit unique key constraints. + // If it is a left semi join on the UK child side, all rows of the UK child will be preserved. + boolean inheritUKChildUK = (joinType.isLeftSemiJoin() && property.isLeftUK) || + (joinType.isRightSemiJoin() && !property.isLeftUK); + if (inheritUKChildUK) { // The unique property is preserved if (outputColumns.contains(property.ukColumnRef)) { constraint.addUniqueKey(property.ukColumnRef.getId(), property.ukConstraint); } + constraint.inheritAggUniqueKey(ukConstraints, outputColumns); + } + + // 2. Inherit aggregate unique key constraints. + // 2.1 from FK child side. + // If it is not an outer join on the UK child side, the FK child side will not produce duplicate rows (NULL rows). + boolean inheritFKChildAggUK = (property.isLeftUK && !joinType.isLeftOuterJoin() && !joinType.isFullOuterJoin()) || + (!property.isLeftUK && !joinType.isRightOuterJoin() && !joinType.isFullOuterJoin()); + if (inheritFKChildAggUK) { + constraint.inheritAggUniqueKey(fkConstraints, outputColumns); + } + + // 2.2 form UK child side. + // If it is not an outer join on the FK child side, the UK child side will not produce duplicate rows (NULL rows). + // + // Assumed that fk_table has a unique key (c11, c12) and a foreign key c11 referencing to c22 of uk_table, + // uk_table has two unique keys c21 and c22. + // Then, after INNER Join(fk_table.c11=uk_table.c21), (c12, c21), (c12, c22) are all unique. + boolean inheritUKChildAggUK = (property.isLeftUK && !joinType.isRightOuterJoin() && !joinType.isFullOuterJoin()) || + (!property.isLeftUK && !joinType.isLeftOuterJoin() && !joinType.isFullOuterJoin()); + if (inheritUKChildAggUK) { + for (UKFKConstraints.UniqueConstraintWrapper fkChildMultiUK : fkChildMultiUKs) { + if (!fkChildMultiUK.ukColumnRefs.containsAll(fkColumnRef)) { + continue; + } + + ColumnRefSet ukScopedColumnRefs = fkChildMultiUK.ukColumnRefs.clone(); + ukScopedColumnRefs.except(fkColumnRef); + if (!outputColumns.containsAll(ukScopedColumnRefs)) { + continue; + } + + ukChildMultiUKs.stream() + .filter(uk -> outputColumns.containsAll(uk.ukColumnRefs)) + .forEach(uk -> { + ColumnRefSet newUKColumnRefs = uk.ukColumnRefs.clone(); + newUKColumnRefs.union(ukScopedColumnRefs); + constraint.addAggUniqueKey(new UKFKConstraints.UniqueConstraintWrapper(null, + uk.nonUKColumnRefs, false, newUKColumnRefs)); + }); + } } } + constraint.inheritRelaxedUniqueKey(leftConstraints, outputColumns); + constraint.inheritRelaxedUniqueKey(rightConstraints, outputColumns); + // All foreign properties can be preserved constraint.inheritForeignKey(leftConstraints, outputColumns); constraint.inheritForeignKey(rightConstraints, outputColumns); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/UKFKConstraints.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/UKFKConstraints.java index f5b49ccf65f5e..22a2eea0d137f 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/UKFKConstraints.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/UKFKConstraints.java @@ -24,23 +24,36 @@ import com.starrocks.sql.plan.ScalarOperatorToExpr; import org.apache.hadoop.shaded.com.google.common.collect.Maps; +import java.util.List; import java.util.Map; +import java.util.stream.Stream; public class UKFKConstraints { // ColumnRefOperator::id -> UniqueConstraint + // When propagating constraints from the child node to the parent node, it needs to ensure that the parent node includes + // all outputs of the child node. private final Map uniqueKeys = Maps.newHashMap(); - // The unique key of the data table needs to be collected when eliminating aggregation - private Map tableUniqueKeys = Maps.newHashMap(); + // ColumnRefOperator::id -> UniqueConstraint + // ukColumnRefs no longer satisfies uniqueness, but the relationship between ukColumnRefs and nonUKColumnRefs still + // satisfies, that is, the rows with the same ukColumnRefs must be the same nonUKColumnRefs. + private final Map relaxedUniqueKeys = Maps.newHashMap(); + // aggUniqueKeys contains all unique keys, including multi-column unique keys, while uniqueKeys only contains single-column + // unique keys. It is only used to eliminate aggregation and not to eliminate joins based on unique keys and foreign keys. + // Therefore, when propagating constraints from the child node to the parent node, it is not necessary to ensure that the + // parent node includes all outputs of the child node; only needs to ensure that the child node’s output has no duplicate rows. + // TODO: unify aggUniqueKeys and uniqueKeys, and apply them to eliminate Join. + private final List aggUniqueKeys = Lists.newArrayList(); // ColumnRefOperator::id -> ForeignKeyConstraint private final Map foreignKeys = Maps.newHashMap(); private JoinProperty joinProperty; public void addUniqueKey(int id, UniqueConstraintWrapper uniqueKey) { uniqueKeys.put(id, uniqueKey); + relaxedUniqueKeys.put(id, uniqueKey); } - public void addTableUniqueKey(int id, UniqueConstraintWrapper uniqueKey) { - tableUniqueKeys.put(id, uniqueKey); + public void addAggUniqueKey(UniqueConstraintWrapper uniqueKey) { + aggUniqueKeys.add(uniqueKey); } public void addForeignKey(int id, ForeignKeyConstraintWrapper foreignKey) { @@ -51,24 +64,23 @@ public UniqueConstraintWrapper getUniqueConstraint(Integer id) { return uniqueKeys.get(id); } - public void setTableUniqueKeys(Map tableUniqueKeys) { - this.tableUniqueKeys = tableUniqueKeys; - } - - public Map getTableUniqueKeys() { - return tableUniqueKeys; + public List getAggUniqueKeys() { + return aggUniqueKeys; } public ForeignKeyConstraintWrapper getForeignKeyConstraint(Integer id) { return foreignKeys.get(id); } + public UniqueConstraintWrapper getRelaxedUniqueConstraint(Integer id) { + return relaxedUniqueKeys.get(id); + } + public JoinProperty getJoinProperty() { return joinProperty; } - public void setJoinProperty( - JoinProperty joinProperty) { + public void setJoinProperty(JoinProperty joinProperty) { this.joinProperty = joinProperty; } @@ -77,13 +89,9 @@ public static UKFKConstraints inheritFrom(UKFKConstraints from, ColumnRefSet toO from.uniqueKeys.entrySet().stream() .filter(entry -> toOutputColumns.contains(entry.getKey())) .forEach(entry -> clone.uniqueKeys.put(entry.getKey(), entry.getValue())); - from.foreignKeys.entrySet().stream() - .filter(entry -> toOutputColumns.contains(entry.getKey())) - .forEach(entry -> clone.foreignKeys.put(entry.getKey(), entry.getValue())); - if (!(from.getTableUniqueKeys().isEmpty()) && - toOutputColumns.containsAll(Lists.newArrayList((from.getTableUniqueKeys().keySet())))) { - clone.setTableUniqueKeys(from.getTableUniqueKeys()); - } + clone.inheritForeignKey(from, toOutputColumns); + clone.inheritRelaxedUniqueKey(from, toOutputColumns); + clone.inheritAggUniqueKey(from, toOutputColumns); return clone; } @@ -94,16 +102,31 @@ public void inheritForeignKey(UKFKConstraints other, ColumnRefSet outputColumns) .forEach(entry -> foreignKeys.put(entry.getKey(), entry.getValue())); } + public void inheritRelaxedUniqueKey(UKFKConstraints other, ColumnRefSet outputColumns) { + Stream.concat(other.uniqueKeys.entrySet().stream(), other.relaxedUniqueKeys.entrySet().stream()) + .filter(entry -> outputColumns.contains(entry.getKey())) + .forEach(entry -> relaxedUniqueKeys.put(entry.getKey(), entry.getValue())); + } + + public void inheritAggUniqueKey(UKFKConstraints other, ColumnRefSet outputColumns) { + other.aggUniqueKeys.stream() + .filter(uk -> outputColumns.containsAll(uk.ukColumnRefs)) + .forEach(aggUniqueKeys::add); + } + public static final class UniqueConstraintWrapper { public final UniqueConstraint constraint; public final ColumnRefSet nonUKColumnRefs; public final boolean isIntact; - public UniqueConstraintWrapper(UniqueConstraint constraint, - ColumnRefSet nonUKColumnRefs, boolean isIntact) { + public final ColumnRefSet ukColumnRefs; + + public UniqueConstraintWrapper(UniqueConstraint constraint, ColumnRefSet nonUKColumnRefs, boolean isIntact, + ColumnRefSet ukColumnRefs) { this.constraint = constraint; this.nonUKColumnRefs = nonUKColumnRefs; this.isIntact = isIntact; + this.ukColumnRefs = ukColumnRefs; } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalAggregationOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalAggregationOperator.java index 40797a3b9698a..8d547013238ee 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalAggregationOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalAggregationOperator.java @@ -114,6 +114,10 @@ public void setOnlyLocalAggregate() { isSplit = false; } + public boolean isOnlyLocalAggregate() { + return type.isLocal() && !isSplit; + } + public List getPartitionByColumns() { return partitionByColumns; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java index 2889142e28f63..ccd4f8dc7db0e 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java @@ -110,6 +110,7 @@ import com.starrocks.sql.optimizer.rule.transformation.PruneTableFunctionColumnRule; import com.starrocks.sql.optimizer.rule.transformation.PruneTopNColumnsRule; import com.starrocks.sql.optimizer.rule.transformation.PruneTrueFilterRule; +import com.starrocks.sql.optimizer.rule.transformation.PruneUKFKGroupByKeysRule; import com.starrocks.sql.optimizer.rule.transformation.PruneUKFKJoinRule; import com.starrocks.sql.optimizer.rule.transformation.PruneUnionColumnsRule; import com.starrocks.sql.optimizer.rule.transformation.PruneValuesColumnsRule; @@ -299,6 +300,7 @@ public class RuleSet { PruneScanColumnRule.BINLOG_SCAN, new PruneProjectColumnsRule(), new PruneFilterColumnsRule(), + new PruneUKFKGroupByKeysRule(), // Put this before PruneAggregateColumnsRule new PruneAggregateColumnsRule(), new PruneGroupByKeysRule(), new PruneTopNColumnsRule(), diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java index 5b6eead20c664..ee3430cf4fbee 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java @@ -97,6 +97,7 @@ public enum RuleType { TF_PRUNE_GROUP_BY_KEYS, TF_PRUNE_SUBFIELD, TF_PRUNE_UKFK_JOIN, + TF_PRUNE_UKFK_GROUP_BY_KEYS, TF_SUBFILED_NOCOPY, TF_PARTITION_COLUMN_MINMAX, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/join/JoinOrder.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/join/JoinOrder.java index 03f9c5cb696d4..81b614ef03fb1 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/join/JoinOrder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/join/JoinOrder.java @@ -370,7 +370,7 @@ protected Optional buildJoinExpr(GroupInfo leftGroup, GroupInfo UKFKConstraints.JoinProperty joinProperty = null; SessionVariable sessionVariable = ConnectContext.get().getSessionVariable(); - if (sessionVariable.isEnableUKFKOpt()) { + if (sessionVariable.isEnableUKFKJoinReorder()) { UKFKConstraintsCollector.collectColumnConstraints(leftExprInfo.expr); UKFKConstraintsCollector.collectColumnConstraints(rightExprInfo.expr); UKFKConstraints constraint = UKFKConstraintsCollector.buildJoinColumnConstraint(newJoin, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateAggRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateAggRule.java index b196d27317e7b..3457663cbb058 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateAggRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateAggRule.java @@ -14,6 +14,7 @@ package com.starrocks.sql.optimizer.rule.transformation; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.starrocks.analysis.Expr; import com.starrocks.catalog.Function; @@ -24,6 +25,7 @@ import com.starrocks.sql.optimizer.OptExpression; import com.starrocks.sql.optimizer.OptimizerContext; import com.starrocks.sql.optimizer.UKFKConstraintsCollector; +import com.starrocks.sql.optimizer.base.ColumnRefSet; import com.starrocks.sql.optimizer.operator.OperatorType; import com.starrocks.sql.optimizer.operator.UKFKConstraints; import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; @@ -40,11 +42,9 @@ import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; // When a column used in a SQL query's Group By statement has a unique attribute, aggregation can be eliminated, // and the LogicalAggregationOperator can be replaced with a LogicalProjectOperator. @@ -80,63 +80,54 @@ public static EliminateAggRule getInstance() { private static final EliminateAggRule INSTANCE = new EliminateAggRule(); + private static final Set SUPPORTED_AGG_FUNCTIONS = ImmutableSet.of( + FunctionSet.SUM, FunctionSet.COUNT, FunctionSet.AVG, FunctionSet.FIRST_VALUE, + FunctionSet.MAX, FunctionSet.MIN, FunctionSet.GROUP_CONCAT + ); + @Override public boolean check(OptExpression input, OptimizerContext context) { - LogicalAggregationOperator aggOp = input.getOp().cast(); - List groupKeys = aggOp.getGroupingKeys(); - - for (Map.Entry entry : aggOp.getAggregations().entrySet()) { - if (entry.getValue().isDistinct()) { - return false; - } - String fnName = entry.getValue().getFnName(); - if (!(fnName.equals(FunctionSet.SUM) || fnName.equals(FunctionSet.COUNT) || - fnName.equals(FunctionSet.AVG) || - fnName.equals(FunctionSet.FIRST_VALUE) || - fnName.equals(FunctionSet.MAX) || fnName.equals(FunctionSet.MIN) || - fnName.equals(FunctionSet.GROUP_CONCAT))) { - return false; - } + if (!context.getSessionVariable().isEnableEliminateAgg()) { + return false; } - // collect uk pk key - UKFKConstraintsCollector collector = new UKFKConstraintsCollector(); - input.getOp().accept(collector, input, null); + LogicalAggregationOperator aggOp = input.getOp().cast(); + OptExpression childOpt = input.inputAt(0); - OptExpression childOptExpression = input.inputAt(0); - Map uniqueKeys = - childOptExpression.getConstraints().getTableUniqueKeys(); - if (uniqueKeys.isEmpty()) { + List groupBys = aggOp.getGroupingKeys(); + if (groupBys.isEmpty()) { return false; } - if (uniqueKeys.size() != groupKeys.size()) { + + boolean supportedAllAggFunctions = aggOp.getAggregations().values().stream() + .allMatch(call -> !call.isDistinct() && SUPPORTED_AGG_FUNCTIONS.contains(call.getFnName())); + if (!supportedAllAggFunctions) { return false; } - Set groupColumnRefIds = groupKeys.stream() - .map(ColumnRefOperator::getId) - .collect(Collectors.toSet()); + UKFKConstraintsCollector.collectColumnConstraintsForce(input); - Set uniqueColumnRefIds = new HashSet<>(uniqueKeys.keySet()); - if (!groupColumnRefIds.equals(uniqueColumnRefIds)) { + List uniqueKeys = childOpt.getConstraints().getAggUniqueKeys(); + if (uniqueKeys.isEmpty()) { return false; } - return true; + ColumnRefSet groupByIds = new ColumnRefSet(); + groupBys.stream().map(ColumnRefOperator::getId).forEach(groupByIds::union); + return uniqueKeys.stream().anyMatch(constraint -> groupByIds.containsAll(constraint.ukColumnRefs)); } @Override public List transform(OptExpression input, OptimizerContext context) { LogicalAggregationOperator aggOp = input.getOp().cast(); - Map newProjectMap = new HashMap<>(); + Map newProjectMap = new HashMap<>(); for (Map.Entry entry : aggOp.getAggregations().entrySet()) { ColumnRefOperator aggColumnRef = entry.getKey(); CallOperator callOperator = entry.getValue(); ScalarOperator newOperator = handleAggregationFunction(callOperator.getFnName(), callOperator); newProjectMap.put(aggColumnRef, newOperator); } - aggOp.getGroupingKeys().forEach(ref -> newProjectMap.put(ref, ref)); LogicalProjectOperator newProjectOp = LogicalProjectOperator.builder().setColumnRefMap(newProjectMap).build(); @@ -184,8 +175,7 @@ private ScalarOperator rewriteCastFunction(CallOperator callOperator) { if (callOperator.getType().equals(argument.getType())) { return argument; } - ScalarOperator scalarOperator = new CastOperator(callOperator.getType(), argument); - return scalarOperator; + return new CastOperator(callOperator.getType(), argument); } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PruneUKFKGroupByKeysRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PruneUKFKGroupByKeysRule.java new file mode 100644 index 0000000000000..a5ad3b3cb24e5 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/PruneUKFKGroupByKeysRule.java @@ -0,0 +1,139 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.optimizer.rule.transformation; + +import com.google.api.client.util.Sets; +import com.google.common.collect.Lists; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.OptimizerContext; +import com.starrocks.sql.optimizer.UKFKConstraintsCollector; +import com.starrocks.sql.optimizer.Utils; +import com.starrocks.sql.optimizer.base.ColumnRefSet; +import com.starrocks.sql.optimizer.operator.AggType; +import com.starrocks.sql.optimizer.operator.OperatorType; +import com.starrocks.sql.optimizer.operator.UKFKConstraints; +import com.starrocks.sql.optimizer.operator.logical.LogicalAggregationOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator; +import com.starrocks.sql.optimizer.operator.pattern.Pattern; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.rule.RuleType; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +public class PruneUKFKGroupByKeysRule extends TransformationRule { + public PruneUKFKGroupByKeysRule() { + super(RuleType.TF_PRUNE_UKFK_GROUP_BY_KEYS, Pattern.create(OperatorType.LOGICAL_AGGR, OperatorType.LOGICAL_PROJECT)); + } + + @Override + public boolean check(OptExpression input, OptimizerContext context) { + if (!context.getSessionVariable().isEnableUKFKOpt()) { + return false; + } + UKFKConstraintsCollector.collectColumnConstraints(input); + return input.getConstraints() != null; + } + + @Override + public List transform(OptExpression aggOpt, OptimizerContext context) { + LogicalAggregationOperator aggOp = aggOpt.getOp().cast(); + OptExpression projectOpt = aggOpt.getInputs().get(0); + LogicalProjectOperator projectOp = projectOpt.getOp().cast(); + + UKFKConstraints constraints = aggOpt.getConstraints(); + ColumnRefSet requiredOutputColumns = context.getTaskContext().getRequiredColumns(); + List groupBys = aggOp.getGroupingKeys(); + + // Retrieve non-UK columns from constraints that contains the UK column used in the GROUP BY clause. + Set groupBysToRemove = Sets.newHashSet(); + Set ukGroupBys = Sets.newHashSet(); + for (ColumnRefOperator groupBy : groupBys) { + if (groupBysToRemove.contains(groupBy)) { + continue; + } + + UKFKConstraints.UniqueConstraintWrapper constraint = constraints.getRelaxedUniqueConstraint(groupBy.getId()); + + if (constraint == null) { + continue; + } + + int prevSize = groupBysToRemove.size(); + getGroupBysToRemoveByConstraint(constraint, requiredOutputColumns, aggOp, projectOp, ukGroupBys, groupBysToRemove); + if (groupBysToRemove.size() > prevSize) { + ukGroupBys.add(groupBy); + } + } + + if (groupBysToRemove.isEmpty()) { + return Lists.newArrayList(); + } + + List newPartitionColumns = aggOp.getPartitionByColumns().stream() + .filter(columnRefOperator -> !groupBysToRemove.contains(columnRefOperator)) + .collect(Collectors.toList()); + List newGroupBys = aggOp.getGroupingKeys().stream() + .filter(columnRefOperator -> !groupBysToRemove.contains(columnRefOperator)) + .collect(Collectors.toList()); + + LogicalAggregationOperator newAggOperator = new LogicalAggregationOperator.Builder().withOperator(aggOp) + .setType(AggType.GLOBAL) + .setGroupingKeys(newGroupBys) + .setPartitionByColumns(newPartitionColumns) + .build(); + OptExpression result = OptExpression.create(newAggOperator, aggOpt.getInputs()); + + return Lists.newArrayList(result); + } + + /** + * Remove group by columns that are non-UK columns and not used in the parent project operator. + */ + private void getGroupBysToRemoveByConstraint(UKFKConstraints.UniqueConstraintWrapper constraint, + ColumnRefSet requiredOutputColumns, + LogicalAggregationOperator aggOp, + LogicalProjectOperator projectOp, + Set ukGroupBys, + Set groupBysToRemove) { + ColumnRefSet nonUKColumnRefs = constraint.nonUKColumnRefs; + if (nonUKColumnRefs.isEmpty()) { + return; + } + + for (ColumnRefOperator groupBy : aggOp.getGroupingKeys()) { + if (requiredOutputColumns.contains(groupBy) || ukGroupBys.contains(groupBy)) { + continue; + } + + if (nonUKColumnRefs.contains(groupBy)) { + groupBysToRemove.add(groupBy); + } else { + ScalarOperator inputOp = projectOp.getColumnRefMap().get(groupBy); + ColumnRefSet usedColumns = inputOp.getUsedColumns(); + // If the expression that the group by column comes from only uses one column, and the expression will always + // produce the same output when the input is the same, it can also be pruned. + // such as group by substr(col, 1, 2) + if (usedColumns.size() == 1 + && nonUKColumnRefs.contains(usedColumns.getFirstId()) + && !Utils.hasNonDeterministicFunc(inputOp)) { + groupBysToRemove.add(groupBy); + } + } + } + } +} diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewRewriteWithSSBTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewRewriteWithSSBTest.java index 785e5bc490daf..cf20fec3e115c 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewRewriteWithSSBTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewRewriteWithSSBTest.java @@ -35,6 +35,7 @@ public static void beforeClass() throws Exception { createTables("sql/ssb/", Lists.newArrayList("customer", "dates", "supplier", "part", "lineorder")); connectContext.getSessionVariable().setMaterializedViewRewriteMode("force"); connectContext.getSessionVariable().setEnableMaterializedViewPushDownRewrite(true); + connectContext.getSessionVariable().setEnableEliminateAgg(false); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java index 3cb6ff3250f57..168b8ed2cd472 100644 --- a/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/planner/MaterializedViewTest.java @@ -403,6 +403,8 @@ public static void beforeClass() throws Exception { GlobalStateMgr globalStateMgr = connectContext.getGlobalStateMgr(); OlapTable t7 = (OlapTable) globalStateMgr.getLocalMetastore().getDb(MATERIALIZED_DB_NAME).getTable("emps"); setTableStatistics(t7, 6000000); + + connectContext.getSessionVariable().setEnableEliminateAgg(false); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateWithUKFKTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateWithUKFKTest.java new file mode 100644 index 0000000000000..412e007679610 --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateWithUKFKTest.java @@ -0,0 +1,499 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.plan; + +import org.junit.BeforeClass; +import org.junit.Test; + +public class AggregateWithUKFKTest extends PlanTestBase { + @BeforeClass + public static void beforeClass() throws Exception { + PlanTestBase.beforeClass(); + + starRocksAssert.withTable("CREATE TABLE `tt2` (\n" + + " `c21` int NULL,\n" + + " `c22` int NULL,\n" + + " `c23` int NULL,\n" + + " `c24` int NULL,\n" + + " `c25` int NULL,\n" + + " `c26` int NULL\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(`c21`)\n" + + "DISTRIBUTED BY HASH(`c21`) BUCKETS 48\n" + + "PROPERTIES (\n" + + "\"replication_num\" = \"1\",\n" + + "\"unique_constraints\" = \"c21;c22;c23;c24;c25\"\n" + + ");"); + starRocksAssert.withTable("CREATE TABLE `tt1` (\n" + + " `c11` int NULL,\n" + + " `c12` int NULL,\n" + + " `c13` int NULL,\n" + + " `c14` int NULL,\n" + + " `c15` int NULL,\n" + + " `c16` int NULL\n" + + ") ENGINE=OLAP\n" + + "DUPLICATE KEY(`c11`)\n" + + "DISTRIBUTED BY HASH(`c11`) BUCKETS 48\n" + + "PROPERTIES (\n" + + "\"replication_num\" = \"1\",\n" + + "\"unique_constraints\" = \"c11;c12;c13,c14;c11,c12,c13,c14\",\n" + + "\"foreign_key_constraints\" = \"(c13) REFERENCES tt2(c23);(c15) REFERENCES tt2(c25);\"\n" + + ");"); + + connectContext.getSessionVariable().setEnableUKFKOpt(true); + } + + @Test + public void testEliminateAgg1() throws Exception { + String sql = "SELECT \n" + + " id, \n" + + " SUM(big_value) AS sum_big_value\n" + + "FROM \n" + + " test_agg_group_single_unique_key\n" + + "GROUP BY \n" + + " id\n" + + "ORDER BY \n" + + " id;"; + String plan = getVerboseExplain(sql); + assertContains(plan, " 1:Project\n" + + " | output columns:\n" + + " | 1 <-> [1: id, INT, false]\n" + + " | 6 <-> [2: big_value, BIGINT, true]\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_agg_group_single_unique_key, rollup: test_agg_group_single_unique_key\n" + + " preAggregation: off. Reason: None aggregate function\n"); + + sql = "SELECT \n" + + " id, \n" + + " COUNT(varchar_value) AS count_varchar_value\n" + + "FROM \n" + + " test_agg_group_single_unique_key\n" + + "GROUP BY \n" + + " id\n" + + "ORDER BY \n" + + " id;"; + plan = getVerboseExplain(sql); + assertContains(plan, " 1:Project\n" + + " | output columns:\n" + + " | 1 <-> [1: id, INT, false]\n" + + " | 6 <-> if[(5: varchar_value IS NULL, 0, 1); " + + "args: BOOLEAN,INT,INT; result: TINYINT; args nullable: false; result nullable: true]\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_agg_group_single_unique_key, rollup: test_agg_group_single_unique_key\n" + + " preAggregation: off. Reason: None aggregate function\n"); + + sql = "SELECT\n" + + " id,\n" + + " big_value,\n" + + " AVG(decimal_value) AS avg_decimal_value\n" + + "FROM\n" + + " test_agg_group_multi_unique_key\n" + + "GROUP BY\n" + + " id, big_value\n" + + "ORDER BY\n" + + " id;"; + plan = getVerboseExplain(sql); + assertContains(plan, " 1:Project\n" + + " | output columns:\n" + + " | 1 <-> [1: id, INT, false]\n" + + " | 2 <-> [2: big_value, BIGINT, true]\n" + + " | 6 <-> cast([4: decimal_value, DECIMAL64(10,5), true] as DECIMAL128(38,11))\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_agg_group_multi_unique_key, rollup: test_agg_group_multi_unique_key\n" + + " preAggregation: off. Reason: None aggregate function\n"); + + sql = "SELECT\n" + + " id,\n" + + " big_value,\n" + + " COUNT(varchar_value) AS count_varchar_value\n" + + "FROM\n" + + " test_agg_group_multi_unique_key\n" + + "GROUP BY\n" + + " id, big_value\n" + + "ORDER BY\n" + + " id;"; + plan = getVerboseExplain(sql); + assertContains(plan, " 1:Project\n" + + " | output columns:\n" + + " | 1 <-> [1: id, INT, false]\n" + + " | 2 <-> [2: big_value, BIGINT, true]\n" + + " | 6 <-> if[(5: varchar_value IS NULL, 0, 1); args: BOOLEAN,INT,INT;" + + " result: TINYINT; args nullable: false; result nullable: true]\n" + + " | cardinality: 1\n" + + " | \n" + + " 0:OlapScanNode\n" + + " table: test_agg_group_multi_unique_key, rollup: test_agg_group_multi_unique_key\n" + + " preAggregation: off. Reason: None aggregate function\n"); + sql = "SELECT\n" + + " id,\n" + + " big_value,\n" + + " COUNT(varchar_value) AS count_varchar_value\n" + + "FROM test_agg_group_multi_unique_key\n" + + "GROUP BY id, big_value\n" + + "HAVING COUNT(varchar_value) > 0;"; + plan = getFragmentPlan(sql); + assertContains(plan, " PREDICATES: if(5: varchar_value IS NULL, 0, 1) > 0"); + sql = "SELECT\n" + + " id,\n" + + " big_value,\n" + + " COUNT(varchar_value) AS count_varchar_value\n" + + "FROM test_agg_group_multi_unique_key\n" + + "GROUP BY id, big_value\n" + + "HAVING SUM(varchar_value) > 0 and id + 2 > 5;"; + plan = getFragmentPlan(sql); + assertContains(plan, "CAST(5: varchar_value AS DOUBLE) > 0.0, 1: id > 3"); + } + + @Test + public void testEliminateAgg2() throws Exception { + String sql; + String plan; + + sql = "select c11, c12, c13, c14, c15, c16, sum(c11) from tt1 group by c11, c12, c13, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, "\n" + + " 1:Project\n" + + " | : 1: c11\n" + + " | : 2: c12\n" + + " | : 3: c13\n" + + " | : 4: c14\n" + + " | : 5: c15\n" + + " | : 6: c16\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c13, c14, c15, c16, sum(c11) from tt1 group by c13, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, "\n" + + " 1:Project\n" + + " | : 3: c13\n" + + " | : 4: c14\n" + + " | : 5: c15\n" + + " | : 6: c16\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + + // UK (c13, c14) is not satisified. + sql = "select c14, c15, c16, sum(c11) from tt1 group by c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " 1:AGGREGATE (update finalize)\n" + + " | output: sum(1: c11)\n" + + " | group by: 4: c14, 5: c15, 6: c16\n" + + " | \n" + + " 0:OlapScanNode"); + } + + @Test + public void testEliminateAggAfterPruneGroupBys() throws Exception { + String sql; + String plan; + + sql = "select sum(c11) from tt1 group by c11, c12, c13, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, "\n" + + " 1:Project\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c11, c12, c13, sum(c11) from tt1 group by c11, c12, c13, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, "\n" + + " 1:Project\n" + + " | : 1: c11\n" + + " | : 2: c12\n" + + " | : 3: c13\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c11, c12, c13, c14, sum(c11) from tt1 group by c11, c12, c13, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, "\n" + + " 1:Project\n" + + " | : 1: c11\n" + + " | : 2: c12\n" + + " | : 3: c13\n" + + " | : 4: c14\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + } + + @Test + public void testEliminateAggAfterJoinFromUKChild() throws Exception { + String sql; + String plan; + + sql = "select c14, c21, sum(c11) from tt1 join tt2 on c13=c23 group by c14, c21"; + plan = getFragmentPlan(sql); + assertContains(plan, " 4:Project\n" + + " | : 4: c14\n" + + " | : 7: c21\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 3:HASH JOIN\n" + + " | join op: INNER JOIN (BROADCAST)\n" + + " | colocate: false, reason: \n" + + " | equal join conjunct: 3: c13 = 9: c23\n" + + " | \n" + + " |----2:EXCHANGE\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select sum(c11) from tt1 join tt2 on c13=c23 group by c14, c21"; + plan = getFragmentPlan(sql); + assertContains(plan, " 4:Project\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 3:HASH JOIN\n" + + " | join op: INNER JOIN (BROADCAST)\n" + + " | colocate: false, reason: \n" + + " | equal join conjunct: 3: c13 = 9: c23\n" + + " | \n" + + " |----2:EXCHANGE\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "with w1 as (select c15, c14, sum(c11) as c11 from tt1 group by c15, c14)" + + " select c14, c21, sum(c11) from w1 join tt2 on c15=c25 group by c14, c21"; + plan = getFragmentPlan(sql); + assertContains(plan, " 5:Project\n" + + " | : 4: c14\n" + + " | : 8: c21\n" + + " | : 7: sum\n" + + " | \n" + + " 4:HASH JOIN"); + + sql = "with w1 as (select c15, c14, sum(c11) as c11 from tt1 group by c15, c14)" + + " select sum(c11) from w1 join tt2 on c15=c25 group by c14, c21"; + plan = getFragmentPlan(sql); + assertContains(plan, " 6:Project\n" + + " | : 7: sum\n" + + " | \n" + + " 5:HASH JOIN"); + } + + + @Test + public void testCannotEliminateAggAfterJoinFromUKChild() throws Exception { + String sql; + String plan; + + sql = "with w1 as (select c15, c14, sum(c11) as c11 from tt1 group by c15, c14)" + + " select c14, c21, sum(c11) from w1 left join tt2 on c15=c25 group by c14, c21"; + plan = getFragmentPlan(sql); + assertContains(plan, " 6:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | output: sum(7: sum)\n" + + " | group by: 4: c14, 8: c21\n" + + " | \n" + + " 5:Project\n" + + " | : 4: c14\n" + + " | : 7: sum\n" + + " | : 8: c21\n" + + " | \n" + + " 4:HASH JOIN"); + + sql = "with w1 as (select c15, c14, sum(c11) as c11 from tt1 group by c15, c14)" + + " select c14, c21, sum(c11) from w1 full join tt2 on c15=c25 group by c14, c21"; + plan = getFragmentPlan(sql); + assertContains(plan, " 7:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | output: sum(7: sum)\n" + + " | group by: 4: c14, 8: c21\n" + + " | \n" + + " 6:Project\n" + + " | : 4: c14\n" + + " | : 7: sum\n" + + " | : 8: c21\n" + + " | \n" + + " 5:HASH JOIN"); + } + + + @Test + public void testEliminateAggAfterJoinFromFKChild() throws Exception { + String sql; + String plan; + + sql = "select c11, c12, c14, c15, c16, sum(c11) from tt1 join tt2 on c13=c23 group by c11, c12, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " 1:Project\n" + + " | : 1: c11\n" + + " | : 2: c12\n" + + " | : 4: c14\n" + + " | : 5: c15\n" + + " | : 6: c16\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select sum(c11) from tt1 join tt2 on c13=c23 group by c11, c12, c13, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " 1:Project\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + } + + + @Test + public void testCannotEliminateAggAfterJoinFromFKChild() throws Exception { + String sql; + String plan; + + sql = "select c11, c12, c14, c15, c16, sum(c11) from tt1 right join tt2 on c13=c23 group by c11, c12, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " 6:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | output: sum(1: c11)\n" + + " | group by: 1: c11, 2: c12, 4: c14, 5: c15, 6: c16\n" + + " | \n" + + " 5:Project\n" + + " | : 1: c11\n" + + " | : 2: c12\n" + + " | : 4: c14\n" + + " | : 5: c15\n" + + " | : 6: c16\n" + + " | \n" + + " 4:HASH JOIN"); + + + sql = "select c11, c12, c14, c15, c16, sum(c11) from tt1 full join tt2 on c13=c23 group by c11, c12, c14, c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, "\n" + + " 6:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | output: sum(1: c11)\n" + + " | group by: 1: c11, 2: c12, 4: c14, 5: c15, 6: c16\n" + + " | \n" + + " 5:Project\n" + + " | : 1: c11\n" + + " | : 2: c12\n" + + " | : 4: c14\n" + + " | : 5: c15\n" + + " | : 6: c16\n" + + " | \n" + + " 4:HASH JOIN"); + } + + @Test + public void testEliminateAggAfterAgg() throws Exception { + String sql; + String plan; + + sql = "select c15, c16 from (" + + "select c15, c16 from tt1 group by c15, c16" + + ")t group by c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " RESULT SINK\n" + + "\n" + + " 1:AGGREGATE (update finalize)\n" + + " | group by: 5: c15, 6: c16\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c15, c15+1 from (" + + "select c15 from tt1 group by c15" + + ")t group by 1, 2"; + plan = getFragmentPlan(sql); + assertContains(plan, " RESULT SINK\n" + + "\n" + + " 2:Project\n" + + " | : 5: c15\n" + + " | : CAST(5: c15 AS BIGINT) + 1\n" + + " | \n" + + " 1:AGGREGATE (update finalize)\n" + + " | group by: 5: c15\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c16, sum_c11, count(sum_c11) from (" + + "select c16, sum(c11) as sum_c11 from tt1 group by c16" + + ")t group by c16, sum_c11"; + plan = getFragmentPlan(sql); + assertContains(plan, " RESULT SINK\n" + + "\n" + + " 2:Project\n" + + " | : 6: c16\n" + + " | : 7: sum\n" + + " | : if(7: sum IS NULL, 0, 1)\n" + + " | \n" + + " 1:AGGREGATE (update finalize)\n" + + " | output: sum(1: c11)\n" + + " | group by: 6: c16\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c16, count(sum_c11) from (" + + "select c15, c16, sum(c11) as sum_c11 from tt1 group by c15, c16 " + + ")t group by c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " 3:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | output: count(7: sum)\n" + + " | group by: 6: c16\n" + + " | \n" + + " 2:Project\n" + + " | : 6: c16\n" + + " | : 7: sum\n" + + " | \n" + + " 1:AGGREGATE (update finalize)\n" + + " | output: sum(1: c11)\n" + + " | group by: 5: c15, 6: c16\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c15, c16, count(sum_c11) from (" + + "select c11, c15, c16, sum(c11) as sum_c11 from tt1 group by c11, c15, c16 " + + ")t group by c15, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " RESULT SINK\n" + + "\n" + + " 2:AGGREGATE (update finalize)\n" + + " | output: count(7: sum)\n" + + " | group by: 5: c15, 6: c16\n" + + " | \n" + + " 1:Project\n" + + " | : 5: c15\n" + + " | : 6: c16\n" + + " | : CAST(1: c11 AS BIGINT)\n" + + " | \n" + + " 0:OlapScanNode"); + + sql = "select c11, c16, count(sum_c11) from (" + + "select c11, c15, c16, sum(c11) as sum_c11 from tt1 group by c11, c15, c16 " + + ")t group by c11, c16"; + plan = getFragmentPlan(sql); + assertContains(plan, " RESULT SINK\n" + + "\n" + + " 1:Project\n" + + " | : 1: c11\n" + + " | : 6: c16\n" + + " | : if(CAST(1: c11 AS BIGINT) IS NULL, 0, 1)\n" + + " | \n" + + " 0:OlapScanNode"); + } + +} diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/InsertPlanTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/InsertPlanTest.java index 47eec58427390..ba7abf6bcd7b9 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/InsertPlanTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/InsertPlanTest.java @@ -675,7 +675,9 @@ public void testInsertExchange() throws Exception { " 1:Project\n" + " | : 1: pk\n" + " | : CAST(2: v1 AS VARCHAR)\n" + - " | : 3: v2\n"); + " | : 3: v2\n" + + " | \n" + + " 0:OlapScanNode"); } { // KesType is AGG_KEYS diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java index 845e3ad300def..04c912f25dfec 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest2.java @@ -1931,7 +1931,7 @@ public void testLowCardForLimit() throws Exception { @Test public void testProjectAggregate() throws Exception { - String sql = "SELECT DISTINCT x1, x2 from (" + + String sql = "SELECT /*+SET_VAR(enable_eliminate_agg=false)*/ DISTINCT x1, x2 from (" + " SELECT lower(t_a_0.`c`) AS x1, t_a_0.`c` AS x2 " + " FROM (select distinct upper(S_ADDRESS) c from supplier) t_a_0) t_a_1;"; diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanFragmentWithCostTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanFragmentWithCostTest.java index 4cf1d66731486..be3efdff83f46 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanFragmentWithCostTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PlanFragmentWithCostTest.java @@ -1549,7 +1549,9 @@ public List getColumnStatistics(Table table, List colum }; boolean prevEnableLocalShuffleAgg = connectContext.getSessionVariable().isEnableLocalShuffleAgg(); + boolean prevEliminateAgg = connectContext.getSessionVariable().isEnableEliminateAgg(); connectContext.getSessionVariable().setEnableLocalShuffleAgg(true); + connectContext.getSessionVariable().setEnableEliminateAgg(false); String sql; String plan; @@ -1728,6 +1730,7 @@ public List getColumnStatistics(Table table, List colum " 2:EXCHANGE"); } finally { connectContext.getSessionVariable().setEnableLocalShuffleAgg(prevEnableLocalShuffleAgg); + connectContext.getSessionVariable().setEnableEliminateAgg(prevEliminateAgg); } } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java index 27ee16afb27bf..cd7cd1e172667 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/PruneUKFKJoinRuleTest.java @@ -81,8 +81,11 @@ public void tearDown() { private void assertPlans(String query, boolean equals, String... patterns) throws Exception { connectContext.getSessionVariable().setEnableUKFKOpt(false); String planDisabled = getFragmentPlan(query); + System.out.println(planDisabled); + System.out.println("==================="); connectContext.getSessionVariable().setEnableUKFKOpt(true); String planEnabled = getFragmentPlan(query); + System.out.println(planEnabled); if (equals) { Assert.assertEquals(planDisabled, planEnabled); } else { @@ -173,7 +176,8 @@ public void testQ3() throws Exception { @Test public void testQ4() throws Exception { - assertPlans(Q04, true); + assertPlans(Q04, false, "group by: \\d+: c_customer_id, \\d+: c_first_name, \\d+: c_last_name, " + + "\\d+: c_preferred_cust_flag, \\d+: c_birth_country, \\d+: c_login, \\d+: c_email_address, \\d+: d_year"); } @Test @@ -208,7 +212,8 @@ public void testQ10() throws Exception { @Test public void testQ11() throws Exception { - assertPlans(Q11, true); + assertPlans(Q11, false, "group by: \\d+: c_customer_id, \\d+: c_first_name, \\d+: c_last_name, " + + "\\d+: c_preferred_cust_flag, \\d+: c_birth_country, \\d+: c_login, \\d+: c_email_address, \\d+: d_year"); } @Test @@ -270,8 +275,10 @@ public void testQ22() throws Exception { @Test public void testQ23() throws Exception { - assertPlans(Q23_1, false, "[0-9]+: ss_customer_sk = [0-9]+: c_customer_sk"); - assertPlans(Q23_2, false, "[0-9]+: ss_customer_sk = [0-9]+: c_customer_sk"); + assertPlans(Q23_1, false, "[0-9]+: ss_customer_sk = [0-9]+: c_customer_sk", + "[0-9]+: i_item_sk = [0-9]+: ss_item_sk"); + assertPlans(Q23_2, false, "[0-9]+: ss_customer_sk = [0-9]+: c_customer_sk", + "[0-9]+: i_item_sk = [0-9]+: ss_item_sk"); } @Test diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/SetTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/SetTest.java index 6206b23b67efc..356d4ed5f844b 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/SetTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/SetTest.java @@ -706,112 +706,4 @@ public void testUnionToValues() throws Exception { " 3 | 4\n" + " 5 | 6\n"); } - - @Test - public void testEliminateAgg() throws Exception { - connectContext.getSessionVariable().setOptimizerExecuteTimeout(-1); - String sql = "SELECT \n" + - " id, \n" + - " SUM(big_value) AS sum_big_value\n" + - "FROM \n" + - " test_agg_group_single_unique_key\n" + - "GROUP BY \n" + - " id\n" + - "ORDER BY \n" + - " id;"; - String plan = getVerboseExplain(sql); - assertContains(plan, " 1:Project\n" + - " | output columns:\n" + - " | 1 <-> [1: id, INT, false]\n" + - " | 6 <-> [2: big_value, BIGINT, true]\n" + - " | cardinality: 1\n" + - " | \n" + - " 0:OlapScanNode\n" + - " table: test_agg_group_single_unique_key, rollup: test_agg_group_single_unique_key\n" + - " preAggregation: off. Reason: None aggregate function\n"); - - sql = "SELECT \n" + - " id, \n" + - " COUNT(varchar_value) AS count_varchar_value\n" + - "FROM \n" + - " test_agg_group_single_unique_key\n" + - "GROUP BY \n" + - " id\n" + - "ORDER BY \n" + - " id;"; - plan = getVerboseExplain(sql); - assertContains(plan, " 1:Project\n" + - " | output columns:\n" + - " | 1 <-> [1: id, INT, false]\n" + - " | 6 <-> if[(5: varchar_value IS NULL, 0, 1); " + - "args: BOOLEAN,INT,INT; result: TINYINT; args nullable: false; result nullable: true]\n" + - " | cardinality: 1\n" + - " | \n" + - " 0:OlapScanNode\n" + - " table: test_agg_group_single_unique_key, rollup: test_agg_group_single_unique_key\n" + - " preAggregation: off. Reason: None aggregate function\n"); - - sql = "SELECT\n" + - " id,\n" + - " big_value,\n" + - " AVG(decimal_value) AS avg_decimal_value\n" + - "FROM\n" + - " test_agg_group_multi_unique_key\n" + - "GROUP BY\n" + - " id, big_value\n" + - "ORDER BY\n" + - " id;"; - plan = getVerboseExplain(sql); - assertContains(plan, " 1:Project\n" + - " | output columns:\n" + - " | 1 <-> [1: id, INT, false]\n" + - " | 2 <-> [2: big_value, BIGINT, true]\n" + - " | 6 <-> cast([4: decimal_value, DECIMAL64(10,5), true] as DECIMAL128(38,11))\n" + - " | cardinality: 1\n" + - " | \n" + - " 0:OlapScanNode\n" + - " table: test_agg_group_multi_unique_key, rollup: test_agg_group_multi_unique_key\n" + - " preAggregation: off. Reason: None aggregate function\n"); - - sql = "SELECT\n" + - " id,\n" + - " big_value,\n" + - " COUNT(varchar_value) AS count_varchar_value\n" + - "FROM\n" + - " test_agg_group_multi_unique_key\n" + - "GROUP BY\n" + - " id, big_value\n" + - "ORDER BY\n" + - " id;"; - plan = getVerboseExplain(sql); - assertContains(plan, " 1:Project\n" + - " | output columns:\n" + - " | 1 <-> [1: id, INT, false]\n" + - " | 2 <-> [2: big_value, BIGINT, true]\n" + - " | 6 <-> if[(5: varchar_value IS NULL, 0, 1); args: BOOLEAN,INT,INT;" + - " result: TINYINT; args nullable: false; result nullable: true]\n" + - " | cardinality: 1\n" + - " | \n" + - " 0:OlapScanNode\n" + - " table: test_agg_group_multi_unique_key, rollup: test_agg_group_multi_unique_key\n" + - " preAggregation: off. Reason: None aggregate function\n"); - sql = "SELECT\n" + - " id,\n" + - " big_value,\n" + - " COUNT(varchar_value) AS count_varchar_value\n" + - "FROM test_agg_group_multi_unique_key\n" + - "GROUP BY id, big_value\n" + - "HAVING COUNT(varchar_value) > 0;"; - plan = getFragmentPlan(sql); - assertContains(plan, " PREDICATES: if(5: varchar_value IS NULL, 0, 1) > 0"); - sql = "SELECT\n" + - " id,\n" + - " big_value,\n" + - " COUNT(varchar_value) AS count_varchar_value\n" + - "FROM test_agg_group_multi_unique_key\n" + - "GROUP BY id, big_value\n" + - "HAVING SUM(varchar_value) > 0 and id + 2 > 5;"; - plan = getFragmentPlan(sql); - assertContains(plan, "CAST(5: varchar_value AS DOUBLE) > 0.0, 1: id > 3"); - } } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSAggregateWithUKFKTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSAggregateWithUKFKTest.java new file mode 100644 index 0000000000000..c187f1fd97ae9 --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSAggregateWithUKFKTest.java @@ -0,0 +1,227 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.sql.plan; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +public class TPCDSAggregateWithUKFKTest extends TPCDS1TTestBase { + + @BeforeAll + public static void beforeClass() throws Exception { + TPCDS1TTestBase.beforeClass(); + prepareUniqueKeys(); + prepareForeignKeys(); + + connectContext.getSessionVariable().setEnableUKFKOpt(true); + } + + @Test + public void testPruneGroupBys() throws Exception { + String sql; + String plan; + + // Prune itemdesc and i_class in Aggregation, and then eliminate item table. + sql = "select item_sk, cnt from (\n" + + " select substr(i_item_desc, 1, 30) itemdesc, i_item_sk item_sk, i_class, d_date solddate, count(*) cnt\n" + + " from store_sales\n" + + " , date_dim\n" + + " , item\n" + + " where ss_sold_date_sk = d_date_sk\n" + + " and ss_item_sk = i_item_sk\n" + + " and d_year in (2000, 2000 + 1, 2000 + 2, 2000 + 3)\n" + + " group by substr(i_item_desc, 1, 30), i_item_sk, i_class, d_date\n" + + " having count(*) > 4\n" + + ")t;\n"; + plan = getFragmentPlan(sql); + assertNotContains(plan, "TABLE: item"); + assertContains(plan, "group by: 52: i_item_sk, 26: d_date"); + + // Cannot prune itemdesc in Aggregation, since it is used as an output column. + // Prune i_class in Aggregation. + sql = "select item_sk, itemdesc, cnt from (\n" + + " select substr(i_item_desc, 1, 30) itemdesc, i_item_sk item_sk, i_class, d_date solddate, count(*) cnt\n" + + " from store_sales\n" + + " , date_dim\n" + + " , item\n" + + " where ss_sold_date_sk = d_date_sk\n" + + " and ss_item_sk = i_item_sk\n" + + " and d_year in (2000, 2000 + 1, 2000 + 2, 2000 + 3)\n" + + " group by substr(i_item_desc, 1, 30), i_item_sk, i_class, d_date\n" + + " having count(*) > 4\n" + + ")t;\n"; + plan = getFragmentPlan(sql); + assertContains(plan, "TABLE: item"); + assertContains(plan, "group by: 74: substr, 52: i_item_sk, 26: d_date"); + + // Cannot prune itemdesc and i_class in Aggregation, since they are used as output columns. + sql = "select item_sk, itemdesc, i_class, cnt from (\n" + + " select substr(i_item_desc, 1, 30) itemdesc, i_item_sk item_sk, i_class, d_date solddate, count(*) cnt\n" + + " from store_sales\n" + + " , date_dim\n" + + " , item\n" + + " where ss_sold_date_sk = d_date_sk\n" + + " and ss_item_sk = i_item_sk\n" + + " and d_year in (2000, 2000 + 1, 2000 + 2, 2000 + 3)\n" + + " group by substr(i_item_desc, 1, 30), i_item_sk, i_class, d_date\n" + + " having count(*) > 4\n" + + ")t;\n"; + plan = getFragmentPlan(sql); + assertContains(plan, "TABLE: item"); + assertContains(plan, "74: substr, 52: i_item_sk, 62: i_class, 26: d_date"); + + // Prune itemdesc and i_class in Aggregation, and then eliminate item table. + sql = "select item_sk, cnt from (\n" + + " select substr(i_item_desc, 1, 30) itemdesc, i_item_sk item_sk, i_class, d_date solddate, count(*) cnt\n" + + " from store_sales\n" + + " , date_dim\n" + + " , item\n" + + " where ss_sold_date_sk = d_date_sk\n" + + " and ss_item_sk = i_item_sk\n" + + " and d_year in (2000, 2000 + 1, 2000 + 2, 2000 + 3)\n" + + " group by substr(i_item_desc, 1, 30), i_item_sk, i_class, d_date_sk, d_date, d_current_month\n" + + " having count(*) > 4\n" + + ")t;\n"; + plan = getFragmentPlan(sql); + assertNotContains(plan, "TABLE: item"); + assertContains(plan, "group by: 52: i_item_sk, 24: d_date_sk"); + } + + @Test + public void testEliminateAggRule() throws Exception { + String sql; + String plan; + + // Agg (Group by ss_customer_sk, d_year) -> Join(ss_customer_sk=c_customer_sk) -> Agg(Group by c_customer_id, d_year, ...) + // Agg(Group by c_customer_id, d_year, ...) could be eliminated. + sql = "\n" + + "with w1 as (\n" + + " select \n" + + " sum(((ss_ext_list_price-ss_ext_wholesale_cost-ss_ext_discount_amt)+ss_ext_sales_price)/2) year_total,\n" + + " ss_customer_sk,\n" + + " d_year\n" + + " from store_sales, date_dim\n" + + " where ss_sold_date_sk = d_date_sk and d_year between 2001 and 2002\n" + + " group by ss_customer_sk, d_year\n" + + ")\n" + + "select c_customer_id customer_id\n" + + " ,c_first_name customer_first_name\n" + + " ,c_last_name customer_last_name\n" + + " ,c_preferred_cust_flag customer_preferred_cust_flag\n" + + " ,c_birth_country customer_birth_country\n" + + " ,c_login customer_login\n" + + " ,c_email_address customer_email_address\n" + + " ,d_year dyear\n" + + " ,sum(year_total) year_total\n" + + " ,'s' sale_type\n" + + "from customer\n" + + " , w1\n" + + "where c_customer_sk = ss_customer_sk\n" + + "group by c_customer_id\n" + + " ,c_first_name\n" + + " ,c_last_name\n" + + " ,c_preferred_cust_flag\n" + + " ,c_birth_country\n" + + " ,c_login\n" + + " ,c_email_address\n" + + " ,d_year"; + plan = getFragmentPlan(sql); + assertContains(plan, " 11:Project\n" + + " | : 2: c_customer_id\n" + + " | : 9: c_first_name\n" + + " | : 10: c_last_name\n" + + " | : 11: c_preferred_cust_flag\n" + + " | : 15: c_birth_country\n" + + " | : 16: c_login\n" + + " | : 17: c_email_address\n" + + " | : 48: d_year\n" + + " | : 71: sum\n" + + " | : 's'\n" + + " | \n" + + " 10:HASH JOIN\n" + + " | join op: INNER JOIN (BROADCAST)\n" + + " | colocate: false, reason: \n" + + " | equal join conjunct: 23: ss_customer_sk = 1: c_customer_sk\n" + + " | \n" + + " |----9:EXCHANGE\n" + + " | \n" + + " 7:AGGREGATE (merge finalize)\n" + + " | output: sum(71: sum)\n" + + " | group by: 23: ss_customer_sk, 48: d_year"); + + // Agg (Group by ss_customer_sk, d_year) -> Join(ss_customer_sk=c_customer_sk) -> Agg(Group by c_customer_id, ...) + // Agg(Group by c_customer_id, ...) could not be eliminated. + sql = "\n" + + "with w1 as (\n" + + " select \n" + + " sum(((ss_ext_list_price-ss_ext_wholesale_cost-ss_ext_discount_amt)+ss_ext_sales_price)/2) year_total,\n" + + " ss_customer_sk,\n" + + " d_year\n" + + " from store_sales, date_dim\n" + + " where ss_sold_date_sk = d_date_sk and d_year between 2001 and 2002\n" + + " group by ss_customer_sk, d_year\n" + + ")\n" + + "select c_customer_id customer_id\n" + + " ,c_first_name customer_first_name\n" + + " ,c_last_name customer_last_name\n" + + " ,c_preferred_cust_flag customer_preferred_cust_flag\n" + + " ,c_birth_country customer_birth_country\n" + + " ,c_login customer_login\n" + + " ,c_email_address customer_email_address\n" + + " ,sum(year_total) year_total\n" + + " ,'s' sale_type\n" + + "from customer\n" + + " , w1\n" + + "where c_customer_sk = ss_customer_sk\n" + + "group by c_customer_id\n" + + " ,c_first_name\n" + + " ,c_last_name\n" + + " ,c_preferred_cust_flag\n" + + " ,c_birth_country\n" + + " ,c_login\n" + + " ,c_email_address\n"; + plan = getFragmentPlan(sql); + assertContains(plan, " 13:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | output: sum(71: sum)\n" + + " | group by: 2: c_customer_id, 9: c_first_name, 10: c_last_name, " + + "11: c_preferred_cust_flag, 15: c_birth_country, 16: c_login, 17: c_email_address\n" + + " | \n" + + " 12:Project\n" + + " | : 2: c_customer_id\n" + + " | : 9: c_first_name\n" + + " | : 10: c_last_name\n" + + " | : 11: c_preferred_cust_flag\n" + + " | : 15: c_birth_country\n" + + " | : 16: c_login\n" + + " | : 17: c_email_address\n" + + " | : 71: sum\n" + + " | \n" + + " 11:HASH JOIN\n" + + " | join op: INNER JOIN (BROADCAST)\n" + + " | colocate: false, reason: \n" + + " | equal join conjunct: 23: ss_customer_sk = 1: c_customer_sk\n" + + " | \n" + + " |----10:EXCHANGE\n" + + " | \n" + + " 8:Project\n" + + " | : 23: ss_customer_sk\n" + + " | : 71: sum\n" + + " | \n" + + " 7:AGGREGATE (merge finalize)\n" + + " | output: sum(71: sum)\n" + + " | group by: 23: ss_customer_sk, 48: d_year"); + } +} diff --git a/fe/fe-core/src/test/resources/sql/tpcds_constraints/AddUniqueKeys.sql b/fe/fe-core/src/test/resources/sql/tpcds_constraints/AddUniqueKeys.sql index d90fcd9415bf7..27ce6ac6ec9b9 100644 --- a/fe/fe-core/src/test/resources/sql/tpcds_constraints/AddUniqueKeys.sql +++ b/fe/fe-core/src/test/resources/sql/tpcds_constraints/AddUniqueKeys.sql @@ -9,7 +9,7 @@ alter table income_band set ("unique_constraints" = "ib_income_band_sk"); alter table item set ("unique_constraints" = "i_item_sk"); alter table store set ("unique_constraints" = "s_store_sk"); alter table call_center set ("unique_constraints" = "cc_call_center_sk"); -alter table customer set ("unique_constraints" = "c_customer_sk"); +alter table customer set ("unique_constraints" = "c_customer_sk;c_customer_id"); alter table web_site set ("unique_constraints" = "web_site_sk"); alter table store_returns set ("unique_constraints" = "sr_item_sk, sr_ticket_number"); alter table household_demographics set ("unique_constraints" = "hd_demo_sk");