Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 21, 2024
1 parent ea05341 commit 4c664fd
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,14 @@ private LogicalHaving<Plan> bindHavingAggregate(
CascadesContext cascadesContext = ctx.cascadesContext;

// keep same behavior as mysql
Supplier<CustomSlotBinderAnalyzer> bindByAggChild = Suppliers.memoize(() -> {
Scope aggChildOutputScope
= toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
return (analyzer, unboundSlot) -> analyzer.bindSlotByScope(unboundSlot, aggChildOutputScope);
});

Scope aggChildOutputScope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
Supplier<SimpleExprAnalyzer> bindByAggChild = Suppliers.memoize(() -> buildSimpleExprAnalyzer(
having, cascadesContext, aggChildOutputScope, false, true)
);

Supplier<SimpleExprAnalyzer> bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> {
Scope aggOuputScope = toScope(cascadesContext, aggregate.getOutput());
Supplier<CustomSlotBinderAnalyzer> bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> {
List<Expression> groupByExprs = aggregate.getGroupByExpressions();
ImmutableList.Builder<Slot> groupBySlotsBuilder
= ImmutableList.builderWithExpectedSize(groupByExprs.size());
Expand All @@ -385,39 +386,74 @@ private LogicalHaving<Plan> bindHavingAggregate(
List<Slot> groupBySlots = groupBySlotsBuilder.build();
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots);

Supplier<Scope> aggOutputScope = Suppliers.memoize(() -> toScope(cascadesContext, aggregate.getOutput()));
return (analyzer, unboundSlot) -> {
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
if (boundInGroupBy.size() == 1) {
return boundInGroupBy;
}

List<Slot> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOuputScope);
if (boundInAggOutput.size() == 1) {
return boundInAggOutput;
}

List<Slot> boundInAggChild = (List) bindByAggChild.get().bindSlot(analyzer, unboundSlot);
return boundInAggChild.size() == 1 ? boundInAggChild : boundInGroupBy;
};
});

FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOuputScope, cascadesContext,
false, true) {
private boolean currentIsInAggregateFunction;

return buildCustomSlotBinderAnalyzer(having, cascadesContext, groupBySlotsScope,
false, true, (self, unboundSlot) -> {
List<Slot> boundInGroupBy = self.bindSlotByScope(unboundSlot, groupBySlotsScope);
if (boundInGroupBy.size() == 1) {
return boundInGroupBy;
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
ExpressionRewriteContext context) {
if (!currentIsInAggregateFunction) {
currentIsInAggregateFunction = true;
try {
return super.visitAggregateFunction(aggregateFunction, context);
} finally {
currentIsInAggregateFunction = false;
}
} else {
return super.visitAggregateFunction(aggregateFunction, context);
}
}

List<Slot> boundInAggOutput = self.bindSlotByScope(unboundSlot, aggOutputScope.get());
if (boundInAggOutput.size() == 1) {
return boundInAggOutput;
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
if (!currentIsInAggregateFunction && isAggregateFunction(unboundFunction, functionRegistry)) {
currentIsInAggregateFunction = true;
try {
return super.visitUnboundFunction(unboundFunction, context);
} finally {
currentIsInAggregateFunction = false;
}
} else {
return super.visitUnboundFunction(unboundFunction, context);
}
}

List<Slot> boundInAggChild = self.bindSlotByScope(unboundSlot, aggChildOutputScope);
return boundInAggChild.size() == 1 ? boundInAggChild : boundInGroupBy;
@Override
protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot) {
if (currentIsInAggregateFunction) {
return bindByAggChild.get().bindSlot(this, unboundSlot);
} else {
return bindByGroupByThenAggOutputThenAggChild.get().bindSlot(this, unboundSlot);
}
);
});
}
};

Set<Expression> havingExprs = having.getConjuncts();
ImmutableSet.Builder<Expression> boundHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size());
FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
ImmutableSet.Builder<Expression> analyzedHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size());
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
for (Expression expression : havingExprs) {
boolean hasAggFunc = hasAggregateFunction(expression, functionRegistry);
if (hasAggFunc) {
boundHaving.add(bindByAggChild.get().analyze(expression));
} else {
boundHaving.add(bindByGroupByThenAggOutputThenAggChild.get().analyze(expression));
}
analyzedHaving.add(havingAnalyzer.analyze(expression, rewriteContext));
}

return new LogicalHaving<>(boundHaving.build(), having.child());
return new LogicalHaving<>(analyzedHaving.build(), having.child());
}

private LogicalHaving<Plan> bindHavingByScopes(
Expand Down Expand Up @@ -810,23 +846,9 @@ private void checkIfOutputAliasNameDuplicatedForGroupBy(Collection<Expression> e
}
}

private boolean hasAggregateFunction(Expression expression, FunctionRegistry functionRegistry) {
return expression.anyMatch(expr -> {
if (expr instanceof AggregateFunction) {
return true;
} else if (expr instanceof UnboundFunction) {
UnboundFunction unboundFunction = (UnboundFunction) expr;
boolean isAggregateFunction = functionRegistry
.isAggregateFunction(
unboundFunction.getDbName(),
unboundFunction.getName()
);
if (isAggregateFunction) {
return true;
}
}
return false;
});
private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionRegistry functionRegistry) {
return functionRegistry.isAggregateFunction(
unboundFunction.getDbName(), unboundFunction.getName());
}

private <E extends Expression> E checkBoundExceptLambda(E expression, Plan plan) {
Expand Down
6 changes: 6 additions & 0 deletions regression-test/data/nereids_syntax_p0/bind_priority.out
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ all 2
-- !having_bind_child5 --
2 11

-- !having_bind_agg_fun --

-- !having_bind_agg_fun --
2 4
3 3

38 changes: 31 additions & 7 deletions regression-test/suites/nereids_syntax_p0/bind_priority.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,11 @@ suite("bind_priority") {
def testBindHaving = {
sql "drop table if exists test_bind_having_slots"

sql "create table test_bind_having_slots " +
"(id int, age int) " +
"distributed by hash(id) " +
"properties('replication_num'='1');"
sql """create table test_bind_having_slots
(id int, age int)
distributed by hash(id)
properties('replication_num'='1');
"""
sql "insert into test_bind_having_slots values(1, 10), (2, 20), (3, 30);"

order_qt_having_bind_child """
Expand All @@ -179,7 +180,6 @@ suite("bind_priority") {
having id = 1; -- bind id from group by
"""


order_qt_having_bind_child3 """
select id + 1 as id, sum(age)
from test_bind_having_slots s
Expand All @@ -201,12 +201,11 @@ suite("bind_priority") {
having id + 1 = 2; -- bind id from project
"""


order_qt_having_bind_project3 """
select id + 1 as id, sum(age + 1) as age
from test_bind_having_slots s
group by id
having age = 10; -- bind id from age
having age = 10; -- bind age from project
"""

order_qt_having_bind_project4 """
Expand All @@ -229,5 +228,30 @@ suite("bind_priority") {
group by id
having sum(age + 1) = 11 -- bind age from s
"""




sql "drop table if exists test_bind_having_slots2"
sql """create table test_bind_having_slots2
(id int)
distributed by hash(id)
properties('replication_num'='1');
"""
sql "insert into test_bind_having_slots2 values(1), (2), (3), (2);"

order_qt_having_bind_agg_fun """
select id, abs(sum(id)) as id
from test_bind_having_slots2
group by id
having sum(id) + id >= 7
"""

order_qt_having_bind_agg_fun """
select id, abs(sum(id)) as id
from test_bind_having_slots2
group by id
having sum(id) + id >= 6
"""
}()
}

0 comments on commit 4c664fd

Please sign in to comment.