Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Support push down aggregate functions in mv rewrite #49979

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@
import com.starrocks.common.Pair;
import com.starrocks.common.profile.Timer;
import com.starrocks.common.profile.Tracers;
import com.starrocks.common.util.concurrent.lock.LockType;
import com.starrocks.common.util.concurrent.lock.Locker;
import com.starrocks.privilege.SecurityPolicyRewriteRule;
import com.starrocks.qe.ConnectContext;
import com.starrocks.server.GlobalStateMgr;
Expand Down Expand Up @@ -1359,7 +1357,6 @@ public Table resolveTable(TableRelation tableRelation) {
}

MetaUtils.checkDbNullAndReport(db, dbName);
Locker locker = new Locker();

Table table = null;
if (tableRelation.isSyncMVQuery()) {
Expand All @@ -1370,17 +1367,12 @@ public Table resolveTable(TableRelation tableRelation) {
Table mvTable = materializedIndex.first;
Preconditions.checkState(mvTable != null);
Preconditions.checkState(mvTable instanceof OlapTable);
try {
// Add read lock to avoid concurrent problems.
locker.lockDatabase(db.getId(), LockType.READ);
OlapTable mvOlapTable = new OlapTable();
((OlapTable) mvTable).copyOnlyForQuery(mvOlapTable);
// Copy the necessary olap table meta to avoid changing original meta;
mvOlapTable.setBaseIndexId(materializedIndex.second.getIndexId());
table = mvOlapTable;
} finally {
locker.unLockDatabase(db.getId(), LockType.READ);
}
// Add read lock to avoid concurrent problems.
OlapTable mvOlapTable = new OlapTable();
((OlapTable) mvTable).copyOnlyForQuery(mvOlapTable);
// Copy the necessary olap table meta to avoid changing original meta;
mvOlapTable.setBaseIndexId(materializedIndex.second.getIndexId());
table = mvOlapTable;
}
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ public class AggregateFunctionRollupUtils {
.put(FunctionSet.ARRAY_AGG_DISTINCT, FunctionSet.ARRAY_UNIQUE_AGG)
.build();

// Functions that can be pushed down to mv union rewrite.
// eg:
// sum(fn(col)) = fn(sum(col))
// min(fn(col)) = fn(min(col))
// max(fn(col)) = fn(max(col))
// if fn is a scalar function, it can be pushed down to mv union rewrite.
public static final Map<String, String> MV_REWRITE_PUSH_DOWN_FUNCTION_MAP = ImmutableMap.<String, String>builder()
// Functions and rollup functions are the same.
.put(FunctionSet.SUM, FunctionSet.SUM)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not support COUNT?

.put(FunctionSet.MAX, FunctionSet.MAX)
.put(FunctionSet.MIN, FunctionSet.MIN)
.build();

public static final Set<String> NON_CUMULATIVE_ROLLUP_FUNCTION_MAP = ImmutableSet.<String>builder()
.add(FunctionSet.MAX)
.add(FunctionSet.MIN)
LiShuMing marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,30 @@
import com.starrocks.analysis.Expr;
import com.starrocks.catalog.Function;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.common.Pair;
import com.starrocks.qe.ConnectContext;
import com.starrocks.sql.analyzer.FunctionAnalyzer;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorUtil;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorVisitor;
import com.starrocks.sql.optimizer.rewrite.BaseScalarOperatorShuttle;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter;
import com.starrocks.sql.optimizer.rewrite.scalar.ImplicitCastRule;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent;
import com.starrocks.sql.parser.NodePosition;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent.EQUIVALENTS;

Expand All @@ -48,6 +58,9 @@ public class EquationRewriter {
private AggregateFunctionRewriter aggregateFunctionRewriter;
boolean underAggFunctionRewriteContext;

// Replace the corresponding ColumnRef with ScalarOperator if this call operator can be pushed down.
private Map<String, Map<ColumnRefOperator, CallOperator>> aggPushDownOperatorMap = Maps.newHashMap();

public EquationRewriter() {
this.equationMap = ArrayListMultimap.create();
this.rewriteEquivalents = Maps.newHashMap();
Expand Down Expand Up @@ -161,6 +174,13 @@ public ScalarOperator visitCall(CallOperator call, Void context) {
return rewritten;
}
}

// rewrite by pushing down aggregate
rewritten = rewriteByPushDownAggregation(call);
if (rewritten != null) {
return rewritten;
}

// If count(1)/sum(1) cannot be rewritten by mv's defined equivalents, return null directly,
// otherwise it may cause a wrong plan.
// mv : SELECT 1, count(distinct k1) from tbl1;
Expand All @@ -169,9 +189,167 @@ public ScalarOperator visitCall(CallOperator call, Void context) {
if (call.isAggregate() && call.isConstant()) {
return null;
}

return super.visitCall(call, context);
}

/**
 * Try to rewrite an aggregate call by pushing the aggregation down into its scalar
 * argument (sum/min/max over a supported scalar fn, per MV_REWRITE_PUSH_DOWN_FUNCTION_MAP),
 * then rewriting the pushed-down form against the mv's equivalents.
 *
 * @param call the query-side aggregate call operator
 * @return the rewritten operator, or null if this call cannot be rewritten by push-down
 */
private ScalarOperator rewriteByPushDownAggregation(CallOperator call) {
    if (AggregateFunctionRollupUtils.MV_REWRITE_PUSH_DOWN_FUNCTION_MAP.containsKey(call.getFnName()) &&
            aggPushDownOperatorMap.containsKey(call.getFnName())) {
        // push down aggregate now only supports one child; check arity BEFORE calling
        // getChild(0) so a zero-argument call cannot throw IndexOutOfBoundsException.
        if (call.getChildren().size() != 1) {
            return null;
        }
        Map<ColumnRefOperator, CallOperator> operatorMap = aggPushDownOperatorMap.get(call.getFnName());
        ScalarOperator arg0 = call.getChild(0);
        // it's fine since rewrite will clone arg0
        ScalarOperator pdCall = pushDownAggregationToArg0(call, arg0, operatorMap);
        if (pdCall == null || pdCall.equals(arg0)) {
            return null;
        }
        ScalarOperator rewritten = pdCall.accept(this, null);
        // only can be used if pdCall is rewritten
        if (rewritten != null && !rewritten.equals(pdCall)) {
            shuttleContext.setRewrittenByEquivalent(true);
            if (FunctionSet.SUM.equalsIgnoreCase(call.getFnName())) {
                // The rewritten argument's type may differ from the original column type,
                // so re-resolve the builtin sum function and re-apply implicit casts.
                Function newFn = ScalarOperatorUtil.findSumFn(new Type[] {rewritten.getType()});
                CallOperator newCall = new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten),
                        newFn);
                ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter();
                CallOperator result = (CallOperator) scalarRewriter.rewrite(newCall,
                        Lists.newArrayList(new ImplicitCastRule()));
                return result;
            } else {
                // min/max keep the original function: their result type is not widened.
                return new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten),
                        call.getFunction());
            }
        }
    }
    return null;
}

/**
* Whether the call operator can be pushed down:
* - Only supports to push down min/max/sum aggregate function
* - If the argument is a column ref, and operator map contains the column ref, return true
* - If the argument is a call operator and contains multi-column refs(Ony IF/CaseWhen is supported),
* ensure the aggregate column does not appear in the condition clause.
*/
private ScalarOperator pushDownAggregationToArg0(ScalarOperator call,
ScalarOperator arg0,
Map<ColumnRefOperator, CallOperator> operatorMap) {
if (arg0 == null || !(arg0 instanceof CallOperator)) {
return null;
}
CallOperator arg0Call = (CallOperator) arg0;
List<ColumnRefOperator> columnRefs = arg0.getColumnRefs();
if (columnRefs.size() == 1) {
ColumnRefOperator child0 = columnRefs.get(0);
if (!operatorMap.containsKey(child0)) {
return null;
}
CallOperator aggFunc = operatorMap.get(child0);
// strict mode, only supports col's type and agg(col)'s type are strict equal; otherwise we should
// refresh the call operator's argument/result type recursively.
if (!aggFunc.getType().equals(child0.getType())) {
return null;
}
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(operatorMap);
return rewriter.rewrite(arg0);
} else {
// if there are many column refs in arg0, the agg column must be the same.
if (arg0Call.getFnName().equalsIgnoreCase(FunctionSet.IF)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about other functions: ifnull, nullif, coalesce?

if (arg0Call.getChildren().size() != 3) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IF functions's children num must be 3, no need check!

return null;
}
// if the first column ref is in the operatorMap, means the agg column maybe as condition which
// cannot be rewritten
if (isContainAggregateColumn(arg0Call.getChild(0), operatorMap)) {
return null;
}
ScalarOperator child1 = arg0Call.getChild(1);
ScalarOperator rewritten1 = rewriteIfOnlyContianAggregateColumn(child1, operatorMap);
if (rewritten1 == null) {
return null;
}
ScalarOperator child2 = arg0Call.getChild(2);
ScalarOperator rewritten2 = rewriteIfOnlyContianAggregateColumn(child2, operatorMap);
if (rewritten2 == null) {
return null;
}
ConnectContext ctx = ConnectContext.get() == null ? new ConnectContext() : ConnectContext.get();
List<ScalarOperator> args = Lists.newArrayList(arg0Call.getChild(0), rewritten1, rewritten2);
Type[] argTypes = args.stream().map(x -> x.getType()).collect(Collectors.toList()).toArray(new Type[0]);
Function newFn = FunctionAnalyzer.getAnalyzedBuiltInFunction(ctx, FunctionSet.IF, null, argTypes,
NodePosition.ZERO);
if (newFn == null) {
return null;
}
return new CallOperator(FunctionSet.IF, newFn.getReturnType(), args, newFn);
} else if (arg0Call instanceof CaseWhenOperator) {
CaseWhenOperator caseClause = (CaseWhenOperator) arg0Call;
// if case condition contains any agg column ref, return false
ScalarOperator caseExpr = caseClause.hasCase() ? caseClause.getCaseClause() : null;
if (caseExpr != null && isContainAggregateColumn(caseExpr, operatorMap)) {
return null;
}
List<ScalarOperator> newCaseWhens = Lists.newArrayList();
for (int i = 0; i < caseClause.getWhenClauseSize(); i++) {
ScalarOperator when = caseClause.getWhenClause(i);
if (isContainAggregateColumn(when, operatorMap)) {
return null;
}
newCaseWhens.add(when);

// when clause or else clause can only contain aggregate column ref
ScalarOperator then = caseClause.getThenClause(i);
ScalarOperator newThen = rewriteIfOnlyContianAggregateColumn(then, operatorMap);
if (newThen == null) {
return null;
}
newCaseWhens.add(newThen);
}
ScalarOperator elseClause = caseClause.hasElse() ? caseClause.getElseClause() : null;
ScalarOperator newElseClause = elseClause;
if (elseClause != null) {
newElseClause = rewriteIfOnlyContianAggregateColumn(elseClause, operatorMap);
if (newElseClause == null) {
return null;
}
}
// NOTE: use call's result type as its input.
return new CaseWhenOperator(call.getType(), caseExpr, newElseClause, newCaseWhens);
}
}
return null;
}

// Returns true if any column ref appearing under `child` is an aggregate column
// tracked in operatorMap.
private boolean isContainAggregateColumn(ScalarOperator child,
                                         Map<ColumnRefOperator, CallOperator> operatorMap) {
    for (ColumnRefOperator columnRef : child.getColumnRefs()) {
        if (operatorMap.containsKey(columnRef)) {
            return true;
        }
    }
    return false;
}

private ScalarOperator rewriteIfOnlyContianAggregateColumn(ScalarOperator child,
Map<ColumnRefOperator, CallOperator> operatorMap) {
List<ColumnRefOperator> colRefs = child.getColumnRefs();
if (colRefs.size() > 1) {
return null;
}
// constant operator
if (colRefs.size() == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about const null operator?
sum(null) seems should be rewritten into 0;
max(null) seems should be rewritten into null;

return child;
}
// TODO: only supports column ref now, support common expression later.
if (!(child instanceof ColumnRefOperator)) {
return null;

}
if (!operatorMap.containsKey(colRefs.get(0))) {
return null;
}
return operatorMap.get(colRefs.get(0));
}

Optional<ScalarOperator> replace(ScalarOperator scalarOperator) {
if (equationMap.containsKey(scalarOperator)) {
Optional<Pair<ColumnRefOperator, ScalarOperator>> mappedColumnAndExprRef =
Expand Down Expand Up @@ -240,6 +418,7 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) {
equationMap.put(extendedEntry.first, Pair.create(col, extendedEntry.second));
}

// add into equivalents
for (IRewriteEquivalent equivalent : EQUIVALENTS) {
IRewriteEquivalent.RewriteEquivalentContext eqContext = equivalent.prepare(expr);
if (eqContext != null) {
Expand All @@ -248,6 +427,25 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) {
.add(eq);
}
}

// add into a push-down operator map
if (expr instanceof CallOperator) {
CallOperator call = (CallOperator) expr;
String fnName = call.getFnName();
if (AggregateFunctionRollupUtils.REWRITE_ROLLUP_FUNCTION_MAP.containsKey(fnName) && call.getChildren().size() == 1) {
ScalarOperator arg0 = call.getChild(0);
// NOTE: only support push down when the argument is a column ref.
// eg:
// mv: sum(cast(col as tinyint))
// query: 2 * sum(col)
// query cannot be used to push down, because the argument is not a column ref.
if (arg0 != null && arg0.isColumnRef()) {
aggPushDownOperatorMap
.computeIfAbsent(fnName, x -> Maps.newHashMap())
.put((ColumnRefOperator) arg0, call);
}
}
}
}

private static class EquationTransformer extends ScalarOperatorVisitor<Void, Void> {
LiShuMing marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ public void testAggPushDown_RollupFunctions_QueryMV_NoMatch() {
// query and mv's agg function is not the same, cannot rewrite.
String mvAggArg = "LO_REVENUE";
String queryAggArg = "(LO_REVENUE + 1) * 2";
Set<String> supportedPushDownAggregateFunctions = Sets.newHashSet("min", "max");
for (Map.Entry<String, String> e : SAFE_REWRITE_ROLLUP_FUNCTION_MAP.entrySet()) {
String funcName = e.getKey();
String mvAggFunc = getAggFunction(funcName, mvAggArg);
Expand All @@ -162,7 +163,11 @@ public void testAggPushDown_RollupFunctions_QueryMV_NoMatch() {
String query = String.format("select LO_ORDERDATE, %s as revenue_sum\n" +
" from lineorder l join dates d on l.LO_ORDERDATE = d.d_datekey\n" +
" group by LO_ORDERDATE", queryAggFunc);
sql(query).nonMatch("mv0");
if (supportedPushDownAggregateFunctions.contains(funcName)) {
sql(query).match("mv0");
} else {
sql(query).nonMatch("mv0");
}
});
}
}
Expand Down
Loading
Loading