Support push down aggregate functions in mv rewrite
Signed-off-by: shuming.li <[email protected]>
LiShuMing committed Sep 19, 2024
1 parent 20ed111 commit b6f7885
Showing 15 changed files with 986 additions and 69 deletions.
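At a glance (our summary, not part of the commit message): when a query computes min/max/sum over a scalar expression of a single aggregatable column (arithmetic, IF, or CASE WHEN), the rewriter can now push the aggregate down onto the bare column, match that pushed-down aggregate against the materialized view's output, and re-apply the scalar expression on top. An illustrative sketch, with made-up table and column names:

    mv (illustrative):  SELECT dt, min(revenue) AS min_rev FROM t GROUP BY dt;
    query:              SELECT dt, min((revenue + 1) * 2) FROM t GROUP BY dt;
    after rewrite:      SELECT dt, min((min_rev + 1) * 2) FROM mv0 GROUP BY dt;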
@@ -55,8 +55,6 @@
import com.starrocks.common.Pair;
import com.starrocks.common.profile.Timer;
import com.starrocks.common.profile.Tracers;
- import com.starrocks.common.util.concurrent.lock.LockType;
- import com.starrocks.common.util.concurrent.lock.Locker;
import com.starrocks.privilege.SecurityPolicyRewriteRule;
import com.starrocks.qe.ConnectContext;
import com.starrocks.server.GlobalStateMgr;
@@ -1359,7 +1357,6 @@ public Table resolveTable(TableRelation tableRelation) {
}

MetaUtils.checkDbNullAndReport(db, dbName);
- Locker locker = new Locker();

Table table = null;
if (tableRelation.isSyncMVQuery()) {
@@ -1370,17 +1367,12 @@ public Table resolveTable(TableRelation tableRelation) {
Table mvTable = materializedIndex.first;
Preconditions.checkState(mvTable != null);
Preconditions.checkState(mvTable instanceof OlapTable);
- try {
-     // Add read lock to avoid concurrent problems.
-     locker.lockDatabase(db.getId(), LockType.READ);
-     OlapTable mvOlapTable = new OlapTable();
-     ((OlapTable) mvTable).copyOnlyForQuery(mvOlapTable);
-     // Copy the necessary olap table meta to avoid changing original meta;
-     mvOlapTable.setBaseIndexId(materializedIndex.second.getIndexId());
-     table = mvOlapTable;
- } finally {
-     locker.unLockDatabase(db.getId(), LockType.READ);
- }
+ // Add read lock to avoid concurrent problems.
+ OlapTable mvOlapTable = new OlapTable();
+ ((OlapTable) mvTable).copyOnlyForQuery(mvOlapTable);
+ // Copy the necessary olap table meta to avoid changing original meta;
+ mvOlapTable.setBaseIndexId(materializedIndex.second.getIndexId());
+ table = mvOlapTable;
}
}
} else {
@@ -63,6 +63,19 @@ public class AggregateFunctionRollupUtils {
.put(FunctionSet.ARRAY_AGG_DISTINCT, FunctionSet.ARRAY_UNIQUE_AGG)
.build();

+ // Functions that can be pushed down to mv union rewrite.
+ // eg:
+ // sum(fn(col)) = fn(sum(col))
+ // min(fn(col)) = fn(min(col))
+ // max(fn(col)) = fn(max(col))
+ // if fn is a scalar function, it can be pushed down to mv union rewrite.
+ public static final Map<String, String> MV_REWRITE_PUSH_DOWN_FUNCTION_MAP = ImmutableMap.<String, String>builder()
+         // Functions and rollup functions are the same.
+         .put(FunctionSet.SUM, FunctionSet.SUM)
+         .put(FunctionSet.MAX, FunctionSet.MAX)
+         .put(FunctionSet.MIN, FunctionSet.MIN)
+         .build();
+
public static final Set<String> NON_CUMULATIVE_ROLLUP_FUNCTION_MAP = ImmutableSet.<String>builder()
.add(FunctionSet.MAX)
.add(FunctionSet.MIN)
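A worked instance of the identities above (our illustration, not part of the patch): a monotonically increasing scalar function such as fn(x) = (x + 1) * 2 commutes with min and max within each group:

    min((col + 1) * 2) = (min(col) + 1) * 2
    max((col + 1) * 2) = (max(col) + 1) * 2

so a query-side min/max over fn(col) can be answered from an mv that only materializes min(col)/max(col). For sum the analogous identity does not hold for arbitrary scalar functions (it fails for fn(x) = (x + 1) * 2 whenever a group has more than one row); it only holds for functions that distribute over addition, such as fn(x) = 2 * x.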
@@ -22,20 +22,30 @@
import com.starrocks.analysis.Expr;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
+ import com.starrocks.catalog.Type;
import com.starrocks.common.Pair;
+ import com.starrocks.qe.ConnectContext;
+ import com.starrocks.sql.analyzer.FunctionAnalyzer;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
+ import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
+ import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorUtil;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorVisitor;
import com.starrocks.sql.optimizer.rewrite.BaseScalarOperatorShuttle;
+ import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
+ import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
+ import com.starrocks.sql.optimizer.rewrite.scalar.ImplicitCastRule;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent;
+ import com.starrocks.sql.parser.NodePosition;

import java.util.List;
import java.util.Map;
import java.util.Optional;
+ import java.util.stream.Collectors;

import static com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent.EQUIVALENTS;

@@ -48,6 +58,9 @@ public class EquationRewriter {
private AggregateFunctionRewriter aggregateFunctionRewriter;
boolean underAggFunctionRewriteContext;

+ // Replace the corresponding ColumnRef with ScalarOperator if this call operator can be pushed down.
+ private Map<String, Map<ColumnRefOperator, CallOperator>> aggPushDownOperatorMap = Maps.newHashMap();
+
public EquationRewriter() {
this.equationMap = ArrayListMultimap.create();
this.rewriteEquivalents = Maps.newHashMap();
@@ -161,6 +174,13 @@ public ScalarOperator visitCall(CallOperator call, Void context) {
return rewritten;
}
}

+ // rewrite by pushing down aggregate
+ rewritten = rewriteByPushDownAggregation(call);
+ if (rewritten != null) {
+     return rewritten;
+ }

// If count(1)/sum(1) cannot be rewritten by mv's defined equivalents, return null directly,
// otherwise it may cause a wrong plan.
// mv : SELECT 1, count(distinct k1) from tbl1;
@@ -169,9 +189,167 @@ public ScalarOperator visitCall(CallOperator call, Void context) {
if (call.isAggregate() && call.isConstant()) {
return null;
}

return super.visitCall(call, context);
}

+ private ScalarOperator rewriteByPushDownAggregation(CallOperator call) {
+     if (AggregateFunctionRollupUtils.MV_REWRITE_PUSH_DOWN_FUNCTION_MAP.containsKey(call.getFnName()) &&
+             aggPushDownOperatorMap.containsKey(call.getFnName())) {
+         Map<ColumnRefOperator, CallOperator> operatorMap = aggPushDownOperatorMap.get(call.getFnName());
+         ScalarOperator arg0 = call.getChild(0);
+         if (call.getChildren().size() != 1) {
+             return null;
+         }
+         // Push-down aggregation only supports one child for now;
+         // that's fine since the rewrite will clone arg0.
+         ScalarOperator pdCall = pushDownAggregationToArg0(call, arg0, operatorMap);
+         if (pdCall == null || pdCall.equals(arg0)) {
+             return null;
+         }
+         ScalarOperator rewritten = pdCall.accept(this, null);
+         // Only usable if pdCall has actually been rewritten.
+         if (rewritten != null && !rewritten.equals(pdCall)) {
+             shuttleContext.setRewrittenByEquivalent(true);
+             if (FunctionSet.SUM.equalsIgnoreCase(call.getFnName())) {
+                 Function newFn = ScalarOperatorUtil.findSumFn(new Type[] {rewritten.getType()});
+                 CallOperator newCall = new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten),
+                         newFn);
+                 ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter();
+                 CallOperator result = (CallOperator) scalarRewriter.rewrite(newCall,
+                         Lists.newArrayList(new ImplicitCastRule()));
+                 return result;
+             } else {
+                 return new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten),
+                         call.getFunction());
+             }
+         }
+     }
+     return null;
+ }
+
+ /**
+  * Whether the call operator can be pushed down:
+  * - Only min/max/sum aggregate functions are supported.
+  * - If the argument is a column ref and the operator map contains that column ref, it can be pushed down.
+  * - If the argument is a call operator containing multiple column refs (only IF/CaseWhen is supported),
+  *   ensure the aggregate column does not appear in the condition clause.
+  */
+ private ScalarOperator pushDownAggregationToArg0(ScalarOperator call,
+                                                  ScalarOperator arg0,
+                                                  Map<ColumnRefOperator, CallOperator> operatorMap) {
+     if (arg0 == null || !(arg0 instanceof CallOperator)) {
+         return null;
+     }
+     CallOperator arg0Call = (CallOperator) arg0;
+     List<ColumnRefOperator> columnRefs = arg0.getColumnRefs();
+     if (columnRefs.size() == 1) {
+         ColumnRefOperator child0 = columnRefs.get(0);
+         if (!operatorMap.containsKey(child0)) {
+             return null;
+         }
+         CallOperator aggFunc = operatorMap.get(child0);
+         // Strict mode: only supported when col's type and agg(col)'s type are strictly equal; otherwise
+         // we would have to refresh the call operator's argument/result types recursively.
+         if (!aggFunc.getType().equals(child0.getType())) {
+             return null;
+         }
+         ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(operatorMap);
+         return rewriter.rewrite(arg0);
+     } else {
+         // If arg0 contains multiple column refs, only IF/CASE WHEN shapes are supported.
+         if (arg0Call.getFnName().equalsIgnoreCase(FunctionSet.IF)) {
+             if (arg0Call.getChildren().size() != 3) {
+                 return null;
+             }
+             // If the condition (first child) references an aggregate column from the operatorMap,
+             // the expression cannot be rewritten.
+             if (isContainAggregateColumn(arg0Call.getChild(0), operatorMap)) {
+                 return null;
+             }
+             ScalarOperator child1 = arg0Call.getChild(1);
+             ScalarOperator rewritten1 = rewriteIfOnlyContianAggregateColumn(child1, operatorMap);
+             if (rewritten1 == null) {
+                 return null;
+             }
+             ScalarOperator child2 = arg0Call.getChild(2);
+             ScalarOperator rewritten2 = rewriteIfOnlyContianAggregateColumn(child2, operatorMap);
+             if (rewritten2 == null) {
+                 return null;
+             }
+             ConnectContext ctx = ConnectContext.get() == null ? new ConnectContext() : ConnectContext.get();
+             List<ScalarOperator> args = Lists.newArrayList(arg0Call.getChild(0), rewritten1, rewritten2);
+             Type[] argTypes = args.stream().map(x -> x.getType()).collect(Collectors.toList()).toArray(new Type[0]);
+             Function newFn = FunctionAnalyzer.getAnalyzedBuiltInFunction(ctx, FunctionSet.IF, null, argTypes,
+                     NodePosition.ZERO);
+             if (newFn == null) {
+                 return null;
+             }
+             return new CallOperator(FunctionSet.IF, newFn.getReturnType(), args, newFn);
+         } else if (arg0Call instanceof CaseWhenOperator) {
+             CaseWhenOperator caseClause = (CaseWhenOperator) arg0Call;
+             // If the CASE condition contains any aggregate column ref, the expression cannot be rewritten.
+             ScalarOperator caseExpr = caseClause.hasCase() ? caseClause.getCaseClause() : null;
+             if (caseExpr != null && isContainAggregateColumn(caseExpr, operatorMap)) {
+                 return null;
+             }
+             List<ScalarOperator> newCaseWhens = Lists.newArrayList();
+             for (int i = 0; i < caseClause.getWhenClauseSize(); i++) {
+                 ScalarOperator when = caseClause.getWhenClause(i);
+                 if (isContainAggregateColumn(when, operatorMap)) {
+                     return null;
+                 }
+                 newCaseWhens.add(when);
+
+                 // A THEN clause (like the ELSE clause) may only contain the aggregate column ref or constants.
+                 ScalarOperator then = caseClause.getThenClause(i);
+                 ScalarOperator newThen = rewriteIfOnlyContianAggregateColumn(then, operatorMap);
+                 if (newThen == null) {
+                     return null;
+                 }
+                 newCaseWhens.add(newThen);
+             }
+             ScalarOperator elseClause = caseClause.hasElse() ? caseClause.getElseClause() : null;
+             ScalarOperator newElseClause = elseClause;
+             if (elseClause != null) {
+                 newElseClause = rewriteIfOnlyContianAggregateColumn(elseClause, operatorMap);
+                 if (newElseClause == null) {
+                     return null;
+                 }
+             }
+             // NOTE: use the original call's result type as the new CASE WHEN's type.
+             return new CaseWhenOperator(call.getType(), caseExpr, newElseClause, newCaseWhens);
+         }
+     }
+     return null;
+ }
+
+ private boolean isContainAggregateColumn(ScalarOperator child,
+                                          Map<ColumnRefOperator, CallOperator> operatorMap) {
+     return child.getColumnRefs().stream().anyMatch(x -> operatorMap.containsKey(x));
+ }
+
+ private ScalarOperator rewriteIfOnlyContianAggregateColumn(ScalarOperator child,
+                                                            Map<ColumnRefOperator, CallOperator> operatorMap) {
+     List<ColumnRefOperator> colRefs = child.getColumnRefs();
+     if (colRefs.size() > 1) {
+         return null;
+     }
+     // Constant operator: contains no aggregate column, keep as-is.
+     if (colRefs.size() == 0) {
+         return child;
+     }
+     // TODO: only supports column ref now, support common expression later.
+     if (!(child instanceof ColumnRefOperator)) {
+         return null;
+     }
+     if (!operatorMap.containsKey(colRefs.get(0))) {
+         return null;
+     }
+     return operatorMap.get(colRefs.get(0));
+ }
+
Optional<ScalarOperator> replace(ScalarOperator scalarOperator) {
if (equationMap.containsKey(scalarOperator)) {
Optional<Pair<ColumnRefOperator, ScalarOperator>> mappedColumnAndExprRef =
@@ -240,6 +418,7 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) {
equationMap.put(extendedEntry.first, Pair.create(col, extendedEntry.second));
}

+ // add into equivalents
for (IRewriteEquivalent equivalent : EQUIVALENTS) {
IRewriteEquivalent.RewriteEquivalentContext eqContext = equivalent.prepare(expr);
if (eqContext != null) {
@@ -248,6 +427,25 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) {
.add(eq);
}
}

+ // add into the push-down operator map
+ if (expr instanceof CallOperator) {
+     CallOperator call = (CallOperator) expr;
+     String fnName = call.getFnName();
+     if (AggregateFunctionRollupUtils.REWRITE_ROLLUP_FUNCTION_MAP.containsKey(fnName) && call.getChildren().size() == 1) {
+         ScalarOperator arg0 = call.getChild(0);
+         // NOTE: only support push down when the argument is a column ref.
+         // eg:
+         // mv: sum(cast(col as tinyint))
+         // query: 2 * sum(col)
+         // The mv expression cannot be used for push-down because its argument is not a bare column ref.
+         if (arg0 != null && arg0.isColumnRef()) {
+             aggPushDownOperatorMap
+                     .computeIfAbsent(fnName, x -> Maps.newHashMap())
+                     .put((ColumnRefOperator) arg0, call);
+         }
+     }
+ }
}

private static class EquationTransformer extends ScalarOperatorVisitor<Void, Void> {
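Putting the new pieces together, a rough end-to-end trace of the push-down path (our illustration; table, column, and alias names are made up):

    mv:     SELECT dt, sum(v) AS sum_v FROM t GROUP BY dt
            (addMapping registers v -> sum(v) in aggPushDownOperatorMap under "sum")
    query:  sum(if(region = 'x', v, 0))

    1. rewriteByPushDownAggregation: arg0 = if(region = 'x', v, 0), a call operator with two column refs
    2. pushDownAggregationToArg0: the IF condition contains no aggregate column; the branches are v (mapped)
       and the constant 0, so v is replaced by sum(v), giving if(region = 'x', sum(v), 0)
    3. re-visiting the pushed-down expression rewrites sum(v) into the mv output column: if(region = 'x', sum_v, 0)
    4. because the outer function is sum, it is re-applied on top with its signature re-resolved via
       findSumFn and ImplicitCastRule: sum(if(region = 'x', sum_v, 0))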
@@ -151,6 +151,7 @@ public void testAggPushDown_RollupFunctions_QueryMV_NoMatch() {
// query and mv's agg function is not the same, cannot rewrite.
String mvAggArg = "LO_REVENUE";
String queryAggArg = "(LO_REVENUE + 1) * 2";
+ Set<String> supportedPushDownAggregateFunctions = Sets.newHashSet("min", "max");
for (Map.Entry<String, String> e : SAFE_REWRITE_ROLLUP_FUNCTION_MAP.entrySet()) {
String funcName = e.getKey();
String mvAggFunc = getAggFunction(funcName, mvAggArg);
@@ -162,7 +163,11 @@
String query = String.format("select LO_ORDERDATE, %s as revenue_sum\n" +
" from lineorder l join dates d on l.LO_ORDERDATE = d.d_datekey\n" +
" group by LO_ORDERDATE", queryAggFunc);
sql(query).nonMatch("mv0");
+ if (supportedPushDownAggregateFunctions.contains(funcName)) {
+     sql(query).match("mv0");
+ } else {
+     sql(query).nonMatch("mv0");
+ }
});
}
}
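The test change above is the user-visible effect of the patch: with the same mv (aggregating LO_REVENUE) and a query aggregating (LO_REVENUE + 1) * 2, min and max now rewrite onto mv0, while the remaining rollup functions, including sum, still do not. Our reading, not stated in the commit: replacing the column with sum(col) widens the result type, so the strict type check in pushDownAggregationToArg0 rejects the sum case.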
(The remaining changed files of the 15 in this commit are not shown here.)
