Skip to content

Commit

Permalink
Make aggregation statement compilation robust
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed Jun 27, 2024
1 parent b9f544b commit 3e92818
Show file tree
Hide file tree
Showing 15 changed files with 434 additions and 59 deletions.
67 changes: 63 additions & 4 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.DataSourceSchemaName;
Expand All @@ -40,13 +41,15 @@
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.CloseCursor;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
Expand Down Expand Up @@ -81,6 +84,7 @@
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.window.WindowFunctionExpression;
import org.opensearch.sql.planner.logical.LogicalAD;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalCloseCursor;
Expand All @@ -102,6 +106,7 @@
import org.opensearch.sql.planner.logical.LogicalValues;
import org.opensearch.sql.planner.physical.datasource.DataSourceTable;
import org.opensearch.sql.storage.Table;
import org.opensearch.sql.utils.ExpressionUtils;
import org.opensearch.sql.utils.ParseUtils;

/**
Expand Down Expand Up @@ -235,6 +240,16 @@ public LogicalPlan visitLimit(Limit node, AnalysisContext context) {
public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);
// Check if the filter condition is a valid predicate.
if (condition.type() != ExprCoreType.BOOLEAN) {
throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(condition.type());
}
// Check if any window functions in filter
List<Expression> results = new ArrayList<>();
ExpressionUtils.findExpressions(condition, e -> e instanceof WindowFunctionExpression, results);
if (!results.isEmpty()) {
throw QueryCompilationError.windowFunctionNotAllowedError();
}

ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
Expand Down Expand Up @@ -291,17 +306,44 @@ public LogicalPlan visitRename(Rename node, AnalysisContext context) {
return new LogicalRename(child, renameMapBuilder.build());
}

/** Build {@link LogicalAggregation}. */
/** Resolve Having clause to merge its aggregators to {@link LogicalAggregation}. */
@Override
public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
final LogicalPlan child = node.getChild().get(0).accept(this, context);
public LogicalPlan visitHaving(Having node, AnalysisContext context) {
LogicalAggregation aggregation =
(LogicalAggregation) node.getChild().get(0).accept(this, context);
if (node.getCondition() instanceof WindowFunction) {
throw QueryCompilationError.windowFunctionNotAllowedError();
}
// Extract aggregator from Having clause
ImmutableList.Builder<NamedAggregator> aggregatorBuilder = new ImmutableList.Builder<>();
for (UnresolvedExpression expr : node.getAggExprList()) {
for (UnresolvedExpression expr : node.getAggregators()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
aggregatorBuilder.add(
new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
}
List<NamedAggregator> aggregatorListFromHaving = aggregatorBuilder.build();
// new context
context.push();
TypeEnvironment newEnv = context.peek();
aggregatorListFromHaving.forEach(
aggregator ->
newEnv.define(
new Symbol(Namespace.FIELD_NAME, aggregator.getName()), aggregator.type()));

List<NamedAggregator> aggregatorListFromChild = aggregation.getAggregatorList();
List<NamedAggregator> mergedList =
Stream.of(aggregatorListFromChild, aggregatorListFromHaving)
.flatMap(List::stream)
.collect(Collectors.toList());

return new LogicalAggregation(
aggregation.getChild().get(0), mergedList, aggregation.getGroupByList());
}

/** Build {@link LogicalAggregation}. */
@Override
public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {
final LogicalPlan child = node.getChild().get(0).accept(this, context);
ImmutableList.Builder<NamedExpression> groupbyBuilder = new ImmutableList.Builder<>();
// Span should be first expression if exist.
if (node.getSpan() != null) {
Expand All @@ -310,12 +352,29 @@ public LogicalPlan visitAggregation(Aggregation node, AnalysisContext context) {

for (UnresolvedExpression expr : node.getGroupExprList()) {
NamedExpression resolvedExpr = namedExpressionAnalyzer.analyze(expr, context);
if (resolvedExpr.getDelegated() instanceof Aggregator) {
throw QueryCompilationError.aggregateFunctionNotAllowedInGroupByError(
((Aggregator<?>) resolvedExpr.getDelegated()).getFunctionName().getFunctionName());
}
verifySupportsCondition(resolvedExpr.getDelegated());
groupbyBuilder.add(resolvedExpr);
}
ImmutableList<NamedExpression> groupBys = groupbyBuilder.build();

ImmutableList.Builder<NamedAggregator> aggregatorBuilder = new ImmutableList.Builder<>();
for (UnresolvedExpression expr : node.getAggExprList()) {
NamedExpression aggExpr = namedExpressionAnalyzer.analyze(expr, context);
if (aggExpr.getDelegated() instanceof Aggregator) {
aggregatorBuilder.add(
new NamedAggregator(aggExpr.getNameOrAlias(), (Aggregator) aggExpr.getDelegated()));
} else if (aggExpr.getDelegated() instanceof ReferenceExpression) {
if (groupBys.stream().noneMatch(k -> k.getName().equalsIgnoreCase(aggExpr.getName()))) {
throw QueryCompilationError.fieldNotInGroupByClauseError(aggExpr);
}
}
}
ImmutableList<NamedAggregator> aggregators = aggregatorBuilder.build();

// new context
context.push();
TypeEnvironment newEnv = context.peek();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
import org.opensearch.sql.expression.window.ranking.RankingWindowFunction;
import org.opensearch.sql.utils.ExpressionUtils;

/**
* Analyze the {@link UnresolvedExpression} in the {@link AnalysisContext} to construct the {@link
Expand Down Expand Up @@ -169,11 +171,23 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
builder.build());
aggregator.distinct(node.getDistinct());
if (node.condition() != null) {
aggregator.condition(analyze(node.condition(), context));
// Check if the filter condition is a valid predicate.
Expression predicate = node.condition().accept(this, context);
if (predicate.type() != ExprCoreType.BOOLEAN) {
throw QueryCompilationError.nonBooleanExpressionInFilterOrHavingError(predicate.type());
}
// Check if any aggregate function in filter
List<Expression> results = new ArrayList<>();
ExpressionUtils.findExpressions(predicate, e -> e instanceof Aggregator, results);
if (!results.isEmpty()) {
throw QueryCompilationError.aggregateFunctionNotAllowedInFilterError(
((Aggregator) results.get(0)).getFunctionName().getFunctionName());
}
aggregator.condition(predicate);
}
return aggregator;
} else {
throw new SemanticCheckException("Unsupported aggregation function " + node.getFuncName());
throw QueryCompilationError.unsupportedAggregateFunctionError(node.getFuncName());
}
}

Expand Down Expand Up @@ -211,6 +225,10 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte
if (expr instanceof Aggregator) {
return new AggregateWindowFunction((Aggregator<AggregationState>) expr);
}
if (expr instanceof RankingWindowFunction && node.getSortList().isEmpty()) {
throw QueryCompilationError.rankingWindowFunctionMissesOrderClauseError(
((RankingWindowFunction) expr).getFunctionName().getFunctionName());
}
return expr;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.analysis;

import static org.opensearch.sql.common.utils.StringUtils.format;

import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.NamedExpression;

/** Grouping error messages from {@link SemanticCheckException} thrown during query compilation. */
public class QueryCompilationError {

public static SemanticCheckException fieldNotInGroupByClauseError(NamedExpression expr) {
return new SemanticCheckException(
format(
"Field [%s] must appear in the GROUP BY clause or be used in an aggregate function",
expr.getName()));
}

public static SemanticCheckException groupByClauseIsMissingError(UnresolvedExpression expr) {
return new SemanticCheckException(
format(
"Explicit GROUP BY clause is required because expression [%s] contains non-aggregated"
+ " column",
expr));
}

public static SemanticCheckException aggregateFunctionNotAllowedInGroupByError(
String functionName) {
return new SemanticCheckException(
format(
"Aggregate function is not allowed in a GROUP BY clause, but found [%s]",
functionName));
}

public static SemanticCheckException nonBooleanExpressionInFilterOrHavingError(ExprType type) {
return new SemanticCheckException(
format(
"FILTER or HAVING expression must be type boolean, but found [%s]", type.typeName()));
}

public static SemanticCheckException groupByOrdinalRefersToAggregateFunctionError(int ordinal) {
return new SemanticCheckException(
format(
"GROUP BY %s refers to an expression that contains an aggregate function. Aggregate"
+ " functions are not allowed in GROUP BY",
ordinal));
}

public static SemanticCheckException ordinalRefersOutOfBounds(int ordinal) {
return new SemanticCheckException(
format("Ordinal [%d] is out of bound of select item list", ordinal));
}

public static SemanticCheckException aggregateFunctionNotAllowedInFilterError(
String functionName) {
return new SemanticCheckException(
format("Aggregate function is not allowed in a FILTER, but found [%s]", functionName));
}

public static SemanticCheckException windowFunctionNotAllowedError() {
return new SemanticCheckException("window functions are not allowed in WHERE or HAVING");
}

public static SemanticCheckException unsupportedAggregateFunctionError(String functionName) {
return new SemanticCheckException(format("Unsupported aggregation function %s", functionName));
}

public static SemanticCheckException rankingWindowFunctionMissesOrderClauseError(
String functionName) {
return new SemanticCheckException(
format(
"Window function [%s] requires window to be ordered, please add ORDER BY clause.",
functionName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.FetchCursor;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.Limit;
Expand Down Expand Up @@ -312,4 +313,8 @@ public T visitFetchCursor(FetchCursor cursor, C context) {
public T visitCloseCursor(CloseCursor closeCursor, C context) {
return visitChildren(closeCursor, context);
}

public T visitHaving(Having having, C context) {
return visitChildren(having, context);
}
}
8 changes: 8 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
import org.opensearch.sql.ast.tree.Having;
import org.opensearch.sql.ast.tree.Head;
import org.opensearch.sql.ast.tree.Limit;
import org.opensearch.sql.ast.tree.Parse;
Expand Down Expand Up @@ -471,4 +472,11 @@ public static Parse parse(
java.util.Map<String, Literal> arguments) {
return new Parse(parseMethod, sourceField, pattern, arguments, input);
}

public static UnresolvedPlan having(
UnresolvedPlan input,
List<UnresolvedExpression> aggregators,
UnresolvedExpression condition) {
return new Having(aggregators, condition).attach(input);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {

@Override
public String toString() {
return StringUtils.format("%s(%s)", funcName, field);
return StringUtils.format("%s(%s%s)", funcName, distinct ? "DISTINCT " : "", field);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ public class Aggregation extends UnresolvedPlan {

/** Aggregation Constructor without span and argument. */
public Aggregation(
// In unresolved logical plan, the aggExprList not only includes AggregatorFunctions,
// but also includes select expressions.
// Those invalid expressions will be erased when it is resolving to resolved plan.
// As a result, only aggregator functions will be converted to NamedAggregator.
List<UnresolvedExpression> aggExprList,
List<UnresolvedExpression> sortExprList,
List<UnresolvedExpression> groupExprList) {
Expand Down
50 changes: 50 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Having.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.tree;

import com.google.common.collect.ImmutableList;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

/**
* Represents unresolved HAVING clause, its child can be Aggregation. Having without aggregation
* equals to {@link Filter}
*/
@Getter
@Setter
@ToString
@EqualsAndHashCode(callSuper = false)
public class Having extends UnresolvedPlan {
private List<UnresolvedExpression> aggregators;
private UnresolvedExpression condition;
private UnresolvedPlan child;

public Having(List<UnresolvedExpression> aggregators, UnresolvedExpression condition) {
this.aggregators = aggregators;
this.condition = condition;
}

@Override
public UnresolvedPlan attach(UnresolvedPlan child) {
this.child = child;
return this;
}

@Override
public List<UnresolvedPlan> getChild() {
return ImmutableList.of(child);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitHaving(this, context);
}
}
Loading

0 comments on commit 3e92818

Please sign in to comment.