diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 7bbe87ada541d45..c134b1ee5c506f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -367,13 +367,14 @@ private LogicalHaving bindHavingAggregate( CascadesContext cascadesContext = ctx.cascadesContext; // keep same behavior as mysql + Supplier bindByAggChild = Suppliers.memoize(() -> { + Scope aggChildOutputScope + = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children())); + return (analyzer, unboundSlot) -> analyzer.bindSlotByScope(unboundSlot, aggChildOutputScope); + }); - Scope aggChildOutputScope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children())); - Supplier bindByAggChild = Suppliers.memoize(() -> buildSimpleExprAnalyzer( - having, cascadesContext, aggChildOutputScope, false, true) - ); - - Supplier bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> { + Scope aggOuputScope = toScope(cascadesContext, aggregate.getOutput()); + Supplier bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> { List groupByExprs = aggregate.getGroupByExpressions(); ImmutableList.Builder groupBySlotsBuilder = ImmutableList.builderWithExpectedSize(groupByExprs.size()); @@ -385,39 +386,74 @@ private LogicalHaving bindHavingAggregate( List groupBySlots = groupBySlotsBuilder.build(); Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots); - Supplier aggOutputScope = Suppliers.memoize(() -> toScope(cascadesContext, aggregate.getOutput())); + return (analyzer, unboundSlot) -> { + List boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope); + if (boundInGroupBy.size() == 1) { + return boundInGroupBy; + } + + List boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOuputScope); + if (boundInAggOutput.size() == 1) { + return boundInAggOutput; + } + + List boundInAggChild = (List) bindByAggChild.get().bindSlot(analyzer, unboundSlot); + return boundInAggChild.size() == 1 ? boundInAggChild : boundInGroupBy; + }; + }); + + FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); + ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOuputScope, cascadesContext, + false, true) { + private boolean currentIsInAggregateFunction; - return buildCustomSlotBinderAnalyzer(having, cascadesContext, groupBySlotsScope, - false, true, (self, unboundSlot) -> { - List boundInGroupBy = self.bindSlotByScope(unboundSlot, groupBySlotsScope); - if (boundInGroupBy.size() == 1) { - return boundInGroupBy; + @Override + public Expression visitAggregateFunction(AggregateFunction aggregateFunction, + ExpressionRewriteContext context) { + if (!currentIsInAggregateFunction) { + currentIsInAggregateFunction = true; + try { + return super.visitAggregateFunction(aggregateFunction, context); + } finally { + currentIsInAggregateFunction = false; } + } else { + return super.visitAggregateFunction(aggregateFunction, context); + } + } - List boundInAggOutput = self.bindSlotByScope(unboundSlot, aggOutputScope.get()); - if (boundInAggOutput.size() == 1) { - return boundInAggOutput; + @Override + public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) { + if (!currentIsInAggregateFunction && isAggregateFunction(unboundFunction, functionRegistry)) { + currentIsInAggregateFunction = true; + try { + return super.visitUnboundFunction(unboundFunction, context); + } finally { + currentIsInAggregateFunction = false; } + } else { + return super.visitUnboundFunction(unboundFunction, context); + } + } - List boundInAggChild = self.bindSlotByScope(unboundSlot, aggChildOutputScope); - return boundInAggChild.size() == 1 ? boundInAggChild : boundInGroupBy; + @Override + protected List bindSlotByThisScope(UnboundSlot unboundSlot) { + if (currentIsInAggregateFunction) { + return bindByAggChild.get().bindSlot(this, unboundSlot); + } else { + return bindByGroupByThenAggOutputThenAggChild.get().bindSlot(this, unboundSlot); } - ); - }); + } + }; Set havingExprs = having.getConjuncts(); - ImmutableSet.Builder boundHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size()); - FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); + ImmutableSet.Builder analyzedHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size()); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); for (Expression expression : havingExprs) { - boolean hasAggFunc = hasAggregateFunction(expression, functionRegistry); - if (hasAggFunc) { - boundHaving.add(bindByAggChild.get().analyze(expression)); - } else { - boundHaving.add(bindByGroupByThenAggOutputThenAggChild.get().analyze(expression)); - } + analyzedHaving.add(havingAnalyzer.analyze(expression, rewriteContext)); } - return new LogicalHaving<>(boundHaving.build(), having.child()); + return new LogicalHaving<>(analyzedHaving.build(), having.child()); } private LogicalHaving bindHavingByScopes( @@ -810,23 +846,9 @@ private void checkIfOutputAliasNameDuplicatedForGroupBy(Collection e } } - private boolean hasAggregateFunction(Expression expression, FunctionRegistry functionRegistry) { - return expression.anyMatch(expr -> { - if (expr instanceof AggregateFunction) { - return true; - } else if (expr instanceof UnboundFunction) { - UnboundFunction unboundFunction = (UnboundFunction) expr; - boolean isAggregateFunction = functionRegistry - .isAggregateFunction( - unboundFunction.getDbName(), - unboundFunction.getName() - ); - if (isAggregateFunction) { - return true; - } - } - return false; - }); + private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionRegistry functionRegistry) { + return functionRegistry.isAggregateFunction( + unboundFunction.getDbName(), unboundFunction.getName()); } private E checkBoundExceptLambda(E expression, Plan plan) { diff --git a/regression-test/data/nereids_syntax_p0/bind_priority.out b/regression-test/data/nereids_syntax_p0/bind_priority.out index b3bc666c23d50ed..fec4313d09eed1d 100644 --- a/regression-test/data/nereids_syntax_p0/bind_priority.out +++ b/regression-test/data/nereids_syntax_p0/bind_priority.out @@ -61,3 +61,9 @@ all 2 -- !having_bind_child5 -- 2 11 +-- !having_bind_agg_fun -- + +-- !having_bind_agg_fun -- +2 4 +3 3 + diff --git a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy index 69fd84c89284758..4595265577a63a5 100644 --- a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy +++ b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy @@ -159,10 +159,11 @@ suite("bind_priority") { def testBindHaving = { sql "drop table if exists test_bind_having_slots" - sql "create table test_bind_having_slots " + - "(id int, age int) " + - "distributed by hash(id) " + - "properties('replication_num'='1');" + sql """create table test_bind_having_slots + (id int, age int) + distributed by hash(id) + properties('replication_num'='1'); + """ sql "insert into test_bind_having_slots values(1, 10), (2, 20), (3, 30);" order_qt_having_bind_child """ @@ -179,7 +180,6 @@ suite("bind_priority") { having id = 1; -- bind id from group by """ - order_qt_having_bind_child3 """ select id + 1 as id, sum(age) from test_bind_having_slots s @@ -201,12 +201,11 @@ suite("bind_priority") { having id + 1 = 2; -- bind id from project """ - order_qt_having_bind_project3 """ select id + 1 as id, sum(age + 1) as age from test_bind_having_slots s group by id - having age = 10; -- bind id from age + having age = 10; -- bind age from project """ order_qt_having_bind_project4 """ @@ -229,5 +228,30 @@ suite("bind_priority") { group by id having sum(age + 1) = 11 -- bind age from s """ + + + + + sql "drop table if exists test_bind_having_slots2" + sql """create table test_bind_having_slots2 + (id int) + distributed by hash(id) + properties('replication_num'='1'); + """ + sql "insert into test_bind_having_slots2 values(1), (2), (3), (2);" + + order_qt_having_bind_agg_fun """ + select id, abs(sum(id)) as id + from test_bind_having_slots2 + group by id + having sum(id) + id >= 7 + """ + + order_qt_having_bind_agg_fun """ + select id, abs(sum(id)) as id + from test_bind_having_slots2 + group by id + having sum(id) + id >= 6 + """ }() }