From 7d5d3f092c9b9a49bd9c9297d53038420a203087 Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Thu, 21 Mar 2024 18:53:54 +0800 Subject: [PATCH] fix --- .../rules/analysis/BindExpression.java | 39 ++++++++++++++----- .../nereids_syntax_p0/bind_priority.groovy | 34 ++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index c134b1ee5c506f5..70678aee35e0ce3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; @@ -376,15 +377,31 @@ private LogicalHaving bindHavingAggregate( Scope aggOuputScope = toScope(cascadesContext, aggregate.getOutput()); Supplier bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> { List groupByExprs = aggregate.getGroupByExpressions(); - ImmutableList.Builder groupBySlotsBuilder + ImmutableList.Builder groupBySlots = ImmutableList.builderWithExpectedSize(groupByExprs.size()); for (Expression groupBy : groupByExprs) { if (groupBy instanceof Slot) { - groupBySlotsBuilder.add((Slot) groupBy); + groupBySlots.add((Slot) groupBy); } } - List groupBySlots = groupBySlotsBuilder.build(); - Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots); + Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build()); + + Supplier> separateAggOutputScopes = Suppliers.memoize(() -> { + ImmutableList.Builder groupByOutputs = ImmutableList.builderWithExpectedSize( + aggregate.getOutputExpressions().size()); + ImmutableList.Builder aggFunOutputs = ImmutableList.builderWithExpectedSize( + aggregate.getOutputExpressions().size()); + for (NamedExpression outputExpression : aggregate.getOutputExpressions()) { + if (outputExpression.anyMatch(AggregateFunction.class::isInstance)) { + aggFunOutputs.add(outputExpression.toSlot()); + } else { + groupByOutputs.add(outputExpression.toSlot()); + } + } + Scope nonAggFunSlotsScope = toScope(cascadesContext, groupByOutputs.build()); + Scope aggFuncSlotsScope = toScope(cascadesContext, aggFunOutputs.build()); + return Pair.of(nonAggFunSlotsScope, aggFuncSlotsScope); + }); return (analyzer, unboundSlot) -> { List boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope); @@ -392,13 +409,17 @@ private LogicalHaving bindHavingAggregate( return boundInGroupBy; } - List boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOuputScope); - if (boundInAggOutput.size() == 1) { - return boundInAggOutput; + Pair separateAggOutputScope = separateAggOutputScopes.get(); + List boundInNonAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.first); + if (boundInNonAggFuncs.size() == 1) { + return boundInNonAggFuncs; } - List boundInAggChild = (List) bindByAggChild.get().bindSlot(analyzer, unboundSlot); - return boundInAggChild.size() == 1 ? boundInAggChild : boundInGroupBy; + List boundInAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.second); + if (boundInAggFuncs.size() == 1) { + return boundInAggFuncs; + } + return analyzer.bindSlotByScope(unboundSlot, aggOuputScope); }; }); diff --git a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy index 4595265577a63a5..4e1740061b63755 100644 --- a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy +++ b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy @@ -253,5 +253,39 @@ suite("bind_priority") { group by id having sum(id) + id >= 6 """ + + + + + + sql "drop table if exists test_bind_having_slots3" + + sql """CREATE TABLE `test_bind_having_slots3`(pk int, pk2 int) + DUPLICATE KEY(`pk`) + DISTRIBUTED BY HASH(`pk`) BUCKETS 10 + properties('replication_num'='1'); + """ + sql "insert into test_bind_having_slots3 values(1, 1), (2, 2), (2, 2), (3, 3), (3, 3), (3, 3);" + + order_qt_having_bind_group_by """ + SELECT pk + 6 as ps, COUNT(pk ) * 3 as pk + FROM test_bind_having_slots3 tbl_alias1 + GROUP by pk + HAVING pk = 1 + """ + + order_qt_having_bind_group_by """ + SELECT pk + 6 as pk, COUNT(pk ) * 3 as pk + FROM test_bind_having_slots3 tbl_alias1 + GROUP by pk + 6 + HAVING pk = 7 + """ + + order_qt_having_bind_group_by """ + SELECT pk + 6, COUNT(pk ) * 3 as pk + FROM test_bind_having_slots3 tbl_alias1 + GROUP by pk + 6 + HAVING pk = 3 + """ }() }