Implement SQL validation based on grammar element #3039

Merged 14 commits on Sep 23, 2024
async-query-core/build.gradle (2 additions, 1 deletion)
@@ -122,7 +122,8 @@ jacocoTestCoverageVerification {
     'org.opensearch.sql.spark.flint.*',
     'org.opensearch.sql.spark.flint.operation.*',
     'org.opensearch.sql.spark.rest.*',
-    'org.opensearch.sql.spark.utils.SQLQueryUtils.*'
+    'org.opensearch.sql.spark.utils.SQLQueryUtils.*',
+    'org.opensearch.sql.spark.validator.SQLQueryValidationVisitor'
   ]
   limit {
     counter = 'LINE'
org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java
@@ -6,7 +6,6 @@
 package org.opensearch.sql.spark.dispatcher;
 
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import lombok.AllArgsConstructor;
 import org.jetbrains.annotations.NotNull;
@@ -24,6 +23,7 @@
 import org.opensearch.sql.spark.execution.session.SessionManager;
 import org.opensearch.sql.spark.rest.model.LangType;
 import org.opensearch.sql.spark.utils.SQLQueryUtils;
+import org.opensearch.sql.spark.validator.SQLQueryValidator;
 
 /** This class takes care of understanding query and dispatching job query to emr serverless. */
 @AllArgsConstructor
@@ -38,6 +38,7 @@ public class SparkQueryDispatcher {
   private final SessionManager sessionManager;
   private final QueryHandlerFactory queryHandlerFactory;
   private final QueryIdProvider queryIdProvider;
+  private final SQLQueryValidator sqlQueryValidator;
 
   public DispatchQueryResponse dispatch(
       DispatchQueryRequest dispatchQueryRequest,
@@ -54,13 +55,7 @@ public DispatchQueryResponse dispatch(
             dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
       }
 
-      List<String> validationErrors =
-          SQLQueryUtils.validateSparkSqlQuery(
-              dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query);
-      if (!validationErrors.isEmpty()) {
-        throw new IllegalArgumentException(
-            "Query is not allowed: " + String.join(", ", validationErrors));
-      }
+      sqlQueryValidator.validate(query, dataSourceMetadata.getConnector());
     }
     return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
   }
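The dispatcher now delegates validation to a single call, sqlQueryValidator.validate(query, dataSourceMetadata.getConnector()), instead of collecting error strings from SQLQueryUtils and throwing on a non-empty list. The SQLQueryValidator class itself is not part of this excerpt; the sketch below only illustrates how such a validator could combine a per-connector GrammarElementValidator with the SQLQueryValidationVisitor named in the jacoco exclusion above. The GrammarElementValidatorProvider helper and the visitor constructor are assumptions, not the PR's actual code.

```java
// Hypothetical sketch of SQLQueryValidator; the real class is not shown in this diff.
package org.opensearch.sql.spark.validator;

import lombok.AllArgsConstructor;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.spark.utils.SQLQueryUtils;

@AllArgsConstructor
public class SQLQueryValidator {
  // Assumed helper that maps a DataSourceType to the GrammarElementValidator configured for it.
  private final GrammarElementValidatorProvider grammarElementValidatorProvider;

  /** Rejects the query if it uses a grammar element that is denied for the given connector. */
  public void validate(String sqlQuery, DataSourceType dataSourceType) {
    GrammarElementValidator validator =
        grammarElementValidatorProvider.getValidatorForDatasource(dataSourceType);
    // Assumed: the visitor throws IllegalArgumentException when it visits a denied element.
    SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(validator);
    visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).statement());
  }
}
```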
org/opensearch/sql/spark/utils/SQLQueryUtils.java
@@ -5,8 +5,6 @@

 package org.opensearch.sql.spark.utils;
 
-import java.util.ArrayList;
-import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Locale;
@@ -20,8 +18,6 @@
 import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
 import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
 import org.opensearch.sql.common.antlr.SyntaxCheckException;
-import org.opensearch.sql.datasource.model.DataSource;
-import org.opensearch.sql.datasource.model.DataSourceType;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer;
 import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
@@ -84,71 +80,12 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
     }
   }
 
-  public static List<String> validateSparkSqlQuery(DataSource datasource, String sqlQuery) {
+  public static SqlBaseParser getBaseParser(String sqlQuery) {
     SqlBaseParser sqlBaseParser =
         new SqlBaseParser(
             new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
     sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
-    try {
-      SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource);
-      StatementContext statement = sqlBaseParser.statement();
-      sqlParserBaseVisitor.visit(statement);
-      return sqlParserBaseVisitor.getValidationErrors();
-    } catch (SyntaxCheckException e) {
-      logger.error(
-          String.format(
-              "Failed to parse sql statement context while validating sql query %s", sqlQuery),
-          e);
-      return Collections.emptyList();
-    }
-  }
-
-  private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) {
-    if (datasource != null
-        && datasource.getConnectorType() != null
-        && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) {
-      return new SparkSqlSecurityLakeValidatorVisitor();
-    } else {
-      return new SparkSqlValidatorVisitor();
-    }
-  }
-
-  /**
-   * A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class
-   * supports accumulating validation errors on visiting sql statement
-   */
-  @Getter
-  private static class SqlBaseValidatorVisitor<T> extends SqlBaseParserBaseVisitor<T> {
-    private final List<String> validationErrors = new ArrayList<>();
-  }
-
-  /** A generic validator impl for Spark Sql Queries */
-  private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
-    @Override
-    public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
-      getValidationErrors().add("Creating user-defined functions is not allowed");
-      return super.visitCreateFunction(ctx);
-    }
-  }
-
-  /** A validator impl specific to Security Lake for Spark Sql Queries */
-  private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
-
-    public SparkSqlSecurityLakeValidatorVisitor() {
-      // only select statement allowed. hence we add the validation error to all types of statements
-      // by default
-      // and remove the validation error only for select statement.
-      getValidationErrors()
-          .add(
-              "Unsupported sql statement for security lake data source. Only select queries are"
-                  + " allowed");
-    }
-
-    @Override
-    public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
-      getValidationErrors().clear();
-      return super.visitStatementDefault(ctx);
-    }
+    return sqlBaseParser;
   }
 
   public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
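With the visitor classes removed, SQLQueryUtils no longer validates anything itself; getBaseParser only returns a SqlBaseParser wired with a case-insensitive char stream and a SyntaxAnalysisErrorListener, and the caller walks the parse tree. A rough usage fragment follows; the anonymous visitor and its single rule are illustrative, not the PR's validation logic.

```java
// Illustrative caller of the new getBaseParser() helper.
SqlBaseParser parser = SQLQueryUtils.getBaseParser("CREATE FUNCTION f AS 'com.example.F'");
// statement() throws SyntaxCheckException (via SyntaxAnalysisErrorListener) on invalid syntax.
SqlBaseParser.StatementContext statement = parser.statement();
new SqlBaseParserBaseVisitor<Void>() {
  @Override
  public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
    // Example rule only; the grammar-element checks now live in the validator package.
    throw new IllegalArgumentException("Creating user-defined functions is not allowed");
  }
}.visit(statement);
```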
org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java (new file)
@@ -0,0 +1,13 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+public class DefaultGrammarElementValidator implements GrammarElementValidator {
+  @Override
+  public boolean isValid(GrammarElement element) {
+    return true;
+  }
+}
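The GrammarElementValidator interface and the GrammarElement type are not included in this diff view. Judging from the two implementations, the contract is presumably a single predicate over grammar elements, roughly:

```java
// Presumed shape of the interface, inferred from the implementations in this PR;
// GrammarElement is likely an enum of SQL grammar constructs, but its definition is not shown here.
public interface GrammarElementValidator {
  /** Returns true if the given grammar element is allowed for the target data source. */
  boolean isValid(GrammarElement element);
}
```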
org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java (new file)
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.validator;
+
+import java.util.Set;
+import lombok.RequiredArgsConstructor;
+
+@RequiredArgsConstructor
+public class DenyListGrammarElementValidator implements GrammarElementValidator {
+  private final Set<GrammarElement> denyList;
+
+  @Override
+  public boolean isValid(GrammarElement element) {
+    return !denyList.contains(element);
+  }
+}
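DenyListGrammarElementValidator rejects exactly the elements in its set and allows everything else, while DefaultGrammarElementValidator allows every element. A short usage sketch; the GrammarElement constants named here are placeholders, since the enum values are not shown in this excerpt.

```java
// Hypothetical wiring; CREATE_FUNCTION, ALTER_VIEW, and SELECT are placeholder constant names.
Set<GrammarElement> denyList = Set.of(GrammarElement.CREATE_FUNCTION, GrammarElement.ALTER_VIEW);

GrammarElementValidator restricted = new DenyListGrammarElementValidator(denyList);
GrammarElementValidator permissive = new DefaultGrammarElementValidator();

restricted.isValid(GrammarElement.CREATE_FUNCTION); // false: the element is on the deny list
restricted.isValid(GrammarElement.SELECT);          // true: anything not listed passes
permissive.isValid(GrammarElement.CREATE_FUNCTION); // true: the default validator allows everything
```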