From 73771ce7fabde94220331b2aeb076bf3c9294541 Mon Sep 17 00:00:00 2001 From: minghong Date: Wed, 13 Nov 2024 14:58:31 +0800 Subject: [PATCH 1/5] enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE --- .../PushDownAggThroughJoinOneSide.java | 146 ++++++++++++++---- .../push_down_count_through_join_one_side.out | 20 +++ ...sh_down_count_through_join_one_side.groovy | 95 ++++++++++++ 3 files changed, 232 insertions(+), 29 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index f32bf8ea91b355..ac42182890cbc2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -36,6 +36,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashMap; @@ -88,15 +89,16 @@ public List buildRules() { }) .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE), logicalAggregate(logicalProject(innerLogicalJoin())) - .when(agg -> agg.child().isAllSlots()) - .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) - .whenNot(agg -> agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate)) + // .when(agg -> agg.child().isAllSlots()) + // .when(agg -> agg.child().child().getOtherJoinConjuncts().isEmpty()) + .whenNot(agg -> agg.child() + .child(0).children().stream().anyMatch(p -> p instanceof LogicalAggregate)) .when(agg -> { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && (!((Count) f).isCountStar()))) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count) && !f.isDistinct() + && (f.children().isEmpty() || f.child(0) instanceof Slot)); }) .thenApply(ctx -> { Set enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -111,6 +113,7 @@ public List buildRules() { ); } + /** * Push down Min/Max/Sum through join. */ @@ -119,21 +122,6 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate leftOutput = join.left().getOutput(); List rightOutput = join.right().getOutput(); - List leftFuncs = new ArrayList<>(); - List rightFuncs = new ArrayList<>(); - for (AggregateFunction func : agg.getAggregateFunctions()) { - Slot slot = (Slot) func.child(0); - if (leftOutput.contains(slot)) { - leftFuncs.add(func); - } else if (rightOutput.contains(slot)) { - rightFuncs.add(func); - } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); - } - } - if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) { - return null; - } Set leftGroupBy = new HashSet<>(); Set rightGroupBy = new HashSet<>(); @@ -144,18 +132,76 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate inputForAlias = proj.getInputSlots(); + if (leftOutput.containsAll(inputForAlias)) { + leftGroupBy.addAll(inputForAlias); + } else if (rightOutput.containsAll(inputForAlias)) { + rightGroupBy.addAll(inputForAlias); + } else { + /* + groupBy(X) + +---> project( a + b as X) + --> join(output: T1.a, T2.b) + --> T1(a) + --> T2(b) + X can not be pushed + */ + return null; + } + break; + } + } + } } } - join.getHashJoinConjuncts().forEach(e -> e.getInputSlots().forEach(slot -> { - if (leftOutput.contains(slot)) { - leftGroupBy.add(slot); - } else if (rightOutput.contains(slot)) { - rightGroupBy.add(slot); + + List leftFuncs = new ArrayList<>(); + List rightFuncs = new ArrayList<>(); + Count countStar = null; + for (AggregateFunction func : agg.getAggregateFunctions()) { + if (func instanceof Count && ((Count) func).isCountStar()) { + countStar = (Count) func; } else { - throw new IllegalStateException("Slot " + slot + " not found in join output"); + Slot slot = (Slot) func.child(0); + if (leftOutput.contains(slot)) { + leftFuncs.add(func); + } else if (rightOutput.contains(slot)) { + rightFuncs.add(func); + } else { + throw new IllegalStateException("Slot " + slot + " not found in join output"); + } + } + } + // determine count(*) + if (countStar != null) { + if (!leftGroupBy.isEmpty()) { + countStar = (Count) countStar.withChildren(leftGroupBy.iterator().next()); + leftFuncs.add(countStar); + } else if (!rightGroupBy.isEmpty()) { + countStar = (Count) countStar.withChildren(rightGroupBy.iterator().next()); + rightFuncs.add(countStar); + } else { + return null; + } + } + for (Expression condition : join.getHashJoinConjuncts()) { + for (Slot joinConditionSlot : condition.getInputSlots()) { + if (leftOutput.contains(joinConditionSlot)) { + leftGroupBy.add(joinConditionSlot); + } else if (rightOutput.contains(joinConditionSlot)) { + rightGroupBy.add(joinConditionSlot); + } else { + // apply failed + return null; + } } - })); + } Plan left = join.left(); Plan right = join.right(); @@ -196,6 +242,10 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate pushMinMaxSumCount(LogicalAggregate newProjections = Lists.newArrayList(); + newProjections.addAll(project.getProjects()); + Set leftDifference = new HashSet(left.getOutput()); + leftDifference.removeAll(project.getProjects()); + newProjections.addAll(leftDifference); + Set rightDifference = new HashSet(right.getOutput()); + rightDifference.removeAll(project.getProjects()); + newProjections.addAll(rightDifference); + newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin); + } // TODO: column prune project - return agg.withAggOutputChild(newOutputExprs, newJoin); + LogicalAggregate newAgg = agg.withAggOutputChild(newOutputExprs, newAggChild); + if (checkOutput(newAgg)) { + return newAgg; + } else { + return (LogicalAggregate) agg; + } } private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) { @@ -222,4 +290,24 @@ private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) return func.withChildren(inputSlot); } } + + private static boolean checkOutput(LogicalAggregate agg) { + if (agg.child() instanceof LogicalProject) { + Set joinOutputs = ((Plan) agg.child().child(0)).getOutputSet(); + if (!joinOutputs.containsAll(((LogicalProject) agg.child()).getInputSlots())) { + return false; + } + Set projectOutputs = ((LogicalProject) agg.child()).getOutputSet(); + if (!projectOutputs.containsAll(agg.getInputSlots())) { + return false; + } + return true; + } else { + Set joinOutputs = ((Plan) agg.child()).getOutputSet(); + if (!joinOutputs.containsAll(agg.getInputSlots())) { + return false; + } + return true; + } + } } diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out index da69919becd7f2..eddd2733ee903f 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out @@ -1034,3 +1034,23 @@ Used: UnUsed: use_push_down_agg_through_join_one_side SyntaxError: +-- !shape -- +PhysicalResultSink +--PhysicalTopN[MERGE_SORT] +----PhysicalTopN[LOCAL_SORT] +------hashAgg[GLOBAL] +--------hashAgg[LOCAL] +----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10))) +------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click')) +--------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd] +------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19')) +--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd] + +Hint log: +Used: +UnUsed: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE +SyntaxError: + +-- !agg_pushed -- +2 是 2024-08-19 + diff --git a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy index 02e06710296333..e551fa04c9110a 100644 --- a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy +++ b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy @@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") { qt_with_hint_groupby_pushdown_nested_queries """ explain shape plan select /*+ USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select * from count_t_one_side where score > 20) t1 join (select * from count_t_one_side where id < 100) t2 on t1.id = t2.id group by t1.name; """ + + sql """ + drop table if exists dw_user_b2c_tracking_info_tmp_ymd; + create table dw_user_b2c_tracking_info_tmp_ymd ( + guid int, + dt varchar, + first_visit_time varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19', '2024-08-19'); + + drop table if exists dwd_tracking_sensor_init_tmp_ymd; + create table dwd_tracking_sensor_init_tmp_ymd ( + guid int, + dt varchar, + tracking_type varchar + )Engine=Olap + DUPLICATE KEY(guid) + distributed by hash(dt) buckets 3 + properties('replication_num' = '1'); + + insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19', 'click'), (1, '2024-08-19', 'click'); + """ + sql """ + set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE"; + set disable_join_reorder=true; + """ + + qt_shape """ + explain shape plan + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ + + qt_agg_pushed """ + SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/ + Count(*) AS accee593, + CASE + WHEN dwd_tracking_sensor_init_tmp_ymd.dt = + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '是' + WHEN dwd_tracking_sensor_init_tmp_ymd.dt > + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, + 10) THEN + '否' + ELSE '-1' + end AS a1302fb2, + dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123 + FROM dwd_tracking_sensor_init_tmp_ymd + LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd + ON dwd_tracking_sensor_init_tmp_ymd.guid = + dw_user_b2c_tracking_info_tmp_ymd.guid + AND dwd_tracking_sensor_init_tmp_ymd.dt = + dw_user_b2c_tracking_info_tmp_ymd.dt + WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19' + AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19' + AND dwd_tracking_sensor_init_tmp_ymd.dt >= + Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1, 10) + AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click' + GROUP BY 2, + 3 + ORDER BY 3 ASC + LIMIT 10000; + """ } From 6a99e1b41769ab492522331e85dd522bed8b6c57 Mon Sep 17 00:00:00 2001 From: minghong Date: Fri, 15 Nov 2024 10:27:41 +0800 Subject: [PATCH 2/5] fmt --- .../nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index ac42182890cbc2..86fa1def90c13a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -113,7 +113,6 @@ public List buildRules() { ); } - /** * Push down Min/Max/Sum through join. */ @@ -121,8 +120,6 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate join, List projects) { List leftOutput = join.left().getOutput(); List rightOutput = join.right().getOutput(); - - Set leftGroupBy = new HashSet<>(); Set rightGroupBy = new HashSet<>(); for (Expression e : agg.getGroupByExpressions()) { From 747e348139402589b9c0f794ce064d387125fadc Mon Sep 17 00:00:00 2001 From: minghong Date: Mon, 18 Nov 2024 14:20:49 +0800 Subject: [PATCH 3/5] fix --- .../PushDownAggThroughJoinOneSide.java | 42 ++++++++----------- .../push_down_count_through_join_one_side.out | 10 +++-- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index 86fa1def90c13a..ae78c79bc57dea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -75,8 +75,8 @@ public List buildRules() { Set funcs = agg.getAggregateFunctions(); return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum - || (f instanceof Count && !((Count) f).isCountStar())) && !f.isDistinct() - && f.child(0) instanceof Slot); + || f instanceof Count && !f.isDistinct() + && f.child(0) instanceof Slot)); }) .thenApply(ctx -> { Set enableNereidsRules = ctx.cascadesContext.getConnectContext() @@ -135,21 +135,15 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate inputForAlias = proj.getInputSlots(); - if (leftOutput.containsAll(inputForAlias)) { - leftGroupBy.addAll(inputForAlias); - } else if (rightOutput.containsAll(inputForAlias)) { - rightGroupBy.addAll(inputForAlias); - } else { - /* - groupBy(X) - +---> project( a + b as X) - --> join(output: T1.a, T2.b) - --> T1(a) - --> T2(b) - X can not be pushed - */ - return null; + Set inputForAliasSet = proj.getInputSlots(); + for (Slot aliasInputSlot : inputForAliasSet) { + if (leftOutput.contains(aliasInputSlot)) { + leftGroupBy.add(aliasInputSlot); + } else if (rightOutput.contains(aliasInputSlot)) { + rightGroupBy.add(aliasInputSlot); + } else { + return null; + } } break; } @@ -161,6 +155,7 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate leftFuncs = new ArrayList<>(); List rightFuncs = new ArrayList<>(); Count countStar = null; + Count rewrittenCountStar = null; for (AggregateFunction func : agg.getAggregateFunctions()) { if (func instanceof Count && ((Count) func).isCountStar()) { countStar = (Count) func; @@ -175,14 +170,14 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate pushMinMaxSumCount(LogicalAggregate pushMinMaxSumCount(LogicalAggregate rightDifference = new HashSet(right.getOutput()); rightDifference.removeAll(project.getProjects()); newProjections.addAll(rightDifference); - newAggChild = ((LogicalProject) agg.child()).withProjectsAndChild(newProjections, newJoin); } // TODO: column prune project diff --git a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out index eddd2733ee903f..8267eb3e38ff91 100644 --- a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out +++ b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out @@ -1041,14 +1041,16 @@ PhysicalResultSink ------hashAgg[GLOBAL] --------hashAgg[LOCAL] ----------hashJoin[INNER_JOIN] hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt = dw_user_b2c_tracking_info_tmp_ymd.dt) and (dwd_tracking_sensor_init_tmp_ymd.guid = dw_user_b2c_tracking_info_tmp_ymd.guid)) otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >= substring(first_visit_time, 1, 10))) -------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click')) ---------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd] +------------hashAgg[GLOBAL] +--------------hashAgg[LOCAL] +----------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19') and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click')) +------------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd] ------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19')) --------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd] Hint log: -Used: -UnUsed: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE +Used: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE +UnUsed: SyntaxError: -- !agg_pushed -- From 5e1f66165a0251d3173dfb01af3f326d048e8c56 Mon Sep 17 00:00:00 2001 From: minghong Date: Mon, 18 Nov 2024 16:32:11 +0800 Subject: [PATCH 4/5] fix --- .../PushDownAggThroughJoinOneSide.java | 28 +------------------ 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index ae78c79bc57dea..0e0c502604174a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -265,13 +265,7 @@ public static LogicalAggregate pushMinMaxSumCount(LogicalAggregate newAgg = agg.withAggOutputChild(newOutputExprs, newAggChild); - if (checkOutput(newAgg)) { - return newAgg; - } else { - return (LogicalAggregate) agg; - } + return agg.withAggOutputChild(newOutputExprs, newAggChild); } private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) { @@ -281,24 +275,4 @@ private static Expression replaceAggFunc(AggregateFunction func, Slot inputSlot) return func.withChildren(inputSlot); } } - - private static boolean checkOutput(LogicalAggregate agg) { - if (agg.child() instanceof LogicalProject) { - Set joinOutputs = ((Plan) agg.child().child(0)).getOutputSet(); - if (!joinOutputs.containsAll(((LogicalProject) agg.child()).getInputSlots())) { - return false; - } - Set projectOutputs = ((LogicalProject) agg.child()).getOutputSet(); - if (!projectOutputs.containsAll(agg.getInputSlots())) { - return false; - } - return true; - } else { - Set joinOutputs = ((Plan) agg.child()).getOutputSet(); - if (!joinOutputs.containsAll(agg.getInputSlots())) { - return false; - } - return true; - } - } } From eae497f52cdd8ecaeeb62b755ac43d7583ef3ac9 Mon Sep 17 00:00:00 2001 From: minghong Date: Wed, 27 Nov 2024 10:58:14 +0800 Subject: [PATCH 5/5] fix ut --- .../rewrite/PushDownAggThroughJoinOneSide.java | 2 +- .../PushDownMinMaxSumThroughJoinTest.java | 16 +++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java index 0e0c502604174a..c5d3d0fb49a0a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java @@ -76,7 +76,7 @@ public List buildRules() { return !funcs.isEmpty() && funcs.stream() .allMatch(f -> (f instanceof Min || f instanceof Max || f instanceof Sum || f instanceof Count && !f.isDistinct() - && f.child(0) instanceof Slot)); + && (f.children().isEmpty() || f.child(0) instanceof Slot))); }) .thenApply(ctx -> { Set enableNereidsRules = ctx.cascadesContext.getConnectContext() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java index 58ab7fbe9e925f..cffe91045d0ab2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java @@ -323,11 +323,11 @@ void testSingleCountStar() { .applyTopDown(new PushDownAggThroughJoinOneSide()) .printlnTree() .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), + logicalJoin( + logicalAggregate( logicalOlapScan() - ) + ), + logicalOlapScan() ) ); } @@ -346,11 +346,9 @@ void testBothSideCountAndCountStar() { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new PushDownAggThroughJoinOneSide()) .matches( - logicalAggregate( - logicalJoin( - logicalOlapScan(), - logicalOlapScan() - ) + logicalJoin( + logicalAggregate(logicalOlapScan()), + logicalAggregate(logicalOlapScan()) ) ); }