-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement] Support push down aggregate functions in mv rewrite #49979
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,20 +22,30 @@ | |
import com.starrocks.analysis.Expr; | ||
import com.starrocks.catalog.Function; | ||
import com.starrocks.catalog.FunctionSet; | ||
import com.starrocks.catalog.Type; | ||
import com.starrocks.common.Pair; | ||
import com.starrocks.qe.ConnectContext; | ||
import com.starrocks.sql.analyzer.FunctionAnalyzer; | ||
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; | ||
import com.starrocks.sql.optimizer.operator.scalar.CallOperator; | ||
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator; | ||
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; | ||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; | ||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorUtil; | ||
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorVisitor; | ||
import com.starrocks.sql.optimizer.rewrite.BaseScalarOperatorShuttle; | ||
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter; | ||
import com.starrocks.sql.optimizer.rewrite.ScalarOperatorRewriter; | ||
import com.starrocks.sql.optimizer.rewrite.scalar.ImplicitCastRule; | ||
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext; | ||
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent; | ||
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent; | ||
import com.starrocks.sql.parser.NodePosition; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
import java.util.stream.Collectors; | ||
|
||
import static com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent.EQUIVALENTS; | ||
|
||
|
@@ -48,6 +58,9 @@ public class EquationRewriter { | |
private AggregateFunctionRewriter aggregateFunctionRewriter; | ||
boolean underAggFunctionRewriteContext; | ||
|
||
// Replace the corresponding ColumnRef with ScalarOperator if this call operator can be pushed down. | ||
private Map<String, Map<ColumnRefOperator, CallOperator>> aggPushDownOperatorMap = Maps.newHashMap(); | ||
|
||
public EquationRewriter() { | ||
this.equationMap = ArrayListMultimap.create(); | ||
this.rewriteEquivalents = Maps.newHashMap(); | ||
|
@@ -161,6 +174,13 @@ public ScalarOperator visitCall(CallOperator call, Void context) { | |
return rewritten; | ||
} | ||
} | ||
|
||
// rewrite by pushing down aggregate | ||
rewritten = rewriteByPushDownAggregation(call); | ||
if (rewritten != null) { | ||
return rewritten; | ||
} | ||
|
||
// If count(1)/sum(1) cannot be rewritten by mv's defined equivalents, return null directly, | ||
// otherwise it may cause a wrong plan. | ||
// mv : SELECT 1, count(distinct k1) from tbl1; | ||
|
@@ -169,9 +189,167 @@ public ScalarOperator visitCall(CallOperator call, Void context) { | |
if (call.isAggregate() && call.isConstant()) { | ||
return null; | ||
} | ||
|
||
return super.visitCall(call, context); | ||
} | ||
|
||
private ScalarOperator rewriteByPushDownAggregation(CallOperator call) { | ||
if (AggregateFunctionRollupUtils.MV_REWRITE_PUSH_DOWN_FUNCTION_MAP.containsKey(call.getFnName()) && | ||
aggPushDownOperatorMap.containsKey(call.getFnName())) { | ||
Map<ColumnRefOperator, CallOperator> operatorMap = aggPushDownOperatorMap.get(call.getFnName()); | ||
ScalarOperator arg0 = call.getChild(0); | ||
if (call.getChildren().size() != 1) { | ||
return null; | ||
} | ||
// push down aggregate now only supports one child | ||
// it's fine since rewrite will clone argo | ||
ScalarOperator pdCall = pushDownAggregationToArg0(call, arg0, operatorMap); | ||
if (pdCall == null || pdCall.equals(arg0)) { | ||
return null; | ||
} | ||
ScalarOperator rewritten = pdCall.accept(this, null); | ||
// only can be used if pdCall is rewritten | ||
if (rewritten != null && !rewritten.equals(pdCall)) { | ||
shuttleContext.setRewrittenByEquivalent(true); | ||
if (FunctionSet.SUM.equalsIgnoreCase(call.getFnName())) { | ||
Function newFn = ScalarOperatorUtil.findSumFn(new Type[] {rewritten.getType()}); | ||
CallOperator newCall = new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten), | ||
newFn); | ||
ScalarOperatorRewriter scalarRewriter = new ScalarOperatorRewriter(); | ||
CallOperator result = (CallOperator) scalarRewriter.rewrite(newCall, | ||
Lists.newArrayList(new ImplicitCastRule())); | ||
return result; | ||
} else { | ||
return new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten), | ||
call.getFunction()); | ||
} | ||
} | ||
} | ||
return null; | ||
} | ||
|
||
/** | ||
* Whether the call operator can be pushed down: | ||
* - Only supports to push down min/max/sum aggregate function | ||
* - If the argument is a column ref, and operator map contains the column ref, return true | ||
* - If the argument is a call operator and contains multi-column refs(Ony IF/CaseWhen is supported), | ||
* ensure the aggregate column does not appear in the condition clause. | ||
*/ | ||
private ScalarOperator pushDownAggregationToArg0(ScalarOperator call, | ||
ScalarOperator arg0, | ||
Map<ColumnRefOperator, CallOperator> operatorMap) { | ||
if (arg0 == null || !(arg0 instanceof CallOperator)) { | ||
return null; | ||
} | ||
CallOperator arg0Call = (CallOperator) arg0; | ||
List<ColumnRefOperator> columnRefs = arg0.getColumnRefs(); | ||
if (columnRefs.size() == 1) { | ||
ColumnRefOperator child0 = columnRefs.get(0); | ||
if (!operatorMap.containsKey(child0)) { | ||
return null; | ||
} | ||
CallOperator aggFunc = operatorMap.get(child0); | ||
// strict mode, only supports col's type and agg(col)'s type are strict equal; otherwise we should | ||
// refresh the call operator's argument/result type recursively. | ||
if (!aggFunc.getType().equals(child0.getType())) { | ||
return null; | ||
} | ||
ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(operatorMap); | ||
return rewriter.rewrite(arg0); | ||
} else { | ||
// if there are many column refs in arg0, the agg column must be the same. | ||
if (arg0Call.getFnName().equalsIgnoreCase(FunctionSet.IF)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about other functions: ifnull, nullif, coalesce? |
||
if (arg0Call.getChildren().size() != 3) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IF functions's children num must be 3, no need check! |
||
return null; | ||
} | ||
// if the first column ref is in the operatorMap, means the agg column maybe as condition which | ||
// cannot be rewritten | ||
if (isContainAggregateColumn(arg0Call.getChild(0), operatorMap)) { | ||
return null; | ||
} | ||
ScalarOperator child1 = arg0Call.getChild(1); | ||
ScalarOperator rewritten1 = rewriteIfOnlyContianAggregateColumn(child1, operatorMap); | ||
if (rewritten1 == null) { | ||
return null; | ||
} | ||
ScalarOperator child2 = arg0Call.getChild(2); | ||
ScalarOperator rewritten2 = rewriteIfOnlyContianAggregateColumn(child2, operatorMap); | ||
if (rewritten2 == null) { | ||
return null; | ||
} | ||
ConnectContext ctx = ConnectContext.get() == null ? new ConnectContext() : ConnectContext.get(); | ||
List<ScalarOperator> args = Lists.newArrayList(arg0Call.getChild(0), rewritten1, rewritten2); | ||
Type[] argTypes = args.stream().map(x -> x.getType()).collect(Collectors.toList()).toArray(new Type[0]); | ||
Function newFn = FunctionAnalyzer.getAnalyzedBuiltInFunction(ctx, FunctionSet.IF, null, argTypes, | ||
NodePosition.ZERO); | ||
if (newFn == null) { | ||
return null; | ||
} | ||
return new CallOperator(FunctionSet.IF, newFn.getReturnType(), args, newFn); | ||
} else if (arg0Call instanceof CaseWhenOperator) { | ||
CaseWhenOperator caseClause = (CaseWhenOperator) arg0Call; | ||
// if case condition contains any agg column ref, return false | ||
ScalarOperator caseExpr = caseClause.hasCase() ? caseClause.getCaseClause() : null; | ||
if (caseExpr != null && isContainAggregateColumn(caseExpr, operatorMap)) { | ||
return null; | ||
} | ||
List<ScalarOperator> newCaseWhens = Lists.newArrayList(); | ||
for (int i = 0; i < caseClause.getWhenClauseSize(); i++) { | ||
ScalarOperator when = caseClause.getWhenClause(i); | ||
if (isContainAggregateColumn(when, operatorMap)) { | ||
return null; | ||
} | ||
newCaseWhens.add(when); | ||
|
||
// when clause or else clause can only contain aggregate column ref | ||
ScalarOperator then = caseClause.getThenClause(i); | ||
ScalarOperator newThen = rewriteIfOnlyContianAggregateColumn(then, operatorMap); | ||
if (newThen == null) { | ||
return null; | ||
} | ||
newCaseWhens.add(newThen); | ||
} | ||
ScalarOperator elseClause = caseClause.hasElse() ? caseClause.getElseClause() : null; | ||
ScalarOperator newElseClause = elseClause; | ||
if (elseClause != null) { | ||
newElseClause = rewriteIfOnlyContianAggregateColumn(elseClause, operatorMap); | ||
if (newElseClause == null) { | ||
return null; | ||
} | ||
} | ||
// NOTE: use call's result type as its input. | ||
return new CaseWhenOperator(call.getType(), caseExpr, newElseClause, newCaseWhens); | ||
} | ||
} | ||
return null; | ||
} | ||
|
||
private boolean isContainAggregateColumn(ScalarOperator child, | ||
Map<ColumnRefOperator, CallOperator> operatorMap) { | ||
return child.getColumnRefs().stream().anyMatch(x -> operatorMap.containsKey(x)); | ||
} | ||
|
||
private ScalarOperator rewriteIfOnlyContianAggregateColumn(ScalarOperator child, | ||
Map<ColumnRefOperator, CallOperator> operatorMap) { | ||
List<ColumnRefOperator> colRefs = child.getColumnRefs(); | ||
if (colRefs.size() > 1) { | ||
return null; | ||
} | ||
// constant operator | ||
if (colRefs.size() == 0) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about const null operator? |
||
return child; | ||
} | ||
// TODO: only supports column ref now, support common expression later. | ||
if (!(child instanceof ColumnRefOperator)) { | ||
return null; | ||
|
||
} | ||
if (!operatorMap.containsKey(colRefs.get(0))) { | ||
return null; | ||
} | ||
return operatorMap.get(colRefs.get(0)); | ||
} | ||
|
||
Optional<ScalarOperator> replace(ScalarOperator scalarOperator) { | ||
if (equationMap.containsKey(scalarOperator)) { | ||
Optional<Pair<ColumnRefOperator, ScalarOperator>> mappedColumnAndExprRef = | ||
|
@@ -240,6 +418,7 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) { | |
equationMap.put(extendedEntry.first, Pair.create(col, extendedEntry.second)); | ||
} | ||
|
||
// add into equivalents | ||
for (IRewriteEquivalent equivalent : EQUIVALENTS) { | ||
IRewriteEquivalent.RewriteEquivalentContext eqContext = equivalent.prepare(expr); | ||
if (eqContext != null) { | ||
|
@@ -248,6 +427,25 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) { | |
.add(eq); | ||
} | ||
} | ||
|
||
// add into a push-down operator map | ||
if (expr instanceof CallOperator) { | ||
CallOperator call = (CallOperator) expr; | ||
String fnName = call.getFnName(); | ||
if (AggregateFunctionRollupUtils.REWRITE_ROLLUP_FUNCTION_MAP.containsKey(fnName) && call.getChildren().size() == 1) { | ||
ScalarOperator arg0 = call.getChild(0); | ||
// NOTE: only support push down when the argument is a column ref. | ||
// eg: | ||
// mv: sum(cast(col as tinyint)) | ||
// query: 2 * sum(col) | ||
// query cannot be used to push down, because the argument is not a column ref. | ||
if (arg0 != null && arg0.isColumnRef()) { | ||
aggPushDownOperatorMap | ||
.computeIfAbsent(fnName, x -> Maps.newHashMap()) | ||
.put((ColumnRefOperator) arg0, call); | ||
} | ||
} | ||
} | ||
} | ||
|
||
private static class EquationTransformer extends ScalarOperatorVisitor<Void, Void> { | ||
LiShuMing marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not support COUNT?