Skip to content

Commit

Permalink
Support push down aggregate functions in mv rewrite
Browse files Browse the repository at this point in the history
Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing committed Aug 22, 2024
1 parent 89a6b77 commit e8ef9ad
Show file tree
Hide file tree
Showing 7 changed files with 609 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ public class AggregateFunctionRollupUtils {
.put(FunctionSet.ARRAY_AGG_DISTINCT, FunctionSet.ARRAY_UNIQUE_AGG)
.build();

// Aggregate functions that support push-down during mv union rewrite: a query-side
// aggregate over an expression of the mv's aggregate column can be rewritten by
// substituting the mv's pre-aggregated column into that expression.
// eg:
//   sum(fn(col)) = fn(sum(col))
//   min(fn(col)) = fn(min(col))
//   max(fn(col)) = fn(max(col))
// NOTE(review): the identity above does not hold for arbitrary scalar fn — it holds for
// min/max under monotonic fn and for sum under linear fn; callers (EquationRewriter)
// are expected to restrict the expression shapes accepted — confirm.
public static final Map<String, String> MV_REWRITE_PUSH_DOWN_FUNCTION_MAP = ImmutableMap.<String, String>builder()
        // Key: query-side aggregate name; value: its mv-side rollup aggregate.
        // For these three the rollup function is the function itself.
        .put(FunctionSet.SUM, FunctionSet.SUM)
        .put(FunctionSet.MAX, FunctionSet.MAX)
        .put(FunctionSet.MIN, FunctionSet.MIN)
        .build();

public static final Set<String> NON_CUMULATIVE_ROLLUP_FUNCTION_MAP = ImmutableSet.<String>builder()
.add(FunctionSet.MAX)
.add(FunctionSet.MIN)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
import com.starrocks.common.Pair;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperatorVisitor;
import com.starrocks.sql.optimizer.rewrite.BaseScalarOperatorShuttle;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.EquivalentShuttleContext;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.IRewriteEquivalent;
import com.starrocks.sql.optimizer.rule.transformation.materialization.equivalent.RewriteEquivalent;
Expand All @@ -48,6 +50,9 @@ public class EquationRewriter {
private AggregateFunctionRewriter aggregateFunctionRewriter;
boolean underAggFunctionRewriteContext;

// Push-down candidates collected from mv-side aggregate expressions in addMapping():
// per aggregate function name, maps the aggregate's single column-ref argument to the
// full aggregate call, so the column ref can be replaced with the pre-aggregated call
// when the query side is rewritten (see rewriteByPushDownAggregation).
private Map<String, Map<ColumnRefOperator, CallOperator>> pushDownAggOperatorMap = Maps.newHashMap();

public EquationRewriter() {
this.equationMap = ArrayListMultimap.create();
this.rewriteEquivalents = Maps.newHashMap();
Expand Down Expand Up @@ -162,9 +167,88 @@ public ScalarOperator visitCall(CallOperator call, Void context) {
}
}

// rewrite by pushing down aggregate
rewritten = rewriteByPushDownAggregation(call);
if (rewritten != null) {
return rewritten;
}

return super.visitCall(call, context);
}

/**
 * Try to rewrite an aggregate call by pushing the aggregation through a scalar
 * expression over an mv-provided aggregate column.
 * eg: query {@code sum(if(pred, col, 0))} against mv {@code sum(col)} becomes
 * {@code sum(if(pred, <mv sum(col) column>, 0))}.
 *
 * @param call the query-side aggregate call operator
 * @return the rewritten operator, or null if this call cannot be rewritten this way
 */
private ScalarOperator rewriteByPushDownAggregation(CallOperator call) {
    if (AggregateFunctionRollupUtils.MV_REWRITE_PUSH_DOWN_FUNCTION_MAP.containsKey(call.getFnName()) &&
            pushDownAggOperatorMap.containsKey(call.getFnName())) {
        // Push-down aggregate now only supports one child. Check the arity BEFORE
        // calling getChild(0) so a zero-argument call cannot trigger an
        // out-of-bounds access.
        if (call.getChildren().size() != 1) {
            return null;
        }
        Map<ColumnRefOperator, CallOperator> operatorMap = pushDownAggOperatorMap.get(call.getFnName());
        ScalarOperator arg0 = call.getChild(0);
        if (!canPushDownCallOperator(arg0, operatorMap)) {
            return null;
        }

        // Substitute the mv's aggregate call for its column ref inside the argument,
        // then retry the equivalence rewrite on the substituted expression.
        ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(operatorMap);
        ScalarOperator pdCall = rewriter.rewrite(arg0);
        if (pdCall == null || pdCall.equals(arg0)) {
            // Nothing was substituted, so pushing down makes no progress.
            return null;
        }
        ScalarOperator rewritten = pdCall.accept(this, null);
        if (rewritten != null) {
            shuttleContext.setRewrittenByEquivalent(true);
            return new CallOperator(call.getFnName(), call.getType(), Lists.newArrayList(rewritten),
                    call.getFunction());
        }
    }
    return null;
}

// Decide whether the aggregate's argument expression can take the push-down
// substitution. The mv's aggregate column must only appear in value positions,
// never in a condition position — substituting the pre-aggregated value into a
// predicate would change the predicate's semantics.
private boolean canPushDownCallOperator(ScalarOperator arg0,
                                        Map<ColumnRefOperator, CallOperator> operatorMap) {
    if (arg0 == null || !(arg0 instanceof CallOperator)) {
        return false;
    }
    CallOperator call = (CallOperator) arg0;
    List<ColumnRefOperator> columnRefs = arg0.getColumnRefs();
    if (columnRefs.size() == 1) {
        // Single column ref: push down iff it is exactly the mv's aggregate column.
        return operatorMap.containsKey(columnRefs.get(0));
    } else {
        // if there are many column refs in arg0, the agg column must be the same.
        if (call.getFnName().equalsIgnoreCase(FunctionSet.IF)) {
            if (columnRefs.size() != 3) {
                return false;
            }
            // if the first column ref is in the operatorMap, means the agg column maybe as condition which
            // cannot be rewritten
            if (operatorMap.containsKey(columnRefs.get(0))) {
                return false;
            }
            // NOTE(review): the then/else refs are not verified against operatorMap
            // here; presumably a no-op substitution is rejected afterwards by the
            // caller's pdCall.equals(arg0) check — confirm.
            return true;
        } else if (call instanceof CaseWhenOperator) {
            CaseWhenOperator caseWhen = (CaseWhenOperator) call;
            // if case condition contains any agg column ref, return false
            if (caseWhen.getCaseClause() != null) {
                List<ColumnRefOperator> caseColumnRefs = caseWhen.getCaseClause().getColumnRefs();
                if (caseColumnRefs.stream().anyMatch(x -> operatorMap.containsKey(x))) {
                    return false;
                }
            }
            // if any WHEN clause contains an agg column ref, return false
            for (int i = 0; i < caseWhen.getWhenClauseSize(); i++) {
                ScalarOperator when = caseWhen.getWhenClause(i);
                if (when != null) {
                    List<ColumnRefOperator> whenColumnRefs = when.getColumnRefs();
                    if (whenColumnRefs.stream().anyMatch(x -> operatorMap.containsKey(x))) {
                        return false;
                    }
                }
            }
        }
        // NOTE(review): the CaseWhenOperator branch above also falls through to this
        // return, so a case-when argument is never accepted even when its conditions
        // are clean. This only looks intentional if case-when is normalized to if()
        // before reaching here (the unit tests exercise the if() path) — confirm.
        return false;
    }
}

Optional<ScalarOperator> replace(ScalarOperator scalarOperator) {
if (equationMap.containsKey(scalarOperator)) {
Optional<Pair<ColumnRefOperator, ScalarOperator>> mappedColumnAndExprRef =
Expand Down Expand Up @@ -233,6 +317,7 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) {
equationMap.put(extendedEntry.first, Pair.create(col, extendedEntry.second));
}

// add into equivalents
for (IRewriteEquivalent equivalent : EQUIVALENTS) {
IRewriteEquivalent.RewriteEquivalentContext eqContext = equivalent.prepare(expr);
if (eqContext != null) {
Expand All @@ -241,6 +326,25 @@ public void addMapping(ScalarOperator expr, ColumnRefOperator col) {
.add(eq);
}
}

// add into a push-down operator map
if (expr instanceof CallOperator) {
CallOperator call = (CallOperator) expr;
String fnName = call.getFnName();
if (AggregateFunctionRollupUtils.REWRITE_ROLLUP_FUNCTION_MAP.containsKey(fnName) && call.getChildren().size() == 1) {
ScalarOperator arg0 = call.getChild(0);
// NOTE: only support push down when the argument is a column ref.
// eg:
// mv: sum(cast(col as tinyint))
// query: 2 * sum(col)
// query cannot be used to push down, because the argument is not a column ref.
if (arg0 != null && arg0.isColumnRef()) {
pushDownAggOperatorMap
.computeIfAbsent(fnName, x -> Maps.newHashMap())
.put((ColumnRefOperator) arg0, call);
}
}
}
}

private static class EquationTransformer extends ScalarOperatorVisitor<Void, Void> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import java.util.Map;
import java.util.Set;

import static com.starrocks.sql.optimizer.rule.transformation.materialization.AggregateFunctionRollupUtils.MV_REWRITE_PUSH_DOWN_FUNCTION_MAP;
import static com.starrocks.sql.optimizer.rule.transformation.materialization.AggregateFunctionRollupUtils.REWRITE_ROLLUP_FUNCTION_MAP;
import static com.starrocks.sql.optimizer.rule.transformation.materialization.AggregateFunctionRollupUtils.SAFE_REWRITE_ROLLUP_FUNCTION_MAP;

Expand Down Expand Up @@ -162,7 +163,11 @@ public void testAggPushDown_RollupFunctions_QueryMV_NoMatch() {
String query = String.format("select LO_ORDERDATE, %s as revenue_sum\n" +
" from lineorder l join dates d on l.LO_ORDERDATE = d.d_datekey\n" +
" group by LO_ORDERDATE", queryAggFunc);
sql(query).nonMatch("mv0");
if (MV_REWRITE_PUSH_DOWN_FUNCTION_MAP.containsKey(funcName)) {
sql(query).contains("mv0");
} else {
sql(query).nonMatch("mv0");
}
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,132 @@ public void testNullableTestCase1() throws Exception {
.match("join_null_mv_2");
}

@Test
public void testRewriteWithPushDownEquivalent1() throws Exception {
    // mv aggregates sum(v1) by (k1, k2, k3); queries aggregate an expression of v1,
    // so rewrite must push sum() through the case-when (normalized to if()) and
    // read the mv's pre-aggregated sum(v1) column instead of raw v1.
    starRocksAssert.withTable("CREATE TABLE `tbl1` (\n" +
            " `k1` date,\n" +
            " `k2` decimal64(18, 2),\n" +
            " `k3` varchar(255),\n" +
            " `v1` bigint \n" +
            ") ENGINE=OLAP \n" +
            "DUPLICATE KEY(`k1`, `k2`, `k3`)\n" +
            "DISTRIBUTED BY RANDOM\n" +
            "PROPERTIES (\n" +
            "\"replication_num\" = \"1\"\n" +
            ");");
    starRocksAssert.withMaterializedView("CREATE MATERIALIZED VIEW `mv1` \n" +
            "DISTRIBUTED BY RANDOM\n" +
            "REFRESH ASYNC\n" +
            "PROPERTIES (\n" +
            "\"replication_num\" = \"1\"\n" +
            ")\n" +
            "AS SELECT k1, k2, k3, sum(v1) from tbl1 group by k1, k2, k3");
    {
        // sum over a case-when of v1: expect a scan of mv1 with sum(if(..., sum(v1), 0)).
        String sql = "select t1.k1, " +
                " sum(case when t1.k1 between date_add('2024-07-20', interval -1 month) and " +
                " date_add('2024-07-20', interval 1 month) then t1.v1 else 0 end) " +
                " from tbl1 t1 group by t1.k1";
        sql(sql).contains("mv1")
                .contains(" 1:AGGREGATE (update serialize)\n" +
                        " | STREAMING\n" +
                        " | output: sum(if((7: k1 >= '2024-06-20') AND (7: k1 <= '2024-08-20'), 10: sum(v1), 0))\n" +
                        " | group by: 7: k1\n" +
                        " | \n" +
                        " 0:OlapScanNode\n" +
                        " TABLE: mv1");
    }
    {
        // Same query scaled by a constant factor outside the aggregate: the push-down
        // still applies, and the 2 * ... multiplication stays in a Project above it.
        String sql = "select t1.k1, " +
                " 2 * sum(case when t1.k1 between date_add('2024-07-20', interval -1 month) and " +
                " date_add('2024-07-20', interval 1 month) then t1.v1 else 0 end) " +
                " from tbl1 t1 group by t1.k1";
        sql(sql).contains("mv1")
                .contains(" 1:AGGREGATE (update serialize)\n" +
                        " | STREAMING\n" +
                        " | output: sum(if((7: k1 >= '2024-06-20') AND (7: k1 <= '2024-08-20'), 10: sum(v1), 0))\n" +
                        " | group by: 7: k1\n" +
                        " | \n" +
                        " 0:OlapScanNode\n" +
                        " TABLE: mv1")
                .contains(" 4:Project\n" +
                        " | <slot 1> : 8: k1\n" +
                        " | <slot 7> : 2 * 12: sum");
    }
    starRocksAssert.dropMaterializedView("mv1");
    starRocksAssert.dropTable("tbl1");
}

@Test
public void testRewriteWithPushDownEquivalent2() throws Exception {
    // Negative case: the mv aggregates sum(2 * v1) — the aggregate's argument is an
    // expression, not a bare column ref, so it is not collected as a push-down
    // candidate and the query must NOT be rewritten against mv1.
    starRocksAssert.withTable("CREATE TABLE `tbl1` (\n" +
            " `k1` date,\n" +
            " `k2` decimal64(18, 2),\n" +
            " `k3` varchar(255),\n" +
            " `v1` bigint \n" +
            ") ENGINE=OLAP \n" +
            "DUPLICATE KEY(`k1`, `k2`, `k3`)\n" +
            "DISTRIBUTED BY RANDOM\n" +
            "PROPERTIES (\n" +
            "\"replication_num\" = \"1\"\n" +
            ");");
    starRocksAssert.withMaterializedView("CREATE MATERIALIZED VIEW `mv1` \n" +
            "DISTRIBUTED BY RANDOM\n" +
            "REFRESH ASYNC\n" +
            "PROPERTIES (\n" +
            "\"replication_num\" = \"1\"\n" +
            ")\n" +
            "AS SELECT k1, k2, k3, sum(2 * v1) from tbl1 group by k1, k2, k3");
    {
        // mv's sum doesn't contain column ref which cannot be used for rewrite
        String sql = "select t1.k1, " +
                " sum(case when t1.k1 between date_add('2024-07-20', interval -1 month) and " +
                " date_add('2024-07-20', interval 1 month) then 2 * t1.v1 else 0 end) " +
                " from tbl1 t1 group by t1.k1";
        sql(sql).notContain("mv1");
    }
    starRocksAssert.dropMaterializedView("mv1");
    starRocksAssert.dropTable("tbl1");
}

@Test
public void testRewriteWithPushDownEquivalent3() throws Exception {
    // Negative cases: queries whose aggregate cannot safely push through the mv's
    // sum(v1)/max(v1) — wrong rollup function, agg column used in a condition, or an
    // argument mixing multiple column refs inside one arithmetic expression.
    starRocksAssert.withTable("CREATE TABLE `tbl1` (\n" +
            " `k1` date,\n" +
            " `k2` decimal64(18, 2),\n" +
            " `k3` varchar(255),\n" +
            " `v1` bigint \n" +
            ") ENGINE=OLAP \n" +
            "DUPLICATE KEY(`k1`, `k2`, `k3`)\n" +
            "DISTRIBUTED BY RANDOM\n" +
            "PROPERTIES (\n" +
            "\"replication_num\" = \"1\"\n" +
            ");");
    starRocksAssert.withMaterializedView("CREATE MATERIALIZED VIEW `mv1` \n" +
            "DISTRIBUTED BY RANDOM\n" +
            "REFRESH ASYNC\n" +
            "PROPERTIES (\n" +
            "\"replication_num\" = \"1\"\n" +
            ")\n" +
            "AS SELECT k1, k2, k3, sum(v1), max(v1) from tbl1 group by k1, k2, k3");
    {
        // min() is not among the mv's aggregates, so no rewrite.
        String sql = "select t1.k1, 2 * min(v1 + 1) from tbl1 t1 group by t1.k1";
        sql(sql).notContain("mv1");
    }
    {
        // v1 appears in the case-when CONDITION (t1.v1 > 10): substituting the
        // pre-aggregated value there would change predicate semantics, so no rewrite.
        String sql = "select t1.k1, 2 * sum(case when t1.v1 > 10 then t1.v1 else 0 end) " +
                " from tbl1 t1 group by t1.k1";
        sql(sql).notContain("mv1");
    }
    {
        // v1 is combined with another column (k3) inside the aggregate's argument,
        // which is not an accepted push-down shape, so no rewrite.
        String sql = "select t1.k1, " +
                " 2 * sum(v1 + cast(k3 as int)) " +
                " from tbl1 t1 group by t1.k1";
        sql(sql).notContain("mv1");
    }
    starRocksAssert.dropMaterializedView("mv1");
    starRocksAssert.dropTable("tbl1");
}

@Test
public void testNullableTestCase2() throws Exception {
String mv = "create materialized view join_null_mv\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public void testTracerLogMV_Fail1() {
Tracers.init(connectContext, Tracers.Mode.LOGS, "MV");
String mv = "select locations.locationid, empid, sum(emps.deptno) as col3 from emps " +
"join locations on emps.locationid = locations.locationid group by empid,locations.locationid";
testRewriteFail(mv, "select emps.locationid, empid, sum(emps.deptno + 1) as col3 from emps " +
testRewriteFail(mv, "select emps.locationid, empid, min(emps.deptno + 1) as col3 from emps " +
"join locations on emps.locationid = locations.locationid where empid > 10 group by empid,emps.locationid");
String pr = Tracers.printLogs();
Tracers.close();
Expand All @@ -163,7 +163,7 @@ public void testTracerLogMV_Fail2() {
Tracers.init(connectContext, Tracers.Mode.LOGS, "MV");
String mv = "select locations.locationid, empid, sum(emps.deptno) as col3 from emps " +
"join locations on emps.locationid = locations.locationid group by empid,locations.locationid";
testRewriteFail(mv, "select emps.locationid, empid, sum(emps.deptno + 1) as col3 from emps " +
testRewriteFail(mv, "select emps.locationid, empid, min(emps.deptno + 1) as col3 from emps " +
"join locations on emps.locationid = locations.locationid where empid > 10 group by empid,emps.locationid");
String pr = Tracers.printLogs();
Tracers.close();
Expand Down
Loading

0 comments on commit e8ef9ad

Please sign in to comment.