Skip to content

Commit

Permalink
not rewrite in cte producer, and change regression mv affected by thi…
Browse files Browse the repository at this point in the history
…s rule, and add some cases
  • Loading branch information
feiniaofeiafei committed Nov 15, 2024
1 parent b393ff3 commit 73c21af
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ public enum RuleType {
REWRITE_REPEAT_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_OLAP_TABLE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SINK_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_WINDOW_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_SET_OPERATION_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_PARTITION_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_QUALIFY_EXPRESSION(RuleTypeClass.REWRITE),
REWRITE_TOPN_EXPRESSION(RuleTypeClass.REWRITE),
EXTRACT_FILTER_FROM_JOIN(RuleTypeClass.REWRITE),
REORDER_JOIN(RuleTypeClass.REWRITE),
MERGE_PERCENTILE_TO_ARRAY(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand All @@ -36,6 +37,7 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
Expand All @@ -49,6 +51,10 @@ public class EliminateGroupByKeyByUniform extends DefaultPlanRewriter<Map<ExprId

@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
Optional<CTEId> cteId = jobContext.getCascadesContext().getCurrentTree();
if (cteId.isPresent()) {
return plan;
}
Map<ExprId, ExprId> replaceMap = new HashMap<>();
ExprIdRewriter.ReplaceRule replaceRule = new ExprIdRewriter.ReplaceRule(replaceMap);
exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.pattern.Pattern;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
Expand All @@ -30,15 +31,25 @@
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalQualify;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**ExprIdReplacer*/
public class ExprIdRewriter extends ExpressionRewrite {
Expand All @@ -56,6 +67,11 @@ public List<Rule> buildRules() {
ImmutableList.Builder<Rule> builder = ImmutableList.builder();
builder.addAll(super.buildRules());
builder.addAll(ImmutableList.of(
new LogicalPartitionTopNExpressionRewrite().build(),
new LogicalQualifyExpressionRewrite().build(),
new LogicalTopNExpressionRewrite().build(),
new LogicalSetOperationRewrite().build(),
new LogicalWindowRewrite().build(),
new LogicalResultSinkRewrite().build(),
new LogicalFileSinkRewrite().build(),
new LogicalHiveTableSinkRewrite().build(),
Expand Down Expand Up @@ -164,6 +180,91 @@ public Rule build() {
}
}

private class LogicalSetOperationRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalSetOperation().thenApply(ctx -> {
LogicalSetOperation setOperation = ctx.root;
List<List<SlotReference>> slotsList = setOperation.getRegularChildrenOutputs();
List<List<SlotReference>> newSlotsList = new ArrayList<>();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
for (List<SlotReference> slots : slotsList) {
List<SlotReference> newSlots = rewriteAll(slots, rewriter, context);
newSlotsList.add(newSlots);
}
if (newSlotsList.equals(slotsList)) {
return setOperation;
}
return setOperation.withChildrenAndTheirOutputs(setOperation.children(), newSlotsList);
})
.toRule(RuleType.REWRITE_SET_OPERATION_EXPRESSION);
}
}

private class LogicalWindowRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalWindow().thenApply(ctx -> {
LogicalWindow<Plan> window = ctx.root;
List<NamedExpression> windowExpressions = window.getWindowExpressions();
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<NamedExpression> newWindowExpressions = rewriteAll(windowExpressions, rewriter, context);
if (newWindowExpressions.equals(windowExpressions)) {
return window;
}
return window.withExpressionsAndChild(newWindowExpressions, window.child());
})
.toRule(RuleType.REWRITE_WINDOW_EXPRESSION);
}
}

private class LogicalTopNExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalTopN().thenApply(ctx -> {
LogicalTopN<Plan> topN = ctx.root;
List<OrderKey> orderKeys = topN.getOrderKeys();
ImmutableList.Builder<OrderKey> rewrittenOrderKeys
= ImmutableList.builderWithExpectedSize(orderKeys.size());
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
boolean changed = false;
for (OrderKey k : orderKeys) {
Expression expression = rewriter.rewrite(k.getExpr(), context);
changed |= expression != k.getExpr();
rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
}
return changed ? topN.withOrderKeys(rewrittenOrderKeys.build()) : topN;
}).toRule(RuleType.REWRITE_TOPN_EXPRESSION);
}
}

private class LogicalPartitionTopNExpressionRewrite extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalPartitionTopN().thenApply(ctx -> {
LogicalPartitionTopN<Plan> partitionTopN = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
List<OrderExpression> newOrderExpressions = new ArrayList<>();
boolean changed = false;
for (OrderExpression orderExpression : partitionTopN.getOrderKeys()) {
OrderKey orderKey = orderExpression.getOrderKey();
Expression expr = rewriter.rewrite(orderKey.getExpr(), context);
changed |= expr != orderKey.getExpr();
OrderKey newOrderKey = new OrderKey(expr, orderKey.isAsc(), orderKey.isNullFirst());
newOrderExpressions.add(new OrderExpression(newOrderKey));
}
List<Expression> newPartitionKeys = rewriteAll(partitionTopN.getPartitionKeys(), rewriter, context);
if (!newPartitionKeys.equals(partitionTopN.getPartitionKeys())) {
changed = true;
}
if (!changed) {
return partitionTopN;
}
return partitionTopN.withPartitionKeysAndOrderKeys(newPartitionKeys, newOrderExpressions);
}).toRule(RuleType.REWRITE_PARTITION_TOPN_EXPRESSION);
}
}

private LogicalSink<Plan> applyRewrite(MatchingContext<? extends LogicalSink<Plan>> ctx) {
LogicalSink<Plan> sink = ctx.root;
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,23 +82,23 @@ cherry 3
107 \N

-- !right_join_right_has_filter --
\N 105
\N 102
\N 106
\N 101
\N 102
\N 103
\N 107
\N 104
\N 105
\N 106
\N 107
100 100

-- !right_join_left_has_filter --
\N 101
\N 102
\N 103
\N 107
\N 104
\N 102
\N 106
\N 105
\N 101
\N 106
\N 107
100 100

-- !left_semi_join_right_has_filter --
Expand Down Expand Up @@ -154,14 +154,14 @@ cherry 3
-- !right_anti_join_right_has_where_filter --

-- !cross_join_left_has_filter --
100 103
100 107
100 105
100 100
100 102
100 106
100 101
100 102
100 103
100 104
100 105
100 106
100 107

-- !cross_join_right_has_filter --
100 100
Expand All @@ -173,3 +173,28 @@ cherry 3
106 100
107 100

-- !union --
1 100
5 105

-- !union_all --
1 100
1 100
5 105

-- !intersect --

-- !except --

-- !set_op_mixed --
1 100

-- !window --

-- !cte_producer --
1 1 100

-- !cte_multi_producer --

-- !cte_consumer --

6 changes: 4 additions & 2 deletions regression-test/suites/mv_p0/count_star/count_star.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ suite ("count_star") {
mv_rewrite_success("select k1,k4,count(*) from d_table group by k1,k4;", "kstar")
qt_select_mv "select k1,k4,count(*) from d_table group by k1,k4 order by 1,2;"

mv_rewrite_success("select k1,k4,count(*) from d_table where k1=1 group by k1,k4;", "kstar")
// fail because RBO rule EliminateGroupByKeyByUniform
mv_rewrite_fail("select k1,k4,count(*) from d_table where k1=1 group by k1,k4;", "kstar")
qt_select_mv "select k1,k4,count(*) from d_table where k1=1 group by k1,k4 order by 1,2;"

mv_rewrite_fail("select k1,k4,count(*) from d_table where k3=1 group by k1,k4;", "kstar")
Expand All @@ -65,7 +66,8 @@ suite ("count_star") {

sql """set enable_stats=true;"""
mv_rewrite_success("select k1,k4,count(*) from d_table group by k1,k4;", "kstar")
mv_rewrite_success("select k1,k4,count(*) from d_table where k1=1 group by k1,k4;", "kstar")
// fail because RBO rule EliminateGroupByKeyByUniform
mv_rewrite_fail("select k1,k4,count(*) from d_table where k1=1 group by k1,k4;", "kstar")
mv_rewrite_fail("select k1,k4,count(*) from d_table where k3=1 group by k1,k4;", "kstar")
mv_rewrite_fail("select count(*) from d_table where k3=1;", "kstar")
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ suite("eliminate_group_by_key_by_uniform") {
qt_project_slot_uniform_confict_value "select max(c3), c1,c2,c3 from (select a c1,1 c2, d c3 from eli_gbk_by_uniform_t where a=1) t where c2=2 group by c1,c2,c3 order by 1,2,3,4;"

// test join
qt_inner_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_inner_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_left_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_left_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_right_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_right_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_inner_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_inner_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 inner join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_left_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_left_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 left join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_right_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_right_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 right join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_left_semi_join_right_has_filter "select t1.b from eli_gbk_by_uniform_t t1 left semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t1.a order by 1"
qt_left_semi_join_left_has_filter "select t1.b from eli_gbk_by_uniform_t t1 left semi join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t1.b,t1.a order by 1"
qt_left_anti_join_right_has_on_filter "select t1.b from eli_gbk_by_uniform_t t1 left anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t1.b,t1.a order by 1"
Expand All @@ -71,6 +71,29 @@ suite("eliminate_group_by_key_by_uniform") {
qt_right_anti_join_right_has_on_filter "select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t2.b=100 group by t2.b,t2.c order by 1"
qt_right_anti_join_left_has_on_filter "select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b and t1.b=100 group by t2.b,t2.c order by 1"
qt_right_anti_join_right_has_where_filter "select t2.b from eli_gbk_by_uniform_t t1 right anti join eli_gbk_by_uniform_t t2 on t1.b=t2.b where t2.b=100 group by t2.b,t2.c order by 1"
qt_cross_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 cross join eli_gbk_by_uniform_t t2 where t1.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_cross_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 cross join eli_gbk_by_uniform_t t2 where t2.b=100 group by t1.b,t2.b,t2.c order by 1"
qt_cross_join_left_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 cross join eli_gbk_by_uniform_t t2 where t1.b=100 group by t1.b,t2.b,t2.c order by 1,2"
qt_cross_join_right_has_filter "select t1.b,t2.b from eli_gbk_by_uniform_t t1 cross join eli_gbk_by_uniform_t t2 where t2.b=100 group by t1.b,t2.b,t2.c order by 1,2"

//test union
qt_union "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b union select a,b from eli_gbk_by_uniform_t where b=100 group by a,b union select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5"
qt_union_all "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b union all select a,b from eli_gbk_by_uniform_t where b=100 group by a,b union all select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5"
qt_intersect "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b intersect select a,b from eli_gbk_by_uniform_t where b=100 group by a,b intersect select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5"
qt_except "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b except select a,b from eli_gbk_by_uniform_t where b=100 group by a,b except select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5"
qt_set_op_mixed "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b union select a,b from eli_gbk_by_uniform_t where b=100 group by a,b except select a,b from eli_gbk_by_uniform_t where a=5 group by a,b) t order by 1,2,3,4,5"

//test window
qt_window "select max(a) over(partition by a order by a) from eli_gbk_by_uniform_t where a=10 group by a,b order by 1"
//test partition topn
qt_partition_topn "select r from (select rank(a) over(partition by a order by a) r from eli_gbk_by_uniform_t where a=10 group by a,b) t where r<2 order by 1"
qt_partition_topn_qualifiy "select rank() over(partition by a order by a) r from eli_gbk_by_uniform_t where a=10 group by a,b qualify r<2 order by 1"
//test cte
qt_cte_producer "with t as (select a,b,count(*) from eli_gbk_by_uniform_t where a=1 group by a,b) select t1.a,t2.a,t2.b from t t1 inner join t t2 on t1.a=t2.a order by 1,2,3"
qt_cte_multi_producer "with t as (select a,b,count(*) from eli_gbk_by_uniform_t where a=1 group by a,b), tt as (select a,b,count(*) from eli_gbk_by_uniform_t where b=10 group by a,b) select t1.a,t2.a,t2.b from t t1 inner join tt t2 on t1.a=t2.a order by 1,2,3"
qt_cte_consumer "with t as (select * from eli_gbk_by_uniform_t) select t1.a,t2.b from t t1 inner join t t2 on t1.a=t2.a where t1.a=10 group by t1.a,t2.b order by 1,2 "

//test filter
qt_filter "select * from (select a,b from eli_gbk_by_uniform_t where a=1 group by a,b) t where a>0 order by 1,2"

//test topn
qt_topn "select a,b from eli_gbk_by_uniform_t where a=1 group by a,b order by a limit 10 offset 0"
}
Loading

0 comments on commit 73c21af

Please sign in to comment.