From 5ed2a28f60a27c0282c4a42b270deafa27c4c25a Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 20:55:02 -0700 Subject: [PATCH] Implement SQL validation based on grammar element (#3039) (#3044) * Implement SQL validation based on grammar element * Add function types * fix style * Add security lake * Add File support * Integrate into SparkQueryDispatcher * Fix style * Add tests * Integration * Add comments * Address comments * Allow join types for now * Fix style * Fix coverage check --------- (cherry picked from commit a87893ac6771cb8739b82cacf8721d0dd7d1cbe3) Signed-off-by: Tomoyuki Morita Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- async-query-core/build.gradle | 3 +- .../dispatcher/SparkQueryDispatcher.java | 11 +- .../sql/spark/utils/SQLQueryUtils.java | 67 +- .../DefaultGrammarElementValidator.java | 13 + .../DenyListGrammarElementValidator.java | 19 + .../sql/spark/validator/FunctionType.java | 436 +++++++++++++ .../sql/spark/validator/GrammarElement.java | 89 +++ .../validator/GrammarElementValidator.java | 15 + .../GrammarElementValidatorProvider.java | 22 + .../S3GlueGrammarElementValidator.java | 71 ++ .../validator/SQLQueryValidationVisitor.java | 609 ++++++++++++++++++ .../spark/validator/SQLQueryValidator.java | 39 ++ .../SecurityLakeGrammarElementValidator.java | 113 ++++ .../asyncquery/AsyncQueryCoreIntegTest.java | 15 +- .../dispatcher/SparkQueryDispatcherTest.java | 39 +- .../sql/spark/utils/SQLQueryUtilsTest.java | 102 --- .../sql/spark/validator/FunctionTypeTest.java | 47 ++ .../GrammarElementValidatorProviderTest.java | 39 ++ .../validator/SQLQueryValidatorTest.java | 600 +++++++++++++++++ .../config/AsyncExecutorServiceModule.java | 24 +- ...AsyncQueryExecutorServiceImplSpecTest.java | 2 +- .../AsyncQueryExecutorServiceSpec.java | 12 +- 22 files changed, 2195 insertions(+), 192 deletions(-) create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java create mode 100644 async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java create mode 100644 async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java diff --git a/async-query-core/build.gradle b/async-query-core/build.gradle index 1de6cb3105..deba81735d 100644 --- a/async-query-core/build.gradle +++ b/async-query-core/build.gradle @@ -122,7 +122,8 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.flint.*', 'org.opensearch.sql.spark.flint.operation.*', 'org.opensearch.sql.spark.rest.*', - 'org.opensearch.sql.spark.utils.SQLQueryUtils.*' + 'org.opensearch.sql.spark.utils.SQLQueryUtils.*', + 'org.opensearch.sql.spark.validator.SQLQueryValidationVisitor' ] limit { counter = 'LINE' diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index 732f5f71ab..ff8c8d1fe8 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -6,7 +6,6 @@ package org.opensearch.sql.spark.dispatcher; import java.util.HashMap; -import java.util.List; import java.util.Map; import lombok.AllArgsConstructor; import org.jetbrains.annotations.NotNull; @@ -24,6 +23,7 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.SQLQueryUtils; +import org.opensearch.sql.spark.validator.SQLQueryValidator; /** This class takes care of understanding query and dispatching job query to emr serverless. */ @AllArgsConstructor @@ -38,6 +38,7 @@ public class SparkQueryDispatcher { private final SessionManager sessionManager; private final QueryHandlerFactory queryHandlerFactory; private final QueryIdProvider queryIdProvider; + private final SQLQueryValidator sqlQueryValidator; public DispatchQueryResponse dispatch( DispatchQueryRequest dispatchQueryRequest, @@ -54,13 +55,7 @@ public DispatchQueryResponse dispatch( dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); } - List validationErrors = - SQLQueryUtils.validateSparkSqlQuery( - dataSourceService.getDataSource(dispatchQueryRequest.getDatasource()), query); - if (!validationErrors.isEmpty()) { - throw new IllegalArgumentException( - "Query is not allowed: " + String.join(", ", validationErrors)); - } + sqlQueryValidator.validate(query, dataSourceMetadata.getConnector()); } return handleDefaultQuery(dispatchQueryRequest, asyncQueryRequestContext, dataSourceMetadata); } diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index 7550de2f1e..3ba9c23ed7 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -5,8 +5,6 @@ package org.opensearch.sql.spark.utils; -import java.util.ArrayList; -import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Locale; @@ -20,8 +18,6 @@ import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; import org.opensearch.sql.common.antlr.SyntaxAnalysisErrorListener; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.datasource.model.DataSource; -import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsBaseVisitor; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsLexer; import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser; @@ -84,71 +80,12 @@ public static boolean isFlintExtensionQuery(String sqlQuery) { } } - public static List validateSparkSqlQuery(DataSource datasource, String sqlQuery) { + public static SqlBaseParser getBaseParser(String sqlQuery) { SqlBaseParser sqlBaseParser = new SqlBaseParser( new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); - try { - SqlBaseValidatorVisitor sqlParserBaseVisitor = getSparkSqlValidatorVisitor(datasource); - StatementContext statement = sqlBaseParser.statement(); - sqlParserBaseVisitor.visit(statement); - return sqlParserBaseVisitor.getValidationErrors(); - } catch (SyntaxCheckException e) { - logger.error( - String.format( - "Failed to parse sql statement context while validating sql query %s", sqlQuery), - e); - return Collections.emptyList(); - } - } - - private SqlBaseValidatorVisitor getSparkSqlValidatorVisitor(DataSource datasource) { - if (datasource != null - && datasource.getConnectorType() != null - && datasource.getConnectorType().equals(DataSourceType.SECURITY_LAKE)) { - return new SparkSqlSecurityLakeValidatorVisitor(); - } else { - return new SparkSqlValidatorVisitor(); - } - } - - /** - * A base class extending SqlBaseParserBaseVisitor for validating Spark Sql Queries. The class - * supports accumulating validation errors on visiting sql statement - */ - @Getter - private static class SqlBaseValidatorVisitor extends SqlBaseParserBaseVisitor { - private final List validationErrors = new ArrayList<>(); - } - - /** A generic validator impl for Spark Sql Queries */ - private static class SparkSqlValidatorVisitor extends SqlBaseValidatorVisitor { - @Override - public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { - getValidationErrors().add("Creating user-defined functions is not allowed"); - return super.visitCreateFunction(ctx); - } - } - - /** A validator impl specific to Security Lake for Spark Sql Queries */ - private static class SparkSqlSecurityLakeValidatorVisitor extends SqlBaseValidatorVisitor { - - public SparkSqlSecurityLakeValidatorVisitor() { - // only select statement allowed. hence we add the validation error to all types of statements - // by default - // and remove the validation error only for select statement. - getValidationErrors() - .add( - "Unsupported sql statement for security lake data source. Only select queries are" - + " allowed"); - } - - @Override - public Void visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) { - getValidationErrors().clear(); - return super.visitStatementDefault(ctx); - } + return sqlBaseParser; } public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor { diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java new file mode 100644 index 0000000000..ddd0a1d094 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DefaultGrammarElementValidator.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +public class DefaultGrammarElementValidator implements GrammarElementValidator { + @Override + public boolean isValid(GrammarElement element) { + return true; + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java new file mode 100644 index 0000000000..514e2c8ad8 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/DenyListGrammarElementValidator.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import java.util.Set; +import lombok.RequiredArgsConstructor; + +@RequiredArgsConstructor +public class DenyListGrammarElementValidator implements GrammarElementValidator { + private final Set denyList; + + @Override + public boolean isValid(GrammarElement element) { + return !denyList.contains(element); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java new file mode 100644 index 0000000000..da3760efd6 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/FunctionType.java @@ -0,0 +1,436 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.AllArgsConstructor; + +/** + * Enum for defining and looking up SQL function type based on its name. Unknown one will be + * considered as UDF (User Defined Function) + */ +@AllArgsConstructor +public enum FunctionType { + AGGREGATE("Aggregate"), + WINDOW("Window"), + ARRAY("Array"), + MAP("Map"), + DATE_TIMESTAMP("Date and Timestamp"), + JSON("JSON"), + MATH("Math"), + STRING("String"), + CONDITIONAL("Conditional"), + BITWISE("Bitwise"), + CONVERSION("Conversion"), + PREDICATE("Predicate"), + CSV("CSV"), + MISC("Misc"), + GENERATOR("Generator"), + UDF("User Defined Function"); + + private final String name; + + private static final Map> FUNCTION_TYPE_TO_FUNCTION_NAMES_MAP = + ImmutableMap.>builder() + .put( + AGGREGATE, + Set.of( + "any", + "any_value", + "approx_count_distinct", + "approx_percentile", + "array_agg", + "avg", + "bit_and", + "bit_or", + "bit_xor", + "bitmap_construct_agg", + "bitmap_or_agg", + "bool_and", + "bool_or", + "collect_list", + "collect_set", + "corr", + "count", + "count_if", + "count_min_sketch", + "covar_pop", + "covar_samp", + "every", + "first", + "first_value", + "grouping", + "grouping_id", + "histogram_numeric", + "hll_sketch_agg", + "hll_union_agg", + "kurtosis", + "last", + "last_value", + "max", + "max_by", + "mean", + "median", + "min", + "min_by", + "mode", + "percentile", + "percentile_approx", + "regr_avgx", + "regr_avgy", + "regr_count", + "regr_intercept", + "regr_r2", + "regr_slope", + "regr_sxx", + "regr_sxy", + "regr_syy", + "skewness", + "some", + "std", + "stddev", + "stddev_pop", + "stddev_samp", + "sum", + "try_avg", + "try_sum", + "var_pop", + "var_samp", + "variance")) + .put( + WINDOW, + Set.of( + "cume_dist", + "dense_rank", + "lag", + "lead", + "nth_value", + "ntile", + "percent_rank", + "rank", + "row_number")) + .put( + ARRAY, + Set.of( + "array", + "array_append", + "array_compact", + "array_contains", + "array_distinct", + "array_except", + "array_insert", + "array_intersect", + "array_join", + "array_max", + "array_min", + "array_position", + "array_prepend", + "array_remove", + "array_repeat", + "array_union", + "arrays_overlap", + "arrays_zip", + "flatten", + "get", + "sequence", + "shuffle", + "slice", + "sort_array")) + .put( + MAP, + Set.of( + "element_at", + "map", + "map_concat", + "map_contains_key", + "map_entries", + "map_from_arrays", + "map_from_entries", + "map_keys", + "map_values", + "str_to_map", + "try_element_at")) + .put( + DATE_TIMESTAMP, + Set.of( + "add_months", + "convert_timezone", + "curdate", + "current_date", + "current_timestamp", + "current_timezone", + "date_add", + "date_diff", + "date_format", + "date_from_unix_date", + "date_part", + "date_sub", + "date_trunc", + "dateadd", + "datediff", + "datepart", + "day", + "dayofmonth", + "dayofweek", + "dayofyear", + "extract", + "from_unixtime", + "from_utc_timestamp", + "hour", + "last_day", + "localtimestamp", + "make_date", + "make_dt_interval", + "make_interval", + "make_timestamp", + "make_timestamp_ltz", + "make_timestamp_ntz", + "make_ym_interval", + "minute", + "month", + "months_between", + "next_day", + "now", + "quarter", + "second", + "session_window", + "timestamp_micros", + "timestamp_millis", + "timestamp_seconds", + "to_date", + "to_timestamp", + "to_timestamp_ltz", + "to_timestamp_ntz", + "to_unix_timestamp", + "to_utc_timestamp", + "trunc", + "try_to_timestamp", + "unix_date", + "unix_micros", + "unix_millis", + "unix_seconds", + "unix_timestamp", + "weekday", + "weekofyear", + "window", + "window_time", + "year")) + .put( + JSON, + Set.of( + "from_json", + "get_json_object", + "json_array_length", + "json_object_keys", + "json_tuple", + "schema_of_json", + "to_json")) + .put( + MATH, + Set.of( + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bin", + "bround", + "cbrt", + "ceil", + "ceiling", + "conv", + "cos", + "cosh", + "cot", + "csc", + "degrees", + "e", + "exp", + "expm1", + "factorial", + "floor", + "greatest", + "hex", + "hypot", + "least", + "ln", + "log", + "log10", + "log1p", + "log2", + "negative", + "pi", + "pmod", + "positive", + "pow", + "power", + "radians", + "rand", + "randn", + "random", + "rint", + "round", + "sec", + "shiftleft", + "sign", + "signum", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "try_add", + "try_divide", + "try_multiply", + "try_subtract", + "unhex", + "width_bucket")) + .put( + STRING, + Set.of( + "ascii", + "base64", + "bit_length", + "btrim", + "char", + "char_length", + "character_length", + "chr", + "concat", + "concat_ws", + "contains", + "decode", + "elt", + "encode", + "endswith", + "find_in_set", + "format_number", + "format_string", + "initcap", + "instr", + "lcase", + "left", + "len", + "length", + "levenshtein", + "locate", + "lower", + "lpad", + "ltrim", + "luhn_check", + "mask", + "octet_length", + "overlay", + "position", + "printf", + "regexp_count", + "regexp_extract", + "regexp_extract_all", + "regexp_instr", + "regexp_replace", + "regexp_substr", + "repeat", + "replace", + "right", + "rpad", + "rtrim", + "sentences", + "soundex", + "space", + "split", + "split_part", + "startswith", + "substr", + "substring", + "substring_index", + "to_binary", + "to_char", + "to_number", + "to_varchar", + "translate", + "trim", + "try_to_binary", + "try_to_number", + "ucase", + "unbase64", + "upper")) + .put(CONDITIONAL, Set.of("coalesce", "if", "ifnull", "nanvl", "nullif", "nvl", "nvl2")) + .put( + BITWISE, Set.of("bit_count", "bit_get", "getbit", "shiftright", "shiftrightunsigned")) + .put( + CONVERSION, + Set.of( + "bigint", + "binary", + "boolean", + "cast", + "date", + "decimal", + "double", + "float", + "int", + "smallint", + "string", + "timestamp", + "tinyint")) + .put(PREDICATE, Set.of("isnan", "isnotnull", "isnull", "regexp", "regexp_like", "rlike")) + .put(CSV, Set.of("from_csv", "schema_of_csv", "to_csv")) + .put( + MISC, + Set.of( + "aes_decrypt", + "aes_encrypt", + "assert_true", + "bitmap_bit_position", + "bitmap_bucket_number", + "bitmap_count", + "current_catalog", + "current_database", + "current_schema", + "current_user", + "equal_null", + "hll_sketch_estimate", + "hll_union", + "input_file_block_length", + "input_file_block_start", + "input_file_name", + "java_method", + "monotonically_increasing_id", + "reflect", + "spark_partition_id", + "try_aes_decrypt", + "typeof", + "user", + "uuid", + "version")) + .put( + GENERATOR, + Set.of( + "explode", + "explode_outer", + "inline", + "inline_outer", + "posexplode", + "posexplode_outer", + "stack")) + .build(); + + private static final Map FUNCTION_NAME_TO_FUNCTION_TYPE_MAP = + FUNCTION_TYPE_TO_FUNCTION_NAMES_MAP.entrySet().stream() + .flatMap( + entry -> entry.getValue().stream().map(value -> Map.entry(value, entry.getKey()))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + public static FunctionType fromFunctionName(String functionName) { + return FUNCTION_NAME_TO_FUNCTION_TYPE_MAP.getOrDefault(functionName.toLowerCase(), UDF); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java new file mode 100644 index 0000000000..217640bada --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElement.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import lombok.AllArgsConstructor; + +@AllArgsConstructor +enum GrammarElement { + ALTER_NAMESPACE("ALTER (DATABASE|TABLE|NAMESPACE)"), + ALTER_VIEW("ALTER VIEW"), + CREATE_NAMESPACE("CREATE (DATABASE|TABLE|NAMESPACE)"), + CREATE_FUNCTION("CREATE FUNCTION"), + CREATE_VIEW("CREATE VIEW"), + DROP_NAMESPACE("DROP (DATABASE|TABLE|NAMESPACE)"), + DROP_FUNCTION("DROP FUNCTION"), + DROP_VIEW("DROP VIEW"), + DROP_TABLE("DROP TABLE"), + REPAIR_TABLE("REPAIR TABLE"), + TRUNCATE_TABLE("TRUNCATE TABLE"), + // DML Statements + INSERT("INSERT"), + LOAD("LOAD"), + + // Data Retrieval Statements + EXPLAIN("EXPLAIN"), + WITH("WITH"), + CLUSTER_BY("CLUSTER BY"), + DISTRIBUTE_BY("DISTRIBUTE BY"), + // GROUP_BY("GROUP BY"), + // HAVING("HAVING"), + HINTS("HINTS"), + INLINE_TABLE("Inline Table(VALUES)"), + FILE("File"), + INNER_JOIN("INNER JOIN"), + CROSS_JOIN("CROSS JOIN"), + LEFT_OUTER_JOIN("LEFT OUTER JOIN"), + LEFT_SEMI_JOIN("LEFT SEMI JOIN"), + RIGHT_OUTER_JOIN("RIGHT OUTER JOIN"), + FULL_OUTER_JOIN("FULL OUTER JOIN"), + LEFT_ANTI_JOIN("LEFT ANTI JOIN"), + TABLESAMPLE("TABLESAMPLE"), + TABLE_VALUED_FUNCTION("Table-valued function"), + LATERAL_VIEW("LATERAL VIEW"), + LATERAL_SUBQUERY("LATERAL SUBQUERY"), + TRANSFORM("TRANSFORM"), + + // Auxiliary Statements + MANAGE_RESOURCE("Resource management statements"), + ANALYZE_TABLE("ANALYZE TABLE(S)"), + CACHE_TABLE("CACHE TABLE"), + CLEAR_CACHE("CLEAR CACHE"), + DESCRIBE_NAMESPACE("DESCRIBE (NAMESPACE|DATABASE|SCHEMA)"), + DESCRIBE_FUNCTION("DESCRIBE FUNCTION"), + DESCRIBE_QUERY("DESCRIBE QUERY"), + DESCRIBE_TABLE("DESCRIBE TABLE"), + REFRESH_RESOURCE("REFRESH"), + REFRESH_TABLE("REFRESH TABLE"), + REFRESH_FUNCTION("REFRESH FUNCTION"), + RESET("RESET"), + SET("SET"), + SHOW_COLUMNS("SHOW COLUMNS"), + SHOW_CREATE_TABLE("SHOW CREATE TABLE"), + SHOW_NAMESPACES("SHOW (DATABASES|SCHEMAS)"), + SHOW_FUNCTIONS("SHOW FUNCTIONS"), + SHOW_PARTITIONS("SHOW PARTITIONS"), + SHOW_TABLE_EXTENDED("SHOW TABLE EXTENDED"), + SHOW_TABLES("SHOW TABLES"), + SHOW_TBLPROPERTIES("SHOW TBLPROPERTIES"), + SHOW_VIEWS("SHOW VIEWS"), + UNCACHE_TABLE("UNCACHE TABLE"), + + // Functions + MAP_FUNCTIONS("Map functions"), + CSV_FUNCTIONS("CSV functions"), + MISC_FUNCTIONS("Misc functions"), + + // UDF + UDF("User Defined functions"); + + String description; + + @Override + public String toString() { + return description; + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java new file mode 100644 index 0000000000..cc49643772 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidator.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +/** Interface for validator to decide if each GrammarElement is valid or not. */ +public interface GrammarElementValidator { + + /** + * @return true if element is valid (accepted) + */ + boolean isValid(GrammarElement element); +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java new file mode 100644 index 0000000000..9755a1c0b6 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProvider.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import java.util.Map; +import lombok.AllArgsConstructor; +import org.opensearch.sql.datasource.model.DataSourceType; + +/** Provides GrammarElementValidator based on DataSourceType. */ +@AllArgsConstructor +public class GrammarElementValidatorProvider { + + private final Map validatorMap; + private final GrammarElementValidator defaultValidator; + + public GrammarElementValidator getValidatorForDatasource(DataSourceType dataSourceType) { + return validatorMap.getOrDefault(dataSourceType, defaultValidator); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java new file mode 100644 index 0000000000..e7a0ce1b36 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/S3GlueGrammarElementValidator.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.opensearch.sql.spark.validator.GrammarElement.ALTER_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.CLUSTER_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DISTRIBUTE_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.FILE; +import static org.opensearch.sql.spark.validator.GrammarElement.HINTS; +import static org.opensearch.sql.spark.validator.GrammarElement.INLINE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.INSERT; +import static org.opensearch.sql.spark.validator.GrammarElement.LOAD; +import static org.opensearch.sql.spark.validator.GrammarElement.MANAGE_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.MISC_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.RESET; +import static org.opensearch.sql.spark.validator.GrammarElement.SET; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_VIEWS; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLESAMPLE; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLE_VALUED_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.TRANSFORM; +import static org.opensearch.sql.spark.validator.GrammarElement.UDF; + +import com.google.common.collect.ImmutableSet; +import java.util.Set; + +public class S3GlueGrammarElementValidator extends DenyListGrammarElementValidator { + private static final Set S3GLUE_DENY_LIST = + ImmutableSet.builder() + .add( + ALTER_VIEW, + CREATE_FUNCTION, + CREATE_VIEW, + DROP_FUNCTION, + DROP_VIEW, + INSERT, + LOAD, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + FILE, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + TRANSFORM, + MANAGE_RESOURCE, + DESCRIBE_FUNCTION, + REFRESH_RESOURCE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_FUNCTIONS, + SHOW_VIEWS, + MISC_FUNCTIONS, + UDF) + .build(); + + public S3GlueGrammarElementValidator() { + super(S3GLUE_DENY_LIST); + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java new file mode 100644 index 0000000000..9ec0fb0109 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidationVisitor.java @@ -0,0 +1,609 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import lombok.AllArgsConstructor; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTableColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AddTablePartitionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterClusterByContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterTableAlterColumnContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewQueryContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AlterViewSchemaBindingContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.AnalyzeTablesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClearCacheContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ClusterBySpecContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateTableLikeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CreateViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.CtesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeQueryContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DescribeRelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropNamespaceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropTablePartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.DropViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ExplainContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.FunctionNameContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HintContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.HiveReplaceColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InlineTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoReplaceWhereContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertIntoTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteDirContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteHiveDirContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.InsertOverwriteTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinRelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.JoinTypeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LateralViewContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.LoadDataContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ManageResourceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.QueryOrganizationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RecoverPartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshResourceContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RefreshTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RelationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableColumnContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RenameTablePartitionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.RepairTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ReplaceTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ResetQuotedConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SampleContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetConfigurationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespaceLocationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetNamespacePropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableLocationContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SetTableSerDeContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowColumnsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowCreateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowFunctionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowNamespacesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowPartitionsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTableExtendedContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTablesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowTblPropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.ShowViewsContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableNameContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TableValuedFunctionContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TransformClauseContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.TruncateTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UncacheTableContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.UnsetNamespacePropertiesContext; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor; + +/** This visitor validate grammar using GrammarElementValidator */ +@AllArgsConstructor +public class SQLQueryValidationVisitor extends SqlBaseParserBaseVisitor { + private final GrammarElementValidator grammarElementValidator; + + @Override + public Void visitCreateFunction(SqlBaseParser.CreateFunctionContext ctx) { + validateAllowed(GrammarElement.CREATE_FUNCTION); + return super.visitCreateFunction(ctx); + } + + @Override + public Void visitSetNamespaceProperties(SetNamespacePropertiesContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetNamespaceProperties(ctx); + } + + @Override + public Void visitUnsetNamespaceProperties(UnsetNamespacePropertiesContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitUnsetNamespaceProperties(ctx); + } + + @Override + public Void visitAddTableColumns(AddTableColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAddTableColumns(ctx); + } + + @Override + public Void visitAddTablePartition(AddTablePartitionContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAddTablePartition(ctx); + } + + @Override + public Void visitRenameTableColumn(RenameTableColumnContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRenameTableColumn(ctx); + } + + @Override + public Void visitDropTableColumns(DropTableColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitDropTableColumns(ctx); + } + + @Override + public Void visitAlterTableAlterColumn(AlterTableAlterColumnContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAlterTableAlterColumn(ctx); + } + + @Override + public Void visitHiveReplaceColumns(HiveReplaceColumnsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitHiveReplaceColumns(ctx); + } + + @Override + public Void visitSetTableSerDe(SetTableSerDeContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetTableSerDe(ctx); + } + + @Override + public Void visitRenameTablePartition(RenameTablePartitionContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRenameTablePartition(ctx); + } + + @Override + public Void visitDropTablePartitions(DropTablePartitionsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitDropTablePartitions(ctx); + } + + @Override + public Void visitSetTableLocation(SetTableLocationContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetTableLocation(ctx); + } + + @Override + public Void visitRecoverPartitions(RecoverPartitionsContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitRecoverPartitions(ctx); + } + + @Override + public Void visitAlterClusterBy(AlterClusterByContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitAlterClusterBy(ctx); + } + + @Override + public Void visitSetNamespaceLocation(SetNamespaceLocationContext ctx) { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + return super.visitSetNamespaceLocation(ctx); + } + + @Override + public Void visitAlterViewQuery(AlterViewQueryContext ctx) { + validateAllowed(GrammarElement.ALTER_VIEW); + return super.visitAlterViewQuery(ctx); + } + + @Override + public Void visitAlterViewSchemaBinding(AlterViewSchemaBindingContext ctx) { + validateAllowed(GrammarElement.ALTER_VIEW); + return super.visitAlterViewSchemaBinding(ctx); + } + + @Override + public Void visitRenameTable(RenameTableContext ctx) { + if (ctx.VIEW() != null) { + validateAllowed(GrammarElement.ALTER_VIEW); + } else { + validateAllowed(GrammarElement.ALTER_NAMESPACE); + } + + return super.visitRenameTable(ctx); + } + + @Override + public Void visitCreateNamespace(CreateNamespaceContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateNamespace(ctx); + } + + @Override + public Void visitCreateTable(CreateTableContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateTable(ctx); + } + + @Override + public Void visitCreateTableLike(CreateTableLikeContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitCreateTableLike(ctx); + } + + @Override + public Void visitReplaceTable(ReplaceTableContext ctx) { + validateAllowed(GrammarElement.CREATE_NAMESPACE); + return super.visitReplaceTable(ctx); + } + + @Override + public Void visitDropNamespace(DropNamespaceContext ctx) { + validateAllowed(GrammarElement.DROP_NAMESPACE); + return super.visitDropNamespace(ctx); + } + + @Override + public Void visitDropTable(DropTableContext ctx) { + validateAllowed(GrammarElement.DROP_NAMESPACE); + return super.visitDropTable(ctx); + } + + @Override + public Void visitCreateView(CreateViewContext ctx) { + validateAllowed(GrammarElement.CREATE_VIEW); + return super.visitCreateView(ctx); + } + + @Override + public Void visitDropView(DropViewContext ctx) { + validateAllowed(GrammarElement.DROP_VIEW); + return super.visitDropView(ctx); + } + + @Override + public Void visitDropFunction(DropFunctionContext ctx) { + validateAllowed(GrammarElement.DROP_FUNCTION); + return super.visitDropFunction(ctx); + } + + @Override + public Void visitRepairTable(RepairTableContext ctx) { + validateAllowed(GrammarElement.REPAIR_TABLE); + return super.visitRepairTable(ctx); + } + + @Override + public Void visitTruncateTable(TruncateTableContext ctx) { + validateAllowed(GrammarElement.TRUNCATE_TABLE); + return super.visitTruncateTable(ctx); + } + + @Override + public Void visitInsertOverwriteTable(InsertOverwriteTableContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteTable(ctx); + } + + @Override + public Void visitInsertIntoReplaceWhere(InsertIntoReplaceWhereContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertIntoReplaceWhere(ctx); + } + + @Override + public Void visitInsertIntoTable(InsertIntoTableContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertIntoTable(ctx); + } + + @Override + public Void visitInsertOverwriteDir(InsertOverwriteDirContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteDir(ctx); + } + + @Override + public Void visitInsertOverwriteHiveDir(InsertOverwriteHiveDirContext ctx) { + validateAllowed(GrammarElement.INSERT); + return super.visitInsertOverwriteHiveDir(ctx); + } + + @Override + public Void visitLoadData(LoadDataContext ctx) { + validateAllowed(GrammarElement.LOAD); + return super.visitLoadData(ctx); + } + + @Override + public Void visitExplain(ExplainContext ctx) { + validateAllowed(GrammarElement.EXPLAIN); + return super.visitExplain(ctx); + } + + @Override + public Void visitTableName(TableNameContext ctx) { + String reference = ctx.identifierReference().getText(); + if (isFileReference(reference)) { + validateAllowed(GrammarElement.FILE); + } + return super.visitTableName(ctx); + } + + private static final String FILE_REFERENCE_PATTERN = "^[a-zA-Z]+\\.`[^`]+`$"; + + private boolean isFileReference(String reference) { + return reference.matches(FILE_REFERENCE_PATTERN); + } + + @Override + public Void visitCtes(CtesContext ctx) { + validateAllowed(GrammarElement.WITH); + return super.visitCtes(ctx); + } + + @Override + public Void visitClusterBySpec(ClusterBySpecContext ctx) { + validateAllowed(GrammarElement.CLUSTER_BY); + return super.visitClusterBySpec(ctx); + } + + @Override + public Void visitQueryOrganization(QueryOrganizationContext ctx) { + if (ctx.CLUSTER() != null) { + validateAllowed(GrammarElement.CLUSTER_BY); + } else if (ctx.DISTRIBUTE() != null) { + validateAllowed(GrammarElement.DISTRIBUTE_BY); + } + return super.visitQueryOrganization(ctx); + } + + @Override + public Void visitHint(HintContext ctx) { + validateAllowed(GrammarElement.HINTS); + return super.visitHint(ctx); + } + + @Override + public Void visitInlineTable(InlineTableContext ctx) { + validateAllowed(GrammarElement.INLINE_TABLE); + return super.visitInlineTable(ctx); + } + + @Override + public Void visitJoinType(JoinTypeContext ctx) { + if (ctx.CROSS() != null) { + validateAllowed(GrammarElement.CROSS_JOIN); + } else if (ctx.LEFT() != null && ctx.SEMI() != null) { + validateAllowed(GrammarElement.LEFT_SEMI_JOIN); + } else if (ctx.ANTI() != null) { + validateAllowed(GrammarElement.LEFT_ANTI_JOIN); + } else if (ctx.LEFT() != null) { + validateAllowed(GrammarElement.LEFT_OUTER_JOIN); + } else if (ctx.RIGHT() != null) { + validateAllowed(GrammarElement.RIGHT_OUTER_JOIN); + } else if (ctx.FULL() != null) { + validateAllowed(GrammarElement.FULL_OUTER_JOIN); + } else { + validateAllowed(GrammarElement.INNER_JOIN); + } + return super.visitJoinType(ctx); + } + + @Override + public Void visitSample(SampleContext ctx) { + validateAllowed(GrammarElement.TABLESAMPLE); + return super.visitSample(ctx); + } + + @Override + public Void visitTableValuedFunction(TableValuedFunctionContext ctx) { + validateAllowed(GrammarElement.TABLE_VALUED_FUNCTION); + return super.visitTableValuedFunction(ctx); + } + + @Override + public Void visitLateralView(LateralViewContext ctx) { + validateAllowed(GrammarElement.LATERAL_VIEW); + return super.visitLateralView(ctx); + } + + @Override + public Void visitRelation(RelationContext ctx) { + if (ctx.LATERAL() != null) { + validateAllowed(GrammarElement.LATERAL_SUBQUERY); + } + return super.visitRelation(ctx); + } + + @Override + public Void visitJoinRelation(JoinRelationContext ctx) { + if (ctx.LATERAL() != null) { + validateAllowed(GrammarElement.LATERAL_SUBQUERY); + } + return super.visitJoinRelation(ctx); + } + + @Override + public Void visitTransformClause(TransformClauseContext ctx) { + if (ctx.TRANSFORM() != null) { + validateAllowed(GrammarElement.TRANSFORM); + } + return super.visitTransformClause(ctx); + } + + @Override + public Void visitManageResource(ManageResourceContext ctx) { + validateAllowed(GrammarElement.MANAGE_RESOURCE); + return super.visitManageResource(ctx); + } + + @Override + public Void visitAnalyze(AnalyzeContext ctx) { + validateAllowed(GrammarElement.ANALYZE_TABLE); + return super.visitAnalyze(ctx); + } + + @Override + public Void visitAnalyzeTables(AnalyzeTablesContext ctx) { + validateAllowed(GrammarElement.ANALYZE_TABLE); + return super.visitAnalyzeTables(ctx); + } + + @Override + public Void visitCacheTable(CacheTableContext ctx) { + validateAllowed(GrammarElement.CACHE_TABLE); + return super.visitCacheTable(ctx); + } + + @Override + public Void visitClearCache(ClearCacheContext ctx) { + validateAllowed(GrammarElement.CLEAR_CACHE); + return super.visitClearCache(ctx); + } + + @Override + public Void visitDescribeNamespace(DescribeNamespaceContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_NAMESPACE); + return super.visitDescribeNamespace(ctx); + } + + @Override + public Void visitDescribeFunction(DescribeFunctionContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_FUNCTION); + return super.visitDescribeFunction(ctx); + } + + @Override + public Void visitDescribeRelation(DescribeRelationContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_TABLE); + return super.visitDescribeRelation(ctx); + } + + @Override + public Void visitDescribeQuery(DescribeQueryContext ctx) { + validateAllowed(GrammarElement.DESCRIBE_QUERY); + return super.visitDescribeQuery(ctx); + } + + @Override + public Void visitRefreshResource(RefreshResourceContext ctx) { + validateAllowed(GrammarElement.REFRESH_RESOURCE); + return super.visitRefreshResource(ctx); + } + + @Override + public Void visitRefreshTable(RefreshTableContext ctx) { + validateAllowed(GrammarElement.REFRESH_TABLE); + return super.visitRefreshTable(ctx); + } + + @Override + public Void visitRefreshFunction(RefreshFunctionContext ctx) { + validateAllowed(GrammarElement.REFRESH_FUNCTION); + return super.visitRefreshFunction(ctx); + } + + @Override + public Void visitResetConfiguration(ResetConfigurationContext ctx) { + validateAllowed(GrammarElement.RESET); + return super.visitResetConfiguration(ctx); + } + + @Override + public Void visitResetQuotedConfiguration(ResetQuotedConfigurationContext ctx) { + validateAllowed(GrammarElement.RESET); + return super.visitResetQuotedConfiguration(ctx); + } + + @Override + public Void visitSetConfiguration(SetConfigurationContext ctx) { + validateAllowed(GrammarElement.SET); + return super.visitSetConfiguration(ctx); + } + + @Override + public Void visitShowColumns(ShowColumnsContext ctx) { + validateAllowed(GrammarElement.SHOW_COLUMNS); + return super.visitShowColumns(ctx); + } + + @Override + public Void visitShowCreateTable(ShowCreateTableContext ctx) { + validateAllowed(GrammarElement.SHOW_CREATE_TABLE); + return super.visitShowCreateTable(ctx); + } + + @Override + public Void visitShowNamespaces(ShowNamespacesContext ctx) { + validateAllowed(GrammarElement.SHOW_NAMESPACES); + return super.visitShowNamespaces(ctx); + } + + @Override + public Void visitShowFunctions(ShowFunctionsContext ctx) { + validateAllowed(GrammarElement.SHOW_FUNCTIONS); + return super.visitShowFunctions(ctx); + } + + @Override + public Void visitShowPartitions(ShowPartitionsContext ctx) { + validateAllowed(GrammarElement.SHOW_PARTITIONS); + return super.visitShowPartitions(ctx); + } + + @Override + public Void visitShowTableExtended(ShowTableExtendedContext ctx) { + validateAllowed(GrammarElement.SHOW_TABLE_EXTENDED); + return super.visitShowTableExtended(ctx); + } + + @Override + public Void visitShowTables(ShowTablesContext ctx) { + validateAllowed(GrammarElement.SHOW_TABLES); + return super.visitShowTables(ctx); + } + + @Override + public Void visitShowTblProperties(ShowTblPropertiesContext ctx) { + validateAllowed(GrammarElement.SHOW_TBLPROPERTIES); + return super.visitShowTblProperties(ctx); + } + + @Override + public Void visitShowViews(ShowViewsContext ctx) { + validateAllowed(GrammarElement.SHOW_VIEWS); + return super.visitShowViews(ctx); + } + + @Override + public Void visitUncacheTable(UncacheTableContext ctx) { + validateAllowed(GrammarElement.UNCACHE_TABLE); + return super.visitUncacheTable(ctx); + } + + @Override + public Void visitFunctionName(FunctionNameContext ctx) { + validateFunctionAllowed(ctx.qualifiedName().getText()); + return super.visitFunctionName(ctx); + } + + private void validateFunctionAllowed(String function) { + FunctionType type = FunctionType.fromFunctionName(function.toLowerCase()); + switch (type) { + case MAP: + validateAllowed(GrammarElement.MAP_FUNCTIONS); + break; + case CSV: + validateAllowed(GrammarElement.CSV_FUNCTIONS); + break; + case MISC: + validateAllowed(GrammarElement.MISC_FUNCTIONS); + break; + case UDF: + validateAllowed(GrammarElement.UDF); + break; + } + } + + private void validateAllowed(GrammarElement element) { + if (!grammarElementValidator.isValid(element)) { + throw new IllegalArgumentException(element + " is not allowed."); + } + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java new file mode 100644 index 0000000000..f387cbad25 --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SQLQueryValidator.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import lombok.AllArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.utils.SQLQueryUtils; + +/** Validate input SQL query based on the DataSourceType. */ +@AllArgsConstructor +public class SQLQueryValidator { + private static final Logger log = LogManager.getLogger(SQLQueryValidator.class); + + private final GrammarElementValidatorProvider grammarElementValidatorProvider; + + /** + * It will look up validator associated with the DataSourceType, and throw + * IllegalArgumentException if invalid grammar element is found. + * + * @param sqlQuery The query to be validated + * @param datasourceType + */ + public void validate(String sqlQuery, DataSourceType datasourceType) { + GrammarElementValidator grammarElementValidator = + grammarElementValidatorProvider.getValidatorForDatasource(datasourceType); + SQLQueryValidationVisitor visitor = new SQLQueryValidationVisitor(grammarElementValidator); + try { + visitor.visit(SQLQueryUtils.getBaseParser(sqlQuery).singleStatement()); + } catch (IllegalArgumentException e) { + log.error("Query validation failed. DataSourceType=" + datasourceType, e); + throw e; + } + } +} diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java new file mode 100644 index 0000000000..ca8f2b5bdd --- /dev/null +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/validator/SecurityLakeGrammarElementValidator.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.opensearch.sql.spark.validator.GrammarElement.ALTER_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.ALTER_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.ANALYZE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.CACHE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.CLEAR_CACHE; +import static org.opensearch.sql.spark.validator.GrammarElement.CLUSTER_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.CREATE_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.CSV_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_QUERY; +import static org.opensearch.sql.spark.validator.GrammarElement.DESCRIBE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.DISTRIBUTE_BY; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_NAMESPACE; +import static org.opensearch.sql.spark.validator.GrammarElement.DROP_VIEW; +import static org.opensearch.sql.spark.validator.GrammarElement.FILE; +import static org.opensearch.sql.spark.validator.GrammarElement.HINTS; +import static org.opensearch.sql.spark.validator.GrammarElement.INLINE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.INSERT; +import static org.opensearch.sql.spark.validator.GrammarElement.LOAD; +import static org.opensearch.sql.spark.validator.GrammarElement.MANAGE_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.MISC_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_RESOURCE; +import static org.opensearch.sql.spark.validator.GrammarElement.REFRESH_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.REPAIR_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.RESET; +import static org.opensearch.sql.spark.validator.GrammarElement.SET; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_COLUMNS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_CREATE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_FUNCTIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_NAMESPACES; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_PARTITIONS; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_TABLES; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_TABLE_EXTENDED; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_TBLPROPERTIES; +import static org.opensearch.sql.spark.validator.GrammarElement.SHOW_VIEWS; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLESAMPLE; +import static org.opensearch.sql.spark.validator.GrammarElement.TABLE_VALUED_FUNCTION; +import static org.opensearch.sql.spark.validator.GrammarElement.TRANSFORM; +import static org.opensearch.sql.spark.validator.GrammarElement.TRUNCATE_TABLE; +import static org.opensearch.sql.spark.validator.GrammarElement.UDF; +import static org.opensearch.sql.spark.validator.GrammarElement.UNCACHE_TABLE; + +import com.google.common.collect.ImmutableSet; +import java.util.Set; + +public class SecurityLakeGrammarElementValidator extends DenyListGrammarElementValidator { + private static final Set SECURITY_LAKE_DENY_LIST = + ImmutableSet.builder() + .add( + ALTER_NAMESPACE, + ALTER_VIEW, + CREATE_NAMESPACE, + CREATE_FUNCTION, + CREATE_VIEW, + DROP_FUNCTION, + DROP_NAMESPACE, + DROP_VIEW, + REPAIR_TABLE, + TRUNCATE_TABLE, + INSERT, + LOAD, + CLUSTER_BY, + DISTRIBUTE_BY, + HINTS, + INLINE_TABLE, + FILE, + TABLESAMPLE, + TABLE_VALUED_FUNCTION, + TRANSFORM, + MANAGE_RESOURCE, + ANALYZE_TABLE, + CACHE_TABLE, + CLEAR_CACHE, + DESCRIBE_NAMESPACE, + DESCRIBE_FUNCTION, + DESCRIBE_QUERY, + DESCRIBE_TABLE, + REFRESH_RESOURCE, + REFRESH_TABLE, + REFRESH_FUNCTION, + RESET, + SET, + SHOW_COLUMNS, + SHOW_CREATE_TABLE, + SHOW_NAMESPACES, + SHOW_FUNCTIONS, + SHOW_PARTITIONS, + SHOW_TABLE_EXTENDED, + SHOW_TABLES, + SHOW_TBLPROPERTIES, + SHOW_VIEWS, + UNCACHE_TABLE, + CSV_FUNCTIONS, + MISC_FUNCTIONS, + UDF) + .build(); + + public SecurityLakeGrammarElementValidator() { + super(SECURITY_LAKE_DENY_LIST); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java index ddadeb65e2..57ad4ecf42 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryCoreIntegTest.java @@ -85,6 +85,10 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; +import org.opensearch.sql.spark.validator.SQLQueryValidator; /** * This tests async-query-core library end-to-end using mocked implementation of extension points. @@ -175,9 +179,18 @@ public void setUp() { emrServerlessClientFactory, metricsService, new SparkSubmitParametersBuilderProvider(collection)); + SQLQueryValidator sqlQueryValidator = + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); asyncQueryExecutorService = new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index 75c0e00337..1a38b6977f 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -42,6 +42,7 @@ import com.amazonaws.services.emrserverless.model.GetJobRunResult; import com.amazonaws.services.emrserverless.model.JobRun; import com.amazonaws.services.emrserverless.model.JobRunState; +import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -88,6 +89,10 @@ import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; +import org.opensearch.sql.spark.validator.SQLQueryValidator; @ExtendWith(MockitoExtension.class) public class SparkQueryDispatcherTest { @@ -111,6 +116,13 @@ public class SparkQueryDispatcherTest { @Mock private AsyncQueryRequestContext asyncQueryRequestContext; @Mock private MetricsService metricsService; @Mock private AsyncQueryScheduler asyncQueryScheduler; + + private final SQLQueryValidator sqlQueryValidator = + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); + private DataSourceSparkParameterComposer dataSourceSparkParameterComposer = (datasourceMetadata, sparkSubmitParameters, dispatchQueryRequest, context) -> { sparkSubmitParameters.setConfigItem(FLINT_INDEX_STORE_AUTH_KEY, "basic"); @@ -159,7 +171,11 @@ void setUp() { sparkSubmitParametersBuilderProvider); sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); } @Test @@ -347,19 +363,12 @@ void testDispatchWithSparkUDFQuery() { sparkQueryDispatcher.dispatch( getBaseDispatchQueryRequestBuilder(query).langType(LangType.SQL).build(), asyncQueryRequestContext)); - assertEquals( - "Query is not allowed: Creating user-defined functions is not allowed", - illegalArgumentException.getMessage()); + assertEquals("CREATE FUNCTION is not allowed.", illegalArgumentException.getMessage()); verifyNoInteractions(emrServerlessClient); verifyNoInteractions(flintIndexMetadataService); } } - @Test - void testInvalidSQLQueryDispatchToSpark() { - testDispatchBatchQuery("myselect 1"); - } - @Test void testDispatchQueryWithoutATableAndDataSourceName() { testDispatchBatchQuery("show tables"); @@ -571,7 +580,11 @@ void testDispatchAlterToManualRefreshIndexQuery() { QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class); sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); String query = "ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH" + " (auto_refresh = false)"; @@ -597,7 +610,11 @@ void testDispatchDropIndexQuery() { QueryHandlerFactory queryHandlerFactory = mock(QueryHandlerFactory.class); sparkQueryDispatcher = new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, + sessionManager, + queryHandlerFactory, + queryIdProvider, + sqlQueryValidator); String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs"; DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata(); when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata( diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index 56cab7ce7f..881ad0e56a 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -10,7 +10,6 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.when; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.index; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.mv; import static org.opensearch.sql.spark.utils.SQLQueryUtilsTest.IndexQuery.skippingIndex; @@ -22,7 +21,6 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.datasource.model.DataSource; -import org.opensearch.sql.datasource.model.DataSourceType; import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName; import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; @@ -444,106 +442,6 @@ void testRecoverIndex() { assertEquals(IndexQueryActionType.RECOVER, indexDetails.getIndexQueryActionType()); } - @Test - void testValidateSparkSqlQuery_ValidQuery() { - List errors = - validateSparkSqlQueryForDataSourceType( - "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste'", - DataSourceType.PROMETHEUS); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors"); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake() { - List errors = - validateSparkSqlQueryForDataSourceType( - "SELECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceTypeNull() { - List errors = - validateSparkSqlQueryForDataSourceType("SELECT * FROM users WHERE age > 18", null); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_InvalidQuery_SyntaxCheckFailureSkippedWithoutValidationError() { - List errors = - validateSparkSqlQueryForDataSourceType( - "SEECT * FROM users WHERE age > 18", DataSourceType.SECURITY_LAKE); - - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - @Test - void testValidateSparkSqlQuery_nullDatasource() { - List errors = - SQLQueryUtils.validateSparkSqlQuery(null, "SELECT * FROM users WHERE age > 18"); - assertTrue(errors.isEmpty(), "Valid query should not produce any errors "); - } - - private List validateSparkSqlQueryForDataSourceType( - String query, DataSourceType dataSourceType) { - when(this.dataSource.getConnectorType()).thenReturn(dataSourceType); - - return SQLQueryUtils.validateSparkSqlQuery(this.dataSource, query); - } - - @Test - void testValidateSparkSqlQuery_SelectQuery_DataSourceSecurityLake_ValidationFails() { - List errors = - validateSparkSqlQueryForDataSourceType( - "REFRESH INDEX cv1 ON mys3.default.http_logs", DataSourceType.SECURITY_LAKE); - - assertFalse( - errors.isEmpty(), - "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" - + " queries. Given query was REFRESH sql query"); - assertEquals( - errors.get(0), - "Unsupported sql statement for security lake data source. Only select queries are allowed"); - } - - @Test - void - testValidateSparkSqlQuery_NonSelectStatementContainingSelectClause_DataSourceSecurityLake_ValidationFails() { - String query = - "CREATE TABLE AccountSummaryOrWhatever AS " - + "select taxid, address1, count(address1) from dbo.t " - + "group by taxid, address1;"; - - List errors = - validateSparkSqlQueryForDataSourceType(query, DataSourceType.SECURITY_LAKE); - - assertFalse( - errors.isEmpty(), - "Invalid query as Security Lake datasource supports only flint queries and SELECT sql" - + " queries. Given query was REFRESH sql query"); - assertEquals( - errors.get(0), - "Unsupported sql statement for security lake data source. Only select queries are allowed"); - } - - @Test - void testValidateSparkSqlQuery_InvalidQuery() { - when(dataSource.getConnectorType()).thenReturn(DataSourceType.PROMETHEUS); - String invalidQuery = "CREATE FUNCTION myUDF AS 'com.example.UDF'"; - - List errors = SQLQueryUtils.validateSparkSqlQuery(dataSource, invalidQuery); - - assertFalse(errors.isEmpty(), "Invalid query should produce errors"); - assertEquals(1, errors.size(), "Should have one error"); - assertEquals( - "Creating user-defined functions is not allowed", - errors.get(0), - "Error message should match"); - } - @Getter protected static class IndexQuery { private String query; diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java new file mode 100644 index 0000000000..a5f868421c --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/FunctionTypeTest.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class FunctionTypeTest { + @Test + public void test() { + assertEquals(FunctionType.AGGREGATE, FunctionType.fromFunctionName("any")); + assertEquals(FunctionType.AGGREGATE, FunctionType.fromFunctionName("variance")); + assertEquals(FunctionType.WINDOW, FunctionType.fromFunctionName("cume_dist")); + assertEquals(FunctionType.WINDOW, FunctionType.fromFunctionName("row_number")); + assertEquals(FunctionType.ARRAY, FunctionType.fromFunctionName("array")); + assertEquals(FunctionType.ARRAY, FunctionType.fromFunctionName("sort_array")); + assertEquals(FunctionType.MAP, FunctionType.fromFunctionName("element_at")); + assertEquals(FunctionType.MAP, FunctionType.fromFunctionName("try_element_at")); + assertEquals(FunctionType.DATE_TIMESTAMP, FunctionType.fromFunctionName("add_months")); + assertEquals(FunctionType.DATE_TIMESTAMP, FunctionType.fromFunctionName("year")); + assertEquals(FunctionType.JSON, FunctionType.fromFunctionName("from_json")); + assertEquals(FunctionType.JSON, FunctionType.fromFunctionName("to_json")); + assertEquals(FunctionType.MATH, FunctionType.fromFunctionName("abs")); + assertEquals(FunctionType.MATH, FunctionType.fromFunctionName("width_bucket")); + assertEquals(FunctionType.STRING, FunctionType.fromFunctionName("ascii")); + assertEquals(FunctionType.STRING, FunctionType.fromFunctionName("upper")); + assertEquals(FunctionType.CONDITIONAL, FunctionType.fromFunctionName("coalesce")); + assertEquals(FunctionType.CONDITIONAL, FunctionType.fromFunctionName("nvl2")); + assertEquals(FunctionType.BITWISE, FunctionType.fromFunctionName("bit_count")); + assertEquals(FunctionType.BITWISE, FunctionType.fromFunctionName("shiftrightunsigned")); + assertEquals(FunctionType.CONVERSION, FunctionType.fromFunctionName("bigint")); + assertEquals(FunctionType.CONVERSION, FunctionType.fromFunctionName("tinyint")); + assertEquals(FunctionType.PREDICATE, FunctionType.fromFunctionName("isnan")); + assertEquals(FunctionType.PREDICATE, FunctionType.fromFunctionName("rlike")); + assertEquals(FunctionType.CSV, FunctionType.fromFunctionName("from_csv")); + assertEquals(FunctionType.CSV, FunctionType.fromFunctionName("to_csv")); + assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("aes_decrypt")); + assertEquals(FunctionType.MISC, FunctionType.fromFunctionName("version")); + assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("explode")); + assertEquals(FunctionType.GENERATOR, FunctionType.fromFunctionName("stack")); + assertEquals(FunctionType.UDF, FunctionType.fromFunctionName("unknown")); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java new file mode 100644 index 0000000000..7d4b255356 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/GrammarElementValidatorProviderTest.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.datasource.model.DataSourceType; + +class GrammarElementValidatorProviderTest { + S3GlueGrammarElementValidator s3GlueGrammarElementValidator = new S3GlueGrammarElementValidator(); + SecurityLakeGrammarElementValidator securityLakeGrammarElementValidator = + new SecurityLakeGrammarElementValidator(); + DefaultGrammarElementValidator defaultGrammarElementValidator = + new DefaultGrammarElementValidator(); + GrammarElementValidatorProvider grammarElementValidatorProvider = + new GrammarElementValidatorProvider( + ImmutableMap.of( + DataSourceType.S3GLUE, s3GlueGrammarElementValidator, + DataSourceType.SECURITY_LAKE, securityLakeGrammarElementValidator), + defaultGrammarElementValidator); + + @Test + public void test() { + assertEquals( + s3GlueGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.S3GLUE)); + assertEquals( + securityLakeGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.SECURITY_LAKE)); + assertEquals( + defaultGrammarElementValidator, + grammarElementValidatorProvider.getValidatorForDatasource(DataSourceType.PROMETHEUS)); + } +} diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java new file mode 100644 index 0000000000..6726b56994 --- /dev/null +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/validator/SQLQueryValidatorTest.java @@ -0,0 +1,600 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.validator; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.antlr.v4.runtime.CommonTokenStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.datasource.model.DataSourceType; +import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser; +import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.SingleStatementContext; + +@ExtendWith(MockitoExtension.class) +class SQLQueryValidatorTest { + @Mock GrammarElementValidatorProvider mockedProvider; + + @InjectMocks SQLQueryValidator sqlQueryValidator; + + private enum TestElement { + // DDL Statements + ALTER_DATABASE( + "ALTER DATABASE inventory SET DBPROPERTIES ('Edit-date' = '01/01/2001');", + "ALTER DATABASE dbx.tab1 UNSET PROPERTIES ('winner');", + "ALTER DATABASE dbx.tab1 SET LOCATION '/path/to/part/ways';"), + ALTER_TABLE( + "ALTER TABLE default.StudentInfo PARTITION (age='10') RENAME TO PARTITION (age='15');", + "ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('winner');", + "ALTER TABLE StudentInfo ADD columns (LastName string, DOB timestamp);", + "ALTER TABLE StudentInfo ADD IF NOT EXISTS PARTITION (age=18);", + "ALTER TABLE StudentInfo RENAME COLUMN name TO FirstName;", + "ALTER TABLE StudentInfo RENAME TO newName;", + "ALTER TABLE StudentInfo DROP columns (LastName, DOB);", + "ALTER TABLE StudentInfo ALTER COLUMN FirstName COMMENT \"new comment\";", + "ALTER TABLE StudentInfo REPLACE COLUMNS (name string, ID int COMMENT 'new comment');", + "ALTER TABLE test_tab SET SERDE 'org.apache.LazyBinaryColumnarSerDe';", + "ALTER TABLE StudentInfo DROP IF EXISTS PARTITION (age=18);", + "ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways';", + "ALTER TABLE dbx.tab1 RECOVER PARTITIONS;", + "ALTER TABLE dbx.tab1 CLUSTER BY NONE;", + "ALTER TABLE dbx.tab1 SET LOCATION '/path/to/part/ways';"), + ALTER_VIEW( + "ALTER VIEW tempdb1.v1 RENAME TO tempdb1.v2;", + "ALTER VIEW tempdb1.v2 AS SELECT * FROM tempdb1.v1;", + "ALTER VIEW tempdb1.v2 WITH SCHEMA BINDING"), + CREATE_DATABASE("CREATE DATABASE IF NOT EXISTS customer_db;\n"), + CREATE_FUNCTION("CREATE FUNCTION simple_udf AS 'SimpleUdf' USING JAR '/tmp/SimpleUdf.jar';"), + CREATE_TABLE( + "CREATE TABLE Student_Dupli like Student;", + "CREATE TABLE student (id INT, name STRING, age INT) USING CSV;", + "CREATE TABLE student_copy USING CSV AS SELECT * FROM student;", + "CREATE TABLE student (id INT, name STRING, age INT);", + "REPLACE TABLE student (id INT, name STRING, age INT) USING CSV;"), + CREATE_VIEW( + "CREATE OR REPLACE VIEW experienced_employee" + + " (ID COMMENT 'Unique identification number', Name)" + + " COMMENT 'View for experienced employees'" + + " AS SELECT id, name FROM all_employee" + + " WHERE working_years > 5;"), + DROP_DATABASE("DROP DATABASE inventory_db CASCADE;"), + DROP_FUNCTION("DROP FUNCTION test_avg;"), + DROP_TABLE("DROP TABLE employeetable;"), + DROP_VIEW("DROP VIEW employeeView;"), + REPAIR_TABLE("REPAIR TABLE t1;"), + TRUNCATE_TABLE("TRUNCATE TABLE Student partition(age=10);"), + + // DML Statements + INSERT_TABLE( + "INSERT INTO target_table SELECT * FROM source_table;", + "INSERT INTO persons REPLACE WHERE ssn = 123456789 SELECT * FROM persons2;", + "INSERT OVERWRITE students VALUES ('Ashua Hill', '456 Erica Ct, Cupertino', 111111);"), + INSERT_OVERWRITE_DIRECTORY( + "INSERT OVERWRITE DIRECTORY '/path/to/output' SELECT * FROM source_table;", + "INSERT OVERWRITE DIRECTORY USING myTable SELECT * FROM source_table;", + "INSERT OVERWRITE LOCAL DIRECTORY '/tmp/destination' STORED AS orc SELECT * FROM" + + " test_table;"), + LOAD("LOAD DATA INPATH '/path/to/data' INTO TABLE target_table;"), + + // Data Retrieval Statements + SELECT("SELECT 1;"), + EXPLAIN("EXPLAIN SELECT * FROM my_table;"), + COMMON_TABLE_EXPRESSION( + "WITH cte AS (SELECT * FROM my_table WHERE age > 30) SELECT * FROM cte;"), + CLUSTER_BY_CLAUSE( + "SELECT * FROM my_table CLUSTER BY age;", "ALTER TABLE testTable CLUSTER BY (age);"), + DISTRIBUTE_BY_CLAUSE("SELECT * FROM my_table DISTRIBUTE BY name;"), + GROUP_BY_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name;"), + HAVING_CLAUSE("SELECT name, count(*) FROM my_table GROUP BY name HAVING count(*) > 1;"), + HINTS("SELECT /*+ BROADCAST(my_table) */ * FROM my_table;"), + INLINE_TABLE("SELECT * FROM (VALUES (1, 'a'), (2, 'b')) AS inline_table(id, value);"), + FILE("SELECT * FROM text.`/path/to/file.txt`;"), + INNER_JOIN("SELECT t1.name, t2.age FROM table1 t1 INNER JOIN table2 t2 ON t1.id = t2.id;"), + CROSS_JOIN("SELECT t1.name, t2.age FROM table1 t1 CROSS JOIN table2 t2;"), + LEFT_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 LEFT OUTER JOIN table2 t2 ON t1.id = t2.id;"), + LEFT_SEMI_JOIN("SELECT t1.name FROM table1 t1 LEFT SEMI JOIN table2 t2 ON t1.id = t2.id;"), + RIGHT_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 RIGHT OUTER JOIN table2 t2 ON t1.id = t2.id;"), + FULL_OUTER_JOIN( + "SELECT t1.name, t2.age FROM table1 t1 FULL OUTER JOIN table2 t2 ON t1.id = t2.id;"), + LEFT_ANTI_JOIN("SELECT t1.name FROM table1 t1 LEFT ANTI JOIN table2 t2 ON t1.id = t2.id;"), + LIKE_PREDICATE("SELECT * FROM my_table WHERE name LIKE 'A%';"), + LIMIT_CLAUSE("SELECT * FROM my_table LIMIT 10;"), + OFFSET_CLAUSE("SELECT * FROM my_table OFFSET 5;"), + ORDER_BY_CLAUSE("SELECT * FROM my_table ORDER BY age DESC;"), + SET_OPERATORS("SELECT * FROM table1 UNION SELECT * FROM table2;"), + SORT_BY_CLAUSE("SELECT * FROM my_table SORT BY age DESC;"), + TABLESAMPLE("SELECT * FROM my_table TABLESAMPLE(10 PERCENT);"), + // TABLE_VALUED_FUNCTION("SELECT explode(array(10, 20));"), TODO: Need to handle this case + TABLE_VALUED_FUNCTION("SELECT * FROM explode(array(10, 20));"), + WHERE_CLAUSE("SELECT * FROM my_table WHERE age > 30;"), + AGGREGATE_FUNCTION("SELECT count(*) FROM my_table;"), + WINDOW_FUNCTION("SELECT name, age, rank() OVER (ORDER BY age DESC) FROM my_table;"), + CASE_CLAUSE("SELECT name, CASE WHEN age > 30 THEN 'Adult' ELSE 'Young' END FROM my_table;"), + PIVOT_CLAUSE( + "SELECT * FROM (SELECT name, age, gender FROM my_table) PIVOT (COUNT(*) FOR gender IN ('M'," + + " 'F'));"), + UNPIVOT_CLAUSE( + "SELECT name, value, category FROM (SELECT name, 'M' AS gender, age AS male_age, 0 AS" + + " female_age FROM my_table) UNPIVOT (value FOR category IN (male_age, female_age));"), + LATERAL_VIEW_CLAUSE( + "SELECT name, age, exploded_value FROM my_table LATERAL VIEW OUTER EXPLODE(split(comments," + + " ',')) exploded_table AS exploded_value;"), + LATERAL_SUBQUERY( + "SELECT * FROM t1, LATERAL (SELECT * FROM t2 WHERE t1.c1 = t2.c1);", + "SELECT * FROM t1 JOIN LATERAL (SELECT * FROM t2 WHERE t1.c1 = t2.c1);"), + TRANSFORM_CLAUSE( + "SELECT transform(zip_code, name, age) USING 'cat' AS (a, b, c) FROM my_table;"), + + // Auxiliary Statements + ADD_FILE("ADD FILE /tmp/test.txt;"), + ADD_JAR("ADD JAR /path/to/my.jar;"), + ANALYZE_TABLE( + "ANALYZE TABLE my_table COMPUTE STATISTICS;", + "ANALYZE TABLES IN school_db COMPUTE STATISTICS NOSCAN;"), + CACHE_TABLE("CACHE TABLE my_table;"), + CLEAR_CACHE("CLEAR CACHE;"), + DESCRIBE_DATABASE("DESCRIBE DATABASE my_db;"), + DESCRIBE_FUNCTION("DESCRIBE FUNCTION my_function;"), + DESCRIBE_QUERY("DESCRIBE QUERY SELECT * FROM my_table;"), + DESCRIBE_TABLE("DESCRIBE TABLE my_table;"), + LIST_FILE("LIST FILE '/path/to/files';"), + LIST_JAR("LIST JAR;"), + REFRESH("REFRESH;"), + REFRESH_TABLE("REFRESH TABLE my_table;"), + REFRESH_FUNCTION("REFRESH FUNCTION my_function;"), + RESET("RESET;", "RESET spark.abc;", "RESET `key`;"), + SET( + "SET spark.sql.shuffle.partitions=200;", + "SET -v;", + "SET;", + "SET spark.sql.variable.substitute;"), + SHOW_COLUMNS("SHOW COLUMNS FROM my_table;"), + SHOW_CREATE_TABLE("SHOW CREATE TABLE my_table;"), + SHOW_DATABASES("SHOW DATABASES;"), + SHOW_FUNCTIONS("SHOW FUNCTIONS;"), + SHOW_PARTITIONS("SHOW PARTITIONS my_table;"), + SHOW_TABLE_EXTENDED("SHOW TABLE EXTENDED LIKE 'my_table';"), + SHOW_TABLES("SHOW TABLES;"), + SHOW_TBLPROPERTIES("SHOW TBLPROPERTIES my_table;"), + SHOW_VIEWS("SHOW VIEWS;"), + UNCACHE_TABLE("UNCACHE TABLE my_table;"), + + // Functions + ARRAY_FUNCTIONS("SELECT array_contains(array(1, 2, 3), 2);"), + MAP_FUNCTIONS("SELECT map_keys(map('a', 1, 'b', 2));"), + DATE_AND_TIMESTAMP_FUNCTIONS("SELECT date_format(current_date(), 'yyyy-MM-dd');"), + JSON_FUNCTIONS("SELECT json_tuple('{\"a\":1, \"b\":2}', 'a', 'b');"), + MATHEMATICAL_FUNCTIONS("SELECT round(3.1415, 2);"), + STRING_FUNCTIONS("SELECT ascii('Hello');"), + BITWISE_FUNCTIONS("SELECT bit_count(42);"), + CONVERSION_FUNCTIONS("SELECT cast('2023-04-01' as date);"), + CONDITIONAL_FUNCTIONS("SELECT if(1 > 0, 'true', 'false');"), + PREDICATE_FUNCTIONS("SELECT isnotnull(1);"), + CSV_FUNCTIONS("SELECT from_csv(array('a', 'b', 'c'), ',');"), + MISC_FUNCTIONS("SELECT current_user();"), + + // Aggregate-like Functions + AGGREGATE_FUNCTIONS("SELECT count(*), max(age), min(age) FROM my_table;"), + WINDOW_FUNCTIONS("SELECT name, age, rank() OVER (ORDER BY age DESC) FROM my_table;"), + + // Generator Functions + GENERATOR_FUNCTIONS("SELECT explode(array(1, 2, 3));"), + + // UDFs (User-Defined Functions) + SCALAR_USER_DEFINED_FUNCTIONS("SELECT my_udf(name) FROM my_table;"), + USER_DEFINED_AGGREGATE_FUNCTIONS("SELECT my_udaf(age) FROM my_table GROUP BY name;"), + INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS("SELECT my_hive_udf(name) FROM my_table;"); + + @Getter private final String[] queries; + + TestElement(String... queries) { + this.queries = queries; + } + } + + @Test + void testAllowAllByDefault() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new DefaultGrammarElementValidator()); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + Arrays.stream(TestElement.values()).forEach(v::ok); + } + + @Test + void testDenyAllValidator() { + when(mockedProvider.getValidatorForDatasource(any())).thenReturn(element -> false); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SPARK); + // The elements which doesn't have validation will be accepted. + // That's why there are some ok case + + // DDL Statements + v.ng(TestElement.ALTER_DATABASE); + v.ng(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ng(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ng(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ng(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ng(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ng(TestElement.REPAIR_TABLE); + v.ng(TestElement.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); + + // Data Retrieval + v.ng(TestElement.EXPLAIN); + v.ng(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ng(TestElement.INNER_JOIN); + v.ng(TestElement.CROSS_JOIN); + v.ng(TestElement.LEFT_OUTER_JOIN); + v.ng(TestElement.LEFT_SEMI_JOIN); + v.ng(TestElement.RIGHT_OUTER_JOIN); + v.ng(TestElement.FULL_OUTER_JOIN); + v.ng(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ng(TestElement.LATERAL_VIEW_CLAUSE); + v.ng(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ng(TestElement.ANALYZE_TABLE); + v.ng(TestElement.CACHE_TABLE); + v.ng(TestElement.CLEAR_CACHE); + v.ng(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ng(TestElement.DESCRIBE_QUERY); + v.ng(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ng(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ng(TestElement.SHOW_COLUMNS); + v.ng(TestElement.SHOW_CREATE_TABLE); + v.ng(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ng(TestElement.SHOW_PARTITIONS); + v.ng(TestElement.SHOW_TABLE_EXTENDED); + v.ng(TestElement.SHOW_TABLES); + v.ng(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ng(TestElement.UNCACHE_TABLE); + + // Functions + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ng(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ng(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestElement.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + + @Test + void testS3glueQueries() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new S3GlueGrammarElementValidator()); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.S3GLUE); + + // DDL Statements + v.ok(TestElement.ALTER_DATABASE); + v.ok(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ok(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ok(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ok(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ok(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ok(TestElement.REPAIR_TABLE); + v.ok(TestElement.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); + + // Data Retrieval + v.ok(TestElement.SELECT); + v.ok(TestElement.EXPLAIN); + v.ok(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ok(TestElement.INNER_JOIN); + v.ok(TestElement.CROSS_JOIN); + v.ok(TestElement.LEFT_OUTER_JOIN); + v.ok(TestElement.LEFT_SEMI_JOIN); + v.ok(TestElement.RIGHT_OUTER_JOIN); + v.ok(TestElement.FULL_OUTER_JOIN); + v.ok(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ok(TestElement.LATERAL_VIEW_CLAUSE); + v.ok(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ok(TestElement.ANALYZE_TABLE); + v.ok(TestElement.CACHE_TABLE); + v.ok(TestElement.CLEAR_CACHE); + v.ok(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ok(TestElement.DESCRIBE_QUERY); + v.ok(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ok(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ok(TestElement.SHOW_COLUMNS); + v.ok(TestElement.SHOW_CREATE_TABLE); + v.ok(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ok(TestElement.SHOW_PARTITIONS); + v.ok(TestElement.SHOW_TABLE_EXTENDED); + v.ok(TestElement.SHOW_TABLES); + v.ok(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ok(TestElement.UNCACHE_TABLE); + + // Functions + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ok(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ok(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestElement.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + + @Test + void testSecurityLakeQueries() { + when(mockedProvider.getValidatorForDatasource(any())) + .thenReturn(new SecurityLakeGrammarElementValidator()); + VerifyValidator v = new VerifyValidator(sqlQueryValidator, DataSourceType.SECURITY_LAKE); + + // DDL Statements + v.ng(TestElement.ALTER_DATABASE); + v.ng(TestElement.ALTER_TABLE); + v.ng(TestElement.ALTER_VIEW); + v.ng(TestElement.CREATE_DATABASE); + v.ng(TestElement.CREATE_FUNCTION); + v.ng(TestElement.CREATE_TABLE); + v.ng(TestElement.CREATE_VIEW); + v.ng(TestElement.DROP_DATABASE); + v.ng(TestElement.DROP_FUNCTION); + v.ng(TestElement.DROP_TABLE); + v.ng(TestElement.DROP_VIEW); + v.ng(TestElement.REPAIR_TABLE); + v.ng(TestElement.TRUNCATE_TABLE); + + // DML Statements + v.ng(TestElement.INSERT_TABLE); + v.ng(TestElement.INSERT_OVERWRITE_DIRECTORY); + v.ng(TestElement.LOAD); + + // Data Retrieval + v.ok(TestElement.SELECT); + v.ok(TestElement.EXPLAIN); + v.ok(TestElement.COMMON_TABLE_EXPRESSION); + v.ng(TestElement.CLUSTER_BY_CLAUSE); + v.ng(TestElement.DISTRIBUTE_BY_CLAUSE); + v.ok(TestElement.GROUP_BY_CLAUSE); + v.ok(TestElement.HAVING_CLAUSE); + v.ng(TestElement.HINTS); + v.ng(TestElement.INLINE_TABLE); + v.ng(TestElement.FILE); + v.ok(TestElement.INNER_JOIN); + v.ok(TestElement.CROSS_JOIN); + v.ok(TestElement.LEFT_OUTER_JOIN); + v.ok(TestElement.LEFT_SEMI_JOIN); + v.ok(TestElement.RIGHT_OUTER_JOIN); + v.ok(TestElement.FULL_OUTER_JOIN); + v.ok(TestElement.LEFT_ANTI_JOIN); + v.ok(TestElement.LIKE_PREDICATE); + v.ok(TestElement.LIMIT_CLAUSE); + v.ok(TestElement.OFFSET_CLAUSE); + v.ok(TestElement.ORDER_BY_CLAUSE); + v.ok(TestElement.SET_OPERATORS); + v.ok(TestElement.SORT_BY_CLAUSE); + v.ng(TestElement.TABLESAMPLE); + v.ng(TestElement.TABLE_VALUED_FUNCTION); + v.ok(TestElement.WHERE_CLAUSE); + v.ok(TestElement.AGGREGATE_FUNCTION); + v.ok(TestElement.WINDOW_FUNCTION); + v.ok(TestElement.CASE_CLAUSE); + v.ok(TestElement.PIVOT_CLAUSE); + v.ok(TestElement.UNPIVOT_CLAUSE); + v.ok(TestElement.LATERAL_VIEW_CLAUSE); + v.ok(TestElement.LATERAL_SUBQUERY); + v.ng(TestElement.TRANSFORM_CLAUSE); + + // Auxiliary Statements + v.ng(TestElement.ADD_FILE); + v.ng(TestElement.ADD_JAR); + v.ng(TestElement.ANALYZE_TABLE); + v.ng(TestElement.CACHE_TABLE); + v.ng(TestElement.CLEAR_CACHE); + v.ng(TestElement.DESCRIBE_DATABASE); + v.ng(TestElement.DESCRIBE_FUNCTION); + v.ng(TestElement.DESCRIBE_QUERY); + v.ng(TestElement.DESCRIBE_TABLE); + v.ng(TestElement.LIST_FILE); + v.ng(TestElement.LIST_JAR); + v.ng(TestElement.REFRESH); + v.ng(TestElement.REFRESH_TABLE); + v.ng(TestElement.REFRESH_FUNCTION); + v.ng(TestElement.RESET); + v.ng(TestElement.SET); + v.ng(TestElement.SHOW_COLUMNS); + v.ng(TestElement.SHOW_CREATE_TABLE); + v.ng(TestElement.SHOW_DATABASES); + v.ng(TestElement.SHOW_FUNCTIONS); + v.ng(TestElement.SHOW_PARTITIONS); + v.ng(TestElement.SHOW_TABLE_EXTENDED); + v.ng(TestElement.SHOW_TABLES); + v.ng(TestElement.SHOW_TBLPROPERTIES); + v.ng(TestElement.SHOW_VIEWS); + v.ng(TestElement.UNCACHE_TABLE); + + // Functions + v.ok(TestElement.ARRAY_FUNCTIONS); + v.ok(TestElement.MAP_FUNCTIONS); + v.ok(TestElement.DATE_AND_TIMESTAMP_FUNCTIONS); + v.ok(TestElement.JSON_FUNCTIONS); + v.ok(TestElement.MATHEMATICAL_FUNCTIONS); + v.ok(TestElement.STRING_FUNCTIONS); + v.ok(TestElement.BITWISE_FUNCTIONS); + v.ok(TestElement.CONVERSION_FUNCTIONS); + v.ok(TestElement.CONDITIONAL_FUNCTIONS); + v.ok(TestElement.PREDICATE_FUNCTIONS); + v.ng(TestElement.CSV_FUNCTIONS); + v.ng(TestElement.MISC_FUNCTIONS); + + // Aggregate-like Functions + v.ok(TestElement.AGGREGATE_FUNCTIONS); + v.ok(TestElement.WINDOW_FUNCTIONS); + + // Generator Functions + v.ok(TestElement.GENERATOR_FUNCTIONS); + + // UDFs + v.ng(TestElement.SCALAR_USER_DEFINED_FUNCTIONS); + v.ng(TestElement.USER_DEFINED_AGGREGATE_FUNCTIONS); + v.ng(TestElement.INTEGRATION_WITH_HIVE_UDFS_UDAFS_UDTFS); + } + + @AllArgsConstructor + private static class VerifyValidator { + private final SQLQueryValidator validator; + private final DataSourceType dataSourceType; + + public void ok(TestElement query) { + runValidate(query.getQueries()); + } + + public void ng(TestElement query) { + assertThrows( + IllegalArgumentException.class, + () -> runValidate(query.getQueries()), + "The query should throw: query=`" + query.toString() + "`"); + } + + void runValidate(String[] queries) { + Arrays.stream(queries).forEach(query -> validator.validate(query, dataSourceType)); + } + + void runValidate(String query) { + validator.validate(query, dataSourceType); + } + + SingleStatementContext getParser(String query) { + SqlBaseParser sqlBaseParser = + new SqlBaseParser( + new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(query)))); + return sqlBaseParser.singleStatement(); + } + } +} diff --git a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java index c6f6ffcd81..db070182a3 100644 --- a/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java +++ b/async-query/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; +import com.google.common.collect.ImmutableMap; import lombok.RequiredArgsConstructor; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -64,6 +65,11 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; +import org.opensearch.sql.spark.validator.SQLQueryValidator; +import org.opensearch.sql.spark.validator.SecurityLakeGrammarElementValidator; @RequiredArgsConstructor public class AsyncExecutorServiceModule extends AbstractModule { @@ -101,9 +107,10 @@ public SparkQueryDispatcher sparkQueryDispatcher( DataSourceService dataSourceService, SessionManager sessionManager, QueryHandlerFactory queryHandlerFactory, - QueryIdProvider queryIdProvider) { + QueryIdProvider queryIdProvider, + SQLQueryValidator sqlQueryValidator) { return new SparkQueryDispatcher( - dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider); + dataSourceService, sessionManager, queryHandlerFactory, queryIdProvider, sqlQueryValidator); } @Provides @@ -174,6 +181,19 @@ public SparkSubmitParametersBuilderProvider sparkSubmitParametersBuilderProvider return new SparkSubmitParametersBuilderProvider(collection); } + @Provides + public SQLQueryValidator sqlQueryValidator() { + GrammarElementValidatorProvider validatorProvider = + new GrammarElementValidatorProvider( + ImmutableMap.of( + DataSourceType.S3GLUE, + new S3GlueGrammarElementValidator(), + DataSourceType.SECURITY_LAKE, + new SecurityLakeGrammarElementValidator()), + new DefaultGrammarElementValidator()); + return new SQLQueryValidator(validatorProvider); + } + @Provides public IndexDMLResultStorageService indexDMLResultStorageService( DataSourceService dataSourceService, StateStore stateStore) { diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index db0adfc156..175f9ac914 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -312,7 +312,7 @@ public void withSessionCreateAsyncQueryFailed() { // 1. create async query. CreateAsyncQueryResponse response = asyncQueryExecutorService.createAsyncQuery( - new CreateAsyncQueryRequest("myselect 1", MYS3_DATASOURCE, LangType.SQL, null), + new CreateAsyncQueryRequest("select 1", MYS3_DATASOURCE, LangType.SQL, null), asyncQueryRequestContext); assertNotNull(response.getSessionId()); Optional statementModel = diff --git a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index 9b897d36b4..72ed17f5aa 100644 --- a/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/async-query/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -102,6 +102,10 @@ import org.opensearch.sql.spark.response.OpenSearchJobExecutionResponseReader; import org.opensearch.sql.spark.scheduler.AsyncQueryScheduler; import org.opensearch.sql.spark.scheduler.OpenSearchAsyncQueryScheduler; +import org.opensearch.sql.spark.validator.DefaultGrammarElementValidator; +import org.opensearch.sql.spark.validator.GrammarElementValidatorProvider; +import org.opensearch.sql.spark.validator.S3GlueGrammarElementValidator; +import org.opensearch.sql.spark.validator.SQLQueryValidator; import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; @@ -308,6 +312,11 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( emrServerlessClientFactory, new OpenSearchMetricsService(), sparkSubmitParametersBuilderProvider); + SQLQueryValidator sqlQueryValidator = + new SQLQueryValidator( + new GrammarElementValidatorProvider( + ImmutableMap.of(DataSourceType.S3GLUE, new S3GlueGrammarElementValidator()), + new DefaultGrammarElementValidator())); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( this.dataSourceService, @@ -318,7 +327,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( sessionConfigSupplier, sessionIdProvider), queryHandlerFactory, - new DatasourceEmbeddedQueryIdProvider()); + new DatasourceEmbeddedQueryIdProvider(), + sqlQueryValidator); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher,