Skip to content

Commit

Permalink
Implement SQL validation based on grammar element (#3039) (#3044)
Browse files Browse the repository at this point in the history
* Implement SQL validation based on grammar element



* Add function types



* Fix style



* Add security lake



* Add File support



* Integrate into SparkQueryDispatcher



* Fix style



* Add tests



* Integration



* Add comments



* Address comments



* Allow join types for now



* Fix style



* Fix coverage check



---------


(cherry picked from commit a87893a)

Signed-off-by: Tomoyuki Morita <[email protected]>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
1 parent a20f655 commit 5ed2a28
Show file tree
Hide file tree
Showing 22 changed files with 2,195 additions and 192 deletions.
3 changes: 2 additions & 1 deletion async-query-core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ jacocoTestCoverageVerification {
'org.opensearch.sql.spark.flint.*',
'org.opensearch.sql.spark.flint.operation.*',
'org.opensearch.sql.spark.rest.*',
'org.opensearch.sql.spark.utils.SQLQueryUtils.*'
'org.opensearch.sql.spark.utils.SQLQueryUtils.*',
'org.opensearch.sql.spark.validator.SQLQueryValidationVisitor'
]
limit {
counter = 'LINE'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.sql.spark.dispatcher;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import org.jetbrains.annotations.NotNull;
Expand All @@ -24,6 +23,7 @@
import org.opensearch.sql.spark.execution.session.SessionManager;
import org.opensearch.sql.spark.rest.model.LangType;
import org.opensearch.sql.spark.utils.SQLQueryUtils;
import org.opensearch.sql.spark.validator.SQLQueryValidator;

/** This class takes care of understanding the query and dispatching the job query to EMR Serverless. */
@AllArgsConstructor
Expand All @@ -38,6 +38,7 @@ public class SparkQueryDispatcher {
private final SessionManager sessionManager;
private final QueryHandlerFactory queryHandlerFactory;
private final QueryIdProvider queryIdProvider;
private final SQLQueryValidator sqlQueryValidator;

public DispatchQueryResponse dispatch(
DispatchQueryRequest dispatchQueryRequest,
Expand All @@ -54,13 +55,7 @@ public DispatchQueryResponse dispatch(
dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
}

List<String> validationErrors =
SQLQueryUtils.validateSparkSqlQuery(
dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query);
if (!validationErrors.isEmpty()) {
throw new IllegalArgumentException(
"Query is not allowed: " + String.join(", ", validationErrors));
}
sqlQueryValidator.validate(query, dataSourceMetadata.getConnector());
}
return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.sql.spark.utils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
Expand All @@ -20,8 +18,6 @@
import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream;
import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceType;
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor;
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer;
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
Expand Down Expand Up @@ -84,71 +80,12 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {
}
}

public static List<String> validateSparkSqlQuery(DataSource datasource, String sqlQuery) {
public static SqlBaseParser getBaseParser(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
try {
SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource);
StatementContext statement = sqlBaseParser.statement();
sqlParserBaseVisitor.visit(statement);
return sqlParserBaseVisitor.getValidationErrors();
} catch (SyntaxCheckException e) {
logger.error(
String.format(
"Failed to parse sql statement context while validating sql query %s", sqlQuery),
e);
return Collections.emptyList();
}
}

private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) {
if (datasource != null
&& datasource.getConnectorType() != null
&& datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) {
return new SparkSqlSecurityLakeValidatorVisitor();
} else {
return new SparkSqlValidatorVisitor();
}
}

/**
* A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class
* supports accumulating validation errors on visiting sql statement
*/
@Getter
private static class SqlBaseValidatorVisitor<T> extends SqlBaseParserBaseVisitor<T> {
private final List<String> validationErrors = new ArrayList<>();
}

/** A generic validator impl for Spark Sql Queries */
private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor<Void> {
@Override
public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) {
getValidationErrors().add("Creating user-defined functions is not allowed");
return super.visitCreateFunction(ctx);
}
}

/** A validator impl specific to Security Lake for Spark Sql Queries */
private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor<Void> {

public SparkSqlSecurityLakeValidatorVisitor() {
// only select statement allowed. hence we add the validation error to all types of statements
// by default
// and remove the validation error only for select statement.
getValidationErrors()
.add(
"Unsupported sql statement for security lake data source. Only select queries are"
+ " allowed");
}

@Override
public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
getValidationErrors().clear();
return super.visitStatementDefault(ctx);
}
return sqlBaseParser;
}

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.validator;

/** {@link GrammarElementValidator} that accepts every grammar element unconditionally. */
public class DefaultGrammarElementValidator implements GrammarElementValidator {

  /**
   * Reports the given element as valid without inspecting it.
   *
   * @param element grammar element being checked; not read by this implementation
   * @return always {@code true}
   */
  @Override
  public boolean isValid(final GrammarElement element) {
    return true;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.spark.validator;

import java.util.Set;
import lombok.RequiredArgsConstructor;

/**
 * {@link GrammarElementValidator} backed by a deny list: an element is valid unless the
 * configured set contains it.
 */
@RequiredArgsConstructor
public class DenyListGrammarElementValidator implements GrammarElementValidator {

  /** Elements this validator rejects; supplied through the generated constructor. */
  private final Set<GrammarElement> denyList;

  /**
   * Checks the given element against the deny list.
   *
   * @param element grammar element to look up
   * @return {@code false} when the deny list contains {@code element}, {@code true} otherwise
   */
  @Override
  public boolean isValid(final GrammarElement element) {
    final boolean denied = denyList.contains(element);
    return !denied;
  }
}
Loading

0 comments on commit 5ed2a28

Please sign in to comment.