Skip to content

Commit

Permalink
fix bind having aggregate failed
Browse files Browse the repository at this point in the history
  • Loading branch information
924060929 committed Mar 21, 2024
1 parent 09e5845 commit 086db66
Show file tree
Hide file tree
Showing 10 changed files with 356 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ public class FunctionRegistry {
// to record the global alias function and other udf.
private static final String GLOBAL_FUNCTION = "__GLOBAL_FUNCTION__";

private final Map<String, List<FunctionBuilder>> name2InternalBuiltinBuilders;
private final Map<String, List<FunctionBuilder>> name2BuiltinBuilders;
private final Map<String, Map<String, List<FunctionBuilder>>> name2UdfBuilders;

public FunctionRegistry() {
name2InternalBuiltinBuilders = new ConcurrentHashMap<>();
name2BuiltinBuilders = new ConcurrentHashMap<>();
name2UdfBuilders = new ConcurrentHashMap<>();
registerBuiltinFunctions(name2InternalBuiltinBuilders);
afterRegisterBuiltinFunctions(name2InternalBuiltinBuilders);
registerBuiltinFunctions(name2BuiltinBuilders);
afterRegisterBuiltinFunctions(name2BuiltinBuilders);
}

// this function is used to test.
Expand All @@ -78,12 +78,33 @@ public FunctionBuilder findFunctionBuilder(String name, Object argument) {
}

public Optional<List<FunctionBuilder>> tryGetBuiltinBuilders(String name) {
List<FunctionBuilder> builders = name2InternalBuiltinBuilders.get(name);
return name2InternalBuiltinBuilders.get(name) == null
List<FunctionBuilder> builders = name2BuiltinBuilders.get(name);
return name2BuiltinBuilders.get(name) == null
? Optional.empty()
: Optional.of(ImmutableList.copyOf(builders));
}

/**
 * Tells whether the given function name refers to an aggregate function.
 *
 * <p>The name is lower-cased before lookup. Builtin functions are consulted only when
 * {@code dbName} is empty; after that the udf builders returned by
 * {@code findUdfBuilder(dbName, name)} are checked. A builder counts as an aggregate when
 * its {@code functionClass()} is assignable to {@code AggregateFunction}.
 *
 * @param dbName database name used to qualify the function, may be empty
 * @param name function name, matched case-insensitively
 * @return true if any matching builtin or udf builder reports an aggregate function class
 */
public boolean isAggregateFunction(String dbName, String name) {
    name = name.toLowerCase();
    Class<?> aggClass = org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class;
    if (StringUtils.isEmpty(dbName)) {
        List<FunctionBuilder> functionBuilders = name2BuiltinBuilders.get(name);
        // the map yields null for a name that is not a builtin (for example a udf used without
        // a db qualifier); skip the builtin check instead of failing with a NullPointerException,
        // so the udf lookup below still gets its turn
        if (functionBuilders != null) {
            for (FunctionBuilder functionBuilder : functionBuilders) {
                if (aggClass.isAssignableFrom(functionBuilder.functionClass())) {
                    return true;
                }
            }
        }
    }

    List<FunctionBuilder> udfBuilders = findUdfBuilder(dbName, name);
    for (FunctionBuilder udfBuilder : udfBuilders) {
        if (aggClass.isAssignableFrom(udfBuilder.functionClass())) {
            return true;
        }
    }
    return false;
}

// currently we only find function by name and arity and args' types.
public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> arguments) {
List<FunctionBuilder> functionBuilders = null;
Expand All @@ -92,11 +113,11 @@ public FunctionBuilder findFunctionBuilder(String dbName, String name, List<?> a

if (StringUtils.isEmpty(dbName)) {
// search internal function only if dbName is empty
functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase());
functionBuilders = name2BuiltinBuilders.get(name.toLowerCase());
if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) {
String nestedName = AggCombinerFunctionBuilder.getNestedName(name);
String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
functionBuilders = name2BuiltinBuilders.get(nestedName.toLowerCase());
if (functionBuilders != null) {
List<FunctionBuilder> candidateBuilders = Lists.newArrayListWithCapacity(functionBuilders.size());
for (FunctionBuilder functionBuilder : functionBuilders) {
Expand Down Expand Up @@ -199,8 +220,8 @@ public void dropUdf(String dbName, String name, List<DataType> argTypes) {
}
synchronized (name2UdfBuilders) {
Map<String, List<FunctionBuilder>> builders = name2UdfBuilders.getOrDefault(dbName, ImmutableMap.of());
builders.getOrDefault(name, Lists.newArrayList()).removeIf(builder -> ((UdfBuilder) builder).getArgTypes()
.equals(argTypes));
builders.getOrDefault(name, Lists.newArrayList())
.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.MappingSlot;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundOneRowRelation;
import org.apache.doris.nereids.analyzer.UnboundResultSink;
import org.apache.doris.nereids.analyzer.UnboundSlot;
Expand Down Expand Up @@ -351,12 +353,12 @@ private LogicalHaving<Plan> bindHaving(MatchingContext<LogicalHaving<Plan>> ctx)
CascadesContext cascadesContext = ctx.cascadesContext;

// bind slot by child.output first
Scope defaultScope = toScope(cascadesContext, childPlan.getOutput());
Scope childOutput = toScope(cascadesContext, childPlan.getOutput());
// then bind slot by child.children.output
Supplier<Scope> backupScope = Suppliers.memoize(() ->
Supplier<Scope> childChildrenOutput = Suppliers.memoize(() ->
toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(childPlan.children()))
);
return bindHavingByScopes(having, cascadesContext, defaultScope, backupScope);
return bindHavingByScopes(having, cascadesContext, childOutput, childChildrenOutput);
}

private LogicalHaving<Plan> bindHavingAggregate(
Expand All @@ -365,13 +367,114 @@ private LogicalHaving<Plan> bindHavingAggregate(
Aggregate<Plan> aggregate = having.child();
CascadesContext cascadesContext = ctx.cascadesContext;

// having(aggregate) should bind slot by aggregate.child.output first
Scope defaultScope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
// then bind slot by aggregate.output
Supplier<Scope> backupScope = Suppliers.memoize(() ->
toScope(cascadesContext, aggregate.getOutput())
);
return bindHavingByScopes(ctx.root, ctx.cascadesContext, defaultScope, backupScope);
// keep same behavior as mysql
Supplier<CustomSlotBinderAnalyzer> bindByAggChild = Suppliers.memoize(() -> {
Scope aggChildOutputScope
= toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children()));
return (analyzer, unboundSlot) -> analyzer.bindSlotByScope(unboundSlot, aggChildOutputScope);
});

Scope aggOutputScope = toScope(cascadesContext, aggregate.getOutput());
Supplier<CustomSlotBinderAnalyzer> bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> {
List<Expression> groupByExprs = aggregate.getGroupByExpressions();
ImmutableList.Builder<Slot> groupBySlots
= ImmutableList.builderWithExpectedSize(groupByExprs.size());
for (Expression groupBy : groupByExprs) {
if (groupBy instanceof Slot) {
groupBySlots.add((Slot) groupBy);
}
}
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());

Supplier<Pair<Scope, Scope>> separateAggOutputScopes = Suppliers.memoize(() -> {
ImmutableList.Builder<Slot> groupByOutputs = ImmutableList.builderWithExpectedSize(
aggregate.getOutputExpressions().size());
ImmutableList.Builder<Slot> aggFunOutputs = ImmutableList.builderWithExpectedSize(
aggregate.getOutputExpressions().size());
for (NamedExpression outputExpression : aggregate.getOutputExpressions()) {
if (outputExpression.anyMatch(AggregateFunction.class::isInstance)) {
aggFunOutputs.add(outputExpression.toSlot());
} else {
groupByOutputs.add(outputExpression.toSlot());
}
}
Scope nonAggFunSlotsScope = toScope(cascadesContext, groupByOutputs.build());
Scope aggFuncSlotsScope = toScope(cascadesContext, aggFunOutputs.build());
return Pair.of(nonAggFunSlotsScope, aggFuncSlotsScope);
});

return (analyzer, unboundSlot) -> {
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
if (boundInGroupBy.size() == 1) {
return boundInGroupBy;
}

Pair<Scope, Scope> separateAggOutputScope = separateAggOutputScopes.get();
List<Slot> boundInNonAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.first);
if (boundInNonAggFuncs.size() == 1) {
return boundInNonAggFuncs;
}

List<Slot> boundInAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.second);
if (boundInAggFuncs.size() == 1) {
return boundInAggFuncs;
}
return analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
};
});

FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry();
ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext,
false, true) {
private boolean currentIsInAggregateFunction;

// Marks the traversal as "inside an aggregate function" while the outermost aggregate
// function is being visited; nested aggregate functions leave the flag untouched.
@Override
public Expression visitAggregateFunction(AggregateFunction aggregateFunction,
        ExpressionRewriteContext context) {
    if (currentIsInAggregateFunction) {
        // already inside an aggregate function: the outer visit owns the flag
        return super.visitAggregateFunction(aggregateFunction, context);
    }
    currentIsInAggregateFunction = true;
    try {
        return super.visitAggregateFunction(aggregateFunction, context);
    } finally {
        // always reset, even when the visit throws
        currentIsInAggregateFunction = false;
    }
}

// An unbound function that the registry reports as an aggregate function also switches the
// traversal into "inside an aggregate function" mode for the duration of its visit.
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) {
    boolean entersAggregateFunction
            = !currentIsInAggregateFunction && isAggregateFunction(unboundFunction, functionRegistry);
    if (!entersAggregateFunction) {
        return super.visitUnboundFunction(unboundFunction, context);
    }
    currentIsInAggregateFunction = true;
    try {
        return super.visitUnboundFunction(unboundFunction, context);
    } finally {
        // always reset, even when the visit throws
        currentIsInAggregateFunction = false;
    }
}

// Slots inside an aggregate function are bound by the aggregate's child output; all other
// slots go through the group-by / aggregate-output binder.
@Override
protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot) {
    CustomSlotBinderAnalyzer slotBinder = currentIsInAggregateFunction
            ? bindByAggChild.get()
            : bindByGroupByThenAggOutputThenAggChild.get();
    return slotBinder.bindSlot(this, unboundSlot);
}
};

Set<Expression> havingExprs = having.getConjuncts();
ImmutableSet.Builder<Expression> analyzedHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size());
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
for (Expression expression : havingExprs) {
analyzedHaving.add(havingAnalyzer.analyze(expression, rewriteContext));
}

return new LogicalHaving<>(analyzedHaving.build(), having.child());
}

private LogicalHaving<Plan> bindHavingByScopes(
Expand Down Expand Up @@ -764,6 +867,11 @@ private void checkIfOutputAliasNameDuplicatedForGroupBy(Collection<Expression> e
}
}

/**
 * Asks the function registry whether the db name and function name carried by the
 * unbound function identify an aggregate function.
 */
private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionRegistry functionRegistry) {
    String dbName = unboundFunction.getDbName();
    String functionName = unboundFunction.getName();
    return functionRegistry.isAggregateFunction(dbName, functionName);
}

private <E extends Expression> E checkBoundExceptLambda(E expression, Plan plan) {
if (expression instanceof Lambda) {
return expression;
Expand Down Expand Up @@ -797,6 +905,12 @@ private SimpleExprAnalyzer buildSimpleExprAnalyzer(
boolean enableExactMatch, boolean bindSlotInOuterScope) {
List<Slot> childrenOutputs = PlanUtils.fastGetChildrenOutputs(children);
Scope scope = toScope(cascadesContext, childrenOutputs);
return buildSimpleExprAnalyzer(currentPlan, cascadesContext, scope, enableExactMatch, bindSlotInOuterScope);
}

private SimpleExprAnalyzer buildSimpleExprAnalyzer(
Plan currentPlan, CascadesContext cascadesContext, Scope scope,
boolean enableExactMatch, boolean bindSlotInOuterScope) {
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext);
ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan,
scope, cascadesContext, enableExactMatch, bindSlotInOuterScope);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder neste
this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
}

/** Forwards to the nested builder, so the combinator reports the nested builder's function class. */
@Override
public Class<? extends BoundFunction> functionClass() {
    return this.nestedBuilder.functionClass();
}

@Override
public boolean canApply(List<? extends Object> arguments) {
if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ public class BuiltinFunctionBuilder extends FunctionBuilder {

// Concrete BoundFunction's constructor
private final Constructor<BoundFunction> builderMethod;
private final Class<? extends BoundFunction> functionClass;

public BuiltinFunctionBuilder(Constructor<BoundFunction> builderMethod) {
public BuiltinFunctionBuilder(
Class<? extends BoundFunction> functionClass, Constructor<BoundFunction> builderMethod) {
this.functionClass = Objects.requireNonNull(functionClass, "functionClass can not be null");
this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null");
this.arity = builderMethod.getParameterCount();
this.isVariableLength = arity > 0 && builderMethod.getParameterTypes()[arity - 1].isArray();
}

/** Returns the function class that was handed to this builder's constructor. */
@Override
public Class<? extends BoundFunction> functionClass() {
    return this.functionClass;
}

@Override
public boolean canApply(List<? extends Object> arguments) {
if (isVariableLength && arity > arguments.size() + 1) {
Expand Down Expand Up @@ -133,7 +141,7 @@ public static List<FunctionBuilder> resolve(Class<? extends BoundFunction> funct
+ functionClass.getSimpleName());
return Arrays.stream(functionClass.getConstructors())
.filter(constructor -> Modifier.isPublic(constructor.getModifiers()))
.map(constructor -> new BuiltinFunctionBuilder((Constructor<BoundFunction>) constructor))
.map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor<BoundFunction>) constructor))
.collect(ImmutableList.toImmutableList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
* This class is used to build a BoundFunction (Builtin or Combinator) from a list of Expressions.
*/
public abstract class FunctionBuilder {
public abstract Class<? extends BoundFunction> functionClass();

/** check whether arguments can apply to the constructor */
public abstract boolean canApply(List<? extends Object> arguments);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ public List<DataType> getArgTypes() {
return aliasUdf.getArgTypes();
}

/** Reports {@code AliasUdf} as the function class of this builder. */
@Override
public Class<? extends BoundFunction> functionClass() {
    Class<? extends BoundFunction> aliasUdfClass = AliasUdf.class;
    return aliasUdfClass;
}

@Override
public boolean canApply(List<?> arguments) {
if (arguments.size() != aliasUdf.arity()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

/** Reports {@code JavaUdaf} as the function class of this builder. */
@Override
public Class<? extends BoundFunction> functionClass() {
    Class<? extends BoundFunction> javaUdafClass = JavaUdaf.class;
    return javaUdafClass;
}

@Override
public boolean canApply(List<?> arguments) {
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ public List<DataType> getArgTypes() {
.collect(Collectors.toList())).get();
}

/** Reports {@code JavaUdf} as the function class of this builder. */
@Override
public Class<? extends BoundFunction> functionClass() {
    Class<? extends BoundFunction> javaUdfClass = JavaUdf.class;
    return javaUdfClass;
}

@Override
public boolean canApply(List<?> arguments) {
if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
Expand Down
31 changes: 31 additions & 0 deletions regression-test/data/nereids_syntax_p0/bind_priority.out
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,34 @@ all 2
4 5
6 6

-- !having_bind_child --
1 10

-- !having_bind_child2 --
2 10

-- !having_bind_child3 --
2 10

-- !having_bind_project --
2 10

-- !having_bind_project2 --

-- !having_bind_project3 --

-- !having_bind_project4 --
2 11

-- !having_bind_child4 --
2 11

-- !having_bind_child5 --
2 11

-- !having_bind_agg_fun --

-- !having_bind_agg_fun --
2 4
3 3

Loading

0 comments on commit 086db66

Please sign in to comment.