Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 21, 2024
1 parent eb5415a commit b555386
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand Down Expand Up @@ -117,9 +118,11 @@ public List<Rule> buildRules() {
* if it's semi join with non-null mark slot
* we can safely change the mark conjunct to hash conjunct
*/
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null))
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct,
rewriteContext), rewriteContext)
: false;

applyPlan = subqueryToApply(subqueryExprs.stream()
Expand Down Expand Up @@ -239,9 +242,11 @@ public List<Rule> buildRules() {
* if it's semi join with non-null mark slot
* we can safely change the mark conjunct to hash conjunct
*/
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null))
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext),
rewriteContext)
: false;
applyPlan = subqueryToApply(
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;

Expand All @@ -38,15 +40,22 @@ public class EliminateMarkJoin extends OneRewriteRuleFactory {
public Rule build() {
return logicalFilter(logicalJoin().when(
join -> join.getJoinType().isSemiJoin() && !join.getMarkJoinConjuncts().isEmpty()))
.when(filter -> canSimplifyMarkJoin(filter.getConjuncts()))
.then(filter -> filter.withChildren(eliminateMarkJoin(filter.child())))
.when(filter -> canSimplifyMarkJoin(filter.getConjuncts(), null))
.thenApply(ctx -> {
LogicalFilter<LogicalJoin<Plan, Plan>> filter = ctx.root;
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
if (canSimplifyMarkJoin(filter.getConjuncts(), rewriteContext)) {
return filter.withChildren(eliminateMarkJoin(filter.child()));
}
return filter;
})
.toRule(RuleType.ELIMINATE_MARK_JOIN);
}

private boolean canSimplifyMarkJoin(Set<Expression> predicates) {
private boolean canSimplifyMarkJoin(Set<Expression> predicates, ExpressionRewriteContext rewriteContext) {
return ExpressionUtils
.canInferNotNullForMarkSlot(TrySimplifyPredicateWithMarkJoinSlot.INSTANCE
.rewrite(ExpressionUtils.and(predicates), null));
.rewrite(ExpressionUtils.and(predicates), rewriteContext), rewriteContext);
}

private LogicalJoin<Plan, Plan> eliminateMarkJoin(LogicalJoin<Plan, Plan> join) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ public static boolean hasOnlyMetricType(List<Expression> children) {
/**
* canInferNotNullForMarkSlot
*/
public static boolean canInferNotNullForMarkSlot(Expression predicate) {
public static boolean canInferNotNullForMarkSlot(Expression predicate, ExpressionRewriteContext ctx) {
/*
* assume predicate is from LogicalFilter
* the idea is replacing each mark join slot with null and false literal then run FoldConstant rule
Expand Down Expand Up @@ -568,7 +568,7 @@ public static boolean canInferNotNullForMarkSlot(Expression predicate) {
}
Expression evalResult = FoldConstantRule.evaluate(
ExpressionUtils.replace(predicate, replaceMap),
new ExpressionRewriteContext(null)
ctx
);

if (evalResult.equals(BooleanLiteral.TRUE)) {
Expand Down

0 comments on commit b555386

Please sign in to comment.