From b9958287a1b083f3f82b3a4f2ae6f5a4f2ff5504 Mon Sep 17 00:00:00 2001 From: Kuan-Po Tseng Date: Wed, 25 Oct 2023 17:14:10 +0800 Subject: [PATCH] Remove syntactic sugar and add to-1 relation support in model expression (#378) * Remove accio query syntax runtime support * Implement model expression 1-1 relationship --- .../src/main/java/io/accio/base/AccioMDL.java | 16 + .../main/java/io/accio/base/dto/Column.java | 4 - .../java/io/accio/base/dto/Relationship.java | 10 +- accio-sqlrewrite/pom.xml | 5 + .../io/accio/sqlrewrite/AccioPlanner.java | 12 +- .../io/accio/sqlrewrite/AccioSqlRewrite.java | 284 +--- .../java/io/accio/sqlrewrite/EnumRewrite.java | 91 ++ .../LambdaExpressionBodyRewrite.java | 150 -- .../sqlrewrite/MetricViewSqlRewrite.java | 3 +- .../java/io/accio/sqlrewrite/ModelInfo.java | 240 +++ .../io/accio/sqlrewrite/RelationshipCTE.java | 137 -- .../sqlrewrite/RelationshipCteGenerator.java | 1064 ------------ .../sqlrewrite/RelationshipRewriter.java | 73 + .../accio/sqlrewrite/ScopeAwareRewrite.java | 220 --- .../sqlrewrite/SyntacticSugarRewrite.java | 131 -- .../main/java/io/accio/sqlrewrite/Utils.java | 16 +- .../accio/sqlrewrite/analyzer/Analysis.java | 108 +- .../analyzer/ExpressionAnalysis.java | 64 - .../analyzer/ExpressionAnalyzer.java | 451 ------ .../ExpressionRelationshipAnalyzer.java | 162 ++ .../analyzer/ExpressionRelationshipInfo.java | 81 + .../analyzer/FunctionChainAnalyzer.java | 304 ---- .../analyzer/PreAggregationAnalysis.java | 14 + .../analyzer/StatementAnalyzer.java | 116 +- .../java/io/accio/TestScopeAwareRewrite.java | 240 --- .../accio/sqlrewrite/TestAccioSqlRewrite.java | 139 -- .../accio/sqlrewrite/TestAllRulesRewrite.java | 31 +- .../io/accio/sqlrewrite/TestEnumRewrite.java | 3 +- .../TestExpressionRelationshipRewriter.java | 106 ++ .../TestLambdaExpressionBodyRewrite.java | 55 - .../sqlrewrite/TestMetricViewSqlRewrite.java | 56 +- .../accio/sqlrewrite/TestModelSqlRewrite.java | 290 ++++ 
.../sqlrewrite/TestRelationshipAccessing.java | 1427 ----------------- .../sqlrewrite/TestSyntacticSugarRewrite.java | 125 -- .../src/test/resources/tpch_mdl.json | 328 ++++ accio-testing/pom.xml | 10 +- .../accio/testing/AbstractTestFramework.java | 52 +- .../io/accio/TestBigQuerySqlConverter.java | 2 +- .../io/accio/testing/RequireAccioServer.java | 15 +- .../bigquery/AbstractPreAggregationTest.java | 11 +- .../bigquery/TestAccioWithBigquery.java | 625 +------- .../bigquery/TestBigQueryPreAggregation.java | 6 +- .../testing/bigquery/TestBigQueryType.java | 13 +- .../testing/bigquery/TestPreAggregation.java | 20 +- .../bigquery/TestRefreshPreAggregation.java | 5 +- .../bigquery/TestReloadPreAggregation.java | 4 +- accio-tests/src/test/resources/tpch_mdl.json | 21 +- pom.xml | 6 + .../io/trino/sql/ExpressionFormatter.java | 2 +- .../trino/sql/tree/DereferenceExpression.java | 6 +- .../java/io/trino/sql/tree/QualifiedName.java | 9 + 51 files changed, 1643 insertions(+), 5720 deletions(-) create mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/EnumRewrite.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/LambdaExpressionBodyRewrite.java create mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelInfo.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCTE.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCteGenerator.java create mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipRewriter.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ScopeAwareRewrite.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/SyntacticSugarRewrite.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalysis.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalyzer.java create mode 100644 
accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipAnalyzer.java create mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipInfo.java delete mode 100644 accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/FunctionChainAnalyzer.java delete mode 100644 accio-sqlrewrite/src/test/java/io/accio/TestScopeAwareRewrite.java delete mode 100644 accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAccioSqlRewrite.java create mode 100644 accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestExpressionRelationshipRewriter.java delete mode 100644 accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestLambdaExpressionBodyRewrite.java create mode 100644 accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java delete mode 100644 accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestRelationshipAccessing.java delete mode 100644 accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestSyntacticSugarRewrite.java create mode 100644 accio-sqlrewrite/src/test/resources/tpch_mdl.json diff --git a/accio-base/src/main/java/io/accio/base/AccioMDL.java b/accio-base/src/main/java/io/accio/base/AccioMDL.java index 8536e8302..12ce822bb 100644 --- a/accio-base/src/main/java/io/accio/base/AccioMDL.java +++ b/accio-base/src/main/java/io/accio/base/AccioMDL.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import io.accio.base.dto.Column; import io.accio.base.dto.EnumDefinition; import io.accio.base.dto.Manifest; import io.accio.base.dto.Metric; @@ -156,4 +157,19 @@ public Optional getView(CatalogSchemaTableName name) } return Optional.empty(); } + + public static Optional getRelationshipColumn(Model model, String name) + { + return getColumn(model, name) + .filter(column -> column.getRelationship().isPresent()); + } + + private static Optional getColumn(Model model, String name) + { + requireNonNull(model); + 
requireNonNull(name); + return model.getColumns().stream() + .filter(column -> column.getName().equals(name)) + .findAny(); + } } diff --git a/accio-base/src/main/java/io/accio/base/dto/Column.java b/accio-base/src/main/java/io/accio/base/dto/Column.java index 502dde7ea..60e2ad2da 100644 --- a/accio-base/src/main/java/io/accio/base/dto/Column.java +++ b/accio-base/src/main/java/io/accio/base/dto/Column.java @@ -107,10 +107,6 @@ public Optional getExpression() public String getSqlExpression() { - if (getRelationship().isPresent()) { - return String.format("'relationship<%s>' as %s", relationship, quote(name)); - } - if (getExpression().isEmpty()) { return quote(name); } diff --git a/accio-base/src/main/java/io/accio/base/dto/Relationship.java b/accio-base/src/main/java/io/accio/base/dto/Relationship.java index 7b36573f1..1bfecac12 100644 --- a/accio-base/src/main/java/io/accio/base/dto/Relationship.java +++ b/accio-base/src/main/java/io/accio/base/dto/Relationship.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.Lists; import java.util.Arrays; import java.util.List; @@ -53,7 +54,14 @@ public static Relationship relationship(String name, List models, JoinTy public static Relationship reverse(Relationship relationship) { - return new Relationship(relationship.name, relationship.getModels(), JoinType.reverse(relationship.joinType), relationship.getCondition(), true, relationship.getManySideSortKeys(), relationship.getDescription()); + return new Relationship( + relationship.name, + Lists.reverse(relationship.getModels()), + JoinType.reverse(relationship.joinType), + relationship.getCondition(), + true, + relationship.getManySideSortKeys(), + relationship.getDescription()); } @JsonCreator diff --git a/accio-sqlrewrite/pom.xml b/accio-sqlrewrite/pom.xml index 7f57bae37..64376c03e 100644 --- a/accio-sqlrewrite/pom.xml +++ b/accio-sqlrewrite/pom.xml @@ -47,6 
+47,11 @@ annotations + + org.jgrapht + jgrapht-core + + io.accio accio-testing diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioPlanner.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioPlanner.java index 51349b9a9..20e3ff2ee 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioPlanner.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioPlanner.java @@ -24,16 +24,16 @@ import java.util.List; import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE; +import static io.accio.sqlrewrite.EnumRewrite.ENUM_REWRITE; import static io.accio.sqlrewrite.MetricViewSqlRewrite.METRIC_VIEW_SQL_REWRITE; -import static io.accio.sqlrewrite.SyntacticSugarRewrite.SYNTACTIC_SUGAR_REWRITE; import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; public class AccioPlanner { public static final List ALL_RULES = List.of( METRIC_VIEW_SQL_REWRITE, - SYNTACTIC_SUGAR_REWRITE, - ACCIO_SQL_REWRITE); + ACCIO_SQL_REWRITE, + ENUM_REWRITE); private static final SqlParser SQL_PARSER = new SqlParser(); private AccioPlanner() {} @@ -46,12 +46,10 @@ public static String rewrite(String sql, SessionContext sessionContext, AccioMDL public static String rewrite(String sql, SessionContext sessionContext, AccioMDL accioMDL, List rules) { Statement statement = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)); - Statement scopedStatement = ScopeAwareRewrite.SCOPE_AWARE_REWRITE.rewrite(statement, accioMDL, sessionContext); - Statement result = scopedStatement; for (AccioRule rule : rules) { // we will replace or rewrite sql node in sql rewrite, to avoid rewrite rules affect each other, format and parse sql before each rewrite - result = rule.apply(SQL_PARSER.createStatement(SqlFormatter.formatSql(result), new ParsingOptions(AS_DECIMAL)), sessionContext, accioMDL); + statement = rule.apply(SQL_PARSER.createStatement(SqlFormatter.formatSql(statement), new ParsingOptions(AS_DECIMAL)), 
sessionContext, accioMDL); } - return SqlFormatter.formatSql(result); + return SqlFormatter.formatSql(statement); } } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java index 7195877e5..af920681b 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java @@ -14,51 +14,38 @@ package io.accio.sqlrewrite; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.accio.base.AccioMDL; import io.accio.base.SessionContext; -import io.accio.base.dto.EnumDefinition; -import io.accio.base.dto.EnumValue; -import io.accio.base.dto.Model; import io.accio.sqlrewrite.analyzer.Analysis; import io.accio.sqlrewrite.analyzer.StatementAnalyzer; -import io.trino.sql.QueryUtil; -import io.trino.sql.tree.AliasedRelation; -import io.trino.sql.tree.ComparisonExpression; import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.FunctionCall; import io.trino.sql.tree.FunctionRelation; import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.JoinCriteria; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.Query; -import io.trino.sql.tree.Relation; import io.trino.sql.tree.Statement; -import io.trino.sql.tree.StringLiteral; -import io.trino.sql.tree.SubscriptExpression; import io.trino.sql.tree.Table; import io.trino.sql.tree.With; import io.trino.sql.tree.WithQuery; +import org.jgrapht.graph.DirectedAcyclicGraph; +import org.jgrapht.graph.GraphCycleProhibitedException; -import java.util.Collection; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Stream; -import static 
io.trino.sql.QueryUtil.equal; -import static io.trino.sql.QueryUtil.joinOn; -import static io.trino.sql.QueryUtil.table; -import static io.trino.sql.tree.DereferenceExpression.getQualifiedName; +import static com.google.common.base.Strings.nullToEmpty; +import static io.accio.sqlrewrite.Utils.parseQuery; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toUnmodifiableList; -import static java.util.stream.Collectors.toUnmodifiableMap; public class AccioSqlRewrite implements AccioRule @@ -70,12 +57,46 @@ private AccioSqlRewrite() {} @Override public Statement apply(Statement root, SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL) { - Map modelQueries = - analysis.getModels().stream() - .collect(toUnmodifiableMap(Model::getName, Utils::parseModelSql)); + DirectedAcyclicGraph graph = new DirectedAcyclicGraph<>(Object.class); + Set modelInfos = analysis.getModels().stream().map(model -> ModelInfo.get(model, accioMDL)).collect(toSet()); + Set requiredModelInfos = new HashSet<>(); + modelInfos.forEach(modelInfo -> addModelToGraph(modelInfo, graph, accioMDL, requiredModelInfos)); + Set allModelInfos = ImmutableSet.builder().addAll(modelInfos).addAll(requiredModelInfos).build(); + + List withQueries = new ArrayList<>(); + graph.iterator().forEachRemaining(modelName -> { + ModelInfo modelInfo = allModelInfos.stream() + .filter(info -> info.getModel().getName().equals(modelName)) + .findAny() + .orElseThrow(() -> new IllegalArgumentException(format("Missing model name %s in graph", modelName))); + withQueries.add(new WithQuery(new Identifier(modelInfo.getModel().getName()), parseQuery(modelInfo.getSql()), Optional.empty())); + }); + + Node rewriteWith = new WithRewriter(withQueries).process(root); + return (Statement) new Rewriter(accioMDL, analysis).process(rewriteWith); + } + + private static void addModelToGraph(ModelInfo 
modelInfo, DirectedAcyclicGraph graph, AccioMDL mdl, Set modelInfos) + { + // add vertex + graph.addVertex(modelInfo.getModel().getName()); + modelInfo.getRequiredModels().forEach(graph::addVertex); - Node rewriteWith = new WithRewriter(modelQueries, analysis).process(root); - return (Statement) new Rewriter(analysis, accioMDL).process(rewriteWith); + //add edge + try { + modelInfo.getRequiredModels().forEach(modelName -> + graph.addEdge(modelName, modelInfo.getModel().getName())); + } + catch (GraphCycleProhibitedException ex) { + throw new IllegalArgumentException("found cycle in models", ex); + } + + // add required models to graph + for (String modelName : modelInfo.getRequiredModels()) { + ModelInfo info = ModelInfo.get(mdl.getModel(modelName).orElseThrow(), mdl); + modelInfos.add(info); + addModelToGraph(info, graph, mdl, modelInfos); + } } @Override @@ -85,60 +106,19 @@ public Statement apply(Statement root, SessionContext sessionContext, AccioMDL a return apply(root, sessionContext, analysis, accioMDL); } - /** - * In MLRewriter, we will add all participated model sql in WITH-QUERY, and rewrite - * all tables that are models to TableSubQuery in WITH-QUERYs - *

- * e.g. Given model "foo" and its reference sql is SELECT * FROM t1 - *

-     *     SELECT * FROM foo
-     * 
- * will be rewritten to - *
-     *     WITH foo AS (SELECT * FROM t1)
-     *     SELECT * FROM foo
-     * 
- * and - *
-     *     WITH a AS (SELECT * FROM foo)
-     *     SELECT * FROM a JOIN b on a.id=b.id
-     * 
- * will be rewritten to - *
-     *     WITH foo AS (SELECT * FROM t1),
-     *          a AS (SELECT * FROM foo)
-     *     SELECT * FROM a JOIN b on a.id=b.id
-     * 
- */ private static class WithRewriter extends BaseRewriter { - private final Map modelQueries; - private final Analysis analysis; + private final List withQueries; - public WithRewriter( - Map modelQueries, - Analysis analysis) + public WithRewriter(List withQueries) { - this.modelQueries = requireNonNull(modelQueries, "modelQueries is null"); - this.analysis = requireNonNull(analysis, "analysis is null"); + this.withQueries = requireNonNull(withQueries, "withQueries is null"); } @Override protected Node visitQuery(Query node, Void context) { - List modelWithQueries = modelQueries.entrySet().stream() - .sorted(Map.Entry.comparingByKey()) // sort here to avoid test failed due to wrong with-query order - .map(e -> new WithQuery(new Identifier(e.getKey()), e.getValue(), Optional.empty())) - .collect(toUnmodifiableList()); - - Collection relationshipCTEs = analysis.getRelationshipCTE().values(); - - List withQueries = ImmutableList.builder() - .addAll(modelWithQueries) - .addAll(relationshipCTEs) - .build(); - return new Query( node.getWith() .map(with -> new With( @@ -158,10 +138,10 @@ protected Node visitQuery(Query node, Void context) private static class Rewriter extends BaseRewriter { - private final Analysis analysis; private final AccioMDL accioMDL; + private final Analysis analysis; - Rewriter(Analysis analysis, AccioMDL accioMDL) + Rewriter(AccioMDL accioMDL, Analysis analysis) { this.analysis = analysis; this.accioMDL = accioMDL; @@ -174,40 +154,25 @@ protected Node visitTable(Table node, Void context) if (analysis.getModelNodeRefs().contains(NodeRef.of(node))) { result = applyModelRule(node); } - - Set relationshipCTENames = analysis.getReplaceTableWithCTEs().getOrDefault(NodeRef.of(node), Set.of()); - if (relationshipCTENames.size() > 0) { - result = applyRelationshipRule((Table) result, relationshipCTENames); - } - return result; } + // remove catalog schema from expression if exist since all tables are in with cte @Override - protected Node 
visitAliasedRelation(AliasedRelation node, Void context) + protected Node visitDereferenceExpression(DereferenceExpression dereferenceExpression, Void context) { - Relation result; - - // rewrite the fields in QueryBody - if (node.getLocation().isPresent()) { - result = new AliasedRelation( - node.getLocation().get(), - visitAndCast(node.getRelation(), context), - node.getAlias(), - node.getColumnNames()); - } - else { - result = new AliasedRelation( - visitAndCast(node.getRelation(), context), - node.getAlias(), - node.getColumnNames()); + QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(dereferenceExpression); + if (qualifiedName != null && !nullToEmpty(accioMDL.getCatalog()).isEmpty() && !nullToEmpty(accioMDL.getSchema()).isEmpty()) { + if (qualifiedName.hasPrefix(QualifiedName.of(accioMDL.getCatalog(), accioMDL.getSchema()))) { + return DereferenceExpression.from( + QualifiedName.of(qualifiedName.getOriginalParts().subList(2, qualifiedName.getOriginalParts().size()))); + } + if (qualifiedName.hasPrefix(QualifiedName.of(accioMDL.getSchema()))) { + return DereferenceExpression.from( + QualifiedName.of(qualifiedName.getOriginalParts().subList(1, qualifiedName.getOriginalParts().size()))); + } } - - Set relationshipCTENames = analysis.getReplaceTableWithCTEs().getOrDefault(NodeRef.of(node), Set.of()); - if (relationshipCTENames.size() > 0) { - result = applyRelationshipRule(result, relationshipCTENames); - } - return result; + return dereferenceExpression; } @Override @@ -220,131 +185,10 @@ protected Node visitFunctionRelation(FunctionRelation node, Void context) throw new IllegalArgumentException("MetricRollup node is not replaced"); } - @Override - protected Node visitDereferenceExpression(DereferenceExpression node, Void context) - { - Expression newNode = analysis.getRelationshipFields().getOrDefault(NodeRef.of(node), rewriteEnumIfNeed(node)); - if (newNode != node) { - return newNode; - } - return new 
DereferenceExpression(node.getLocation(), (Expression) process(node.getBase()), node.getField()); - } - - @Override - protected Node visitSubscriptExpression(SubscriptExpression node, Void context) - { - Expression newNode = analysis.getRelationshipFields().getOrDefault(NodeRef.of(node), node); - if (newNode != node) { - return newNode; - } - return new SubscriptExpression(node.getLocation(), (Expression) process(node.getBase()), node.getIndex()); - } - - private Expression rewriteEnumIfNeed(DereferenceExpression node) - { - QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node); - if (qualifiedName == null || qualifiedName.getOriginalParts().size() != 2) { - return node; - } - - String enumName = qualifiedName.getOriginalParts().get(0).getValue(); - Optional enumDefinitionOptional = accioMDL.getEnum(enumName); - if (enumDefinitionOptional.isEmpty()) { - return node; - } - - return enumDefinitionOptional.get().valueOf(qualifiedName.getOriginalParts().get(1).getValue()) - .map(EnumValue::getValue) - .map(StringLiteral::new) - .orElseThrow(() -> new IllegalArgumentException(format("Enum value '%s' not found in enum '%s'", qualifiedName.getParts().get(1), qualifiedName.getParts().get(0)))); - } - - @Override - protected Node visitIdentifier(Identifier node, Void context) - { - return analysis.getRelationshipFields().getOrDefault(NodeRef.of(node), node); - } - - @Override - protected Node visitFunctionCall(FunctionCall node, Void context) - { - return analysis.getRelationshipFields().getOrDefault(NodeRef.of(node), - new FunctionCall( - node.getLocation(), - node.getName(), - node.getWindow(), - node.getFilter(), - node.getOrderBy(), - node.isDistinct(), - node.getNullTreatment(), - node.getProcessingMode(), - visitNodes(node.getArguments(), context))); - } - // the model is added in with query, and the catalog and schema should be removed private Node applyModelRule(Table table) { return new Table(QualifiedName.of(table.getName().getSuffix())); } - 
- private Relation applyRelationshipRule(Relation table, Set relationshipCTENames) - { - Map relationshipInfoMapping = analysis.getRelationshipInfoMapping(); - Set requiredRsCteName = analysis.getRelationshipFields().values().stream() - .map(this::getBaseName) - .collect(toSet()); - - List cteTables = - relationshipCTENames.stream() - .filter(name -> requiredRsCteName.contains(analysis.getRelationshipNameMapping().get(name))) - .map(name -> analysis.getRelationshipCTE().get(name)) - .map(WithQuery::getName) - .map(Identifier::getValue) - .map(QualifiedName::of) - .map(name -> relationshipInfoMapping.get(name.toString())) - .collect(toUnmodifiableList()); - - return leftJoin(table, cteTables); - } - - private String getBaseName(Expression expression) - { - if (expression instanceof DereferenceExpression) { - return ((DereferenceExpression) expression).getBase().toString(); - } - else if (expression instanceof Identifier) { - return ((Identifier) expression).getValue(); - } - throw new IllegalArgumentException("Unexpected expression: " + expression.getClass().getName()); - } - - private static Relation leftJoin(Relation left, List relationshipCTEJoinInfos) - { - Identifier aliasedName = null; - if (left instanceof AliasedRelation) { - aliasedName = ((AliasedRelation) left).getAlias(); - } - - for (RelationshipCteGenerator.RelationshipCTEJoinInfo info : relationshipCTEJoinInfos) { - left = QueryUtil.leftJoin(left, table(QualifiedName.of(info.getCteName())), replaceIfAliased(info.getCondition(), info.getBaseModelName(), aliasedName)); - } - return left; - } - - private static JoinCriteria replaceIfAliased(JoinCriteria original, String baseModelName, Identifier aliasedName) - { - if (aliasedName == null) { - return original; - } - - ComparisonExpression comparisonExpression = (ComparisonExpression) original.getNodes().get(0); - DereferenceExpression left = (DereferenceExpression) comparisonExpression.getLeft(); - Optional originalTableName = 
requireNonNull(getQualifiedName(left)).getPrefix(); - - if (originalTableName.isPresent() && originalTableName.get().getSuffix().equals(baseModelName)) { - left = new DereferenceExpression(aliasedName, left.getField().orElseThrow()); - } - return joinOn(equal(left, comparisonExpression.getRight())); - } } } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/EnumRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/EnumRewrite.java new file mode 100644 index 000000000..30cde46da --- /dev/null +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/EnumRewrite.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.accio.sqlrewrite; + +import io.accio.base.AccioMDL; +import io.accio.base.SessionContext; +import io.accio.base.dto.EnumDefinition; +import io.accio.base.dto.EnumValue; +import io.accio.sqlrewrite.analyzer.Analysis; +import io.trino.sql.tree.DereferenceExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.Node; +import io.trino.sql.tree.QualifiedName; +import io.trino.sql.tree.Statement; +import io.trino.sql.tree.StringLiteral; + +import java.util.Optional; + +import static java.lang.String.format; + +public class EnumRewrite + implements AccioRule +{ + public static final EnumRewrite ENUM_REWRITE = new EnumRewrite(); + + private EnumRewrite() {} + + @Override + public Statement apply(Statement root, SessionContext sessionContext, AccioMDL accioMDL) + { + return apply(root, sessionContext, null, accioMDL); + } + + @Override + public Statement apply(Statement root, SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL) + { + return (Statement) new Rewriter(accioMDL).process(root); + } + + private static class Rewriter + extends BaseRewriter + { + private final AccioMDL accioMDL; + + Rewriter(AccioMDL accioMDL) + { + this.accioMDL = accioMDL; + } + + @Override + protected Node visitDereferenceExpression(DereferenceExpression node, Void context) + { + Expression newNode = rewriteEnumIfNeed(node); + if (newNode != node) { + return newNode; + } + return new DereferenceExpression(node.getLocation(), (Expression) process(node.getBase()), node.getField()); + } + + private Expression rewriteEnumIfNeed(DereferenceExpression node) + { + QualifiedName qualifiedName = DereferenceExpression.getQualifiedName(node); + if (qualifiedName == null || qualifiedName.getOriginalParts().size() != 2) { + return node; + } + + String enumName = qualifiedName.getOriginalParts().get(0).getValue(); + Optional enumDefinitionOptional = accioMDL.getEnum(enumName); + if (enumDefinitionOptional.isEmpty()) { + return node; + } + + return 
enumDefinitionOptional.get().valueOf(qualifiedName.getOriginalParts().get(1).getValue()) + .map(EnumValue::getValue) + .map(StringLiteral::new) + .orElseThrow(() -> new IllegalArgumentException(format("Enum value '%s' not found in enum '%s'", qualifiedName.getParts().get(1), qualifiedName.getParts().get(0)))); + } + } +} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/LambdaExpressionBodyRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/LambdaExpressionBodyRewrite.java deleted file mode 100644 index 95eee7387..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/LambdaExpressionBodyRewrite.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.accio.sqlrewrite; - -import io.trino.sql.tree.ComparisonExpression; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.Literal; -import io.trino.sql.tree.Node; -import io.trino.sql.tree.StringLiteral; -import io.trino.sql.tree.SubscriptExpression; -import io.trino.sql.tree.Window; -import io.trino.sql.tree.WindowReference; -import io.trino.sql.tree.WindowSpecification; - -import java.util.List; -import java.util.Optional; - -import static io.accio.sqlrewrite.RelationshipCteGenerator.TARGET_REFERENCE; -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; - -public class LambdaExpressionBodyRewrite -{ - public static Expression rewrite(Node node, String modelName, Identifier argument) - { - return (Expression) new Visitor(modelName, argument).process(node, Optional.empty()); - } - - private LambdaExpressionBodyRewrite() {} - - static class Visitor - extends BaseRewriter> - { - private final String modelName; - private final Identifier argument; - - Visitor(String modelName, Identifier argument) - { - this.modelName = requireNonNull(modelName, "baseField is null"); - this.argument = requireNonNull(argument, "argument is null"); - } - - @Override - protected Node visitDereferenceExpression(DereferenceExpression node, Optional context) - { - if (node.getBase() instanceof Identifier) { - return visitAndCast(node.getBase(), Optional.ofNullable(node.getField().orElse(null))); - } - return new DereferenceExpression(visitAndCast(node.getBase(), Optional.ofNullable(node.getField().orElse(null))), node.getField().orElseThrow()); - } - - @Override - protected Node visitSubscriptExpression(SubscriptExpression node, Optional context) - { - if (node.getBase() instanceof DereferenceExpression) { - return new SubscriptExpression(visitAndCast(node.getBase(), - 
Optional.ofNullable((((DereferenceExpression) node.getBase()).getBase()))), - node.getIndex()); - } - return new SubscriptExpression(visitAndCast(node.getBase(), Optional.empty()), node.getIndex()); - } - - @Override - protected Node visitIdentifier(Identifier node, Optional context) - { - if (context.isEmpty()) { - return new StringLiteral(String.format("Relationship<%s>", modelName)); - } - if (argument.equals(node)) { - return new DereferenceExpression(new Identifier(TARGET_REFERENCE), (Identifier) context.get()); - } - return node; - } - - @Override - protected Node visitFunctionCall(FunctionCall node, Optional context) - { - return new FunctionCall( - node.getLocation(), - node.getName(), - node.getWindow().map(window -> visitAndCast(window, context)), - node.getFilter().map(filter -> visitAndCast(filter, context)), - node.getOrderBy().map(orderBy -> visitAndCast(orderBy, context)), node.isDistinct(), - node.getNullTreatment(), node.getProcessingMode(), - visitNodes(node.getArguments(), context)); - } - - @Override - protected Node visitComparisonExpression(ComparisonExpression node, Optional context) - { - if (node.getLocation().isPresent()) { - return new ComparisonExpression( - node.getLocation().get(), - node.getOperator(), - visitAndCast(node.getLeft(), context), - visitAndCast(node.getRight(), context)); - } - return new ComparisonExpression( - node.getOperator(), - visitAndCast(node.getLeft(), context), - visitAndCast(node.getRight(), context)); - } - - @Override - protected Node visitLiteral(Literal node, Optional context) - { - return node; - } - - protected S visitAndCast(S node, Optional context) - { - return (S) process(node, context); - } - - protected S visitAndCast(S window, Optional context) - { - Node node = null; - if (window instanceof WindowSpecification) { - node = (WindowSpecification) window; - } - else if (window instanceof WindowReference) { - node = (WindowReference) window; - } - return (S) process(node, context); - } - - 
@SuppressWarnings("unchecked") - protected List visitNodes(List nodes, Optional context) - { - return nodes.stream() - .map(node -> (S) process(node, context)) - .collect(toList()); - } - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricViewSqlRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricViewSqlRewrite.java index fe2c96c7e..8fc7845e1 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricViewSqlRewrite.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricViewSqlRewrite.java @@ -35,7 +35,6 @@ import java.util.Optional; import java.util.stream.Stream; -import static io.accio.sqlrewrite.ScopeAwareRewrite.SCOPE_AWARE_REWRITE; import static io.accio.sqlrewrite.Utils.parseView; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toUnmodifiableList; @@ -77,7 +76,7 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a // The generation of views has a sequential order, with later views being able to reference earlier views. Map viewQueries = new LinkedHashMap<>(); allAnalysis.stream().flatMap(a -> a.getViews().stream()) - .forEach(view -> viewQueries.put(view.getName(), (Query) SCOPE_AWARE_REWRITE.rewrite(parseView(view.getStatement()), accioMDL, sessionContext))); + .forEach(view -> viewQueries.put(view.getName(), parseView(view.getStatement()))); return (Statement) new WithRewriter(metricQueries, metricRollupQueries, ImmutableMap.copyOf(viewQueries)).process(root); } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelInfo.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelInfo.java new file mode 100644 index 000000000..c6bf5b499 --- /dev/null +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelInfo.java @@ -0,0 +1,240 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.accio.sqlrewrite; + +import io.accio.base.AccioMDL; +import io.accio.base.dto.Column; +import io.accio.base.dto.Model; +import io.accio.base.dto.Relationship; +import io.accio.sqlrewrite.analyzer.ExpressionRelationshipAnalyzer; +import io.accio.sqlrewrite.analyzer.ExpressionRelationshipInfo; +import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.DereferenceExpression; +import io.trino.sql.tree.Expression; + +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static io.accio.base.Utils.checkArgument; +import static io.accio.sqlrewrite.Utils.parseExpression; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; + +public class ModelInfo +{ + private final Model model; + private final Set requiredModels; + private final String sql; + + public static ModelInfo get(Model model, AccioMDL mdl) + { + return getModelSql(model, mdl); + } + + private ModelInfo( + Model model, + Set requiredModels, + String sql) + { + this.model = requireNonNull(model); + this.requiredModels = requireNonNull(requiredModels); + this.sql = requireNonNull(sql); + } + + public Model getModel() + { + return model; + } + + public Set getRequiredModels() + { + return 
requiredModels; + } + + public String getSql() + { + return sql; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ModelInfo modelInfo = (ModelInfo) o; + return Objects.equals(model, modelInfo.model) + && Objects.equals(requiredModels, modelInfo.requiredModels) + && Objects.equals(sql, modelInfo.sql); + } + + @Override + public int hashCode() + { + return Objects.hash(model, requiredModels, sql); + } + + private static ModelInfo getModelSql(Model model, AccioMDL mdl) + { + requireNonNull(model, "model is null"); + if (model.getColumns().isEmpty()) { + return new ModelInfo(model, Set.of(), model.getRefSql()); + } + + // key is alias_name.column_name, value is column name, this map is used to compose select items in model sql + Map selectItems = new LinkedHashMap<>(); + // key is alias name, value is query contains join condition, this map is used to compose join conditions in model sql + Map tableJoinSqls = new LinkedHashMap<>(); + // collect models that used in relationships + Set requiredModels = new HashSet<>(); + // key is column name in model, value is column expression, this map store columns not use relationships + Map columnWithoutRelationships = new LinkedHashMap<>(); + model.getColumns().stream() + .filter(column -> column.getRelationship().isEmpty()) + .forEach(column -> { + if (column.getExpression().isPresent()) { + Expression expression = parseExpression(column.getExpression().get()); + List relationshipInfos = ExpressionRelationshipAnalyzer.getRelationships(expression, mdl, model); + if (!relationshipInfos.isEmpty()) { + Expression newExpression = (Expression) RelationshipRewriter.rewrite(relationshipInfos, expression); + String tableJoins = format("(%s) AS \"%s\" %s", + getSubquerySql(model, relationshipInfos.stream().map(ExpressionRelationshipInfo::getBaseModelRelationship).collect(toList()), mdl), + model.getName(), + 
relationshipInfos.stream() + .map(ExpressionRelationshipInfo::getRelationships) + .flatMap(List::stream) + .distinct() + .map(relationship -> format(" LEFT JOIN \"%s\" ON %s", relationship.getModels().get(1), relationship.getCondition())) + .collect(Collectors.joining())); + + checkArgument(model.getPrimaryKey() != null, format("primary key in model %s contains relationship shouldn't be null", model.getName())); + + tableJoinSqls.put( + column.getName(), + format("SELECT \"%s\".\"%s\", %s AS \"%s\" FROM (%s)", + model.getName(), + model.getPrimaryKey(), + newExpression, + column.getName(), + tableJoins)); + // collect all required models in relationships + requiredModels.addAll( + relationshipInfos.stream() + .map(ExpressionRelationshipInfo::getRelationships) + .flatMap(List::stream) + .map(Relationship::getModels) + .flatMap(List::stream) + .filter(modelName -> !modelName.equals(model.getName())) + .collect(toSet())); + + // output from column use relationship will use another subquery which use column name from model as alias name + selectItems.put(column.getName(), format("\"%s\".\"%s\"", column.getName(), column.getName())); + } + else { + selectItems.put(column.getName(), format("\"%s\".\"%s\"", model.getName(), column.getName())); + columnWithoutRelationships.put(column.getName(), column.getExpression().get()); + } + } + else { + selectItems.put(column.getName(), format("\"%s\".\"%s\"", model.getName(), column.getName())); + columnWithoutRelationships.put(column.getName(), format("\"%s\".\"%s\"", model.getName(), column.getName())); + } + }); + + String modelSubQuery = format("(SELECT %s FROM (%s) AS \"%s\") AS \"%s\"", + columnWithoutRelationships.entrySet().stream() + .map(e -> format("%s AS \"%s\"", e.getValue(), e.getKey())) + .collect(joining(", ")), + model.getRefSql(), + model.getName(), + model.getName()); + Function tableJoinCondition = (name) -> format("\"%s\".\"%s\" = \"%s\".\"%s\"", model.getName(), model.getPrimaryKey(), name, 
model.getPrimaryKey()); + String tableJoinsSql = modelSubQuery + + tableJoinSqls.entrySet().stream() + .map(e -> format(" LEFT JOIN (%s) AS \"%s\" ON %s", e.getValue(), e.getKey(), tableJoinCondition.apply(e.getKey()))) + .collect(joining()); + + String selectItemsSql = selectItems.entrySet().stream() + .map(e -> format("%s AS \"%s\"", e.getValue(), e.getKey())) + .collect(joining(", ")); + + return new ModelInfo( + model, + requiredModels, + format("SELECT %s FROM %s", selectItemsSql, tableJoinsSql)); + } + + private static String getSubquerySql(Model model, List relationships, AccioMDL mdl) + { + Column primaryKey = model.getColumns().stream() + .filter(column -> column.getName().equals(model.getPrimaryKey())) + .findAny() + .orElseThrow(() -> new IllegalArgumentException("primary key not found in model " + model.getName())); + // TODO: this should be checked in validator too + primaryKey.getExpression().ifPresent(expression -> + checkArgument(ExpressionRelationshipAnalyzer.getRelationships(parseExpression(expression), mdl, model).isEmpty(), + "primary key expression can't use relation")); + + String joinKeys = relationships.stream() + .map(relationship -> { + String joinColumnName = findJoinColumn(model, relationship); + Column joinColumn = model.getColumns().stream() + .filter(column -> column.getName().equals(joinColumnName)) + .findAny() + .orElseThrow(() -> new IllegalArgumentException(format("join column %s not found in model %s", joinColumnName, model.getName()))); + // TODO: this should be checked in validator too + joinColumn.getExpression().ifPresent(expression -> + checkArgument(ExpressionRelationshipAnalyzer.getRelationships(parseExpression(expression), mdl, model).isEmpty(), + "column in join condition can't use relation")); + return format("%s AS \"%s\"", joinColumn.getExpression().orElse(joinColumn.getName()), joinColumn.getName()); + }) + .collect(joining(",")); + + return format("SELECT %s, %s FROM (%s) AS \"%s\"", + format("%s AS \"%s\"", 
primaryKey.getExpression().orElse(primaryKey.getName()), primaryKey.getName()), + joinKeys, + model.getRefSql(), + model.getName()); + } + + private static String findJoinColumn(Model model, Relationship relationship) + { + checkArgument(relationship.getModels().contains(model.getName()), format("model %s not found in relationship %s", model.getName(), relationship.getName())); + ComparisonExpression joinCondition = (ComparisonExpression) parseExpression(relationship.getCondition()); + checkArgument(joinCondition.getLeft() instanceof DereferenceExpression, "invalid join condition"); + checkArgument(joinCondition.getRight() instanceof DereferenceExpression, "invalid join condition"); + DereferenceExpression left = (DereferenceExpression) joinCondition.getLeft(); + DereferenceExpression right = (DereferenceExpression) joinCondition.getRight(); + if (left.getBase().toString().equals(model.getName())) { + return left.getField().orElseThrow().getValue(); + } + if (right.getBase().toString().equals(model.getName())) { + return right.getField().orElseThrow().getValue(); + } + throw new IllegalArgumentException(format("join column in relationship %s not found in model %s", relationship.getName(), model.getName())); + } +} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCTE.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCTE.java deleted file mode 100644 index ce26a9961..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCTE.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite; - -import io.accio.base.dto.Relationship; - -import java.util.List; -import java.util.Optional; -import java.util.stream.Stream; - -import static java.lang.String.format; -import static java.util.stream.Collectors.toList; - -/** - * Defines how to compose a Relationship Common Table Expression (CTE). A relationship defines how to map a source table to a target table. - * For example, if a relationship defines a connection between the User and Book tables, - * the User table is considered the source table and the Book table is considered the target table. - *

- * For a TO_ONE relationship, all columns in the CTE are target objects. - *

- * For a TO_MANY relationship, only the relationship field is target object. - **/ -public class RelationshipCTE -{ - private final String name; - private final Relation source; - private final Relation target; - - private final String index; - private final Relationship relationship; - // The base key is the primary key of the base model. - private final String baseKey; - - public RelationshipCTE(String name, Relation source, Relation target, Relationship relationship, String index, String baseKey) - { - this.name = name; - this.source = source; - this.target = target; - this.relationship = relationship; - this.index = index; - this.baseKey = baseKey; - } - - public String getName() - { - return name; - } - - public Relation getSource() - { - return source; - } - - public Relation getTarget() - { - return target; - } - - public Relation getManySide() - { - switch (relationship.getJoinType()) { - case MANY_TO_ONE: - return source; - case ONE_TO_MANY: - return target; - } - throw new IllegalArgumentException(format("join type %s can't get many side", relationship.getJoinType())); - } - - public Relationship getRelationship() - { - return relationship; - } - - public List getOutputColumn() - { - return Stream.concat(target.getColumns().stream(), List.of(source.getJoinKey()).stream()).collect(toList()); - } - - public Optional getIndex() - { - return Optional.ofNullable(index); - } - - public String getBaseKey() - { - return baseKey; - } - - public static class Relation - { - private final String name; - private final List columns; - private final String primaryKey; - private final String joinKey; - - public Relation(String name, List columns, String primaryKey, String joinKey) - { - this.name = name; - this.columns = columns; - this.primaryKey = primaryKey; - this.joinKey = joinKey; - } - - public String getName() - { - return name; - } - - public List getColumns() - { - return columns; - } - - public String getPrimaryKey() - { - return primaryKey; - } - - public 
String getJoinKey() - { - return joinKey; - } - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCteGenerator.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCteGenerator.java deleted file mode 100644 index ee1011c24..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipCteGenerator.java +++ /dev/null @@ -1,1064 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import io.accio.base.AccioMDL; -import io.accio.base.dto.Column; -import io.accio.base.dto.Model; -import io.accio.base.dto.Relationship; -import io.accio.base.dto.Relationship.SortKey; -import io.trino.sql.QueryUtil; -import io.trino.sql.tree.ComparisonExpression; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.GroupBy; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.IsNotNullPredicate; -import io.trino.sql.tree.JoinCriteria; -import io.trino.sql.tree.LongLiteral; -import io.trino.sql.tree.OrderBy; -import io.trino.sql.tree.QualifiedName; -import io.trino.sql.tree.Select; -import io.trino.sql.tree.SelectItem; -import io.trino.sql.tree.SimpleGroupBy; -import io.trino.sql.tree.SingleColumn; -import io.trino.sql.tree.SortItem; -import io.trino.sql.tree.WithQuery; - 
-import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.IntStream; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.Iterables.getLast; -import static io.accio.base.dto.Relationship.SortKey.Ordering.ASC; -import static io.accio.base.dto.Relationship.SortKey.sortKey; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.RS; -import static io.accio.sqlrewrite.Utils.randomTableSuffix; -import static io.trino.sql.QueryUtil.aliased; -import static io.trino.sql.QueryUtil.crossJoin; -import static io.trino.sql.QueryUtil.equal; -import static io.trino.sql.QueryUtil.getConditionNode; -import static io.trino.sql.QueryUtil.identifier; -import static io.trino.sql.QueryUtil.joinOn; -import static io.trino.sql.QueryUtil.leftJoin; -import static io.trino.sql.QueryUtil.nameReference; -import static io.trino.sql.QueryUtil.quotedIdentifier; -import static io.trino.sql.QueryUtil.simpleQuery; -import static io.trino.sql.QueryUtil.subscriptExpression; -import static io.trino.sql.QueryUtil.table; -import static io.trino.sql.QueryUtil.unnest; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; - -/** - *

This class will generate a registered relationship accessing to be a CTE SQL. e.g., - * If given a model, `Book`, with a relationship column, `author`, related to the model, `User`, - * when a sql try to access the field fo a relationship:

- *

SELECT Book.author.name FROM Book

- *

The generator will produce the followed sql:

- *

- * rs_1inddfmsuy (bookId, name, author, authorId) AS (
- * SELECT r.bookId, r.name, r.author, r.authorId
- * FROM (User l LEFT JOIN Book r ON (l.userId = r.authorId))) - *

- *

- * The output result of a relationship CTE is the all column of the right-side model. - * We can think the right-side model is the output side. - */ -public class RelationshipCteGenerator -{ - private static final String ONE_REFERENCE = "o"; - private static final String MANY_REFERENCE = "m"; - - public static final String SOURCE_REFERENCE = "s"; - - public static final String TARGET_REFERENCE = "t"; - public static final String LAMBDA_RESULT_NAME = "f1"; - - private static final String UNNEST_REFERENCE = "u"; - - private static final String UNNEST_COLUMN_REFERENCE = "uc"; - private static final String BASE_KEY_ALIAS = "bk"; - private final AccioMDL accioMDL; - private final Map registeredWithQuery = new LinkedHashMap<>(); - private final Map registeredCte = new HashMap<>(); - private final Map nameMapping = new HashMap<>(); - private final Map relationshipInfoMapping = new HashMap<>(); - - public RelationshipCteGenerator(AccioMDL accioMDL) - { - this.accioMDL = requireNonNull(accioMDL); - } - - public void register(List nameParts, RelationshipOperation operation) - { - requireNonNull(nameParts, "nameParts is null"); - checkArgument(!nameParts.isEmpty(), "nameParts is empty"); - register(nameParts, operation, nameParts.get(0), getLast(nameParts)); - } - - public void register(List nameParts, RelationshipOperation operation, String baseModel) - { - requireNonNull(nameParts, "nameParts is null"); - checkArgument(!nameParts.isEmpty(), "nameParts is empty"); - register(nameParts, operation, baseModel, getLast(nameParts)); - } - - public void register(List nameParts, RelationshipOperation operation, String baseModel, String originalName) - { - requireNonNull(nameParts, "nameParts is null"); - requireNonNull(operation.getRsItems(), "rsItems is null"); - checkArgument(!nameParts.isEmpty(), "nameParts is empty"); - checkArgument(!operation.getRsItems().isEmpty() && operation.getRsItems().size() <= 2, "The size of rsItems should be 1 or 2"); - - // avoid duplicate cte 
registering - if (nameMapping.containsKey(String.join(".", nameParts))) { - return; - } - - RelationshipCTE relationshipCTE = createRelationshipCTE(operation.getRsItems()); - String name = String.join(".", nameParts); - WithQuery withQuery = transferToCte(originalName, relationshipCTE, operation); - registeredWithQuery.put(name, withQuery); - registeredCte.put(name, relationshipCTE); - nameMapping.put(name, withQuery.getName().getValue()); - relationshipInfoMapping.put(withQuery.getName().getValue(), transferToRelationshipCTEJoinInfo(withQuery.getName().getValue(), relationshipCTE, baseModel)); - } - - /** - * Generate the join condition between the base model and the relationship CTE. - * We used the output-side(right-side) model to decide the join condition with the base model. - *

- * - * @param rsName the QualifiedName of the relationship column, e.g., `Book.author.book`. - * @param relationshipCTE the parsed information of relationship. - * @param baseModelName the base model of the relationship column, e.g., `Book`. - * @return The join condition between the base model and this relationship cte. - */ - private RelationshipCTEJoinInfo transferToRelationshipCTEJoinInfo(String rsName, RelationshipCTE relationshipCTE, String baseModelName) - { - Model baseModel = accioMDL.getModel(baseModelName).orElseThrow(() -> new IllegalArgumentException(format("Model %s is not found", baseModelName))); - return new RelationshipCTEJoinInfo(relationshipCTE.getName(), - buildCondition(baseModelName, baseModel.getPrimaryKey(), rsName, BASE_KEY_ALIAS), baseModelName); - } - - private JoinCriteria buildCondition(String leftName, String leftKey, String rightName, String rightKey) - { - return joinOn(equal(nameReference(leftName, leftKey), nameReference(rightName, rightKey))); - } - - private WithQuery transferToCte(String originalName, RelationshipCTE relationshipCTE, RelationshipOperation operation) - { - List arguments = operation.getFunctionCallArguments(); - switch (operation.getOperatorType()) { - case ACCESS: - return transferToAccessCte(originalName, relationshipCTE); - case TRANSFORM: - checkArgument(operation.getLambdaExpression().isPresent(), "Lambda expression is missing"); - return transferToTransformCte( - operation.getManySideResultField().orElse(operation.getLambdaExpression().get().toString()), - operation.getLambdaExpression().get(), relationshipCTE, operation.getUnnestField()); - case FILTER: - checkArgument(operation.getLambdaExpression().isPresent(), "Lambda expression is missing"); - return transferToFilterCte( - operation.getManySideResultField().orElse(operation.getLambdaExpression().get().toString()), - operation.getLambdaExpression().get(), relationshipCTE, operation.getUnnestField()); - case AGGREGATE: - 
checkArgument(operation.getAggregateOperator().isPresent(), "Aggregate operator is missing"); - return transferToAggregateCte( - operation.getManySideResultField() - .orElseThrow(() -> new IllegalArgumentException(operation.getAggregateOperator().get() + " relationship field not found")), - relationshipCTE, operation.getUnnestField(), operation.getAggregateOperator().get()); - case ARRAY_SORT: - checkArgument(arguments.size() == 3, "array_sort function should have 3 arguments"); - SortKey sortKey = sortKey(arguments.get(1).toString(), SortKey.Ordering.get(arguments.get(2).toString())); - return transferToArraySortCte( - operation.getManySideResultField().orElseThrow(() -> new IllegalArgumentException("array_sort relationship field not found")), - sortKey, - relationshipCTE, - operation.getUnnestField()); - case SLICE: - checkArgument(arguments.size() == 3, "slice function should have 3 arguments"); - checkArgument(arguments.get(1) instanceof LongLiteral, "Incorrect argument in slice function second argument"); - checkArgument(arguments.get(2) instanceof LongLiteral, "Incorrect argument in slice function third argument"); - return transferToSliceCte( - operation.getManySideResultField().orElseThrow(() -> new IllegalArgumentException("array_sort relationship field not found")), - relationshipCTE, - arguments.subList(1, 3)); - } - throw new UnsupportedOperationException(format("%s relationship operation is unsupported", operation.getOperatorType())); - } - - private WithQuery transferToAccessCte(String originalName, RelationshipCTE relationshipCTE) - { - switch (relationshipCTE.getRelationship().getJoinType()) { - case ONE_TO_ONE: - return oneToOneResultRelationship(relationshipCTE); - case MANY_TO_ONE: - return manyToOneResultRelationship(relationshipCTE); - case ONE_TO_MANY: - return relationshipCTE.getIndex().isPresent() ? 
- oneToManyRelationshipAccessByIndex(originalName, relationshipCTE) : - oneToManyResultRelationship(originalName, relationshipCTE); - } - throw new UnsupportedOperationException(format("%s relationship accessing is unsupported", relationshipCTE.getRelationship().getJoinType())); - } - - private WithQuery oneToOneResultRelationship(RelationshipCTE relationshipCTE) - { - List targetSelectItem = - ImmutableSet.builder() - // make sure the primary key come first. - .add(relationshipCTE.getTarget().getPrimaryKey()) - .addAll(relationshipCTE.getTarget().getColumns()) - .add(relationshipCTE.getTarget().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(TARGET_REFERENCE, column)) - .map(SingleColumn::new) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(targetSelectItem) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(item -> (SingleColumn) item) - .map(item -> item.getAlias().map(alias -> (Expression) alias).orElse(item.getExpression())) - .map(this::getReferenceField).map(QueryUtil::identifier).collect(toList()); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - leftJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), TARGET_REFERENCE), - joinOn(equal(nameReference(SOURCE_REFERENCE, relationshipCTE.getSource().getJoinKey()), nameReference(TARGET_REFERENCE, relationshipCTE.getTarget().getJoinKey()))))), - Optional.of(outputSchema)); - } - - private WithQuery manyToOneResultRelationship(RelationshipCTE relationshipCTE) - { - List targetSelectItem = - ImmutableSet.builder() - // make sure the primary key come first. 
- .add(relationshipCTE.getTarget().getPrimaryKey()) - .addAll(relationshipCTE.getTarget().getColumns()) - .add(relationshipCTE.getTarget().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(TARGET_REFERENCE, column)) - .map(SingleColumn::new) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(targetSelectItem) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(item -> (SingleColumn) item) - .map(item -> item.getAlias().map(alias -> (Expression) alias).orElse(item.getExpression())) - .map(this::getReferenceField).map(QueryUtil::identifier).collect(toList()); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(true, selectItems), - leftJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), TARGET_REFERENCE), - joinOn(equal(nameReference(SOURCE_REFERENCE, relationshipCTE.getSource().getJoinKey()), nameReference(TARGET_REFERENCE, relationshipCTE.getTarget().getJoinKey()))))), - Optional.of(outputSchema)); - } - - private WithQuery oneToManyResultRelationship(String originalName, RelationshipCTE relationshipCTE) - { - List oneTableFields = - ImmutableSet.builder() - // make sure the primary key be first. - .add(relationshipCTE.getSource().getPrimaryKey()) - .add(relationshipCTE.getSource().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(ONE_REFERENCE, column)) - .collect(toList()); - List sortKeys = relationshipCTE.getRelationship().getManySideSortKeys().isEmpty() ? 
- List.of(new Relationship.SortKey(relationshipCTE.getManySide().getPrimaryKey(), ASC)) : - relationshipCTE.getRelationship().getManySideSortKeys(); - - SingleColumn relationshipField = new SingleColumn( - toArrayAgg(nameReference(MANY_REFERENCE, relationshipCTE.getTarget().getPrimaryKey()), MANY_REFERENCE, sortKeys), - identifier(originalName)); - - List normalFields = oneTableFields - .stream() - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - normalFields.add(new SingleColumn(nameReference(ONE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(normalFields) - .add(relationshipField); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - .orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - leftJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), ONE_REFERENCE), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), MANY_REFERENCE), - joinOn(equal(nameReference(ONE_REFERENCE, relationshipCTE.getSource().getJoinKey()), nameReference(MANY_REFERENCE, relationshipCTE.getTarget().getJoinKey())))), - Optional.empty(), - Optional.of(new GroupBy(false, IntStream.rangeClosed(1, normalFields.size()) - .mapToObj(number -> new LongLiteral(String.valueOf(number))) - .map(longLiteral -> new SimpleGroupBy(List.of(longLiteral))) - .collect(toList()))), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - Optional.of(outputSchema)); - } - - private WithQuery oneToManyRelationshipAccessByIndex(String originalName, RelationshipCTE 
relationshipCTE) - { - checkArgument(relationshipCTE.getIndex().isPresent(), "index is null"); - List targetSelectItem = - ImmutableSet.builder() - // make sure the primary key be first. - .add(relationshipCTE.getTarget().getPrimaryKey()) - .addAll(relationshipCTE.getTarget().getColumns()) - .add(relationshipCTE.getTarget().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(TARGET_REFERENCE, column)) - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(targetSelectItem) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - .orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - leftJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), TARGET_REFERENCE), - joinOn(equal(subscriptExpression(nameReference(SOURCE_REFERENCE, originalName.split("\\[")[0]), relationshipCTE.getIndex().get()), nameReference(TARGET_REFERENCE, relationshipCTE.getTarget().getPrimaryKey()))))), - Optional.of(outputSchema)); - } - - private WithQuery transferToTransformCte( - String manyResultField, - Expression lambdaExpressionBody, - RelationshipCTE relationshipCTE, - Optional outputField) - { - List oneTableFields = - ImmutableSet.builder() - // make sure the primary key be first. 
- .add(relationshipCTE.getSource().getPrimaryKey()) - // remove duplicate column name - .addAll(relationshipCTE.getSource().getColumns().stream() - .filter(column -> !column.equals(manyResultField) && !column.equals(LAMBDA_RESULT_NAME)) - .collect(toList())) - .add(relationshipCTE.getSource().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(SOURCE_REFERENCE, column)) - .collect(toList()); - List sortKeys = relationshipCTE.getRelationship().getManySideSortKeys().isEmpty() ? - List.of(new Relationship.SortKey(relationshipCTE.getManySide().getPrimaryKey(), ASC)) : - relationshipCTE.getRelationship().getManySideSortKeys(); - - SingleColumn arrayAggField = new SingleColumn( - toArrayAgg(lambdaExpressionBody, TARGET_REFERENCE, sortKeys), - identifier(LAMBDA_RESULT_NAME)); - - List normalFields = oneTableFields - .stream() - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(normalFields) - .add(arrayAggField) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - .orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - Expression unnestField = outputField.orElse(nameReference(SOURCE_REFERENCE, manyResultField)); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - leftJoin( - crossJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(unnest(unnestField), UNNEST_REFERENCE, List.of(UNNEST_COLUMN_REFERENCE))), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), TARGET_REFERENCE), - 
joinOn(equal(nameReference(UNNEST_REFERENCE, UNNEST_COLUMN_REFERENCE), nameReference(TARGET_REFERENCE, relationshipCTE.getTarget().getPrimaryKey())))), - Optional.empty(), - Optional.of(new GroupBy(false, IntStream.range(1, oneTableFields.size() + 1) - .mapToObj(number -> new LongLiteral(String.valueOf(number))) - .map(longLiteral -> new SimpleGroupBy(List.of(longLiteral))).collect(toList()))), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - Optional.of(outputSchema)); - } - - private WithQuery transferToAggregateCte( - String manyResultField, - RelationshipCTE relationshipCTE, - Optional outputField, - String operator) - { - List oneTableFields = - ImmutableSet.builder() - // make sure the primary key be first. - .add(relationshipCTE.getSource().getPrimaryKey()) - // remove duplicate column name - .addAll(relationshipCTE.getSource().getColumns().stream() - .filter(column -> !column.equals(manyResultField) && !column.equals(LAMBDA_RESULT_NAME)) - .collect(toList())) - .add(relationshipCTE.getSource().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(SOURCE_REFERENCE, column)) - .collect(toList()); - - SingleColumn aggField = new SingleColumn( - toAggregate(DereferenceExpression.from(QualifiedName.of(UNNEST_REFERENCE, UNNEST_COLUMN_REFERENCE)), operator), - identifier(LAMBDA_RESULT_NAME)); - - List normalFields = oneTableFields - .stream() - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(normalFields) - .add(aggField) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - 
.orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - Expression unnestField = outputField.orElse(nameReference(SOURCE_REFERENCE, manyResultField)); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - crossJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(unnest(unnestField), UNNEST_REFERENCE, List.of(UNNEST_COLUMN_REFERENCE))), - Optional.empty(), - Optional.of(new GroupBy(false, IntStream.range(1, oneTableFields.size() + 1) - .mapToObj(number -> new LongLiteral(String.valueOf(number))) - .map(longLiteral -> new SimpleGroupBy(List.of(longLiteral))).collect(toList()))), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - Optional.of(outputSchema)); - } - - private WithQuery transferToFilterCte( - String manyResultField, - Expression lambdaExpressionBody, - RelationshipCTE relationshipCTE, - Optional outputField) - { - List oneTableFields = - ImmutableSet.builder() - // make sure the primary key be first. - .add(relationshipCTE.getSource().getPrimaryKey()) - // remove column name - .addAll(relationshipCTE.getSource().getColumns().stream() - .filter(column -> !column.equals(manyResultField) && !column.equals(LAMBDA_RESULT_NAME)) - .collect(toList())) - .add(relationshipCTE.getSource().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(SOURCE_REFERENCE, column)) - .collect(toList()); - List sortKeys = relationshipCTE.getRelationship().getManySideSortKeys().isEmpty() ? 
- List.of(new Relationship.SortKey(relationshipCTE.getManySide().getPrimaryKey(), ASC)) : - relationshipCTE.getRelationship().getManySideSortKeys(); - - SingleColumn arrayAggField = new SingleColumn( - toArrayAgg(DereferenceExpression.from(QualifiedName.of(TARGET_REFERENCE, relationshipCTE.getTarget().getPrimaryKey())), TARGET_REFERENCE, sortKeys), - identifier(LAMBDA_RESULT_NAME)); - - List normalFields = oneTableFields - .stream() - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(normalFields) - .add(arrayAggField) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - .orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - Expression unnestField = outputField.orElse(nameReference(SOURCE_REFERENCE, manyResultField)); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - leftJoin( - crossJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(unnest(unnestField), UNNEST_REFERENCE, List.of(UNNEST_COLUMN_REFERENCE))), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), TARGET_REFERENCE), - joinOn(equal(nameReference(UNNEST_REFERENCE, UNNEST_COLUMN_REFERENCE), nameReference(TARGET_REFERENCE, relationshipCTE.getTarget().getPrimaryKey())))), - Optional.of(lambdaExpressionBody), - Optional.of(new GroupBy(false, IntStream.range(1, oneTableFields.size() + 1) - .mapToObj(number -> new LongLiteral(String.valueOf(number))) - .map(longLiteral -> new 
SimpleGroupBy(List.of(longLiteral))).collect(toList()))), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - Optional.of(outputSchema)); - } - - private WithQuery transferToArraySortCte( - String manyResultField, - SortKey sortKey, - RelationshipCTE relationshipCTE, - Optional outputField) - { - List oneTableFields = - ImmutableSet.builder() - // make sure the primary key be first. - .add(relationshipCTE.getSource().getPrimaryKey()) - // remove column name - .addAll(relationshipCTE.getSource().getColumns().stream() - .filter(column -> !column.equals(manyResultField) && !column.equals(LAMBDA_RESULT_NAME)) - .collect(toList())) - .add(relationshipCTE.getSource().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(SOURCE_REFERENCE, column)) - .collect(toList()); - List sortKeys = List.of(sortKey); - - SingleColumn arrayAggField = new SingleColumn( - toArrayAgg(DereferenceExpression.from(QualifiedName.of(TARGET_REFERENCE, relationshipCTE.getTarget().getPrimaryKey())), TARGET_REFERENCE, sortKeys), - identifier(LAMBDA_RESULT_NAME)); - - List normalFields = oneTableFields - .stream() - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(normalFields) - .add(arrayAggField) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - .orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - Expression unnestField = outputField.orElse(nameReference(SOURCE_REFERENCE, manyResultField)); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), 
- leftJoin( - crossJoin(aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - aliased(unnest(unnestField), UNNEST_REFERENCE, List.of(UNNEST_COLUMN_REFERENCE))), - aliased(table(QualifiedName.of(relationshipCTE.getTarget().getName())), TARGET_REFERENCE), - joinOn(equal(nameReference(UNNEST_REFERENCE, UNNEST_COLUMN_REFERENCE), nameReference(TARGET_REFERENCE, relationshipCTE.getTarget().getPrimaryKey())))), - Optional.empty(), - Optional.of(new GroupBy(false, IntStream.range(1, oneTableFields.size() + 1) - .mapToObj(number -> new LongLiteral(String.valueOf(number))) - .map(longLiteral -> new SimpleGroupBy(List.of(longLiteral))).collect(toList()))), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - Optional.of(outputSchema)); - } - - // TODO: find a way to combine slice to upper ctes https://github.com/Canner/canner-metric-layer/issues/302 - private WithQuery transferToSliceCte( - String manyResultField, - RelationshipCTE relationshipCTE, - List startEnd) - { - List oneTableFields = - ImmutableSet.builder() - // make sure the primary key be first. 
- .add(relationshipCTE.getSource().getPrimaryKey()) - // remove column name - .addAll(relationshipCTE.getSource().getColumns().stream() - .filter(column -> !column.equals(manyResultField) && !column.equals(LAMBDA_RESULT_NAME)) - .collect(toList())) - .add(relationshipCTE.getSource().getJoinKey()) - .build() - .stream() - .map(column -> nameReference(SOURCE_REFERENCE, column)) - .collect(toList()); - - SingleColumn sliceColumn = new SingleColumn( - new FunctionCall( - QualifiedName.of("slice"), - List.of( - DereferenceExpression.from(QualifiedName.of(SOURCE_REFERENCE, manyResultField)), - startEnd.get(0), - startEnd.get(1))), - identifier(LAMBDA_RESULT_NAME)); - - List normalFields = oneTableFields - .stream() - .map(field -> new SingleColumn(field, identifier(requireNonNull(getQualifiedName(field)).getSuffix()))) - .collect(toList()); - - ImmutableSet.Builder builder = ImmutableSet - .builder() - .addAll(normalFields) - .add(sliceColumn) - .add(new SingleColumn(nameReference(SOURCE_REFERENCE, relationshipCTE.getBaseKey()), identifier(BASE_KEY_ALIAS))); - List selectItems = ImmutableList.copyOf(builder.build()); - - List outputSchema = selectItems.stream() - .map(selectItem -> (SingleColumn) selectItem) - .map(singleColumn -> - singleColumn.getAlias() - .orElse(quotedIdentifier(singleColumn.getExpression().toString()))) - .collect(toList()); - - return new WithQuery(identifier(relationshipCTE.getName()), - simpleQuery( - new Select(false, selectItems), - aliased(table(QualifiedName.of(relationshipCTE.getSource().getName())), SOURCE_REFERENCE), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.empty()), - Optional.of(outputSchema)); - } - - private QualifiedName getQualifiedName(Expression expression) - { - if (expression instanceof DereferenceExpression) { - return DereferenceExpression.getQualifiedName((DereferenceExpression) expression); - } - if (expression instanceof Identifier) { - return 
QualifiedName.of(ImmutableList.of((Identifier) expression)); - } - return null; - } - - private Expression toArrayAgg(Expression field, String sortKeyPrefix, List sortKeys) - { - return new FunctionCall( - Optional.empty(), - QualifiedName.of("array_agg"), - Optional.empty(), - Optional.of(new IsNotNullPredicate(field)), - Optional.of(new OrderBy(sortKeys.stream() - .map(sortKey -> - new SortItem(nameReference(sortKeyPrefix, sortKey.getName()), sortKey.isDescending() ? SortItem.Ordering.DESCENDING : SortItem.Ordering.ASCENDING, - SortItem.NullOrdering.UNDEFINED)) - .collect(toList()))), - false, - Optional.empty(), - Optional.empty(), - List.of(field)); - } - - private Expression toAggregate(Expression field, String operator) - { - return new FunctionCall( - Optional.empty(), - QualifiedName.of(operator), - Optional.empty(), - Optional.empty(), - Optional.empty(), - false, - Optional.empty(), - Optional.empty(), - List.of(field)); - } - - private RelationshipCTE createRelationshipCTE(List rsItems) - { - RelationshipCTE.Relation source; - RelationshipCTE.Relation target; - Relationship relationship; - String baseKey; - // If the first item is CTE, the second one is RS or REVERSE_RS. - if (rsItems.get(0).getType().equals(RsItem.Type.CTE)) { - relationship = accioMDL.listRelationships().stream().filter(r -> r.getName().equals(rsItems.get(1).getName())).findAny().get(); - ComparisonExpression comparisonExpression = getConditionNode(relationship.getCondition()); - WithQuery leftQuery = registeredWithQuery.get(rsItems.get(0).getName()); - baseKey = BASE_KEY_ALIAS; - - if (rsItems.get(1).getType() == RS) { - source = new RelationshipCTE.Relation( - leftQuery.getName().getValue(), - leftQuery.getColumnNames().get().stream().map(Identifier::getValue).collect(toList()), - // TODO: we should make sure the first field is its primary key. 
- leftQuery.getColumnNames().map(columns -> columns.get(0).getValue()).get(), - getReferenceField(comparisonExpression.getLeft())); - Model rightModel = getRightModel(relationship, accioMDL.listModels()); - target = new RelationshipCTE.Relation( - rightModel.getName(), - rightModel.getColumns().stream().map(Column::getName).collect(toList()), - rightModel.getPrimaryKey(), - getReferenceField(comparisonExpression.getRight())); - } - else { - source = new RelationshipCTE.Relation( - leftQuery.getName().getValue(), - leftQuery.getColumnNames().get().stream().map(Identifier::getValue).collect(toList()), - // TODO: we should make sure the first field is its primary key. - leftQuery.getColumnNames().map(columns -> columns.get(0).getValue()).get(), - // If it's a REVERSE relationship, the left and right side will be swapped. - getReferenceField(comparisonExpression.getRight())); - // If it's a REVERSE relationship, the left and right side will be swapped. - Model rightModel = getLeftModel(relationship, accioMDL.listModels()); - target = new RelationshipCTE.Relation( - rightModel.getName(), - rightModel.getColumns().stream().map(Column::getName).collect(toList()), - rightModel.getPrimaryKey(), - // If it's a REVERSE relationship, the left and right side will be swapped. 
- getReferenceField(comparisonExpression.getLeft())); - } - } - else { - if (rsItems.get(0).getType() == RS) { - relationship = accioMDL.listRelationships().stream().filter(r -> r.getName().equals(rsItems.get(0).getName())).findAny().get(); - ComparisonExpression comparisonExpression = getConditionNode(relationship.getCondition()); - Model leftModel = getLeftModel(relationship, accioMDL.listModels()); - source = new RelationshipCTE.Relation( - leftModel.getName(), - leftModel.getColumns().stream().map(Column::getName).collect(toList()), - leftModel.getPrimaryKey(), getReferenceField(comparisonExpression.getLeft())); - - Model rightModel = getRightModel(relationship, accioMDL.listModels()); - target = new RelationshipCTE.Relation( - rightModel.getName(), - rightModel.getColumns().stream().map(Column::getName).collect(toList()), - rightModel.getPrimaryKey(), getReferenceField(comparisonExpression.getRight())); - } - else { - relationship = accioMDL.listRelationships().stream().filter(r -> r.getName().equals(rsItems.get(0).getName())).findAny().get(); - ComparisonExpression comparisonExpression = getConditionNode(relationship.getCondition()); - // If it's a REVERSE relationship, the left and right side will be swapped. - Model leftModel = getRightModel(relationship, accioMDL.listModels()); - source = new RelationshipCTE.Relation( - leftModel.getName(), - leftModel.getColumns().stream().map(Column::getName).collect(toList()), - // If it's a REVERSE relationship, the left and right side will be swapped. - leftModel.getPrimaryKey(), getReferenceField(comparisonExpression.getRight())); - - // If it's a REVERSE relationship, the left and right side will be swapped. - Model rightModel = getLeftModel(relationship, accioMDL.listModels()); - target = new RelationshipCTE.Relation( - rightModel.getName(), - rightModel.getColumns().stream().map(Column::getName).collect(toList()), - // If it's a REVERSE relationship, the left and right side will be swapped. 
- rightModel.getPrimaryKey(), getReferenceField(comparisonExpression.getLeft())); - } - baseKey = source.getPrimaryKey(); - } - - return new RelationshipCTE("rs_" + randomTableSuffix(), source, target, - getLast(rsItems).getType().equals(RS) ? relationship : Relationship.reverse(relationship), - rsItems.get(0).getIndex().orElse(null), baseKey); - } - - private static Model getLeftModel(Relationship relationship, List models) - { - return models.stream().filter(model -> model.getName().equals(relationship.getModels().get(0))).findAny() - .orElseThrow(() -> new IllegalArgumentException(format("Left model %s not found in the given models.", relationship.getModels().get(0)))); - } - - private static Model getRightModel(Relationship relationship, List models) - { - return models.stream().filter(model -> model.getName().equals(relationship.getModels().get(1))).findAny() - .orElseThrow(() -> new IllegalArgumentException(format("Right model %s not found in the given models.", relationship.getModels().get(1)))); - } - - private String getReferenceField(Expression expression) - { - if (expression instanceof DereferenceExpression) { - return ((DereferenceExpression) expression).getField().orElseThrow().getValue(); - } - - return expression.toString(); - } - - public Map getRegisteredWithQuery() - { - return registeredWithQuery; - } - - public Map getNameMapping() - { - return nameMapping; - } - - public Map getRelationshipInfoMapping() - { - return relationshipInfoMapping; - } - - public Map getRelationshipCTEs() - { - return registeredCte; - } - - public static class RelationshipOperation - { - enum OperatorType - { - ACCESS, - TRANSFORM, - FILTER, - AGGREGATE, - ARRAY_SORT, - SLICE, - } - - public static RelationshipOperation access(List rsItems) - { - return new RelationshipOperation(rsItems, OperatorType.ACCESS, null, null, null, null, null); - } - - public static RelationshipOperation transform(List rsItems, Expression lambdaExpression, String manySideResultField, 
Expression unnestField) - { - return new RelationshipOperation(rsItems, OperatorType.TRANSFORM, lambdaExpression, manySideResultField, unnestField, null, null); - } - - public static RelationshipOperation filter(List rsItems, Expression lambdaExpression, String manySideResultField, Expression unnestField) - { - return new RelationshipOperation(rsItems, OperatorType.FILTER, lambdaExpression, manySideResultField, unnestField, null, null); - } - - public static RelationshipOperation aggregate(List rsItems, String manySideResultField, String aggregateOperator) - { - return new RelationshipOperation(rsItems, OperatorType.AGGREGATE, null, manySideResultField, null, aggregateOperator, null); - } - - public static RelationshipOperation arraySort(List rsItems, String manySideResultField, Expression unnestField, List functionCallArguments) - { - return new RelationshipOperation(rsItems, OperatorType.ARRAY_SORT, null, manySideResultField, unnestField, null, functionCallArguments); - } - - public static RelationshipOperation slice(List rsItems, String manySideResultField, List functionCallArguments) - { - return new RelationshipOperation(rsItems, OperatorType.SLICE, null, manySideResultField, null, null, functionCallArguments); - } - - private final List rsItems; - private final OperatorType operatorType; - private final Expression lambdaExpression; - private final String manySideResultField; - // for lambda cte generation - private final Expression unnestField; - private final String aggregateOperator; - private final List functionCallArguments; - - private RelationshipOperation( - List rsItems, - OperatorType operatorType, - Expression lambdaExpression, - String manySideResultField, - Expression unnestField, - String aggregateOperator, - List functionCallArguments) - { - this.rsItems = requireNonNull(rsItems); - this.operatorType = requireNonNull(operatorType); - this.lambdaExpression = lambdaExpression; - this.manySideResultField = manySideResultField; - this.unnestField = 
unnestField; - this.aggregateOperator = aggregateOperator; - this.functionCallArguments = functionCallArguments == null ? List.of() : functionCallArguments; - } - - public List getRsItems() - { - return rsItems; - } - - public OperatorType getOperatorType() - { - return operatorType; - } - - public Optional getLambdaExpression() - { - return Optional.ofNullable(lambdaExpression); - } - - public Optional getManySideResultField() - { - return Optional.ofNullable(manySideResultField); - } - - public Optional getUnnestField() - { - return Optional.ofNullable(unnestField); - } - - public Optional getAggregateOperator() - { - return Optional.ofNullable(aggregateOperator); - } - - public List getFunctionCallArguments() - { - return functionCallArguments; - } - } - - public static class RsItem - { - public static RsItem rsItem(String name, Type type) - { - return rsItem(name, type, null); - } - - public static RsItem rsItem(String name, Type type, String index) - { - return new RsItem(name, type, index); - } - - public enum Type - { - CTE, - RS, - REVERSE_RS - } - - private final String name; - private final Type type; - - private final String index; - - private RsItem(String name, Type type, String index) - { - this.name = name; - this.type = type; - this.index = index; - } - - public Optional getIndex() - { - return Optional.ofNullable(index); - } - - public String getName() - { - return name; - } - - public Type getType() - { - return type; - } - } - - /** - * Used for build the join node between the base model and relationship cte. 
- */ - public static class RelationshipCTEJoinInfo - { - private final String cteName; - private final JoinCriteria condition; - private final String baseModelName; - - public RelationshipCTEJoinInfo(String cteName, JoinCriteria condition, String baseModelName) - { - this.cteName = cteName; - this.condition = condition; - this.baseModelName = baseModelName; - } - - public String getCteName() - { - return cteName; - } - - public JoinCriteria getCondition() - { - return condition; - } - - public String getBaseModelName() - { - return baseModelName; - } - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipRewriter.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipRewriter.java new file mode 100644 index 000000000..31693304e --- /dev/null +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationshipRewriter.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.accio.sqlrewrite; + +import io.accio.sqlrewrite.analyzer.ExpressionRelationshipInfo; +import io.trino.sql.tree.DereferenceExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.Identifier; +import io.trino.sql.tree.Node; +import io.trino.sql.tree.QualifiedName; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static io.trino.sql.tree.DereferenceExpression.getQualifiedName; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toUnmodifiableMap; + +public class RelationshipRewriter + extends BaseRewriter +{ + private final Map replacements; + + public static Node rewrite(List relationshipInfos, Expression expression) + { + requireNonNull(relationshipInfos); + return new RelationshipRewriter( + relationshipInfos.stream() + .collect(toUnmodifiableMap(ExpressionRelationshipInfo::getQualifiedName, RelationshipRewriter::toDereferenceExpression))) + .process(expression); + } + + public RelationshipRewriter(Map replacements) + { + this.replacements = requireNonNull(replacements); + } + + @Override + protected Node visitDereferenceExpression(DereferenceExpression node, Void ignored) + { + if (node.getField().isPresent()) { + QualifiedName qualifiedName = getQualifiedName(node); + if (qualifiedName != null) { + return replacements.get(qualifiedName) == null ? 
node : replacements.get(qualifiedName); + } + } + return node; + } + + private static DereferenceExpression toDereferenceExpression(ExpressionRelationshipInfo expressionRelationshipInfo) + { + String base = expressionRelationshipInfo.getRelationships().get(expressionRelationshipInfo.getRelationships().size() - 1).getModels().get(1); + List parts = new ArrayList<>(); + parts.add(new Identifier(base, true)); + expressionRelationshipInfo.getRemainingParts().stream() + .map(part -> new Identifier(part, true)) + .forEach(parts::add); + return (DereferenceExpression) DereferenceExpression.from(QualifiedName.of(parts)); + } +} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ScopeAwareRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ScopeAwareRewrite.java deleted file mode 100644 index 5e8691929..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ScopeAwareRewrite.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.accio.sqlrewrite; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import io.accio.base.AccioMDL; -import io.accio.base.SessionContext; -import io.accio.sqlrewrite.analyzer.Field; -import io.accio.sqlrewrite.analyzer.Scope; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.Node; -import io.trino.sql.tree.QualifiedName; -import io.trino.sql.tree.QuerySpecification; -import io.trino.sql.tree.Statement; -import io.trino.sql.tree.SubscriptExpression; - -import java.util.Arrays; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -import static com.google.common.collect.Iterables.getLast; -import static io.accio.sqlrewrite.Utils.analyzeFrom; -import static io.accio.sqlrewrite.Utils.getNextPart; -import static io.accio.sqlrewrite.Utils.parseExpression; -import static io.accio.sqlrewrite.Utils.toQualifiedName; -import static io.trino.sql.QueryUtil.identifier; -import static java.lang.String.format; - -/** - * Rewrite the AST to replace all identifiers or dereference expressions - * without a relation prefix with the relation prefix. 
- */ -public class ScopeAwareRewrite -{ - public static final ScopeAwareRewrite SCOPE_AWARE_REWRITE = new ScopeAwareRewrite(); - - public Statement rewrite(Node root, AccioMDL accioMDL, SessionContext sessionContext) - { - return (Statement) new Rewriter(accioMDL, sessionContext).process(root); - } - - private static class Rewriter - extends BaseRewriter - { - private final AccioMDL accioMDL; - private final SessionContext sessionContext; - - public Rewriter(AccioMDL accioMDL, SessionContext sessionContext) - { - this.accioMDL = accioMDL; - this.sessionContext = sessionContext; - } - - @Override - protected Node visitQuerySpecification(QuerySpecification node, Scope context) - { - Scope relationScope; - if (node.getFrom().isPresent()) { - relationScope = analyzeFrom(accioMDL, sessionContext, node.getFrom().get(), Optional.ofNullable(context)); - } - else { - relationScope = context; - } - return super.visitQuerySpecification(node, relationScope); - } - - @Override - protected Node visitIdentifier(Identifier node, Scope context) - { - if (context.getRelationType().isPresent()) { - List field = context.getRelationType().get().resolveFields(QualifiedName.of(node.getValue())); - if (field.size() == 1) { - return new DereferenceExpression(identifier(field.get(0).getRelationAlias() - .orElse(toQualifiedName(field.get(0).getModelName())) - .getSuffix()), identifier(field.get(0).getColumnName())); - } - if (field.size() > 1) { - throw new IllegalArgumentException("Ambiguous column name: " + node.getValue()); - } - } - return node; - } - - @Override - protected Node visitDereferenceExpression(DereferenceExpression node, Scope context) - { - if (context != null && context.getRelationType().isPresent()) { - List parts = getPartsQuietly(node); - for (int i = 0; i < parts.size(); i++) { - List field = context.getRelationType().get().resolveFields(QualifiedName.of(parts.subList(0, i + 1))); - if (field.size() == 1) { - Field firstMatch = field.get(0); - if (i == 3) { - // 
catalog.schema.table.column - return removePrefix(node, 2); - } - if (i == 2) { - // schema.table.column - return removePrefix(node, 1); - } - if (i == 1) { - // table.column - return node; - } - return addPrefix(node, identifier(field.get(0).getRelationAlias().orElse(toQualifiedName(field.get(0).getModelName())).getSuffix())); - } - if (field.size() > 1) { - throw new IllegalArgumentException("Ambiguous column name: " + DereferenceExpression.getQualifiedName(node)); - } - } - } - return node; - } - - private List getPartsQuietly(Expression expression) - { - try { - return getParts(expression); - } - catch (IllegalArgumentException ex) { - return List.of(); - } - } - - private List getParts(Expression expression) - { - if (expression instanceof Identifier) { - return ImmutableList.of(((Identifier) expression).getValue()); - } - else if (expression instanceof DereferenceExpression) { - DereferenceExpression dereferenceExpression = (DereferenceExpression) expression; - List baseQualifiedName = getParts(dereferenceExpression.getBase()); - ImmutableList.Builder builder = ImmutableList.builder(); - builder.addAll(baseQualifiedName); - builder.add(dereferenceExpression.getField().orElseThrow().getValue()); - return builder.build(); - } - else if (expression instanceof SubscriptExpression) { - SubscriptExpression subscriptExpression = (SubscriptExpression) expression; - List baseQualifiedName = getParts(subscriptExpression.getBase()); - if (baseQualifiedName != null) { - ImmutableList.Builder builder = ImmutableList.builder(); - builder.addAll(baseQualifiedName.subList(0, baseQualifiedName.size() - 1)); - builder.add(format("%s[%s]", getLast(baseQualifiedName), subscriptExpression.getIndex().toString())); - return builder.build(); - } - } - else { - throw new IllegalArgumentException("Unsupported node "); - } - return ImmutableList.of(); - } - } - - @VisibleForTesting - public static Expression addPrefix(Expression source, Identifier prefix) - { - ImmutableList.Builder 
builder = ImmutableList.builder(); - - Expression node = source; - while (node instanceof DereferenceExpression || node instanceof SubscriptExpression) { - if (node instanceof DereferenceExpression) { - DereferenceExpression dereferenceExpression = (DereferenceExpression) node; - builder.add(dereferenceExpression.getField().orElseThrow()); - node = dereferenceExpression.getBase(); - } - else { - SubscriptExpression subscriptExpression = (SubscriptExpression) node; - Identifier base; - if (subscriptExpression.getBase() instanceof Identifier) { - base = (Identifier) subscriptExpression.getBase(); - } - else { - base = ((DereferenceExpression) subscriptExpression.getBase()).getField().orElseThrow(); - } - builder.add(new SubscriptExpression(base, subscriptExpression.getIndex())); - node = getNextPart(subscriptExpression); - } - } - - if (node instanceof Identifier) { - builder.add(node); - } - - return builder.add(prefix).build().reverse().stream().reduce((a, b) -> { - if (b instanceof SubscriptExpression) { - SubscriptExpression subscriptExpression = (SubscriptExpression) b; - return new SubscriptExpression(new DereferenceExpression(a, (Identifier) subscriptExpression.getBase()), ((SubscriptExpression) b).getIndex()); - } - else if (b instanceof Identifier) { - return new DereferenceExpression(a, (Identifier) b); - } - throw new IllegalArgumentException(format("Unexpected expression: %s", b)); - }).orElseThrow(() -> new IllegalArgumentException(format("Unexpected expression: %s", source))); - } - - private static Expression removePrefix(DereferenceExpression dereferenceExpression, int removeFirstN) - { - return parseExpression( - Arrays.stream(dereferenceExpression.toString().split("\\.")) - .skip(removeFirstN) - .collect(Collectors.joining("."))); - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/SyntacticSugarRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/SyntacticSugarRewrite.java deleted file mode 100644 index 
3dcb0b0c2..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/SyntacticSugarRewrite.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite; - -import io.accio.base.AccioMDL; -import io.accio.base.SessionContext; -import io.accio.sqlrewrite.analyzer.Analysis; -import io.accio.sqlrewrite.analyzer.Scope; -import io.accio.sqlrewrite.analyzer.StatementAnalyzer; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.LongLiteral; -import io.trino.sql.tree.Node; -import io.trino.sql.tree.QualifiedName; -import io.trino.sql.tree.SingleColumn; -import io.trino.sql.tree.Statement; -import io.trino.sql.tree.SubscriptExpression; - -import java.util.List; -import java.util.Optional; - -import static java.util.Objects.requireNonNull; - -/** - * Rewrite Accio syntactic sugar: - *

  • Add column alias to avoid losing original column name since we will rewrite relationship column in {@link AccioSqlRewrite} - * e.g. {@code SELECT author FROM Book} -> {@code SELECT author AS author FROM Book}
  • - *
  • `any` Function is an alias of to-many result accessing. - * e.g. {@code SELECT any(books) FROM User} -> {@code SELECT books[1] FROM User}
  • - *
  • `first` Function is an alias of sorted to-many result accessing. - * e.g. {@code SELECT first(books) FROM User} -> {@code SELECT array_sort(books)[1] FROM User}
  • - */ -public class SyntacticSugarRewrite - implements AccioRule -{ - public static final SyntacticSugarRewrite SYNTACTIC_SUGAR_REWRITE = new SyntacticSugarRewrite(); - - @Override - public Statement apply(Statement root, SessionContext sessionContext, AccioMDL accioMDL) - { - Analysis analysis = StatementAnalyzer.analyze(root, sessionContext, accioMDL); - return apply(root, sessionContext, analysis, accioMDL); - } - - @Override - public Statement apply(Statement root, SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL) - { - return (Statement) new SyntacticSugarRewrite.Rewriter(analysis).process(root); - } - - private static class Rewriter - extends BaseRewriter - { - private final Analysis analysis; - - Rewriter(Analysis analysis) - { - this.analysis = requireNonNull(analysis); - } - - @Override - protected Node visitSingleColumn(SingleColumn node, Void context) - { - Expression result = visitAndCast(node.getExpression(), context); - if (result.equals(node.getExpression()) && !belongsToAccioDataObject(node.getExpression())) { - return new SingleColumn(result, node.getAlias()); - } - // Because we rewrite the relationship field in AccioSqlRewrite - // we need to add an alias to keep its original name. 
- Identifier resultAlias = node.getAlias().orElse(null); - if (node.getExpression() instanceof Identifier) { - Identifier identifier = (Identifier) node.getExpression(); - resultAlias = node.getAlias().orElse(identifier); - } - else if (node.getExpression() instanceof DereferenceExpression) { - DereferenceExpression dereferenceExpression = (DereferenceExpression) node.getExpression(); - resultAlias = node.getAlias().orElse(dereferenceExpression.getField().orElse(null)); - } - if (node.getLocation().isPresent()) { - return new SingleColumn(node.getLocation().get(), result, Optional.ofNullable(resultAlias)); - } - return new SingleColumn(result, Optional.ofNullable(resultAlias)); - } - - @Override - protected Node visitFunctionCall(FunctionCall node, Void context) - { - String name = node.getName().toString(); - if (name.equalsIgnoreCase("any")) { - return new SubscriptExpression(requireNonNull(node.getArguments().get(0)), new LongLiteral("1")); - } - if (node.getName().toString().equalsIgnoreCase("first")) { - return new SubscriptExpression(new FunctionCall(QualifiedName.of("array_sort"), node.getArguments()), new LongLiteral("1")); - } - return super.visitFunctionCall(node, context); - } - - private boolean belongsToAccioDataObject(Expression node) - { - QualifiedName qualifiedName; - if (node instanceof Identifier) { - qualifiedName = QualifiedName.of(List.of((Identifier) node)); - } - else if (node instanceof DereferenceExpression) { - qualifiedName = DereferenceExpression.getQualifiedName((DereferenceExpression) node); - } - else { - return false; - } - - return analysis.tryGetScope(node) - .flatMap(Scope::getRelationType) - .flatMap(relationType -> relationType.resolveAnyField(qualifiedName)) - .isPresent(); - } - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/Utils.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/Utils.java index c2d98ccfc..221c8dc6a 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/Utils.java +++ 
b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/Utils.java @@ -21,7 +21,6 @@ import io.accio.base.SessionContext; import io.accio.base.dto.Column; import io.accio.base.dto.Metric; -import io.accio.base.dto.Model; import io.accio.sqlrewrite.analyzer.Field; import io.accio.sqlrewrite.analyzer.MetricRollupInfo; import io.accio.sqlrewrite.analyzer.RelationType; @@ -87,24 +86,13 @@ public static DataType parseType(String type) return SQL_PARSER.createType(type); } - public static Query parseModelSql(Model model) + public static Query parseQuery(String sql) { - String sql = getModelSql(model); Statement statement = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)); if (statement instanceof Query) { return (Query) statement; } - throw new IllegalArgumentException(format("model %s is not a query, sql %s", model.getName(), sql)); - } - - public static String getModelSql(Model model) - { - requireNonNull(model, "model is null"); - if (model.getColumns().isEmpty()) { - return model.getRefSql(); - } - // In postgres, all subquery should have alias. 
- return format("SELECT %s FROM (%s) t", model.getColumns().stream().map(Column::getSqlExpression).collect(joining(", ")), model.getRefSql()); + throw new IllegalArgumentException("model sql is not a query"); } public static Query parseMetricSql(Metric metric) diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/Analysis.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/Analysis.java index c33508384..5fca867cd 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/Analysis.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/Analysis.java @@ -19,48 +19,32 @@ import io.accio.base.dto.Model; import io.accio.base.dto.Relationship; import io.accio.base.dto.View; -import io.accio.sqlrewrite.RelationshipCteGenerator; -import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionRelation; -import io.trino.sql.tree.GroupBy; -import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; -import io.trino.sql.tree.Relation; import io.trino.sql.tree.Statement; import io.trino.sql.tree.Table; -import io.trino.sql.tree.WithQuery; import java.util.HashMap; import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class Analysis { private final Statement root; private final Set tables = new HashSet<>(); - private final RelationshipCteGenerator relationshipCteGenerator; - private final Map, Expression> relationshipFields = new HashMap<>(); private final Set> modelNodeRefs = new HashSet<>(); - private final Map, Set> replaceTableWithCTEs = new HashMap<>(); private final Set relationships = new HashSet<>(); private final Set models = new HashSet<>(); - private final Map, Scope> scopes = new LinkedHashMap<>(); private final Set metrics = new HashSet<>(); private final Map, MetricRollupInfo> metricRollups = new 
HashMap<>(); private final Set views = new HashSet<>(); - private final Map, GroupByAnalysis> groupByAnalysis = new HashMap<>(); - Analysis(Statement statement, RelationshipCteGenerator relationshipCteGenerator) + Analysis(Statement statement) { this.root = requireNonNull(statement, "statement is null"); - this.relationshipCteGenerator = relationshipCteGenerator; } public Statement getRoot() @@ -78,46 +62,6 @@ public Set getTables() return Set.copyOf(tables); } - public Map getRelationshipCTE() - { - return relationshipCteGenerator.getRegisteredWithQuery(); - } - - public Map getRelationshipNameMapping() - { - return relationshipCteGenerator.getNameMapping(); - } - - public Map getRelationshipInfoMapping() - { - return relationshipCteGenerator.getRelationshipInfoMapping(); - } - - void addRelationshipFields(Map, Expression> relationshipFields) - { - this.relationshipFields.putAll(relationshipFields); - } - - public Map, Expression> getRelationshipFields() - { - return Map.copyOf(relationshipFields); - } - - public void addReplaceTableWithCTEs(NodeRef relationNodeRef, Set relationshipCTENames) - { - this.replaceTableWithCTEs.put(relationNodeRef, relationshipCTENames); - } - - public Map, Set> getReplaceTableWithCTEs() - { - return Map.copyOf(replaceTableWithCTEs); - } - - void addRelationships(Set relationships) - { - this.relationships.addAll(relationships); - } - public Set getRelationships() { return relationships; @@ -143,26 +87,6 @@ public Set> getModelNodeRefs() return modelNodeRefs; } - public Optional getOutputDescriptor(Node node) - { - return getScope(node).getRelationType(); - } - - public Scope getScope(Node node) - { - return tryGetScope(node).orElseThrow(() -> new IllegalArgumentException(format("Analysis does not contain information for node: %s", node))); - } - - public Optional tryGetScope(Node node) - { - NodeRef key = NodeRef.of(node); - if (scopes.containsKey(key)) { - return Optional.of(scopes.get(key)); - } - - return Optional.empty(); - } - 
void addMetrics(Set metrics) { this.metrics.addAll(metrics); @@ -173,11 +97,6 @@ public Set getMetrics() return metrics; } - public void setScope(Node node, Scope scope) - { - scopes.put(NodeRef.of(node), scope); - } - void addMetricRollups(NodeRef metricRollupNodeRef, MetricRollupInfo metricRollupInfo) { metricRollups.put(metricRollupNodeRef, metricRollupInfo); @@ -188,16 +107,6 @@ public Map, MetricRollupInfo> getMetricRollups() return metricRollups; } - void addGroupAnalysis(GroupBy groupByNode, GroupByAnalysis groupByAnalysis) - { - this.groupByAnalysis.put(NodeRef.of(groupByNode), groupByAnalysis); - } - - public Map, GroupByAnalysis> getGroupByAnalysis() - { - return groupByAnalysis; - } - public Set getViews() { return views; @@ -207,19 +116,4 @@ void addViews(Set views) { this.views.addAll(views); } - - public static class GroupByAnalysis - { - private final List originalExpressions; - - public GroupByAnalysis(List originalExpressions) - { - this.originalExpressions = originalExpressions; - } - - public List getOriginalExpressions() - { - return originalExpressions; - } - } } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalysis.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalysis.java deleted file mode 100644 index 6173f40c8..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalysis.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite.analyzer; - -import io.accio.base.dto.Relationship; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.NodeRef; - -import java.util.Map; -import java.util.Set; - -import static java.util.Objects.requireNonNull; - -public class ExpressionAnalysis -{ - private final Expression expression; - private final Map, Expression> relationshipFieldRewrites; - private final Set relationshipCTENames; - private final Set relationships; - - public ExpressionAnalysis( - Expression expression, - Map, Expression> relationshipFields, - Set relationshipCTENames, - Set relationships) - { - this.expression = requireNonNull(expression, "expression is null"); - this.relationshipFieldRewrites = requireNonNull(relationshipFields, "relationshipFields is null"); - this.relationshipCTENames = requireNonNull(relationshipCTENames, "relationshipCTENames is null"); - this.relationships = requireNonNull(relationships, "relationships is null"); - } - - public Expression getExpression() - { - return expression; - } - - public Map, Expression> getRelationshipFieldRewrites() - { - return relationshipFieldRewrites; - } - - public Set getRelationshipCTENames() - { - return relationshipCTENames; - } - - public Set getRelationships() - { - return relationships; - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalyzer.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalyzer.java deleted file mode 100644 index 7bfb31917..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionAnalyzer.java +++ /dev/null @@ -1,451 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite.analyzer; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import io.accio.base.AccioMDL; -import io.accio.base.SessionContext; -import io.accio.base.dto.Column; -import io.accio.base.dto.Model; -import io.accio.base.dto.Relationship; -import io.accio.sqlrewrite.RelationshipCTE; -import io.accio.sqlrewrite.RelationshipCteGenerator; -import io.accio.sqlrewrite.analyzer.FunctionChainAnalyzer.ReturnContext; -import io.trino.sql.tree.DefaultTraversalVisitor; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.NodeRef; -import io.trino.sql.tree.QualifiedName; -import io.trino.sql.tree.SubscriptExpression; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; - -import static io.accio.base.Utils.checkArgument; -import static io.accio.sqlrewrite.RelationshipCteGenerator.LAMBDA_RESULT_NAME; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RelationshipOperation.access; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.CTE; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.REVERSE_RS; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.RS; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.rsItem; -import static 
java.lang.String.format; -import static java.util.Objects.requireNonNull; - -public final class ExpressionAnalyzer -{ - private ExpressionAnalyzer() {} - - private final Map, Expression> relationshipFieldsRewrite = new HashMap<>(); - private final Set relationshipCTENames = new HashSet<>(); - private final Set relationships = new HashSet<>(); - - public static ExpressionAnalysis analyze( - Expression expression, - SessionContext sessionContext, - AccioMDL accioMDL, - RelationshipCteGenerator relationshipCteGenerator, - Scope scope) - { - ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(); - return expressionAnalyzer.analyzeExpression(expression, sessionContext, accioMDL, relationshipCteGenerator, scope); - } - - private ExpressionAnalysis analyzeExpression( - Expression expression, - SessionContext sessionContext, - AccioMDL accioMDL, - RelationshipCteGenerator relationshipCteGenerator, - Scope scope) - { - new Visitor(sessionContext, accioMDL, relationshipCteGenerator, scope).process(expression); - return new ExpressionAnalysis(expression, relationshipFieldsRewrite, relationshipCTENames, relationships); - } - - private class Visitor - extends DefaultTraversalVisitor - { - private final SessionContext sessionContext; - private final AccioMDL accioMDL; - private final RelationshipCteGenerator relationshipCteGenerator; - private final Scope scope; - private final FunctionChainAnalyzer functionChainAnalyzer; - - public Visitor(SessionContext sessionContext, AccioMDL accioMDL, RelationshipCteGenerator relationshipCteGenerator, Scope scope) - { - this.sessionContext = requireNonNull(sessionContext, "sessionContext is null"); - this.accioMDL = requireNonNull(accioMDL, "accioMDL is null"); - this.relationshipCteGenerator = requireNonNull(relationshipCteGenerator, "relationshipCteGenerator is null"); - this.scope = requireNonNull(scope, "scope is null"); - this.functionChainAnalyzer = FunctionChainAnalyzer.of(relationshipCteGenerator, 
this::registerRelationshipCTEs); - } - - @Override - protected Void visitFunctionCall(FunctionCall node, Void ignored) - { - Optional returnContext = functionChainAnalyzer.analyze(node); - if (returnContext.isEmpty()) { - return null; - } - - returnContext.get().getNodesToReplace().forEach((nodeToReplace, rsField) -> { - // nodeToReplace is a lambda function call - if (nodeToReplace.getNode() instanceof FunctionCall) { - relationshipCTENames.add(nodeToReplace.getNode().toString()); - relationshipFieldsRewrite.put( - nodeToReplace, - DereferenceExpression.from( - QualifiedName.of( - List.of( - relationshipCteGenerator.getNameMapping().get(nodeToReplace.getNode().toString()), - LAMBDA_RESULT_NAME)))); - } - // nodeToReplace is a relationship field in function - else { - String cteName = String.join(".", rsField.getCteNameParts()); - relationshipCTENames.add(cteName); - relationshipFieldsRewrite.put( - nodeToReplace, - DereferenceExpression.from( - QualifiedName.of( - List.of( - relationshipCteGenerator.getNameMapping().get(cteName), - rsField.getColumnName())))); - } - }); - return null; - } - - @Override - protected Void visitDereferenceExpression(DereferenceExpression node, Void ignored) - { - registerRelationshipCTEs(node) - .ifPresent(info -> { - relationshipCTENames.add(String.join(".", info.getReplacementNameParts())); - relationshipFieldsRewrite.put(NodeRef.of(info.getOriginal()), info.getReplacement()); - }); - return null; - } - - @Override - protected Void visitIdentifier(Identifier node, Void ignored) - { - registerRelationshipCTEs(node) - .ifPresent(info -> { - relationshipCTENames.add(String.join(".", info.getReplacementNameParts())); - relationshipFieldsRewrite.put(NodeRef.of(info.getOriginal()), info.getReplacement()); - }); - return null; - } - - // register needed relationship CTEs and return node replacement information - private Optional registerRelationshipCTEs(Expression node) - { - if (!scope.isTableScope()) { - return Optional.empty(); - } - - 
LinkedList elements = elements(node); - if (elements.isEmpty()) { - return Optional.empty(); - } - - String baseModelName; - Expression root = elements.peekFirst(); - LinkedList nameParts = new LinkedList<>(); - LinkedList chain = new LinkedList<>(); - // process the root node, root node should be either FunctionCall or Identifier, if not, relationship rewrite won't be fired - if (root instanceof FunctionCall) { - Optional returnContext = functionChainAnalyzer.analyze((FunctionCall) root); - boolean functionCallNeedsReplacement = returnContext.isPresent() && returnContext.get().getNodesToReplace().size() > 0; - if (!functionCallNeedsReplacement) { - return Optional.empty(); - } - Map, RelationshipField> nodesToReplace = returnContext.get().getNodesToReplace(); - checkArgument(nodesToReplace.size() == 1, "No node or multiple node to replace in function chain in DereferenceExpression chain"); - - nodesToReplace.forEach((nodeToReplace, rsField) -> - relationshipFieldsRewrite.put( - nodeToReplace, - DereferenceExpression.from( - QualifiedName.of( - ImmutableList.builder() - .add(relationshipCteGenerator.getNameMapping().get(nodeToReplace.getNode().toString())) - .add(LAMBDA_RESULT_NAME).build())))); - nameParts.add(elements.pop().toString()); - baseModelName = nodesToReplace.values().iterator().next().getBaseModelName(); - } - else if (root instanceof Identifier) { - List modelFields = scope.getRelationType() - .orElseThrow(() -> new IllegalArgumentException("relation type is empty")) - .getFields(); - - Optional relationshipField = Optional.empty(); - // process column with prefix. i.e. 
[TableAlias|TableName].column - while (elements.size() > 0 && relationshipField.isEmpty()) { - QualifiedName current; - Expression element = elements.pop(); - if (element instanceof Identifier) { - current = QualifiedName.of(List.of((Identifier) element)); - } - else if (element instanceof DereferenceExpression) { - current = DereferenceExpression.getQualifiedName((DereferenceExpression) element); - } - else { - break; - } - - relationshipField = modelFields.stream() - .filter(scopeField -> scopeField.canResolve(current)) - .filter(Field::isRelationship) - .findAny(); - } - - if (relationshipField.isPresent()) { - String fieldModelName = relationshipField.get().getModelName().getSchemaTableName().getTableName(); - String fieldTypeName = relationshipField.get().getType(); - String fieldName = relationshipField.get().getColumnName(); - Relationship relationship = relationshipField.get().getRelationship().orElseThrow(); - List parts = List.of(fieldModelName, relationshipField.get().getColumnName()); - relationships.add(relationship); - relationshipCteGenerator.register( - parts, - access(List.of(rsItem(relationship.getName(), relationship.getModels().get(0).equals(fieldTypeName) ? 
REVERSE_RS : RS)))); - - baseModelName = relationshipField.get().getModelName().getSchemaTableName().getTableName(); - nameParts.addAll(parts); - chain.add(new RelationshipField(nameParts, fieldModelName, fieldName, relationship, baseModelName)); - } - else { - return Optional.empty(); - } - } - else { - return Optional.empty(); - } - - while (elements.size() > 0) { - Expression expression = elements.pop(); - if (expression instanceof DereferenceExpression) { - DereferenceExpression dereferenceExpression = (DereferenceExpression) expression; - RelationshipCTE cte = relationshipCteGenerator.getRelationshipCTEs().get(String.join(".", nameParts)); - if (cte == null) { - return Optional.empty(); - } - - Identifier field = dereferenceExpression.getField().orElseThrow(); - String modelName = cte.getTarget().getName(); - Optional relationshipColumn = accioMDL.getModel(modelName) - .stream() - .map(Model::getColumns) - .flatMap(List::stream) - .filter(column -> column.getName().equals(field.getValue()) && column.getRelationship().isPresent()) - .findAny(); - - if (relationshipColumn.isPresent()) { - Relationship relationship = accioMDL.getRelationship(relationshipColumn.get().getRelationship().get()) - .orElseThrow(() -> new IllegalArgumentException("Relationship not found")); - relationships.add(relationship); - String relationshipColumnType = relationshipColumn.get().getType(); - nameParts.add(field.getValue()); - relationshipCteGenerator.register( - nameParts, - access(List.of( - rsItem(String.join(".", nameParts.subList(0, nameParts.size() - 1)), CTE), - rsItem(relationship.getName(), relationship.getModels().get(0).equals(relationshipColumnType) ? 
REVERSE_RS : RS))), - baseModelName); - chain.add(new RelationshipField(nameParts, modelName, relationshipColumn.get().getName(), relationship, baseModelName)); - } - else { - return Optional.of( - new ReplaceNodeInfo( - nameParts, - dereferenceExpression.getBase(), - new Identifier( - relationshipCteGenerator.getNameMapping().get(String.join(".", nameParts))), - chain.isEmpty() ? Optional.empty() : Optional.of(chain.getLast()))); - } - } - else if (expression instanceof SubscriptExpression) { - SubscriptExpression subscriptExpression = (SubscriptExpression) expression; - String index = subscriptExpression.getIndex().toString(); - String cteName = String.join(".", nameParts); - RelationshipCTE cte = relationshipCteGenerator.getRelationshipCTEs().get(cteName); - if (cte == null) { - return Optional.empty(); - } - Relationship relationship = cte.getRelationship(); - relationships.add(relationship); - String lastNamePart = nameParts.removeLast(); - nameParts.add(format("%s[%s]", lastNamePart, index)); - - relationshipCteGenerator.register( - nameParts, - access(List.of( - rsItem(cteName, CTE, index), - rsItem(relationship.getName(), relationship.isReverse() ? REVERSE_RS : RS))), - baseModelName, - subscriptExpression.getBase() instanceof FunctionCall ? LAMBDA_RESULT_NAME : lastNamePart); - } - else { - throw new IllegalArgumentException("Unsupported operation"); - } - } - - // whole dereference expression or identifier is a relationship column - return Optional.ofNullable(relationshipCteGenerator.getRelationshipCTEs().get(String.join(".", nameParts))) - .map(relationshipCTE -> - new ReplaceNodeInfo( - nameParts, - node, - toRelationshipReplacement(relationshipCTE, nameParts.getLast()), - chain.isEmpty() ? 
Optional.empty() : Optional.of(chain.getLast()))); - } - - // if rs column is a to-1 rs, use the to-1 side model primary key - // if rs column is a to-N rs, use the last name part of expression since we keep the relationship column - // name as the same in all CTEs, here we could directly use last name part in expression to represent - // the target to-N rs column name. - private Expression toRelationshipReplacement(RelationshipCTE relationshipCTE, String lastNamePart) - { - if (relationshipCTE.getRelationship().getJoinType().isToOne()) { - return DereferenceExpression.from( - QualifiedName.of(relationshipCTE.getName(), relationshipCTE.getTarget().getPrimaryKey())); - } - else { - return DereferenceExpression.from( - QualifiedName.of(relationshipCTE.getName(), lastNamePart)); - } - } - } - - private static LinkedList elements(Expression expression) - { - Expression current = expression; - LinkedList elements = new LinkedList<>(); - while (true) { - if (current instanceof FunctionCall || current instanceof Identifier) { - elements.add(current); - // in dereference expression, function call or identifier should be the root node - break; - } - else if (current instanceof DereferenceExpression) { - elements.add(current); - current = ((DereferenceExpression) current).getBase(); - } - else if (current instanceof SubscriptExpression) { - elements.add(current); - current = ((SubscriptExpression) current).getBase(); - } - else { - // unexpected node in dereference expression, clear everything and return - elements.clear(); - break; - } - } - return new LinkedList<>(Lists.reverse(elements)); - } - - static class RelationshipField - { - private final List cteNameParts; - private final String modelName; - private final String columnName; - private final Relationship relationship; - private final String baseModelName; - - public RelationshipField(List cteNameParts, String modelName, String columnName, Relationship relationship, String baseModelName) - { - this.cteNameParts = 
requireNonNull(cteNameParts); - this.modelName = requireNonNull(modelName); - this.columnName = requireNonNull(columnName); - this.relationship = requireNonNull(relationship); - this.baseModelName = requireNonNull(baseModelName); - } - - public List getCteNameParts() - { - return cteNameParts; - } - - public String getModelName() - { - return modelName; - } - - public String getColumnName() - { - return columnName; - } - - public Relationship getRelationship() - { - return relationship; - } - - public String getBaseModelName() - { - return baseModelName; - } - } - - static class ReplaceNodeInfo - { - private final List replacementNameParts; - private final Expression original; - private final Expression replacement; - // TODO: this is required for function call processor, find other better way to search relationship field in function call - private final Optional lastRelationshipField; - - public ReplaceNodeInfo( - List replacementNameParts, - Expression original, - Expression replacement, - Optional lastRelationshipField) - { - this.replacementNameParts = requireNonNull(replacementNameParts); - this.original = requireNonNull(original); - this.replacement = requireNonNull(replacement); - this.lastRelationshipField = requireNonNull(lastRelationshipField); - } - - public List getReplacementNameParts() - { - return replacementNameParts; - } - - public Expression getOriginal() - { - return original; - } - - public Expression getReplacement() - { - return replacement; - } - - public Optional getLastRelationshipField() - { - return lastRelationshipField; - } - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipAnalyzer.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipAnalyzer.java new file mode 100644 index 000000000..347a01032 --- /dev/null +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipAnalyzer.java @@ -0,0 +1,162 @@ +/* + * Licensed under 
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.accio.sqlrewrite.analyzer; + +import io.accio.base.AccioMDL; +import io.accio.base.dto.Column; +import io.accio.base.dto.Model; +import io.accio.base.dto.Relationship; +import io.trino.sql.tree.DefaultTraversalVisitor; +import io.trino.sql.tree.DereferenceExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.QualifiedName; + +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.accio.base.AccioMDL.getRelationshipColumn; +import static io.trino.sql.tree.DereferenceExpression.getQualifiedName; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ExpressionRelationshipAnalyzer +{ + private ExpressionRelationshipAnalyzer() {} + + public static List getRelationships(Expression expression, AccioMDL mdl, Model model) + { + RelationshipCollector collector = new RelationshipCollector(mdl, model); + collector.process(expression); + return collector.getExpressionRelationshipInfo(); + } + + private static class RelationshipCollector + extends DefaultTraversalVisitor + { + private final AccioMDL accioMDL; + private final Model model; + private final List relationships = new ArrayList<>(); + + public RelationshipCollector(AccioMDL accioMDL, Model model) + { + this.accioMDL = requireNonNull(accioMDL); + 
this.model = requireNonNull(model); + } + + public List getExpressionRelationshipInfo() + { + return relationships; + } + + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, Void ignored) + { + if (node.getField().isPresent()) { + QualifiedName qualifiedName = getQualifiedName(node); + if (qualifiedName != null) { + Optional expressionRelationshipInfo = createRelationshipInfo(qualifiedName, model, accioMDL); + if (expressionRelationshipInfo.isPresent()) { + validateToOne(expressionRelationshipInfo.get()); + relationships.add(expressionRelationshipInfo.get()); + } + } + } + return null; + } + } + + private static Optional createRelationshipInfo(QualifiedName qualifiedName, Model model, AccioMDL mdl) + { + List relationships = new ArrayList<>(); + Model current = model; + Relationship baseModelRelationship = null; + + for (int i = 0; i < qualifiedName.getParts().size(); i++) { + String columnName = qualifiedName.getParts().get(i); + Optional relationshipColumnOpt = getRelationshipColumn(current, columnName); + + if (relationshipColumnOpt.isEmpty()) { + if (i == 0) { + return Optional.empty(); + } + return buildExpressionRelationshipInfo(qualifiedName, relationships, baseModelRelationship, i); + } + + Column relationshipColumn = relationshipColumnOpt.get(); + Relationship relationship = getRelationshipFromMDL(relationshipColumn, mdl); + relationship = reverseIfNeeded(relationship, relationshipColumn.getType()); + + relationships.add(relationship); + if (current == model) { + baseModelRelationship = relationship; + } + + current = getNextModel(relationshipColumn, mdl); + checkForCycle(current, model); + } + + return Optional.empty(); + } + + private static Relationship getRelationshipFromMDL(Column relationshipColumn, AccioMDL mdl) + { + String relationshipName = relationshipColumn.getRelationship().get(); + return mdl.getRelationship(relationshipName) + .orElseThrow(() -> new NoSuchElementException(format("relationship %s not found", 
relationshipName))); + } + + private static Model getNextModel(Column relationshipColumn, AccioMDL mdl) + { + return mdl.getModel(relationshipColumn.getType()) + .orElseThrow(() -> new NoSuchElementException(format("model %s not found", relationshipColumn.getType()))); + } + + private static void checkForCycle(Model current, Model model) + { + checkArgument(current != model, "found cycle in expression"); + } + + private static Optional buildExpressionRelationshipInfo( + QualifiedName qualifiedName, + List relationships, + Relationship baseModelRelationship, + int index) + { + return Optional.of(new ExpressionRelationshipInfo( + qualifiedName, + qualifiedName.getParts().subList(0, index), + qualifiedName.getParts().subList(index, qualifiedName.getParts().size()), + relationships, + baseModelRelationship)); + } + + private static Relationship reverseIfNeeded(Relationship relationship, String firstModelName) + { + if (relationship.getModels().get(1).equals(firstModelName)) { + return relationship; + } + return Relationship.reverse(relationship); + } + + private static void validateToOne(ExpressionRelationshipInfo expressionRelationshipInfo) + { + for (Relationship relationship : expressionRelationshipInfo.getRelationships()) { + checkArgument(relationship.getJoinType().isToOne(), "expr in model only accept to-one relation"); + } + } +} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipInfo.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipInfo.java new file mode 100644 index 000000000..9719cf885 --- /dev/null +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/ExpressionRelationshipInfo.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.accio.sqlrewrite.analyzer; + +import io.accio.base.dto.Relationship; +import io.trino.sql.tree.QualifiedName; + +import java.util.List; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static io.accio.base.Utils.checkArgument; +import static java.util.Objects.requireNonNull; + +public class ExpressionRelationshipInfo +{ + private final QualifiedName qualifiedName; + // for debug usage + private final List relationshipParts; + private final List remainingParts; + private final List relationships; + private final Relationship baseModelRelationship; + + public ExpressionRelationshipInfo( + QualifiedName qualifiedName, + List relationshipParts, + List remainingParts, + List relationships, + Relationship baseModelRelationship) + { + this.qualifiedName = requireNonNull(qualifiedName); + this.relationshipParts = requireNonNull(relationshipParts); + this.remainingParts = requireNonNull(remainingParts); + this.relationships = requireNonNull(relationships); + this.baseModelRelationship = requireNonNull(baseModelRelationship); + checkArgument(relationshipParts.size() + remainingParts.size() == qualifiedName.getParts().size(), "mismatch part size"); + } + + public QualifiedName getQualifiedName() + { + return qualifiedName; + } + + public List getRemainingParts() + { + return remainingParts; + } + + public List getRelationships() + { + return relationships; + } + + public Relationship getBaseModelRelationship() + { + return baseModelRelationship; + } + + @Override + public String toString() + { + return 
toStringHelper(this) + .add("qualifiedName", qualifiedName) + .add("relationshipParts", relationshipParts) + .add("remainingParts", remainingParts) + .add("relationships", relationships) + .add("baseModelRelationship", baseModelRelationship) + .toString(); + } +} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/FunctionChainAnalyzer.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/FunctionChainAnalyzer.java deleted file mode 100644 index e5984cbc3..000000000 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/FunctionChainAnalyzer.java +++ /dev/null @@ -1,304 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.accio.sqlrewrite.analyzer; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import io.accio.base.dto.Relationship; -import io.accio.sqlrewrite.LambdaExpressionBodyRewrite; -import io.accio.sqlrewrite.RelationshipCteGenerator; -import io.accio.sqlrewrite.analyzer.ExpressionAnalyzer.RelationshipField; -import io.accio.sqlrewrite.analyzer.ExpressionAnalyzer.ReplaceNodeInfo; -import io.trino.sql.tree.AstVisitor; -import io.trino.sql.tree.DereferenceExpression; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.FunctionCall; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.LambdaExpression; -import io.trino.sql.tree.NodeRef; -import io.trino.sql.tree.QualifiedName; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.function.Function; - -import static io.accio.base.Utils.checkArgument; -import static io.accio.sqlrewrite.RelationshipCteGenerator.LAMBDA_RESULT_NAME; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RelationshipOperation.aggregate; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RelationshipOperation.arraySort; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RelationshipOperation.filter; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RelationshipOperation.slice; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RelationshipOperation.transform; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.CTE; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.REVERSE_RS; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.Type.RS; -import static io.accio.sqlrewrite.RelationshipCteGenerator.RsItem.rsItem; -import static io.accio.sqlrewrite.RelationshipCteGenerator.SOURCE_REFERENCE; -import static java.lang.String.format; -import static 
java.util.Locale.ENGLISH; -import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toMap; - -public interface FunctionChainAnalyzer -{ - Optional analyze(FunctionCall functionCall); - - static FunctionChainAnalyzer of( - RelationshipCteGenerator relationshipCteGenerator, - Function> registerRelationshipCTEs) - { - return functionCall -> new FunctionChainProcessor(relationshipCteGenerator, registerRelationshipCTEs) - .process(functionCall, new Context()); - } - - class Context - { - private final List functionChain; - - private Context() - { - this.functionChain = List.of(); - } - - private Context(List functionChain) - { - this.functionChain = requireNonNull(functionChain); - } - - public List getFunctionChain() - { - return functionChain; - } - } - - class ReturnContext - { - private final Map, RelationshipField> nodesToReplace; - - private ReturnContext(Map, RelationshipField> nodesToReplace) - { - this.nodesToReplace = requireNonNull(nodesToReplace); - } - - public Map, RelationshipField> getNodesToReplace() - { - return nodesToReplace; - } - } - - class FunctionChainProcessor - extends AstVisitor, Context> - { - private final RelationshipCteGenerator relationshipCteGenerator; - private final Function> registerRelationshipCTEs; - - private FunctionChainProcessor( - RelationshipCteGenerator relationshipCteGenerator, - Function> registerRelationshipCTEs) - { - this.relationshipCteGenerator = requireNonNull(relationshipCteGenerator); - this.registerRelationshipCTEs = requireNonNull(registerRelationshipCTEs); - } - - @Override - protected Optional visitFunctionCall(FunctionCall node, Context context) - { - List returnContexts = new ArrayList<>(); - Context newContext = new Context(ImmutableList.builder().addAll(context.getFunctionChain()).add(node).build()); - for (Expression argument : node.getArguments()) { - if (argument instanceof FunctionCall) { - visitFunctionCall((FunctionCall) argument, 
newContext).ifPresent(returnContexts::add); - } - else if (argument instanceof DereferenceExpression || argument instanceof Identifier) { - processFunctionChain(argument, newContext).ifPresent(returnContexts::add); - } - } - - Map, RelationshipField> nodesToReplace = returnContexts.stream() - .map(returnContext -> returnContext.nodesToReplace.entrySet()) - .flatMap(Collection::stream) - .collect(toMap(Map.Entry::getKey, Map.Entry::getValue)); - - Optional, RelationshipField>> fullMatch = nodesToReplace.entrySet().stream() - .filter(entry -> NodeRef.of(node).equals(entry.getKey())) - .findAny(); - - return Optional.of( - fullMatch - .map(match -> new ReturnContext(Map.of(match.getKey(), match.getValue()))) - .orElseGet(() -> new ReturnContext(nodesToReplace))); - } - - private Optional processFunctionChain(Expression node, Context context) - { - checkArgument(node instanceof DereferenceExpression || node instanceof Identifier, "node is not DereferenceExpression or Identifier"); - checkFunctionChainIsValid(context.getFunctionChain()); - - Optional replaceNodeInfo = registerRelationshipCTEs.apply(node); - if (replaceNodeInfo.isEmpty()) { - return Optional.empty(); - } - - checkArgument(replaceNodeInfo.get().getLastRelationshipField().isPresent(), "last relationship field not found in function call"); - RelationshipField rsField = replaceNodeInfo.get().getLastRelationshipField().get(); - - FunctionCall previousFunctionCall = null; - for (FunctionCall functionCall : Lists.reverse(context.getFunctionChain())) { - if (isAccioFunction(functionCall.getName())) { - collectRelationshipInFunction( - functionCall, - rsField, - Optional.ofNullable(previousFunctionCall), - relationshipCteGenerator); - } - else if (isArrayFunction(functionCall.getName())) { - return Optional.of(new ReturnContext(Map.of(NodeRef.of(node), rsField))); - } - else { - break; - } - previousFunctionCall = functionCall; - } - return Optional.ofNullable(previousFunctionCall) - .map(functionCall -> new 
ReturnContext(Map.of(NodeRef.of(functionCall), rsField))); - } - } - - private static void checkFunctionChainIsValid(List functionCalls) - { - boolean startChaining = false; - for (FunctionCall functionCall : functionCalls) { - if (isAccioFunction(functionCall.getName())) { - startChaining = true; - } - else { - checkArgument(!startChaining, format("accio function chain contains invalid function %s", functionCall.getName())); - } - } - } - - private static boolean isAccioFunction(QualifiedName funcName) - { - return isLambdaFunction(funcName) - || isArrayAggregateFunction(funcName) - || funcName.getSuffix().equalsIgnoreCase("array_sort") - || funcName.getSuffix().equalsIgnoreCase("slice"); - } - - private static boolean isLambdaFunction(QualifiedName funcName) - { - return List.of("transform", "filter").contains(funcName.getSuffix()); - } - - private static boolean isArrayAggregateFunction(QualifiedName funcName) - { - return List.of("array_count", - "array_sum", - "array_avg", - "array_min", - "array_max", - "array_bool_or", - "array_every") - .contains(funcName.getSuffix()); - } - - private static boolean isArrayFunction(QualifiedName funcName) - { - // TODO: define what's array function - // Refer to trino array function temporarily - // TODO: bigquery array function mapping - return List.of("cardinality", "array_max", "array_min", "array_length").contains(funcName.toString()); - } - - private static void collectRelationshipInFunction( - FunctionCall functionCall, - RelationshipField relationshipField, - Optional previousLambdaCall, - RelationshipCteGenerator relationshipCteGenerator) - { - String modelName = relationshipField.getModelName(); - String columnName = relationshipField.getColumnName(); - Relationship relationship = relationshipField.getRelationship(); - - RelationshipCteGenerator.RelationshipOperation operation; - String functionName = functionCall.getName().toString(); - String cteName = 
previousLambdaCall.map(Expression::toString).orElse(String.join(".", relationshipField.getCteNameParts())); - Expression unnestField = previousLambdaCall.isPresent() ? DereferenceExpression.from(QualifiedName.of(SOURCE_REFERENCE, LAMBDA_RESULT_NAME)) : null; - List arguments = functionCall.getArguments(); - if (isLambdaFunction(QualifiedName.of(functionName.toLowerCase(ENGLISH)))) { - checkArgument(arguments.size() == 2, "Lambda function should have 2 arguments"); - LambdaExpression lambdaExpression = (LambdaExpression) functionCall.getArguments().get(1); - checkArgument(lambdaExpression.getArguments().size() == 1, "lambda expression must have one argument"); - Expression expression = LambdaExpressionBodyRewrite.rewrite(lambdaExpression.getBody(), modelName, lambdaExpression.getArguments().get(0).getName()); - if (functionName.equalsIgnoreCase("transform")) { - operation = transform( - List.of(rsItem(cteName, CTE), - rsItem(relationship.getName(), relationship.getModels().get(0).equals(modelName) ? RS : REVERSE_RS)), - expression, - columnName, - unnestField); - } - else if (functionName.equalsIgnoreCase("filter")) { - operation = filter( - List.of(rsItem(cteName, CTE), - rsItem(relationship.getName(), relationship.getModels().get(0).equals(modelName) ? RS : REVERSE_RS)), - expression, - columnName, - unnestField); - } - else { - throw new IllegalArgumentException(functionName + " not supported"); - } - } - else if (isArrayAggregateFunction(QualifiedName.of(functionName))) { - checkArgument(arguments.size() == 1, "Accio aggregate function should have 1 argument"); - operation = aggregate( - List.of(rsItem(cteName, CTE), - rsItem(relationship.getName(), relationship.getModels().get(0).equals(modelName) ? RS : REVERSE_RS)), - previousLambdaCall.isPresent() ? 
LAMBDA_RESULT_NAME : columnName, - getArrayBaseFunctionName(functionName)); - } - else if (functionName.equalsIgnoreCase("array_sort")) { - operation = arraySort( - List.of(rsItem(cteName, CTE), - rsItem(relationship.getName(), relationship.getModels().get(0).equals(modelName) ? RS : REVERSE_RS)), - columnName, - unnestField, - arguments); - } - else if (functionName.equalsIgnoreCase("slice")) { - checkArgument(arguments.size() == 3, "slice function should have 3 arguments"); - operation = slice( - List.of(rsItem(cteName, CTE), - rsItem(relationship.getName(), relationship.getModels().get(0).equals(modelName) ? RS : REVERSE_RS)), - previousLambdaCall.isPresent() ? LAMBDA_RESULT_NAME : columnName, - arguments); - } - else { - throw new IllegalArgumentException(functionName + " not supported"); - } - - relationshipCteGenerator.register(List.of(functionCall.toString()), operation, relationshipField.getBaseModelName()); - } - - private static String getArrayBaseFunctionName(String functionName) - { - return functionName.split("array_")[1]; - } -} diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/PreAggregationAnalysis.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/PreAggregationAnalysis.java index 7de7950f4..a87e168a3 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/PreAggregationAnalysis.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/PreAggregationAnalysis.java @@ -1,3 +1,17 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.accio.sqlrewrite.analyzer; import io.accio.base.CatalogSchemaTableName; diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java index f6cec338c..64d6dd5f3 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java @@ -14,8 +14,6 @@ package io.accio.sqlrewrite.analyzer; -import com.google.common.collect.ImmutableList; -import io.accio.base.AccioException; import io.accio.base.AccioMDL; import io.accio.base.CatalogSchemaTableName; import io.accio.base.SessionContext; @@ -23,25 +21,17 @@ import io.accio.base.dto.Model; import io.accio.base.dto.Relationship; import io.accio.base.dto.View; -import io.accio.sqlrewrite.RelationshipCteGenerator; import io.trino.sql.tree.AliasedRelation; -import io.trino.sql.tree.AllColumns; import io.trino.sql.tree.AstVisitor; import io.trino.sql.tree.Expression; import io.trino.sql.tree.FunctionRelation; -import io.trino.sql.tree.GroupingElement; import io.trino.sql.tree.Identifier; import io.trino.sql.tree.Join; -import io.trino.sql.tree.LongLiteral; import io.trino.sql.tree.Node; import io.trino.sql.tree.NodeRef; import io.trino.sql.tree.QualifiedName; import io.trino.sql.tree.Query; import io.trino.sql.tree.QuerySpecification; -import io.trino.sql.tree.SelectItem; -import io.trino.sql.tree.SimpleGroupBy; -import io.trino.sql.tree.SingleColumn; -import io.trino.sql.tree.SortItem; import io.trino.sql.tree.Statement; import io.trino.sql.tree.Table; import io.trino.sql.tree.TableSubquery; @@ -51,7 +41,6 @@ import io.trino.sql.tree.With; import io.trino.sql.tree.WithQuery; -import java.util.ArrayList; import java.util.Collection; import java.util.List; import 
java.util.Optional; @@ -61,14 +50,10 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.accio.base.Utils.checkArgument; import static io.accio.base.dto.TimeGrain.TimeUnit.timeUnit; -import static io.accio.base.metadata.StandardErrorCode.INVALID_COLUMN_REFERENCE; import static io.accio.sqlrewrite.Utils.toCatalogSchemaTableName; import static io.trino.sql.QueryUtil.getQualifiedName; -import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toUnmodifiableSet; /** @@ -80,13 +65,8 @@ private StatementAnalyzer() {} public static Analysis analyze(Statement statement, SessionContext sessionContext, AccioMDL accioMDL) { - return analyze(statement, sessionContext, accioMDL, new RelationshipCteGenerator(accioMDL)); - } - - public static Analysis analyze(Statement statement, SessionContext sessionContext, AccioMDL accioMDL, RelationshipCteGenerator relationshipCteGenerator) - { - Analysis analysis = new Analysis(statement, relationshipCteGenerator); - new Visitor(sessionContext, analysis, accioMDL, relationshipCteGenerator).process(statement, Optional.empty()); + Analysis analysis = new Analysis(statement); + new Visitor(sessionContext, analysis, accioMDL).process(statement, Optional.empty()); // add models directly used in sql query analysis.addModels( @@ -149,14 +129,12 @@ private static class Visitor private final SessionContext sessionContext; private final Analysis analysis; private final AccioMDL accioMDL; - private final RelationshipCteGenerator relationshipCteGenerator; - public Visitor(SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL, RelationshipCteGenerator relationshipCteGenerator) + public Visitor(SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL) { this.sessionContext = 
requireNonNull(sessionContext, "sessionContext is null"); this.analysis = requireNonNull(analysis, "analysis is null"); this.accioMDL = requireNonNull(accioMDL, "accioMDL is null"); - this.relationshipCteGenerator = requireNonNull(relationshipCteGenerator, "relationshipCteGenerator is null"); } @Override @@ -226,26 +204,9 @@ protected Scope visitQuery(Query node, Optional scope) @Override protected Scope visitQuerySpecification(QuerySpecification node, Optional scope) { - Scope sourceScope = analyzeFrom(node, scope); - List expressionAnalysisList = analyzeSelect(node, sourceScope); - Set relationshipCTENames = expressionAnalysisList.stream() - .map(ExpressionAnalysis::getRelationshipCTENames) - .flatMap(Set::stream) - .collect(toSet()); - node.getWhere().ifPresent(where -> relationshipCTENames.addAll(analyzeExpression(where, sourceScope).getRelationshipCTENames())); - node.getGroupBy().ifPresent(groupBy -> { - analyzeGroupBy(node, sourceScope, expressionAnalysisList.stream().map(ExpressionAnalysis::getExpression).collect(toList())); - groupBy.getGroupingElements().stream() - .map(GroupingElement::getExpressions) - .flatMap(Collection::stream) - .forEach(expression -> relationshipCTENames.addAll(analyzeExpression(expression, sourceScope).getRelationshipCTENames())); - }); - node.getHaving().ifPresent(having -> relationshipCTENames.addAll(analyzeExpression(having, sourceScope).getRelationshipCTENames())); - node.getOrderBy().ifPresent(orderBy -> - orderBy.getSortItems().stream() - .map(SortItem::getSortKey) - .forEach(expression -> relationshipCTENames.addAll(analyzeExpression(expression, sourceScope).getRelationshipCTENames()))); - node.getFrom().ifPresent(relation -> analysis.addReplaceTableWithCTEs(NodeRef.of(relation), relationshipCTENames)); + if (node.getFrom().isPresent()) { + return process(node.getFrom().get(), scope); + } // TODO: output scope here isn't right return Scope.builder().parent(scope).build(); } @@ -359,71 +320,6 @@ private Optional 
analyzeWith(Query node, Optional scope) return Optional.of(withScopeBuilder.build()); } - private Scope analyzeFrom(QuerySpecification node, Optional scope) - { - if (node.getFrom().isPresent()) { - return process(node.getFrom().get(), scope); - } - return Scope.builder().parent(scope).build(); - } - - private List analyzeSelect(QuerySpecification node, Scope scope) - { - List selectExpressionAnalyses = new ArrayList<>(); - for (SelectItem item : node.getSelect().getSelectItems()) { - if (item instanceof SingleColumn) { - selectExpressionAnalyses.add(analyzeSelectSingleColumn((SingleColumn) item, scope)); - } - else if (item instanceof AllColumns) { - // DO NOTHING - } - else { - throw new IllegalArgumentException("Unsupported SelectItem type: " + item.getClass().getName()); - } - } - return List.copyOf(selectExpressionAnalyses); - } - - public void analyzeGroupBy(QuerySpecification node, Scope scope, List outputExpressions) - { - if (node.getGroupBy().isEmpty()) { - return; - } - ImmutableList.Builder groupingExpressions = ImmutableList.builder(); - for (GroupingElement groupingElement : node.getGroupBy().get().getGroupingElements()) { - if (groupingElement instanceof SimpleGroupBy) { - for (Expression column : groupingElement.getExpressions()) { - // simple GROUP BY expressions allow ordinals or arbitrary expressions - if (column instanceof LongLiteral) { - long ordinal = ((LongLiteral) column).getValue(); - if (ordinal < 1 || ordinal > outputExpressions.size()) { - throw new AccioException(INVALID_COLUMN_REFERENCE, format("GROUP BY position %s is not in select list", ordinal)); - } - column = outputExpressions.get(toIntExact(ordinal - 1)); - } - groupingExpressions.add(column); - } - } - // TODO: support other grouping elements - } - analysis.addGroupAnalysis(node.getGroupBy().get(), new Analysis.GroupByAnalysis(groupingExpressions.build())); - } - - private ExpressionAnalysis analyzeSelectSingleColumn(SingleColumn singleColumn, Scope scope) - { - Expression 
expression = singleColumn.getExpression(); - return analyzeExpression(expression, scope); - } - - private ExpressionAnalysis analyzeExpression(Expression expression, Scope scope) - { - ExpressionAnalysis expressionAnalysis = ExpressionAnalyzer.analyze(expression, sessionContext, accioMDL, relationshipCteGenerator, scope); - analysis.addRelationshipFields(expressionAnalysis.getRelationshipFieldRewrites()); - analysis.addRelationships(expressionAnalysis.getRelationships()); - analysis.setScope(expression, scope); - return expressionAnalysis; - } - private Scope process(Node node, Scope scope) { return process(node, Optional.of(scope)); diff --git a/accio-sqlrewrite/src/test/java/io/accio/TestScopeAwareRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/TestScopeAwareRewrite.java deleted file mode 100644 index 546a0fcff..000000000 --- a/accio-sqlrewrite/src/test/java/io/accio/TestScopeAwareRewrite.java +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.accio; - -import io.accio.base.AccioMDL; -import io.accio.base.AccioTypes; -import io.accio.base.SessionContext; -import io.accio.base.dto.JoinType; -import io.accio.base.dto.Model; -import io.accio.sqlrewrite.ScopeAwareRewrite; -import io.accio.testing.AbstractTestFramework; -import io.trino.sql.SqlFormatter; -import io.trino.sql.parser.ParsingOptions; -import io.trino.sql.parser.SqlParser; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.Query; -import io.trino.sql.tree.QuerySpecification; -import io.trino.sql.tree.SingleColumn; -import io.trino.sql.tree.Statement; -import org.assertj.core.api.Assertions; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import java.util.List; - -import static io.accio.base.AccioTypes.INTEGER; -import static io.accio.base.AccioTypes.VARCHAR; -import static io.accio.base.dto.Column.column; -import static io.accio.base.dto.Metric.metric; -import static io.accio.base.dto.Relationship.relationship; -import static io.accio.sqlrewrite.ScopeAwareRewrite.SCOPE_AWARE_REWRITE; -import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -public class TestScopeAwareRewrite - extends AbstractTestFramework -{ - private final AccioMDL accioMDL; - private static final SqlParser SQL_PARSER = new SqlParser(); - - public TestScopeAwareRewrite() - { - accioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() - .setModels(List.of( - Model.model("Book", - "select * from (values (1, 'book1', 1), (2, 'book2', 2), (3, 'book3', 3)) Book(bookId, name, authorId)", - List.of( - column("bookId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("author", "People", "BookPeople", true), - column("authorId", AccioTypes.INTEGER, null, true)), - "bookId"), - Model.model("People", - "select * from (values (1, 'user1'), 
(2, 'user2'), (3, 'user3')) People(userId, name)", - List.of( - column("userId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("book", "Book", "BookPeople", true), - column("items", "Item", "ItemPeople", true)), - "userId"), - Model.model("Item", - "select * from (values (1, 'item1', 1), (2, 'item2', 2), (3, 'item3', 1), (4, 'item4', 3) Item(itemId, name, userId)", - List.of( - column("itemId", AccioTypes.INTEGER, null, true), - column("price", AccioTypes.INTEGER, null, true), - column("user", "People", "ItemPeople", true)), - "orderId"))) - .setRelationships(List.of( - relationship("BookPeople", List.of("Book", "People"), JoinType.ONE_TO_ONE, "Book.authorId = People.userId"), - relationship("ItemPeople", List.of("Item", "People"), JoinType.MANY_TO_ONE, "Item.userId = People.userId"))) - .setMetrics(List.of( - metric( - "AuthorBookCount", - "Book", - List.of(column("authorId", VARCHAR, null, true)), - List.of(column("count", INTEGER, null, true, "sum(*)")), - List.of()))) - .build()); - } - - @DataProvider - public Object[][] scopeRewrite() - { - return new Object[][] { - {"SELECT name FROM Book", "SELECT Book.name FROM Book"}, - {"SELECT name FROM Book b", "SELECT b.name FROM Book b"}, - {"SELECT * FROM Book WHERE name = 'canner'", "SELECT * FROM Book WHERE Book.name = 'canner'"}, - {"SELECT * FROM Book b WHERE name = 'canner'", "SELECT * FROM Book b WHERE b.name = 'canner'"}, - {"SELECT author.book.name FROM Book", "SELECT Book.author.book.name FROM Book"}, - {"SELECT author.book.name FROM Book b", "SELECT b.author.book.name FROM Book b"}, - {"SELECT name, count(*) FROM Book b GROUP BY name", "SELECT b.name, count(*) FROM Book b GROUP BY b.name"}, - {"SELECT b.name, p.name, book FROM Book b JOIN People p ON authorId = userId", "SELECT b.name, p.name, p.book FROM Book b JOIN People p ON b.authorId = p.userId"}, - {"SELECT user.items[1].price FROM Item", "SELECT Item.user.items[1].price FROM Item"}, - {"SELECT name FROM 
accio.test.Book", "SELECT Book.name FROM accio.test.Book"}, - {"SELECT name FROM test.Book", "SELECT Book.name FROM test.Book"}, - {"SELECT name FROM (SELECT * FROM Book) b", "SELECT name FROM (SELECT * FROM Book) b"}, - {"SELECT name FROM (SELECT name FROM Book) b", "SELECT name FROM (SELECT Book.name FROM Book) b"}, - {"WITH b AS (SELECT name, author FROM Book) SELECT author FROM b", "WITH b AS (SELECT Book.name, Book.author FROM Book) SELECT author FROM b"}, - {"WITH b AS (SELECT o_clerk, author FROM Book) SELECT author FROM b", "WITH b AS (SELECT o_clerk, Book.author FROM Book) SELECT author FROM b"}, - {"SELECT concat(name, '12') FROM test.Book", "SELECT concat(Book.name, '12') FROM test.Book"}, - {"SELECT concat(name, '12') = '123' FROM test.Book", "SELECT concat(Book.name, '12') = '123' FROM test.Book"}, - {"SELECT concat(name, '12') + 123 FROM test.Book", "SELECT concat(Book.name, '12') + 123 FROM test.Book"}, - {"SELECT accio.test.Book.author.book.name FROM Book", "SELECT Book.author.book.name FROM Book"}, - {"SELECT accio.test.Book.author.books[1].name FROM Book", "SELECT Book.author.books[1].name FROM Book"}, - {"SELECT test.Book.author.book.name FROM Book", "SELECT Book.author.book.name FROM Book"}, - {"SELECT test.Book.author.books[1].name FROM Book", "SELECT Book.author.books[1].name FROM Book"}, - {"SELECT accio.test.AuthorBookCount.authorId FROM AuthorBookCount", "SELECT AuthorBookCount.authorId FROM AuthorBookCount"}, - {"SELECT test.AuthorBookCount.authorId FROM AuthorBookCount", "SELECT AuthorBookCount.authorId FROM AuthorBookCount"}, - {"SELECT authorId FROM AuthorBookCount", "SELECT AuthorBookCount.authorId FROM AuthorBookCount"}, - }; - } - - @Test(dataProvider = "scopeRewrite") - public void testScopeRewriter(String original, String expected) - { - Statement expectedState = SQL_PARSER.createStatement(expected, new ParsingOptions(AS_DECIMAL)); - String actualSql = rewrite(original); - 
Assertions.assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedState)); - } - - @DataProvider - public Object[][] wrongSessionContextNoRewrite() - { - return new Object[][] { - {"SELECT name FROM Book"}, - {"SELECT name FROM test.Book"}, - {"SELECT name FROM fake1.test.Book"} - }; - } - - @Test(dataProvider = "wrongSessionContextNoRewrite") - public void testScopeRewriterWithWrongSessionContextNoRewrite(String original) - { - SessionContext sessionContext = SessionContext.builder() - .setCatalog("wrongCatalog") - .setSchema("wrongSchema") - .build(); - - Statement expectedState = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - String actualSql = rewrite(original, sessionContext); - Assertions.assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedState)); - } - - @Test - public void testScopeRewriterWithWrongSessionContextRewrite() - { - SessionContext sessionContext = SessionContext.builder() - .setCatalog("wrongCatalog") - .setSchema("wrongSchema") - .build(); - String sql = "SELECT name FROM accio.test.Book"; - String expected = "SELECT Book.name FROM accio.test.Book"; - - Statement expectedState = SQL_PARSER.createStatement(expected, new ParsingOptions(AS_DECIMAL)); - String actualSql = rewrite(sql, sessionContext); - Assertions.assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedState)); - } - - @DataProvider - public Object[][] notRewritten() - { - return new Object[][] { - {"SELECT Book.name FROM Book"}, - {"SELECT name FROM NotBelongToAccio"}, - {"SELECT Book.author.book.name FROM Book"}, - {"SELECT name FROM fake1.fake2.Book"}, - {"SELECT name FROM fake2.Book"}, - {"WITH b AS (SELECT * FROM Book) SELECT author FROM b"}, - {"SELECT notfound FROM b"}, - {"SELECT Book.author.books[1].name FROM Book"}, - {"SELECT AuthorBookCount.authorId FROM AuthorBookCount"}, - {"SELECT Book.author.book.name FROM Book"}, - {"SELECT fakecatalog.fakeschema.Book.author.book.name FROM Book"}, - {"SELECT 
fakeschema.Book.author.book.name FROM Book"}, - }; - } - - @Test(dataProvider = "notRewritten") - public void testNotRewritten(String sql) - { - String rewrittenSql = rewrite(sql); - Statement expectedResult = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)); - assertThat(rewrittenSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @Test - public void testDetectAmbiguous() - { - String sql = "SELECT name, book FROM Book b JOIN People p ON authorId = userId"; - Assertions.assertThatThrownBy(() -> rewrite(sql)) - .hasMessage("Ambiguous column name: name"); - } - - private String rewrite(String sql) - { - return rewrite(sql, DEFAULT_SESSION_CONTEXT); - } - - private String rewrite(String sql, SessionContext sessionContext) - { - Statement statement = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)); - return SqlFormatter.formatSql(SCOPE_AWARE_REWRITE.rewrite(statement, accioMDL, sessionContext)); - } - - @DataProvider - public Object[][] addPrefix() - { - return new Object[][] { - {"author.book.name", "Book.author.book.name"}, - {"author.books[1].name", "Book.author.books[1].name"}, - }; - } - - @Test(dataProvider = "addPrefix") - public void testAddPrefix(String source, String expected) - { - Expression expression = getSelectItem(String.format("SELECT %s FROM Book", source)); - Expression node = ScopeAwareRewrite.addPrefix(expression, new Identifier("Book")); - assertThat(node.toString()).isEqualTo(expected); - } - - private Expression getSelectItem(String sql) - { - Statement statement = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)); - return ((SingleColumn) ((QuerySpecification) ((Query) statement).getQueryBody()).getSelect().getSelectItems().get(0)).getExpression(); - } -} diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAccioSqlRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAccioSqlRewrite.java deleted file mode 100644 index 8cf95f8d5..000000000 --- 
a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAccioSqlRewrite.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite; - -import io.accio.base.AccioMDL; -import io.accio.testing.AbstractTestFramework; -import io.trino.sql.SqlFormatter; -import io.trino.sql.parser.ParsingOptions; -import io.trino.sql.parser.SqlParser; -import io.trino.sql.tree.Statement; -import org.intellij.lang.annotations.Language; -import org.testng.annotations.Test; - -import java.util.List; - -import static io.accio.base.dto.Column.column; -import static io.accio.base.dto.Model.model; -import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE; -import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; -import static java.lang.String.format; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNoException; - -public class TestAccioSqlRewrite - extends AbstractTestFramework -{ - private static final AccioMDL ACCIOMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() - .setModels(List.of( - model( - "People", - "SELECT * FROM People", - List.of( - column("id", "STRING", null, false), - column("email", "STRING", null, false))), - model( - "Book", - "SELECT * FROM Book", - List.of( - column("authorId", "STRING", null, false), - column("publish_date", "STRING", null, false), - column("publish_year", "DATE", null, 
false, "date_trunc('year', publish_date)"))))) - .build()); - - @Override - protected void prepareData() - { - exec("CREATE TABLE People AS SELECT * FROM\n" + - "(VALUES\n" + - "('SN1001', 'foo@foo.org'),\n" + - "('SN1002', 'bar@bar.org'),\n" + - "('SN1003', 'code@code.org'))\n" + - "People (id, email)"); - exec("CREATE TABLE Book AS SELECT * FROM\n" + - "(VALUES\n" + - "('P1001', CAST('1991-01-01' AS TIMESTAMP)),\n" + - "('P1002', CAST('1992-02-02' AS TIMESTAMP)),\n" + - "('P1003', CAST('1993-03-03' AS TIMESTAMP)))\n" + - "Book (authorId, publish_date)"); - exec("CREATE TABLE WishList AS SELECT * FROM\n" + - "(VALUES\n" + - "('SN1001'),\n" + - "('SN1002'),\n" + - "('SN10010'))\n" + - "WishList (id)"); - } - - @Test - public void testModelRewrite() - { - assertSqlEqualsAndValid(rewrite("SELECT * FROM People"), - "WITH People AS (SELECT id, email FROM (SELECT * FROM People) t) SELECT * FROM People"); - assertSqlEqualsAndValid(rewrite("SELECT * FROM Book"), - "WITH Book AS (SELECT authorId, publish_date, date_trunc('year', publish_date) publish_year FROM (SELECT * FROM Book) t) SELECT * FROM Book"); - assertSqlEqualsAndValid(rewrite("SELECT * FROM People WHERE id = 'SN1001'"), - "WITH People AS (SELECT id, email FROM (SELECT * FROM People) t) SELECT * FROM People WHERE People.id = 'SN1001'"); - - assertSqlEqualsAndValid(rewrite("SELECT * FROM People a join Book b ON a.id = b.authorId WHERE a.id = 'SN1001'"), - "WITH Book AS (SELECT authorId, publish_date, date_trunc('year', publish_date) publish_year FROM (SELECT * FROM Book) t),\n" + - "People AS (SELECT id, email FROM (SELECT * FROM People) t)\n" + - "SELECT * FROM People a join Book b ON a.id = b.authorId WHERE a.id = 'SN1001'"); - - assertSqlEqualsAndValid(rewrite("SELECT * FROM People a join WishList b ON a.id = b.id WHERE a.id = 'SN1001'"), - "WITH People AS (SELECT id, email FROM (SELECT * FROM People) t)\n" + - "SELECT * FROM People a join WishList b ON a.id = b.id WHERE a.id = 'SN1001'"); - - 
assertSqlEqualsAndValid(rewrite("WITH a AS (SELECT * FROM WishList) SELECT * FROM a JOIN People ON a.id = People.id"), - "WITH People AS (SELECT id, email FROM (SELECT * FROM People) t), a AS (SELECT * FROM WishList)\n" + - "SELECT * FROM a JOIN People ON a.id = People.id"); - - // rewrite table in with query - assertSqlEqualsAndValid(rewrite("WITH a AS (SELECT * FROM People) SELECT * FROM a"), - "WITH People AS (SELECT id, email FROM (SELECT * FROM People) t),\n" + - "a AS (SELECT * FROM People)\n" + - "SELECT * FROM a"); - } - - @Test - public void testNoRewrite() - { - assertSqlEquals(rewrite("SELECT * FROM WithList"), "SELECT * FROM WithList"); - } - - private String rewrite(String sql) - { - return AccioPlanner.rewrite(sql, DEFAULT_SESSION_CONTEXT, ACCIOMDL, List.of(ACCIO_SQL_REWRITE)); - } - - private void assertSqlEqualsAndValid(@Language("SQL") String actual, @Language("SQL") String expected) - { - assertSqlEquals(actual, expected); - assertThatNoException() - .describedAs(format("actual sql: %s is invalid", actual)) - .isThrownBy(() -> query(actual)); - } - - private void assertSqlEquals(String actual, String expected) - { - SqlParser sqlParser = new SqlParser(); - ParsingOptions parsingOptions = new ParsingOptions(AS_DECIMAL); - Statement actualStmt = sqlParser.createStatement(actual, parsingOptions); - Statement expectedStmt = sqlParser.createStatement(expected, parsingOptions); - assertThat(actualStmt) - .describedAs("%n[actual]%n%s[expect]%n%s", - SqlFormatter.formatSql(actualStmt), SqlFormatter.formatSql(expectedStmt)) - .isEqualTo(expectedStmt); - } -} diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAllRulesRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAllRulesRewrite.java index c31f4fd0d..3a2ec5cd2 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAllRulesRewrite.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestAllRulesRewrite.java @@ -17,7 +17,6 @@ import 
io.accio.base.AccioMDL; import io.accio.base.dto.JoinType; import io.accio.testing.AbstractTestFramework; -import io.trino.sql.SqlFormatter; import io.trino.sql.parser.ParsingOptions; import io.trino.sql.tree.Statement; import org.testng.annotations.DataProvider; @@ -36,6 +35,7 @@ import static io.accio.base.dto.Relationship.relationship; import static io.accio.base.dto.View.view; import static io.accio.sqlrewrite.Utils.SQL_PARSER; +import static io.trino.sql.SqlFormatter.formatSql; import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; import static org.assertj.core.api.Assertions.assertThat; @@ -59,6 +59,7 @@ public TestAllRulesRewrite() relationshipColumn("band", "Band", "AlbumBand"), column("price", INTEGER, null, true), column("bandId", INTEGER, null, true), + column("bandName", VARCHAR, null, true, "band.name"), column("status", "Inventory", null, true), column("statusA", "InventoryA", null, true), relationshipColumn("orders", "Order", "AlbumOrder")), @@ -87,7 +88,7 @@ public TestAllRulesRewrite() metric( "Collection", "Album", - List.of(column("band", VARCHAR, null, true, "Album.band.name")), + List.of(column("band", VARCHAR, null, true, "bandName")), List.of(column("price", INTEGER, null, true, "sum(Album.price)")), List.of()), metric( @@ -103,7 +104,6 @@ public TestAllRulesRewrite() enumDefinition("InventoryA", List.of(enumValue("IN_STOCK"), enumValue("OUT_OF_STOCK"))))) .setViews(List.of( view("UseModel", "select * from Album"), - view("useRelationship", "select name, band.name as band_name from Album"), view("useMetric", "select band, price from Collection"))) .build()); } @@ -116,23 +116,8 @@ public Object[][] accioUsedCases() "values('Gusare', 2560), ('HisoHiso Banashi', 1500), ('Dakara boku wa ongaku o yameta', 2553)"}, {"SELECT name, price FROM accio.test.Album", "values('Gusare', 2560), ('HisoHiso Banashi', 1500), ('Dakara boku wa ongaku o yameta', 2553)"}, - {"select band.name, count(*) from Album group by band", 
"values ('ZUTOMAYO', cast(2 as long)), ('Yorushika', cast(1 as long))"}, - {"select band, price from CollectionA order by price", "values (2, cast(2553 as long)), (1, cast(4060 as long))"}, - {"select band from Album", "values (1), (1), (2)"}, - {"select Inventory.IN_STOCK, InventoryA.IN_STOCK", "values ('I', 'IN_STOCK')"}, - {"select band.name as band_name, name from Album where status = Inventory.IN_STOCK", - "values ('ZUTOMAYO', 'Gusare'), ('Yorushika', 'Dakara boku wa ongaku o yameta')"}, - {"select name, band_name from useRelationship", - "values ('Gusare', 'ZUTOMAYO'), ('HisoHiso Banashi', 'ZUTOMAYO'), ('Dakara boku wa ongaku o yameta', 'Yorushika')"}, - {"WITH A as (SELECT b.band.name FROM Album b) SELECT A.name FROM A", "values ('ZUTOMAYO'), ('ZUTOMAYO'), ('Yorushika')"}, - {"select band, price from useMetric", "values ('Yorushika', cast(2553 as long)), ('ZUTOMAYO', cast(4060 as long))"}, - {"select albums[1] from Band", "values (1), (3)"}, - {"select any(albums) from Band", "values (1), (3)"}, + {"select band, cast(price as integer) from useMetric order by band", "values ('Yorushika', 2553), ('ZUTOMAYO', 4060)"}, {"select * from \"Order\"", "values (1, 1), (2, 1), (3, 2), (4, 3)"}, - {"select orders[1].orderkey from Album", "values (1), (3), (4)"} - - // TODO: h2 doesn't support the BigQuery style array element converting. 
(unnest cross join with an implicit join key) - // {"select any(filter(albums, a -> a.name = 'Gusare')) from Band", "values (1)"}, }; } @@ -143,9 +128,9 @@ public void testAccioRewrite(String original, String expected) assertQuery(actualSql, expected); } - private void assertQuery(String acutal, String expected) + private void assertQuery(String actual, String expected) { - assertThat(query(acutal)).isEqualTo(query(expected)); + assertThat(query(actual)).isEqualTo(query(expected)); } @DataProvider @@ -155,7 +140,7 @@ public Object[][] noRewriteCase() {"select 1, 2, 3"}, {"select id, name from normalTable"}, {"with normalCte as (select id, name from normalTable) select id, name from normalCte"}, - {"SELECT accio.test.Album.id FROM catalog.schema.Album"}, + {"SELECT Album.id FROM catalog.schema.Album"}, }; } @@ -163,7 +148,7 @@ public Object[][] noRewriteCase() public void testAccioNoRewrite(String original) { Statement expectedState = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - assertThat(rewrite(original)).isEqualTo(SqlFormatter.formatSql(expectedState)); + assertThat(rewrite(original)).isEqualTo(formatSql(expectedState)); } private String rewrite(String sql) diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestEnumRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestEnumRewrite.java index 6199103b2..1578aa8f0 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestEnumRewrite.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestEnumRewrite.java @@ -28,6 +28,7 @@ import static io.accio.base.dto.EnumDefinition.enumDefinition; import static io.accio.base.dto.EnumValue.enumValue; import static io.accio.base.dto.Model.model; +import static io.accio.sqlrewrite.EnumRewrite.ENUM_REWRITE; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -91,7 +92,7 @@ private void assertNoRewrite(String 
sql) private Statement rewrite(String sql) { - return AccioSqlRewrite.ACCIO_SQL_REWRITE.apply(parse(sql), DEFAULT_SESSION_CONTEXT, accioMDL); + return ENUM_REWRITE.apply(parse(sql), DEFAULT_SESSION_CONTEXT, accioMDL); } private Statement parse(String sql) diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestExpressionRelationshipRewriter.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestExpressionRelationshipRewriter.java new file mode 100644 index 000000000..a9dc34fbb --- /dev/null +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestExpressionRelationshipRewriter.java @@ -0,0 +1,106 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.accio.sqlrewrite; + +import io.accio.base.AccioMDL; +import io.accio.base.dto.Model; +import io.accio.base.dto.Relationship; +import io.accio.sqlrewrite.analyzer.ExpressionRelationshipAnalyzer; +import io.accio.sqlrewrite.analyzer.ExpressionRelationshipInfo; +import io.trino.sql.tree.Expression; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.accio.sqlrewrite.Utils.parseExpression; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestExpressionRelationshipRewriter +{ + private AccioMDL mdl; + private Model orders; + private Model nation; + private Relationship ordersCustomer; + private Relationship customerNation; + + @BeforeClass + public void init() + throws IOException + { + mdl = AccioMDL.fromJson(Files.readString(Path.of(getClass().getClassLoader().getResource("tpch_mdl.json").getPath()))); + orders = mdl.getModel("Orders").orElseThrow(); + nation = mdl.getModel("Nation").orElseThrow(); + ordersCustomer = mdl.getRelationship("OrdersCustomer").orElseThrow(); + customerNation = mdl.getRelationship("CustomerNation").orElseThrow(); + } + + @DataProvider + public Object[][] rewriteTests() + { + return new Object[][] { + {"customer.custkey", "\"Customer\".\"custkey\"", List.of(ordersCustomer)}, + {"customer.nation.name", "\"Nation\".\"name\"", List.of(ordersCustomer, customerNation)}, + {"customer.nation.nationkey + 1", "(\"Nation\".\"nationkey\" + 1)", List.of(ordersCustomer, customerNation)}, + {"concat('#', customer.nation.name)", "concat('#', \"Nation\".\"name\")", List.of(ordersCustomer, customerNation)}, + {"concat(customer.name, '#', customer.nation.name)", 
"concat(\"Customer\".\"name\", '#', \"Nation\".\"name\")", + List.of(ordersCustomer, customerNation, ordersCustomer)}, + }; + } + + @Test(dataProvider = "rewriteTests") + public void testRewrite(String actual, String expected, List relationships) + { + Expression expression = parseExpression(actual); + List expressionRelationshipInfos = ExpressionRelationshipAnalyzer.getRelationships(expression, mdl, orders); + assertThat(expressionRelationshipInfos.stream().map(ExpressionRelationshipInfo::getRelationships).flatMap(List::stream).collect(toImmutableList())) + .containsExactlyInAnyOrderElementsOf(relationships); + assertThat(RelationshipRewriter.rewrite(expressionRelationshipInfos, expression).toString()).isEqualTo(expected); + } + + @Test + public void testToMany() + { + assertThatThrownBy(() -> ExpressionRelationshipAnalyzer.getRelationships(parseExpression("customer.custkey"), mdl, nation)) + .hasMessage("expr in model only accept to-one relation"); + assertThatThrownBy(() -> ExpressionRelationshipAnalyzer.getRelationships(parseExpression("customer.nation.customer.custkey"), mdl, orders)) + .hasMessage("expr in model only accept to-one relation"); + } + + @Test + public void testNoRelationshipFound() + { + // won't collect relationship if direct access relationship column + assertThat(ExpressionRelationshipAnalyzer.getRelationships(parseExpression("customer"), mdl, nation)).isEmpty(); + assertThat(ExpressionRelationshipAnalyzer.getRelationships(parseExpression("customer.nation"), mdl, orders)).isEmpty(); + // won't collect relationship if column not found in model + assertThat(ExpressionRelationshipAnalyzer.getRelationships(parseExpression("foo"), mdl, orders)).isEmpty(); + // won't collect relationship since "Orders" is not a column in orders model + assertThat(ExpressionRelationshipAnalyzer.getRelationships(parseExpression("Orders.customer.custkey"), mdl, orders)).isEmpty(); + } + + @Test + public void testCycle() + { + assertThatThrownBy(() -> 
ExpressionRelationshipAnalyzer.getRelationships(parseExpression("region.nation"), mdl, nation)) + .hasMessage("found cycle in expression"); + } +} diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestLambdaExpressionBodyRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestLambdaExpressionBodyRewrite.java deleted file mode 100644 index 6cecd6ee1..000000000 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestLambdaExpressionBodyRewrite.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.accio.sqlrewrite; - -import io.trino.sql.parser.ParsingOptions; -import io.trino.sql.parser.SqlParser; -import io.trino.sql.tree.Expression; -import io.trino.sql.tree.Identifier; -import io.trino.sql.tree.Node; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -public class TestLambdaExpressionBodyRewrite -{ - private static final SqlParser SQL_PARSER = new SqlParser(); - - @DataProvider - public Object[][] lambdaExpression() - { - return new Object[][] { - {"book.f1.f2.f3", "t.f1.f2.f3"}, - {"book", "'Relationship'"}, - {"book.f1.a1[1].f2", "t.f1.a1[1].f2"}, - {"concat(book.name, '_1')", "concat(t.name, '_1')"}, - {"book.name = 'Lord of the Rings'", "t.name = 'Lord of the Rings'"}, - }; - } - - @Test(dataProvider = "lambdaExpression") - public void testLambdaExpressionRewrite(String actual, String expected) - { - Node node = LambdaExpressionBodyRewrite.rewrite(parse(actual), "Book", new Identifier("book")); - assertThat(node.toString()).isEqualTo(parse(expected).toString()); - } - - private Expression parse(String sql) - { - return SQL_PARSER.createExpression(sql, new ParsingOptions(AS_DECIMAL)); - } -} diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetricViewSqlRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetricViewSqlRewrite.java index c6e15e52c..0b5b7ecb6 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetricViewSqlRewrite.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetricViewSqlRewrite.java @@ -49,27 +49,37 @@ public class TestMetricViewSqlRewrite extends AbstractTestFramework { @Language("sql") - private static final String MODEL_CTES = - "Album AS (\n" + - " SELECT\n" + - " \"id\"\n" + - " , \"name\"\n" + - " , \"author\"\n" + - " , \"price\"\n" + - " , 
\"publish_date\"\n" + - " , \"release_date\"\n" + - " FROM\n" + - " (\n" + - " SELECT *\n" + - " FROM\n" + - " (\n" + - " VALUES \n" + - " ROW (1, 'Gusare', 'ZUTOMAYO', 2560, DATE '2023-03-29', TIMESTAMP '2023-04-27 06:06:06')\n" + - " , ROW (2, 'HisoHiso Banashi', 'ZUTOMAYO', 1500, DATE '2023-04-29', TIMESTAMP '2023-05-27 07:07:07')\n" + - " , ROW (3, 'Dakara boku wa ongaku o yameta', 'Yorushika', 2553, DATE '2023-05-29', TIMESTAMP '2023-06-27 08:08:08')\n" + - " ) album (id, name, author, price, publish_date, release_date)\n" + - " ) t\n" + - ") \n"; + private static final String MODEL_CTES = "" + + " Album AS (\n" + + " SELECT\n" + + " \"Album\".\"id\" \"id\"\n" + + " , \"Album\".\"name\" \"name\"\n" + + " , \"Album\".\"author\" \"author\"\n" + + " , \"Album\".\"price\" \"price\"\n" + + " , \"Album\".\"publish_date\" \"publish_date\"\n" + + " , \"Album\".\"release_date\" \"release_date\"\n" + + " FROM\n" + + " (\n" + + " SELECT\n" + + " \"Album\".\"id\" \"id\"\n" + + " , \"Album\".\"name\" \"name\"\n" + + " , \"Album\".\"author\" \"author\"\n" + + " , \"Album\".\"price\" \"price\"\n" + + " , \"Album\".\"publish_date\" \"publish_date\"\n" + + " , \"Album\".\"release_date\" \"release_date\"\n" + + " FROM\n" + + " (\n" + + " SELECT *\n" + + " FROM\n" + + " (\n" + + " VALUES \n" + + " ROW (1, 'Gusare', 'ZUTOMAYO', 2560, DATE '2023-03-29', TIMESTAMP '2023-04-27 06:06:06')\n" + + " , ROW (2, 'HisoHiso Banashi', 'ZUTOMAYO', 1500, DATE '2023-04-29', TIMESTAMP '2023-05-27 07:07:07')\n" + + " , ROW (3, 'Dakara boku wa ongaku o yameta', 'Yorushika', 2553, DATE '2023-05-29', TIMESTAMP '2023-06-27 08:08:08')\n" + + " ) album (id, name, author, price, publish_date, release_date)\n" + + " ) \"Album\"\n" + + " ) \"Album\"\n" + + ")\n"; @Language("sql") private static final String METRIC_CTES = @@ -134,8 +144,8 @@ public Object[][] metricCases() "WITH\n" + METRIC_CTES + "SELECT\n" + - " Collection.author\n" + - ", Collection.price\n" + + " author\n" + + ", price\n" + "FROM\n" + 
" Collection"}, { diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java new file mode 100644 index 000000000..622258036 --- /dev/null +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java @@ -0,0 +1,290 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.accio.sqlrewrite; + +import io.accio.base.AccioMDL; +import io.accio.testing.AbstractTestFramework; +import io.trino.sql.parser.ParsingOptions; +import io.trino.sql.parser.SqlParser; +import io.trino.sql.tree.Statement; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.List; + +import static io.accio.base.dto.Column.column; +import static io.accio.base.dto.Column.relationshipColumn; +import static io.accio.base.dto.JoinType.ONE_TO_MANY; +import static io.accio.base.dto.JoinType.ONE_TO_ONE; +import static io.accio.base.dto.Model.model; +import static io.accio.base.dto.Relationship.relationship; +import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE; +import static io.trino.sql.SqlFormatter.formatSql; +import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static 
org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestModelSqlRewrite + extends AbstractTestFramework +{ + private static final AccioMDL ACCIOMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() + .setRelationships(List.of( + relationship("WishListPeople", List.of("WishList", "People"), ONE_TO_ONE, "WishList.id = People.id"), + relationship("PeopleBook", List.of("People", "Book"), ONE_TO_MANY, "People.id = Book.authorId"))) + .setModels(List.of( + model( + "People", + "SELECT * FROM table_people", + List.of( + column("id", "STRING", null, false), + column("email", "STRING", null, false), + column("gift", "STRING", null, false, "wishlist.bookId"), + relationshipColumn("book", "Book", "PeopleBook"), + relationshipColumn("wishlist", "WishList", "WishListPeople")), + "id"), + model( + "Book", + "SELECT * FROM table_book", + List.of( + column("bookId", "STRING", null, false), + column("authorId", "STRING", null, false), + column("publish_date", "STRING", null, false), + column("publish_year", "DATE", null, false, "date_trunc('year', publish_date)"), + column("author_gift_id", "STRING", null, false, "people.wishlist.bookId"), + relationshipColumn("people", "People", "PeopleBook")), + "bookId"), + model( + "WishList", + "SELECT * FROM table_wishlist", + List.of( + column("id", "STRING", null, false), + column("bookId", "STRING", null, false)), + "id"))) + .build()); + + @Override + protected void prepareData() + { + exec("CREATE TABLE table_people AS SELECT * FROM\n" + + "(VALUES\n" + + "('P1001', 'foo@foo.org'),\n" + + "('P1002', 'bar@bar.org'))\n" + + "People (id, email)"); + exec("CREATE TABLE table_book AS SELECT * FROM\n" + + "(VALUES\n" + + "('SN1001', 'P1001', CAST('1991-01-01' AS TIMESTAMP)),\n" + + "('SN1002', 'P1002', CAST('1992-02-02' AS TIMESTAMP)),\n" + + "('SN1003', 'P1001', CAST('1993-03-03' AS TIMESTAMP)))\n" + + "Book (bookId, authorId, publish_date)"); + exec("CREATE TABLE table_wishlist AS SELECT * FROM\n" + + "(VALUES\n" + + 
"('P1001', 'SN1002'),\n" + + "('P1002', 'SN1001'))\n" + + "WishList (id, bookId)"); + } + + @Override + protected void cleanup() + { + exec("DROP TABLE table_people"); + exec("DROP TABLE table_book"); + exec("DROP TABLE table_wishlist"); + } + + @Test + public void testModelRewrite() + { + @Language("SQL") String withPeopleQuery = "" + + " WishList AS (\n" + + " SELECT\n" + + " \"WishList\".\"id\" \"id\"\n" + + " , \"WishList\".\"bookId\" \"bookId\"\n" + + " FROM\n" + + " (\n" + + " SELECT\n" + + " \"WishList\".\"id\" \"id\"\n" + + " , \"WishList\".\"bookId\" \"bookId\"\n" + + " FROM\n" + + " (\n" + + " SELECT *\n" + + " FROM\n" + + " table_wishlist\n" + + " ) \"WishList\"\n" + + " ) \"WishList\"\n" + + ") \n" + + ", People AS (\n" + + " SELECT\n" + + " \"People\".\"id\" \"id\"\n" + + " , \"People\".\"email\" \"email\"\n" + + " , \"gift\".\"gift\" \"gift\"\n" + + " FROM\n" + + " ((\n" + + " SELECT\n" + + " \"People\".\"id\" \"id\"\n" + + " , \"People\".\"email\" \"email\"\n" + + " FROM\n" + + " (\n" + + " SELECT *\n" + + " FROM\n" + + " table_people\n" + + " ) \"People\"\n" + + " ) \"People\"\n" + + " LEFT JOIN (\n" + + " SELECT\n" + + " \"People\".\"id\"\n" + + " , \"WishList\".\"bookId\" \"gift\"\n" + + " FROM\n" + + " ((\n" + + " SELECT\n" + + " id \"id\"\n" + + " , id \"id\"\n" + + " FROM\n" + + " (\n" + + " SELECT *\n" + + " FROM\n" + + " table_people\n" + + " ) \"People\"\n" + + " ) \"People\"\n" + + " LEFT JOIN \"WishList\" ON (WishList.id = People.id))\n" + + " ) \"gift\" ON (\"People\".\"id\" = \"gift\".\"id\"))\n" + + ")\n"; + + @Language("SQL") String withBookQuery = withPeopleQuery + + ", Book AS (\n" + + " SELECT\n" + + " \"Book\".\"bookId\" \"bookId\"\n" + + " , \"Book\".\"authorId\" \"authorId\"\n" + + " , \"Book\".\"publish_date\" \"publish_date\"\n" + + " , \"Book\".\"publish_year\" \"publish_year\"\n" + + " , \"author_gift_id\".\"author_gift_id\" \"author_gift_id\"\n" + + " FROM\n" + + " ((\n" + + " SELECT\n" + + " \"Book\".\"bookId\" 
\"bookId\"\n" + + " , \"Book\".\"authorId\" \"authorId\"\n" + + " , \"Book\".\"publish_date\" \"publish_date\"\n" + + " , date_trunc('year', publish_date) \"publish_year\"\n" + + " FROM\n" + + " (\n" + + " SELECT *\n" + + " FROM\n" + + " table_book\n" + + " ) \"Book\"\n" + + " ) \"Book\"\n" + + " LEFT JOIN (\n" + + " SELECT\n" + + " \"Book\".\"bookId\"\n" + + " , \"WishList\".\"bookId\" \"author_gift_id\"\n" + + " FROM\n" + + " (((\n" + + " SELECT\n" + + " bookId \"bookId\"\n" + + " , authorId \"authorId\"\n" + + " FROM\n" + + " (\n" + + " SELECT *\n" + + " FROM\n" + + " table_book\n" + + " ) \"Book\"\n" + + " ) \"Book\"\n" + + " LEFT JOIN \"People\" ON (People.id = Book.authorId))\n" + + " LEFT JOIN \"WishList\" ON (WishList.id = People.id))\n" + + " ) \"author_gift_id\" ON (\"Book\".\"bookId\" = \"author_gift_id\".\"bookId\"))\n" + + ")\n"; + + assertSqlEqualsAndValid(rewrite("SELECT * FROM People"), "WITH " + withPeopleQuery + "SELECT * FROM People"); + assertSqlEqualsAndValid(rewrite("SELECT * FROM People WHERE id = 'SN1001'"), "WITH " + withPeopleQuery + "SELECT * FROM People WHERE id = 'SN1001'"); + assertSqlEqualsAndValid(rewrite("SELECT * FROM Book"), "WITH " + withBookQuery + "SELECT * FROM Book"); + assertSqlEqualsAndValid(rewrite("SELECT * FROM People a join Book b ON a.id = b.authorId WHERE a.id = 'SN1001'"), + "WITH " + withBookQuery + "SELECT * FROM People a join Book b ON a.id = b.authorId WHERE a.id = 'SN1001'"); + assertSqlEqualsAndValid(rewrite("SELECT * FROM People a join WishList b ON a.id = b.id WHERE a.id = 'SN1001'"), + "WITH " + withPeopleQuery + "SELECT * FROM People a join WishList b ON a.id = b.id WHERE a.id = 'SN1001'"); + + assertSqlEqualsAndValid(rewrite("WITH a AS (SELECT * FROM WishList) SELECT * FROM a JOIN People ON a.id = People.id"), + "WITH" + withPeopleQuery + ", a AS (SELECT * FROM WishList) SELECT * FROM a JOIN People ON a.id = People.id"); + // rewrite table in with query + assertSqlEqualsAndValid(rewrite("WITH a AS (SELECT 
* FROM People) SELECT * FROM a"), + "WITH" + withPeopleQuery + ", a AS (SELECT * FROM People) SELECT * FROM a"); + } + + @Test + public void testCycle() + { + AccioMDL cycle = AccioMDL.fromManifest(withDefaultCatalogSchema() + .setRelationships(List.of( + relationship("WishListPeople", List.of("WishList", "People"), ONE_TO_ONE, "WishList.id = People.id"))) + .setModels(List.of( + model( + "People", + "SELECT * FROM People", + List.of( + column("id", "STRING", null, false), + column("email", "STRING", null, false), + column("gift", "STRING", null, false, "wishlist.bookId"), + relationshipColumn("wishlist", "WishList", "WishListPeople")), + "id"), + model( + "WishList", + "SELECT * FROM WishList", + List.of( + column("id", "STRING", null, false), + column("bookId", "STRING", null, false), + column("peopleId", "STRING", null, false, "people.id"), + relationshipColumn("people", "People", "WishListPeople")), + "id"))) + .build()); + + // TODO: This is not allowed because Accio lacks the ability to analyze which select items of a model are used in the SQL. + // Currently we treat all columns in a model as required, which causes cycles in generating WITH queries when models reference each other. 
+ assertThatThrownBy(() -> rewrite("SELECT * FROM People", cycle), "") + .hasMessage("found cycle in models"); + } + + @Test + public void testNoRewrite() + { + assertSqlEquals(rewrite("SELECT * FROM foo"), "SELECT * FROM foo"); + } + + private String rewrite(String sql) + { + return rewrite(sql, ACCIOMDL); + } + + private String rewrite(String sql, AccioMDL mdl) + { + return AccioPlanner.rewrite(sql, DEFAULT_SESSION_CONTEXT, mdl, List.of(ACCIO_SQL_REWRITE)); + } + + private void assertSqlEqualsAndValid(@Language("SQL") String actual, @Language("SQL") String expected) + { + assertSqlEquals(actual, expected); + assertThatNoException() + .describedAs(format("actual sql: %s is invalid", actual)) + .isThrownBy(() -> query(actual)); + } + + private void assertSqlEquals(String actual, String expected) + { + SqlParser sqlParser = new SqlParser(); + ParsingOptions parsingOptions = new ParsingOptions(AS_DECIMAL); + Statement actualStmt = sqlParser.createStatement(actual, parsingOptions); + Statement expectedStmt = sqlParser.createStatement(expected, parsingOptions); + assertThat(formatSql(actualStmt)) + .isEqualTo(formatSql(expectedStmt)); + } +} diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestRelationshipAccessing.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestRelationshipAccessing.java deleted file mode 100644 index c6393a7e8..000000000 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestRelationshipAccessing.java +++ /dev/null @@ -1,1427 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.accio.sqlrewrite; - -import io.accio.base.AccioMDL; -import io.accio.base.AccioTypes; -import io.accio.base.dto.JoinType; -import io.accio.base.dto.Model; -import io.accio.base.dto.Relationship; -import io.accio.sqlrewrite.analyzer.Analysis; -import io.accio.sqlrewrite.analyzer.StatementAnalyzer; -import io.accio.testing.AbstractTestFramework; -import io.trino.sql.SqlFormatter; -import io.trino.sql.parser.ParsingOptions; -import io.trino.sql.parser.SqlParser; -import io.trino.sql.tree.Node; -import io.trino.sql.tree.Statement; -import org.apache.commons.lang3.text.StrSubstitutor; -import org.intellij.lang.annotations.Language; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static io.accio.base.dto.Column.column; -import static io.accio.base.dto.Relationship.SortKey.sortKey; -import static io.accio.base.dto.Relationship.relationship; -import static io.accio.sqlrewrite.AccioPlanner.ALL_RULES; -import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE; -import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; -import static java.lang.String.format; -import static org.assertj.core.api.Assertions.assertThatNoException; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -public class TestRelationshipAccessing - extends AbstractTestFramework -{ - @Language("SQL") - private static final String ONE_TO_ONE_MODEL_CTE = "" + - " Book AS (\n" + - " SELECT\n" + - " \"bookId\",\n" + - " \"name\",\n" + - " 'relationship' as \"author\",\n" + - " \"authorId\"\n" + - " FROM (\n" + - " SELECT *\n" + - " FROM (\n" + - " VALUES\n" + - " (1, 'book1', 1),\n" + - " (2, 'book2', 2),\n" + - " (3, 'book3', 3)\n" + - " ) 
Book(bookId, name, authorId)\n" + - " ) t\n" + - " ),\n" + - " People AS (\n" + - " SELECT\n" + - " \"userId\",\n" + - " \"name\",\n" + - " 'relationship' AS \"book\"\n" + - " FROM\n" + - " (\n" + - " SELECT *\n" + - " FROM\n" + - " (\n" + - " VALUES\n" + - " (1, 'user1'),\n" + - " (2, 'user2'),\n" + - " (3, 'user3')\n" + - " ) People (userId, name)\n" + - " ) t\n" + - " )\n"; - - @Language("SQL") - private static final String ONE_TO_MANY_MODEL_CTE = "" + - " Book AS (\n" + - " SELECT\n" + - " \"bookId\",\n" + - " \"name\",\n" + - " 'relationship' as \"author\",\n" + - " 'relationship' as \"author_reverse\",\n" + - " \"authorId\"\n" + - " FROM (\n" + - " SELECT *\n" + - " FROM (\n" + - " VALUES\n" + - " (1, 'book1', 1),\n" + - " (2, 'book2', 2),\n" + - " (3, 'book3', 1)\n" + - " ) Book(bookId, name, authorId)\n" + - " ) t\n" + - " ),\n" + - " People AS (\n" + - " SELECT\n" + - " \"userId\",\n" + - " \"name\",\n" + - // TODO: Remove this field. In ONE_TO_MANY relationship, user can access it directly. 
- " 'relationship' AS \"books\"\n" + - ", 'relationship' \"sorted_books\"\n" + - " FROM\n" + - " (\n" + - " SELECT *\n" + - " FROM\n" + - " (\n" + - " VALUES\n" + - " (1, 'user1'),\n" + - " (2, 'user2')\n" + - " ) People (userId, name)\n" + - " ) t\n" + - " )\n"; - - @Language("SQL") - private static final String EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES = "" + - "WITH\n" + ONE_TO_ONE_MODEL_CTE + ",\n" + - " ${Book.author} (userId, name, book, bk) AS (\n" + - " SELECT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.book\n" + - " , s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - ", ${Book.author.book} (bookId, name, author, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId\n" + - " , t.name\n" + - " , t.author\n" + - " , t.authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${Book.author} s\n" + - " LEFT JOIN Book t ON (s.userId = t.authorId))\n" + - ") \n" + - ", ${Book.author.book.author} (userId, name, book, bk) AS (\n" + - " SELECT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.book\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${Book.author.book} s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ")"; - private static final SqlParser SQL_PARSER = new SqlParser(); - - private final AccioMDL oneToOneAccioMDL; - private final AccioMDL oneToManyAccioMDL; - - public TestRelationshipAccessing() - { - oneToOneAccioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() - .setModels(List.of( - Model.model("Book", - "select * from (values (1, 'book1', 1), (2, 'book2', 2), (3, 'book3', 3)) Book(bookId, name, authorId)", - List.of( - column("bookId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("author", "People", "BookPeople", true), - column("authorId", AccioTypes.INTEGER, null, true)), - "bookId"), - Model.model("People", - "select * from (values (1, 'user1'), (2, 'user2'), (3, 'user3')) People(userId, name)", - List.of( - column("userId", 
AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("book", "Book", "BookPeople", true)), - "userId"))) - .setRelationships(List.of(relationship("BookPeople", List.of("Book", "People"), JoinType.ONE_TO_ONE, "Book.authorId = People.userId"))) - .build()); - - oneToManyAccioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() - .setModels(List.of( - Model.model("Book", - "select * from (values (1, 'book1', 1), (2, 'book2', 2), (3, 'book3', 1)) Book(bookId, name, authorId)", - List.of( - column("bookId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("author", "People", "PeopleBook", true), - column("author_reverse", "People", "BookPeople", true), - column("authorId", AccioTypes.INTEGER, null, true)), - "bookId"), - Model.model("People", - "select * from (values (1, 'user1'), (2, 'user2')) People(userId, name)", - List.of( - column("userId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("books", "Book", "PeopleBook", true), - column("sorted_books", "Book", "PeopleBookOrderByName", true)), - "userId"))) - .setRelationships(List.of( - relationship("PeopleBook", List.of("People", "Book"), JoinType.ONE_TO_MANY, "People.userId = Book.authorId"), - relationship("BookPeople", List.of("Book", "People"), JoinType.MANY_TO_ONE, "Book.authorId = People.userId"), - relationship("PeopleBookOrderByName", List.of("People", "Book"), JoinType.ONE_TO_MANY, "People.userId = Book.authorId", - List.of(sortKey("name", Relationship.SortKey.Ordering.ASC), sortKey("bookId", Relationship.SortKey.Ordering.DESC))))) - .build()); - } - - @DataProvider - public Object[][] oneToOneRelationshipAccessCases() - { - return new Object[][] { - {"SELECT a.author.book.author.name\n" + - "FROM Book a", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT ${Book.author.book.author}.name\n" + - "FROM\n" + - " (Book a\n" + - "LEFT JOIN ${Book.author.book.author} ON (a.bookId = 
${Book.author.book.author}.bk))", - true}, - {"SELECT a.author.book.author.name, a.author.book.name, a.author.name\n" + - "FROM Book a", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT\n" + - " ${Book.author.book.author}.name\n" + - ", ${Book.author.book}.name\n" + - ", ${Book.author}.name\n" + - "FROM\n" + - " (((Book a\n" + - "LEFT JOIN ${Book.author} ON (a.bookId = ${Book.author}.bk))\n" + - "LEFT JOIN ${Book.author.book} ON (a.bookId = ${Book.author.book}.bk))\n" + - "LEFT JOIN ${Book.author.book.author} ON (a.bookId = ${Book.author.book.author}.bk))", - true}, - // TODO: support join models - // {"SELECT author.book.author.name, book.name\n" + - // "FROM Book JOIN People on Book.authorId = People.userId", - // "SELECT 1", - // true}, - // {"SELECT a.author.book.author.name, b book.name\n" + - // "FROM Book a JOIN People b on a.authorId = b.userId", - // "SELECT 1", - // true}, - {"SELECT accio.test.Book.author.book.author.name,\n" + - "test.Book.author.book.author.name,\n" + - "Book.author.book.author.name\n" + - "FROM accio.test.Book", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT ${Book.author.book.author}.name,\n" + - "${Book.author.book.author}.name,\n" + - "${Book.author.book.author}.name\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author.book.author} ON (Book.bookId = ${Book.author.book.author}.bk))", - true}, - {"select author.book.author.name,\n" + - "author.book.name,\n" + - "author.name\n" + - "from Book", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT\n" + - " ${Book.author.book.author}.name\n" + - ", ${Book.author.book}.name\n" + - ", ${Book.author}.name\n" + - "FROM\n" + - " (((Book\n" + - "LEFT JOIN ${Book.author} ON (Book.bookId = ${Book.author}.bk))\n" + - "LEFT JOIN ${Book.author.book} ON (Book.bookId = ${Book.author.book}.bk))\n" + - "LEFT JOIN ${Book.author.book.author} ON (Book.bookId = ${Book.author.book.author}.bk))", - true}, - {"select name from Book where author.book.author.name = 'jax'", - 
EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT name\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author.book.author} ON (Book.bookId = ${Book.author.book.author}.bk))\n" + - "WHERE (${Book.author.book.author}.name = 'jax')", - false}, - {"select name, author.book.author.name from Book group by author.book.author.name having author.book.name = 'destiny'", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT\n" + - " name\n" + - ", ${Book.author.book.author}.name\n" + - "FROM\n" + - " ((Book\n" + - "LEFT JOIN ${Book.author.book} ON (Book.bookId = ${Book.author.book}.bk))\n" + - "LEFT JOIN ${Book.author.book.author} ON (Book.bookId = ${Book.author.book.author}.bk))\n" + - "GROUP BY ${Book.author.book.author}.name\n" + - "HAVING (${Book.author.book}.name = 'destiny')", - false}, - {"select name, author.book.author.name from Book order by author.book.author.name", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT\n" + - " name\n" + - ", ${Book.author.book.author}.name\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author.book.author} ON (Book.bookId = ${Book.author.book.author}.bk))\n" + - "ORDER BY ${Book.author.book.author}.name ASC", - false}, - {"select a.* from (select name, author.book.author.name from Book order by author.book.author.name) a", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - "SELECT a.*\n" + - "FROM\n" + - " (\n" + - " SELECT\n" + - " name\n" + - " , ${Book.author.book.author}.name\n" + - " FROM\n" + - " (Book\n" + - " LEFT JOIN ${Book.author.book.author} ON (Book.bookId = ${Book.author.book.author}.bk))\n" + - " ORDER BY ${Book.author.book.author}.name ASC\n" + - ") a", - false}, - {"with a as (select b.* from (select name, author.book.author.name from Book) b)\n" + - "select * from a", - EXPECTED_AUTHOR_BOOK_AUTHOR_WITH_QUERIES + - ", a as (" + - "SELECT b.* from (\n" + - " SELECT " + - " name,\n" + - " ${Book.author.book.author}.name\n" + - " FROM " + - " (Book " + - " LEFT JOIN ${Book.author.book.author} ON 
(Book.bookId = ${Book.author.book.author}.bk))\n" + - ") b)\n" + - "SELECT * FROM a", - false - }, - // test the reverse relationship accessing - {"select book.author.book.name, book.author.name, book.name from People", - "WITH\n" + ONE_TO_ONE_MODEL_CTE + ",\n" + - " ${People.book} (bookId, name, author, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId\n" + - " , t.name\n" + - " , t.author\n" + - " , t.authorId\n" + - " , s.userId bk\n" + - " FROM\n" + - " (People s\n" + - " LEFT JOIN Book t ON (s.userId = t.authorId))\n" + - ") \n" + - ", ${People.book.author} (userId, name, book, bk) AS (\n" + - " SELECT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.book\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.book} s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - ", ${People.book.author.book} (bookId, name, author, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId\n" + - " , t.name\n" + - " , t.author\n" + - " , t.authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.book.author} s\n" + - " LEFT JOIN Book t ON (s.userId = t.authorId))\n" + - ") \n" + - "SELECT\n" + - " ${People.book.author.book}.name\n" + - ", ${People.book.author}.name\n" + - ", ${People.book}.name\n" + - "FROM\n" + - " (((People\n" + - "LEFT JOIN ${People.book.author} ON (People.userId = ${People.book.author}.bk))\n" + - "LEFT JOIN ${People.book.author.book} ON (People.userId = ${People.book.author.book}.bk))\n" + - "LEFT JOIN ${People.book} ON (People.userId = ${People.book}.bk))", - true}, - {"WITH A as (SELECT b.author.name FROM Book b) SELECT A.name FROM A", - "WITH\n" + ONE_TO_ONE_MODEL_CTE + ",\n" + - " ${Book.author} (userId, name, book, bk) AS (\n" + - " SELECT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.book\n" + - " , s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - ", A AS (\n" + - " SELECT ${Book.author}.name\n" + - " FROM\n" + - " (Book b\n" + - " LEFT JOIN 
${Book.author} ON (b.bookId = ${Book.author}.bk))\n" + - ") \n" + - "SELECT A.name\n" + - "FROM\n" + - " A", true}, - }; - } - - @Test(dataProvider = "oneToOneRelationshipAccessCases") - public void testOneToOneRelationshipAccessingRewrite(String original, String expected, boolean enableH2Assertion) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToOneAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToOneAccioMDL, generator); - - Map replaceMap = new HashMap<>(); - replaceMap.put("Book.author", generator.getNameMapping().get("Book.author")); - replaceMap.put("Book.author.book", generator.getNameMapping().get("Book.author.book")); - replaceMap.put("Book.author.book.author", generator.getNameMapping().get("Book.author.book.author")); - replaceMap.put("People.book", generator.getNameMapping().get("People.book")); - replaceMap.put("People.book.author", generator.getNameMapping().get("People.book.author")); - replaceMap.put("People.book.author.book", generator.getNameMapping().get("People.book.author.book")); - - Node rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply((Statement) rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToOneAccioMDL); - } - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - // TODO: remove this flag, disabled h2 assertion due to ambiguous column name - if (enableH2Assertion) { - assertThatNoException() - .describedAs(format("actual sql: %s is invalid", actualSql)) - .isThrownBy(() -> query(actualSql)); - } - } - - @Test - public void testNotFoundRelationAliased() - { - String 
original = "select b.book.author.book.name from Book a"; - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToOneAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToOneAccioMDL, generator); - - Node rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply((Statement) rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToOneAccioMDL); - } - - String expected = "WITH\n" + - " Book AS (\n" + - " SELECT\n" + - " \"bookId\"\n" + - " , \"name\"\n" + - " , 'relationship' \"author\"\n" + - " , \"authorId\"\n" + - " FROM\n" + - " (\n" + - " SELECT *\n" + - " FROM\n" + - " (\n" + - " VALUES \n" + - " ROW (1, 'book1', 1)\n" + - " , ROW (2, 'book2', 2)\n" + - " , ROW (3, 'book3', 3)\n" + - " ) Book (bookId, name, authorId)\n" + - " ) t\n" + - ") \n" + - "SELECT b.book.author.book.name\n" + - "FROM\n" + - " Book a"; - Statement expectedResult = SQL_PARSER.createStatement(expected, new ParsingOptions(AS_DECIMAL)); - @Language("SQL") String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - assertThatThrownBy(() -> query(actualSql)) - .hasMessageContaining("Database \"b\" not found;"); - } - - @DataProvider - public Object[][] oneToManyRelationshipAccessCase() - { - return new Object[][] { - {"SELECT books[1].name FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${People.books[1]} (bookId, name, author, author_reverse, authorId, bk) AS 
(\n" + - " SELECT\n" + - " t.bookId bookId\n" + - " , t.name name\n" + - " , t.author author\n" + - " , t.author_reverse author_reverse\n" + - " , t.authorId authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.books} s\n" + - " LEFT JOIN Book t ON (s.books[1] = t.bookId))\n" + - ") \n" + - "SELECT ${People.books[1]}.name\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.books[1]} ON (People.userId = ${People.books[1]}.bk))", false}, - {"SELECT books[1].author.books[1].name FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${People.books[1]} (bookId, name, author, author_reverse, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId bookId\n" + - " , t.name name\n" + - " , t.author author\n" + - " , t.author_reverse author_reverse\n" + - " , t.authorId authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.books} s\n" + - " LEFT JOIN Book t ON (s.books[1] = t.bookId))\n" + - ") \n" + - ", ${People.books[1].author} (userId, name, books, sorted_books, bk) AS (\n" + - " SELECT DISTINCT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.books\n" + - " , t.sorted_books\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.books[1]} s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - ", ${People.books[1].author.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.bk bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (${People.books[1].author} o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${People.books[1].author.books[1]} (bookId, name, 
author, author_reverse, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId bookId\n" + - " , t.name name\n" + - " , t.author author\n" + - " , t.author_reverse author_reverse\n" + - " , t.authorId authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.books[1].author.books} s\n" + - " LEFT JOIN Book t ON (s.books[1] = t.bookId))\n" + - ") \n" + - "SELECT ${People.books[1].author.books[1]}.name\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.books[1].author.books[1]} ON (People.userId = ${People.books[1].author.books[1]}.bk))", false}, - {"SELECT cardinality(books) FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT cardinality(${People.books}.books)\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.books} ON (People.userId = ${People.books}.bk))", false}, - {"SELECT cardinality(People.books) FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT cardinality(${People.books}.books)\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.books} ON (People.userId = ${People.books}.bk))", false}, - {"SELECT cardinality(author.books) FROM Book", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${Book.author} (userId, name, books, sorted_books, bk) AS (\n" + - " SELECT DISTINCT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.books\n" + - " , t.sorted_books\n" + - " , 
s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - ", ${Book.author.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.bk bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (${Book.author} o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT cardinality(${Book.author.books}.books)\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author.books} ON (Book.bookId = ${Book.author.books}.bk))", false}, - {"SELECT author_reverse.name FROM Book", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${Book.author_reverse} (userId, name, books, sorted_books, bk) AS (\n" + - " SELECT DISTINCT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.books\n" + - " , t.sorted_books\n" + - " , s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - "SELECT ${Book.author_reverse}.name\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author_reverse} ON (Book.bookId = ${Book.author_reverse}.bk))", false}, - {"SELECT author.name FROM Book", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${Book.author} (userId, name, books, sorted_books, bk) AS (\n" + - " SELECT DISTINCT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.books\n" + - " , t.sorted_books\n" + - " , s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - "SELECT ${Book.author}.name\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author} ON (Book.bookId = ${Book.author}.bk))", false}, - {"SELECT cardinality(sorted_books) FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.sorted_books} (userId, bk, sorted_books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.name ASC, m.bookId DESC) filter(WHERE m.bookId 
IS NOT NULL) sorted_books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT cardinality(${People.sorted_books}.sorted_books)\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.sorted_books} ON (People.userId = ${People.sorted_books}.bk))", false}, - {"SELECT sorted_books[1].name FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.sorted_books} (userId, bk, sorted_books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.name ASC, m.bookId DESC) filter(WHERE m.bookId IS NOT NULL) sorted_books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${People.sorted_books[1]} (bookId, name, author, author_reverse, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId bookId\n" + - " , t.name name\n" + - " , t.author author\n" + - " , t.author_reverse author_reverse\n" + - " , t.authorId authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${People.sorted_books} s\n" + - " LEFT JOIN Book t ON (s.sorted_books[1] = t.bookId))\n" + - ") \n" + - "SELECT ${People.sorted_books[1]}.name\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.sorted_books[1]} ON (People.userId = ${People.sorted_books[1]}.bk))", false}, - {"SELECT cardinality(books) FROM People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT cardinality(${People.books}.books)\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${People.books} ON (p.userId = ${People.books}.bk))", false}, - {"SELECT p.name, cardinality(books) FROM People p", - 
"WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - "${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT\n" + - " p.name\n" + - ", cardinality(${People.books}.books)\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${People.books} ON (p.userId = ${People.books}.bk))", false}, - }; - } - - @Test(dataProvider = "oneToManyRelationshipAccessCase") - public void testOneToManyRelationshipAccessingRewrite(String original, String expected, boolean enableH2Assertion) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Node rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply((Statement) rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("Book.author", generator.getNameMapping().get("Book.author")); - replaceMap.put("Book.author.books", generator.getNameMapping().get("Book.author.books")); - replaceMap.put("Book.author_reverse", generator.getNameMapping().get("Book.author_reverse")); - replaceMap.put("People.books[1]", generator.getNameMapping().get("People.books[1]")); - replaceMap.put("People.books[1].author", generator.getNameMapping().get("People.books[1].author")); - replaceMap.put("People.books[1].author.books", generator.getNameMapping().get("People.books[1].author.books")); - 
replaceMap.put("People.books[1].author.books[1]", generator.getNameMapping().get("People.books[1].author.books[1]")); - replaceMap.put("People.sorted_books", generator.getNameMapping().get("People.sorted_books")); - replaceMap.put("People.sorted_books[1]", generator.getNameMapping().get("People.sorted_books[1]")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - // TODO: remove this flag, disabled h2 assertion due to ambiguous column name - if (enableH2Assertion) { - assertThatNoException() - .describedAs(format("actual sql: %s is invalid", actualSql)) - .isThrownBy(() -> query(actualSql)); - } - } - - @DataProvider - public Object[][] notRewritten() - { - return new Object[][] { - {"SELECT col_1 FROM foo"}, - {"SELECT foo.col_1 FROM foo"}, - {"SELECT col_1.a FROM foo"}, - {"WITH foo AS (SELECT 1 AS col_1) SELECT col_1 FROM foo"}, - }; - } - - @Test(dataProvider = "notRewritten") - public void testNotRewritten(String sql) - { - String rewrittenSql = AccioPlanner.rewrite(sql, DEFAULT_SESSION_CONTEXT, oneToOneAccioMDL, List.of(ACCIO_SQL_REWRITE)); - Statement expectedResult = SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)); - assertThat(rewrittenSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @Test - public void testRelationshipOutsideQuery() - { - // this is invalid since we don't allow access to relationship field outside the sub-query - // hence this sql shouldn't be rewritten - String actualSql = "SELECT a.name, a.author.book.author.name from (SELECT * FROM Book) a"; - String expectedSql = format("WITH Book AS (%s) SELECT a.name, a.author.book.author.name from (SELECT * FROM Book) a", - Utils.getModelSql(oneToOneAccioMDL.getModel("Book").orElseThrow())); - - String rewrittenSql = 
AccioPlanner.rewrite(actualSql, DEFAULT_SESSION_CONTEXT, oneToOneAccioMDL, List.of(ACCIO_SQL_REWRITE)); - Statement expectedResult = SQL_PARSER.createStatement(expectedSql, new ParsingOptions(AS_DECIMAL)); - assertThat(rewrittenSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] transform() - { - return new Object[][] { - {"select p.name, transform(p.books, book -> book.name) as book_names from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${transform(p.books, (book) -> book.name)} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.name ORDER BY t.bookId ASC) FILTER (WHERE (t.name IS NOT NULL)) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT\n" + - " p.name\n" + - ", ${transform(p.books, (book) -> book.name)}.f1 book_names\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${transform(p.books, (book) -> book.name)} ON (p.userId = ${transform(p.books, (book) -> book.name)}.bk))"}, - {"select p.name, transform(p.books, book -> concat(book.name, '_1')) as book_names from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${transform(p.books, (book) -> 
concat(book.name, '_1'))} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(concat(t.name, '_1') ORDER BY t.bookId ASC) FILTER (WHERE (concat(t.name, '_1') IS NOT NULL)) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT\n" + - " p.name\n" + - ", ${transform(p.books, (book) -> concat(book.name, '_1'))}.f1 book_names\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${transform(p.books, (book) -> concat(book.name, '_1'))} ON (p.userId = ${transform(p.books, (book) -> concat(book.name, '_1'))}.bk))"}, - }; - } - - @Test(dataProvider = "transform") - public void testTransform(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("transform(p.books, (book) -> book.name)", generator.getNameMapping().get("transform(p.books, (book) -> book.name)")); - replaceMap.put("transform(p.books, (book) -> concat(book.name, '_1'))", generator.getNameMapping().get("transform(p.books, (book) -> concat(book.name, '_1'))")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - System.out.println(actualSql); - - 
assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] filter() - { - return new Object[][] { - {"select p.name, filter(p.books, (book) -> book.name = 'book1' or book.name = 'book2') as filter_books from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${filter_cte} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.bookId ORDER BY t.bookId ASC) filter(WHERE t.bookId IS NOT NULL) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " WHERE ((t.name = 'book1') OR (t.name = 'book2'))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT\n" + - " p.name\n" + - ", ${filter_cte}.f1 filter_books\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${filter_cte}\n" + - "ON (p.userId = ${filter_cte}.bk))"}, - }; - } - - @Test(dataProvider = "filter") - public void testFilter(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - 
replaceMap.put("filter_cte", - generator.getNameMapping().get("filter(p.books, (book) -> ((book.name = 'book1') OR (book.name = 'book2')))")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] functionChain() - { - return new Object[][] { - {"select p.name, transform(filter(p.books, (book) -> book.name = 'book1' or book.name = 'book2'), book -> book.name) as book_names from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${filter_cte} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.bookId ORDER BY t.bookId ASC) FILTER (WHERE (t.bookId IS NOT NULL)) f1\n" + - " FROM\n" + - " (( ${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " WHERE ((t.name = 'book1') OR (t.name = 'book2'))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${transform_cte} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.name ORDER BY t.bookId ASC) FILTER (WHERE (t.name IS NOT NULL)) f1\n" + - " FROM\n" + - " ((${filter_cte} s\n" + - " CROSS JOIN UNNEST(s.f1) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT\n" + - " p.name\n" + - ", ${transform_cte}.f1 book_names\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${transform_cte} ON (p.userId = 
${transform_cte}.bk))"}, - }; - } - - @Test(dataProvider = "functionChain") - public void testFunctionChain(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("filter_cte", - generator.getNameMapping().get("filter(p.books, (book) -> ((book.name = 'book1') OR (book.name = 'book2')))")); - replaceMap.put("transform_cte", - generator.getNameMapping().get("transform(filter(p.books, (book) -> ((book.name = 'book1') OR (book.name = 'book2'))), (book) -> book.name)")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] functionIndex() - { - return new Object[][] { - {"select filter(p.books, (book) -> book.name = 'book1' or book.name = 'book2')[0].name as filter_books from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${filter_cte} (userId, bk, 
f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.bookId ORDER BY t.bookId ASC) filter(WHERE t.bookId IS NOT NULL) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " WHERE ((t.name = 'book1') OR (t.name = 'book2'))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${filter_cte_index} (bookId, name, author, author_reverse, authorId, bk) AS (\n" + - " SELECT\n" + - " t.bookId bookId\n" + - " , t.name name\n" + - " , t.author author\n" + - " , t.author_reverse author_reverse\n" + - " , t.authorId authorId\n" + - " , s.bk bk\n" + - " FROM\n" + - " (${filter_cte} s\n" + - " LEFT JOIN Book t ON (s.f1[0] = t.bookId))\n" + - ") \n" + - "SELECT ${filter_cte_index}.name filter_books\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${filter_cte_index} ON (p.userId = ${filter_cte_index}.bk))"}, - }; - } - - @Test(dataProvider = "functionIndex") - public void testFunctionIndex(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("filter_cte", - generator.getNameMapping().get("filter(p.books, (book) -> ((book.name = 'book1') OR (book.name = 'book2')))")); - replaceMap.put("filter_cte_index", - generator.getNameMapping().get("filter(p.books, (book) -> ((book.name = 'book1') OR (book.name = 'book2')))[0]")); - - Statement 
expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] aggregateForArray() - { - return new Object[][] { - {"select array_count(p.books) as arraycount from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${array_count} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , count(u.uc) f1\n" + - " FROM\n" + - " ${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc)\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT ${array_count}.f1 arraycount\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${array_count} ON (p.userId = ${array_count}.bk))"}, - {"select array_bool_or(transform(p.books, (book) -> (book.name = 'The Lord of the rings'))) as result from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${transform} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.name = 'The Lord of the rings' ORDER BY t.bookId ASC) FILTER (WHERE ((t.name = 'The Lord of the rings') IS NOT NULL)) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN 
UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${array_bool_or} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , bool_or(u.uc) f1\n" + - " FROM\n" + - " ${transform} s\n" + - " CROSS JOIN UNNEST(s.f1) u (uc)\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT ${array_bool_or}.f1 result\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${array_bool_or} ON (p.userId = ${array_bool_or}.bk))"}, - {"select array_every(transform(p.books, (book) -> (book.name = 'The Lord of the rings'))) as result from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${transform} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , array_agg(t.name = 'The Lord of the rings' ORDER BY t.bookId ASC) FILTER (WHERE ((t.name = 'The Lord of the rings') IS NOT NULL)) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${array_every} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , every(u.uc) f1\n" + - " FROM\n" + - " ${transform} s\n" + - " CROSS JOIN UNNEST(s.f1) u (uc)\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT ${array_every}.f1 result\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${array_every} ON (p.userId = ${array_every}.bk))"}, - }; - } - - @Test(dataProvider = "aggregateForArray") - public void testAggregateForArray(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new 
ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("transform", generator.getNameMapping().get("transform(p.books, (book) -> (book.name = 'The Lord of the rings'))")); - replaceMap.put("array_count", generator.getNameMapping().get("array_count(p.books)")); - replaceMap.put("array_bool_or", generator.getNameMapping().get("array_bool_or(transform(p.books, (book) -> (book.name = 'The Lord of the rings')))")); - replaceMap.put("array_every", generator.getNameMapping().get("array_every(transform(p.books, (book) -> (book.name = 'The Lord of the rings')))")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] arraySort() - { - return new Object[][] { - {"select array_sort(p.books, name, ASC) as sorted_books from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${array_sort_cte} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , 
s.bk bk\n" + - " , array_agg(t.bookId ORDER BY t.name ASC) filter(WHERE t.bookId IS NOT NULL) f1\n" + - " FROM\n" + - " ((${People.books} s\n" + - " CROSS JOIN UNNEST(s.books) u (uc))\n" + - " LEFT JOIN Book t ON (u.uc = t.bookId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - "SELECT\n" + - " ${array_sort_cte}.f1 sorted_books\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${array_sort_cte}\n" + - "ON (p.userId = ${array_sort_cte}.bk))"}, - }; - } - - @Test(dataProvider = "arraySort") - public void testArraySort(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("array_sort_cte", - generator.getNameMapping().get("array_sort(p.books, name, ASC)")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] slice() - { - return new Object[][] { - {"select slice(p.books, 1, 5) as sliced_books from People p", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + ",\n" + - " ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) filter(WHERE m.bookId IS NOT NULL) books\n" + - " FROM\n" + - " (People o\n" + 
- " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2\n" + - ") \n" + - ", ${slice_cte} (userId, bk, f1) AS (\n" + - " SELECT\n" + - " s.userId userId\n" + - " , s.bk bk\n" + - " , slice(s.books, 1, 5) f1\n" + - " FROM\n" + - " ${People.books} s\n" + - ") \n" + - "SELECT\n" + - " ${slice_cte}.f1 sliced_books\n" + - "FROM\n" + - " (People p\n" + - "LEFT JOIN ${slice_cte}\n" + - "ON (p.userId = ${slice_cte}.bk))"} - }; - } - - @Test(dataProvider = "slice") - public void testSlice(String original, String expected) - { - Statement statement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(statement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - - Statement rewrittenStatement = statement; - for (AccioRule rule : List.of(ACCIO_SQL_REWRITE)) { - rewrittenStatement = rule.apply(rewrittenStatement, DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("People.books", generator.getNameMapping().get("People.books")); - replaceMap.put("slice_cte", generator.getNameMapping().get("slice(p.books, 1, 5)")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } - - @DataProvider - public Object[][] directAccessRelationship() - { - return new Object[][] { - {"SELECT author, count(*) FROM Book GROUP BY author", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + "\n" + - ", ${Book.author} (userId, name, books, sorted_books, bk) AS (\n" + - " SELECT DISTINCT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.books\n" + - " , t.sorted_books\n" + - " , s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON 
(s.authorId = t.userId))\n" + - ") \n" + - "SELECT ${Book.author}.userId AS author , count(*)\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author} ON (Book.bookId = ${Book.author}.bk))" + - "GROUP BY ${Book.author}.userId"}, - {"SELECT author, name, count(*) FROM Book GROUP BY (author, name)", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + "\n" + - ", ${Book.author} (userId, name, books, sorted_books, bk) AS (\n" + - " SELECT DISTINCT\n" + - " t.userId\n" + - " , t.name\n" + - " , t.books\n" + - " , t.sorted_books\n" + - " , s.bookId bk\n" + - " FROM\n" + - " (Book s\n" + - " LEFT JOIN People t ON (s.authorId = t.userId))\n" + - ") \n" + - "SELECT ${Book.author}.userId AS author, name name, count(*)\n" + - "FROM\n" + - " (Book\n" + - "LEFT JOIN ${Book.author} ON (Book.bookId = ${Book.author}.bk))" + - "GROUP BY (${Book.author}.userId, name)"}, - {"SELECT books FROM People", - "WITH\n" + ONE_TO_MANY_MODEL_CTE + "\n" + - ", ${People.books} (userId, bk, books) AS (\n" + - " SELECT\n" + - " o.userId userId\n" + - " , o.userId bk\n" + - " , array_agg(m.bookId ORDER BY m.bookId ASC) FILTER (WHERE (m.bookId IS NOT NULL)) books\n" + - " FROM\n" + - " (People o\n" + - " LEFT JOIN Book m ON (o.userId = m.authorId))\n" + - " GROUP BY 1, 2" + - ") \n" + - "SELECT ${People.books}.books AS books\n" + - "FROM\n" + - " (People\n" + - "LEFT JOIN ${People.books} ON (People.userId = ${People.books}.bk))"}, - }; - } - - @Test(dataProvider = "directAccessRelationship") - public void testDirectAccessRelationship(String original, String expected) - { - Statement rewrittenStatement = SQL_PARSER.createStatement(original, new ParsingOptions(AS_DECIMAL)); - Map nameMapping = Map.of(); - for (AccioRule rule : ALL_RULES) { - RelationshipCteGenerator generator = new RelationshipCteGenerator(oneToManyAccioMDL); - Analysis analysis = StatementAnalyzer.analyze(rewrittenStatement, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL, generator); - rewrittenStatement = rule.apply(rewrittenStatement, 
DEFAULT_SESSION_CONTEXT, analysis, oneToManyAccioMDL); - nameMapping = generator.getNameMapping(); - } - - Map replaceMap = new HashMap<>(); - replaceMap.put("Book.author", nameMapping.get("Book.author")); - replaceMap.put("People.books", nameMapping.get("People.books")); - - Statement expectedResult = SQL_PARSER.createStatement(new StrSubstitutor(replaceMap).replace(expected), new ParsingOptions(AS_DECIMAL)); - String actualSql = SqlFormatter.formatSql(rewrittenStatement); - assertThat(actualSql).isEqualTo(SqlFormatter.formatSql(expectedResult)); - } -} diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestSyntacticSugarRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestSyntacticSugarRewrite.java deleted file mode 100644 index 814bcf9ce..000000000 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestSyntacticSugarRewrite.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package io.accio.sqlrewrite; - -import io.accio.base.AccioMDL; -import io.accio.base.AccioTypes; -import io.accio.base.dto.JoinType; -import io.accio.base.dto.Model; -import io.accio.base.dto.Relationship; -import io.accio.testing.AbstractTestFramework; -import io.trino.sql.parser.ParsingOptions; -import io.trino.sql.parser.SqlParser; -import io.trino.sql.tree.Statement; -import org.testng.annotations.DataProvider; -import org.testng.annotations.Test; - -import java.util.List; - -import static io.accio.base.dto.Column.column; -import static io.accio.base.dto.Relationship.SortKey.sortKey; -import static io.accio.base.dto.Relationship.relationship; -import static io.accio.sqlrewrite.SyntacticSugarRewrite.SYNTACTIC_SUGAR_REWRITE; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -public class TestSyntacticSugarRewrite - extends AbstractTestFramework -{ - private final AccioMDL oneToManyAccioMDL; - private static final SqlParser SQL_PARSER = new SqlParser(); - - public TestSyntacticSugarRewrite() - { - oneToManyAccioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() - .setModels(List.of( - Model.model("Book", - "select * from (values (1, 'book1', 1), (2, 'book2', 2), (3, 'book3', 1)) Book(bookId, name, authorId)", - List.of( - column("bookId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("author", "People", "PeopleBook", true), - column("author_reverse", "People", "BookPeople", true), - column("authorId", AccioTypes.INTEGER, null, true)), - "bookId"), - Model.model("People", - "select * from (values (1, 'user1'), (2, 'user2')) People(userId, name)", - List.of( - column("userId", AccioTypes.INTEGER, null, true), - column("name", AccioTypes.VARCHAR, null, true), - column("books", "Book", "PeopleBook", true), - column("sorted_books", "Book", "PeopleBookOrderByName", true)), - "userId"))) - .setRelationships(List.of( - relationship("PeopleBook", List.of("People", "Book"), 
JoinType.ONE_TO_MANY, "People.userId = Book.authorId"), - relationship("BookPeople", List.of("Book", "People"), JoinType.MANY_TO_ONE, "Book.authorId = People.userId"), - relationship("PeopleBookOrderByName", List.of("People", "Book"), JoinType.ONE_TO_MANY, "People.userId = Book.authorId", - List.of(sortKey("name", Relationship.SortKey.Ordering.ASC), sortKey("bookId", Relationship.SortKey.Ordering.DESC))))) - .build()); - } - - @DataProvider - public Object[][] testCase() - { - return new Object[][] { - {"SELECT author FROM Book", "SELECT Book.author AS author FROM Book"}, - {"SELECT books FROM People", "SELECT People.books AS books FROM People"}, - }; - } - - @Test(dataProvider = "testCase") - public void testBasic(String actual, String expected) - { - assertThat(rewrite(actual)).isEqualTo(parse(expected)); - } - - @DataProvider - public Object[][] anyFunction() - { - return new Object[][] { - {"SELECT any(filter(author, rs -> rs.name = 'F')) FROM Book", "SELECT filter(Book.author, rs -> rs.name = 'F')[1] FROM Book"}, - {"SELECT any(filter(author, rs -> rs.name = 'F')) IS NOT NULL FROM Book", "SELECT filter(Book.author, rs -> rs.name = 'F')[1] IS NOT NULL FROM Book"}, - // TODO: Fix scope awareness for the arguments of dereferenceExpression with functionCalls - {"SELECT any(filter(author, rs -> rs.name = 'F')).name + 1 FROM Book", "SELECT filter(author, rs -> rs.name = 'F')[1].name + 1 FROM Book"}, - {"SELECT any(filter(author, rs -> rs.name = 'F')) AS a FROM Book", "SELECT filter(Book.author, rs -> rs.name = 'F')[1] AS a FROM Book"}, - // TODO: Fix scope awareness for the arguments of dereferenceExpression with functionCalls - {"SELECT any(filter(author, rs -> rs.name = 'F')).name FROM Book", "SELECT filter(author, rs -> rs.name = 'F')[1].name name FROM Book"}, - // TODO: Fix scope awareness for the arguments of dereferenceExpression with functionCalls - {"SELECT any(filter(author, rs -> rs.name = 'F')).name AS a FROM Book", "SELECT filter(author, rs -> rs.name = 
'F')[1].name AS a FROM Book"}, - // TODO: Fix scope awareness for the arguments of dereferenceExpression with functionCalls - {"SELECT concat(any(filter(author, rs -> rs.name = 'F')).name, 'foo') AS a FROM Book", - "SELECT concat(filter(author, rs -> rs.name = 'F')[1].name, 'foo') AS a FROM Book"}, - // TODO: Fix scope awareness for the arguments of dereferenceExpression with functionCalls - {"SELECT first(filter(author, rs -> rs.name = 'F'), name, ASC).name AS a FROM Book", - "SELECT array_sort(filter(author, rs -> rs.name = 'F'), name, ASC)[1].name AS a FROM Book"}, - }; - } - - @Test(dataProvider = "anyFunction") - public void testAnyFunctionRewrite(String actual, String expected) - { - assertThat(rewrite(actual)).isEqualTo(parse(expected)); - } - - private Statement rewrite(String sql) - { - Statement scoped = ScopeAwareRewrite.SCOPE_AWARE_REWRITE.rewrite(parse(sql), oneToManyAccioMDL, DEFAULT_SESSION_CONTEXT); - return SYNTACTIC_SUGAR_REWRITE.apply(scoped, DEFAULT_SESSION_CONTEXT, oneToManyAccioMDL); - } - - private Statement parse(String sql) - { - return SQL_PARSER.createStatement(sql, new ParsingOptions()); - } -} diff --git a/accio-sqlrewrite/src/test/resources/tpch_mdl.json b/accio-sqlrewrite/src/test/resources/tpch_mdl.json new file mode 100644 index 000000000..06bf657ef --- /dev/null +++ b/accio-sqlrewrite/src/test/resources/tpch_mdl.json @@ -0,0 +1,328 @@ +{ + "catalog": "canner-cml", + "schema": "tpch_tiny", + "models": [ + { + "name": "Orders", + "refSql": "select * from \"canner-cml\".tpch_tiny.orders", + "columns": [ + { + "name": "orderkey", + "expression": "o_orderkey", + "type": "int4" + }, + { + "name": "custkey", + "expression": "o_custkey", + "type": "int4" + }, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "OrderStatus" + }, + { + "name": "totalprice", + "expression": "o_totalprice", + "type": "float8" + }, + { + "name": "customer", + "type": "Customer", + "relationship": "OrdersCustomer" + }, + { + "name": 
"orderdate", + "expression": "o_orderdate", + "type": "date" + }, + { + "name": "lineitems", + "type": "Lineitem", + "relationship": "OrdersLineitem" + } + ], + "primaryKey": "orderkey" + }, + { + "name": "Customer", + "refSql": "select * from \"canner-cml\".tpch_tiny.customer", + "columns": [ + { + "name": "custkey", + "expression": "c_custkey", + "type": "int4" + }, + { + "name": "nationkey", + "expression": "c_nationkey", + "type": "integer" + }, + { + "name": "name", + "expression": "c_name", + "type": "varchar" + }, + { + "name": "orders", + "type": "Orders", + "relationship": "OrdersCustomer" + }, + { + "name": "nation", + "type": "Nation", + "relationship": "CustomerNation" + } + ], + "primaryKey": "custkey" + }, + { + "name": "Lineitem", + "refSql": "select * from \"canner-cml\".tpch_tiny.lineitem", + "columns": [ + { + "name": "orderkey", + "expression": "l_orderkey", + "type": "int4" + }, + { + "name": "partkey", + "expression": "l_partkey", + "type": "int4" + }, + { + "name": "linenumber", + "expression": "l_linenumber", + "type": "int4" + }, + { + "name": "extendedprice", + "expression": "l_extendedprice", + "type": "float8" + }, + { + "name": "discount", + "expression": "l_discount", + "type": "float8" + }, + { + "name": "shipdate", + "expression": "l_shipdate", + "type": "date" + }, + { + "name": "order", + "type": "int4", + "expression": "1" + }, + { + "name": "part", + "type": "Part", + "relationship": "LineitemPart" + }, + { + "name": "orderkey_linenumber", + "type": "varchar", + "expression": "concat(l_orderkey, l_linenumber)" + } + ], + "primaryKey": "orderkey_linenumber" + }, + { + "name": "Part", + "refSql": "select * from \"canner-cml\".tpch_tiny.part", + "columns": [ + { + "name": "partkey", + "expression": "p_partkey", + "type": "int4" + }, + { + "name": "name", + "expression": "p_name", + "type": "varchar" + } + ], + "primaryKey": "partkey" + }, + { + "name": "Nation", + "refSql": "select * from \"canner-cml\".tpch_tiny.nation", + 
"columns": [ + { + "name": "nationkey", + "expression": "n_nationkey", + "type": "int4" + }, + { + "name": "name", + "expression": "n_name", + "type": "varchar" + }, + { + "name": "regionkey", + "expression": "n_regionkey", + "type": "int4" + }, + { + "name": "comment", + "expression": "n_comment", + "type": "varchar" + }, + { + "name": "region", + "type": "Region", + "relationship": "NationRegion" + }, + { + "name": "customer", + "type": "Customer", + "relationship": "CustomerNation" + }, + { + "name": "supplier", + "type": "Supplier", + "relationship": "NationSupplier" + } + ], + "primaryKey": "nationkey" + }, + { + "name": "Region", + "refSql": "select * from \"canner-cml\".tpch_tiny.region", + "columns": [ + { + "name": "regionkey", + "expression": "r_regionkey", + "type": "integer" + }, + { + "name": "name", + "expression": "r_name", + "type": "varchar" + }, + { + "name": "comment", + "expression": "r_comment", + "type": "varchar" + }, + { + "name": "nation", + "type": "Nation", + "relationship": "NationRegion" + } + ], + "primaryKey": "regionkey" + } + ], + "relationships": [ + { + "name": "OrdersCustomer", + "models": [ + "Orders", + "Customer" + ], + "joinType": "MANY_TO_ONE", + "condition": "Orders.custkey = Customer.custkey" + }, + { + "name": "OrdersLineitem", + "models": [ + "Orders", + "Lineitem" + ], + "joinType": "ONE_TO_MANY", + "condition": "Orders.orderkey = Lineitem.orderkey" + }, + { + "name": "LineitemPart", + "models": [ + "Lineitem", + "Part" + ], + "joinType": "MANY_TO_ONE", + "condition": "Lineitem.partkey = Part.partkey" + }, + { + "name": "CustomerNation", + "models": [ + "Customer", + "Nation" + ], + "joinType": "MANY_TO_ONE", + "condition": "Customer.nationkey = Nation.nationkey" + }, + { + "name": "NationRegion", + "models": [ + "Nation", + "Region" + ], + "joinType": "MANY_TO_ONE", + "condition": "Nation.regionkey = Region.regionkey" + } + ], + "metrics": [ + { + "name": "Revenue", + "baseModel": "Orders", + "dimension": [ + { + 
"name": "custkey", + "type": "int4" + } + ], + "measure": [ + { + "name": "totalprice", + "type": "int4", + "expression": "sum(totalprice)" + } + ], + "timeGrain": [ + { + "name": "orderdate", + "refColumn": "orderdate", + "dateParts": [ + "YEAR", + "MONTH" + ] + } + ] + } + ], + "enumDefinitions": [ + { + "name": "Status", + "values": [ + { + "name": "F" + }, + { + "name": "O" + }, + { + "name": "P" + } + ] + } + ], + "views": [ + { + "name": "useModel", + "statement": "select * from Orders" + }, + { + "name": "useMetric", + "statement": "select * from Revenue" + }, + { + "name": "useMetricRollUp", + "statement": "select * from roll_up(Revenue, orderdate, YEAR)" + }, + { + "name": "useUseMetric", + "statement": "select * from useMetric" + } + ] +} \ No newline at end of file diff --git a/accio-testing/pom.xml b/accio-testing/pom.xml index 911d0ab5c..b734721f9 100644 --- a/accio-testing/pom.xml +++ b/accio-testing/pom.xml @@ -43,8 +43,8 @@
    - org.jdbi - jdbi3-core + io.accio + trino-parser @@ -61,11 +61,5 @@ org.testng testng - - - com.h2database - h2 - runtime - diff --git a/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java b/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java index 37faeb381..3dc401fab 100644 --- a/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java +++ b/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java @@ -14,23 +14,30 @@ package io.accio.testing; +import com.google.common.collect.ImmutableList; import io.accio.base.SessionContext; +import io.accio.base.client.AutoCloseableIterator; +import io.accio.base.client.duckdb.DuckdbClient; import io.accio.base.dto.Manifest; +import io.trino.sql.parser.ParsingOptions; +import io.trino.sql.parser.SqlParser; import org.intellij.lang.annotations.Language; -import org.jdbi.v3.core.Handle; -import org.jdbi.v3.core.Jdbi; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; -import java.util.concurrent.ThreadLocalRandom; + +import static io.trino.sql.SqlFormatter.Dialect.DUCKDB; +import static io.trino.sql.SqlFormatter.formatSql; +import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL; public abstract class AbstractTestFramework { + private static final SqlParser SQL_PARSER = new SqlParser(); public static final SessionContext DEFAULT_SESSION_CONTEXT = SessionContext.builder().setCatalog("accio").setSchema("test").build(); - private Handle handle; + private DuckdbClient duckdbClient; public static Manifest.Builder withDefaultCatalogSchema() { @@ -42,40 +49,37 @@ public static Manifest.Builder withDefaultCatalogSchema() @BeforeClass public void init() { - handle = Jdbi.open("jdbc:h2:mem:test" + System.nanoTime() + ThreadLocalRandom.current().nextLong() + ";MODE=PostgreSQL;database_to_upper=false"); + duckdbClient = new 
DuckdbClient(); prepareData(); } @AfterClass(alwaysRun = true) public final void close() { - try { - handle.close(); - } - finally { - handle = null; - } + cleanup(); } protected void prepareData() {} + protected void cleanup() {} + protected List> query(@Language("SQL") String sql) { - return handle.createQuery(sql) - .map((resultSet, index, context) -> { - int count = resultSet.getMetaData().getColumnCount(); - List row = new ArrayList<>(count); - for (int i = 1; i <= count; i++) { - Object value = resultSet.getObject(i); - row.add(value); - } - return row; - }) - .list(); + sql = formatSql(SQL_PARSER.createStatement(sql, new ParsingOptions(AS_DECIMAL)), DUCKDB); + try (AutoCloseableIterator iterator = duckdbClient.query(sql)) { + ImmutableList.Builder> builder = ImmutableList.builder(); + while (iterator.hasNext()) { + builder.add(Arrays.asList(iterator.next())); + } + return builder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } } protected void exec(@Language("SQL") String sql) { - handle.execute(sql); + duckdbClient.executeDDL(sql); } } diff --git a/accio-tests/src/test/java/io/accio/TestBigQuerySqlConverter.java b/accio-tests/src/test/java/io/accio/TestBigQuerySqlConverter.java index 1ad0dd9d2..8890dd9f3 100644 --- a/accio-tests/src/test/java/io/accio/TestBigQuerySqlConverter.java +++ b/accio-tests/src/test/java/io/accio/TestBigQuerySqlConverter.java @@ -271,7 +271,7 @@ public void testRemoveCatalogSchemaColumnPrefix() "ORDER BY test.t1.c2", SessionContext.builder().build())) .isEqualTo("SELECT\n" + " t1.c1\n" + - ", t1.c2\n" + + ", `t1`.c2\n" + ", t1.c3\n" + "FROM\n" + " accio.test.t1\n" + diff --git a/accio-tests/src/test/java/io/accio/testing/RequireAccioServer.java b/accio-tests/src/test/java/io/accio/testing/RequireAccioServer.java index f6125cc13..b011d4720 100644 --- a/accio-tests/src/test/java/io/accio/testing/RequireAccioServer.java +++ b/accio-tests/src/test/java/io/accio/testing/RequireAccioServer.java @@ -27,6 +27,7 @@ 
import io.airlift.json.JsonCodec; import io.airlift.units.Duration; import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Response; @@ -44,18 +45,22 @@ public abstract class RequireAccioServer { - private final TestingAccioServer accioServer; - protected final Closer closer = Closer.create(); - protected final HttpClient client; + private TestingAccioServer accioServer; + protected Closer closer = Closer.create(); + protected HttpClient client; public static final JsonCodec TASK_INFO_CODEC = jsonCodec(TaskInfo.class); private static final JsonCodec ERROR_CODEC = jsonCodec(ErrorMessageDto.class); - public RequireAccioServer() + public RequireAccioServer() {} + + @BeforeClass + public void init() { this.accioServer = createAccioServer(); this.client = closer.register(new JettyHttpClient(new HttpClientConfig().setIdleTimeout(new Duration(20, SECONDS)))); closer.register(accioServer); + prepare(); } protected abstract TestingAccioServer createAccioServer(); @@ -65,6 +70,8 @@ protected TestingAccioServer server() return accioServer; } + protected void prepare() {} + public T getInstance(Key key) { return accioServer.getInstance(key); diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/AbstractPreAggregationTest.java b/accio-tests/src/test/java/io/accio/testing/bigquery/AbstractPreAggregationTest.java index 3b81103f3..a05e52f6e 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/AbstractPreAggregationTest.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/AbstractPreAggregationTest.java @@ -29,6 +29,7 @@ import java.time.LocalDate; import java.util.List; +import java.util.function.Supplier; import static io.accio.base.Utils.randomIntString; import static java.lang.String.format; @@ -37,9 +38,9 @@ public abstract class AbstractPreAggregationTest extends AbstractWireProtocolTestWithBigQuery { - protected final PreAggregationManager 
preAggregationManager = getInstance(Key.get(PreAggregationManager.class)); - protected final PreAggregationTableMapping preAggregationTableMapping = getInstance(Key.get(PreAggregationTableMapping.class)); - protected final DuckdbClient duckdbClient = getInstance(Key.get(DuckdbClient.class)); + protected final Supplier preAggregationManager = () -> getInstance(Key.get(PreAggregationManager.class)); + protected final Supplier preAggregationTableMapping = () -> getInstance(Key.get(PreAggregationTableMapping.class)); + protected final Supplier duckdbClient = () -> getInstance(Key.get(DuckdbClient.class)); @Override protected TestingAccioServer createAccioServer() @@ -65,12 +66,12 @@ protected TestingAccioServer createAccioServer() protected PreAggregationInfoPair getDefaultPreAggregationInfoPair(String name) { - return preAggregationTableMapping.getPreAggregationInfoPair("canner-cml", "tpch_tiny", name); + return preAggregationTableMapping.get().getPreAggregationInfoPair("canner-cml", "tpch_tiny", name); } protected List queryDuckdb(String statement) { - try (AutoCloseableIterator iterator = duckdbClient.query(statement)) { + try (AutoCloseableIterator iterator = duckdbClient.get().query(statement)) { ImmutableList.Builder builder = ImmutableList.builder(); while (iterator.hasNext()) { builder.add(iterator.next()); diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/TestAccioWithBigquery.java b/accio-tests/src/test/java/io/accio/testing/bigquery/TestAccioWithBigquery.java index 2eb36f261..71d8fde20 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/TestAccioWithBigquery.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/TestAccioWithBigquery.java @@ -14,14 +14,12 @@ package io.accio.testing.bigquery; -import org.intellij.lang.annotations.Language; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; -import 
java.sql.SQLException; import java.util.Optional; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; @@ -41,9 +39,10 @@ protected Optional getAccioMDLPath() public Object[][] queryModel() { return new Object[][] { - {"select * from Orders"}, - {"select * from Orders WHERE orderkey > 100"}, - {"select * from Orders a JOIN Customer b ON a.custkey = b.custkey"}, + {"SELECT * FROM Orders"}, + {"SELECT * FROM Orders WHERE orderkey > 100"}, + {"SELECT * FROM Orders a JOIN Customer b ON a.custkey = b.custkey"}, + {"SELECT * FROM Orders WHERE nation_name IS NOT NULL"} }; } @@ -71,6 +70,7 @@ public void testQueryOnlyModelColumn() assertThatNoException().isThrownBy(() -> resultSet.getInt("custkey")); assertThatNoException().isThrownBy(() -> resultSet.getString("orderstatus")); assertThatNoException().isThrownBy(() -> resultSet.getString("totalprice")); + assertThatNoException().isThrownBy(() -> resultSet.getString("nation_name")); assertThatThrownBy(() -> resultSet.getString("o_orderkey")) .hasMessageMatching(".*The column name o_orderkey was not found in this ResultSet.*"); int count = 1; @@ -82,115 +82,6 @@ public void testQueryOnlyModelColumn() } } - @Test - public void testQueryRelationship() - throws Exception - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select orderkey, customer.name as name from Orders limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("orderkey")); - assertThatNoException().isThrownBy(() -> resultSet.getString("name")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select c.custkey, array_length(orders) as agg from Customer c limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - 
assertThatNoException().isThrownBy(() -> resultSet.getInt("custkey")); - assertThatNoException().isThrownBy(() -> resultSet.getString("agg")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select custkey, array_length(orders) as agg from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("custkey")); - assertThatNoException().isThrownBy(() -> resultSet.getString("agg")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_length(orders) as agg from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("agg")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select orders[1].orderstatus as orderstatus from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("orderstatus")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select customer from Orders limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("customer")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection 
connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select customer.nation as nation_key from Orders limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("nation_key")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select orders from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("orders")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - @Test public void testQueryMetric() throws Exception @@ -229,344 +120,6 @@ void testQueryMetricRollup() } } - @Test - public void testTransform() - throws Exception - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select transform(Customer.orders, orderItem -> orderItem.orderstatus) as orderstatuses from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("orderstatuses")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - - @DataProvider - public static Object[][] functionIndex() - { - return new Object[][] { - {"select filter(orders, orderItem -> orderItem.orderstatus = 'F')[1].orderstatus as col_1 from Customer limit 100"}, - {"select filter(Customer.orders, orderItem -> orderItem.orderstatus = 'F')[1].orderstatus as col_1 from Customer limit 100"}, - {"select filter(Customer.orders, orderItem -> orderItem.orderstatus = 'F')[1].customer.name as col_1 from Customer limit 100"}, - {"select filter(Customer.orders, orderItem -> orderItem.orderstatus = 
'F')[1].customer.orders[2].orderstatus as col_1 from Customer limit 100"}, - {"select filter(Customer.orders[1].lineitems, lineitem -> lineitem.linenumber = 1)[1].linenumber as col_1 from Customer limit 100"}, - {"select filter(filter(Customer.orders[1].lineitems, lineitem -> lineitem.linenumber = 1), lineitem -> lineitem.partkey = 1)[1].linenumber as col_1 from Customer limit 100"}, - }; - } - - @Test(dataProvider = "functionIndex") - public void testFunctionIndex(String sql) - throws Exception - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement(sql); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - - @Test - public void testLambdaFunctionChain() - throws Exception - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement( - "select transform(filter(Customer.orders, orderItem -> orderItem.orderstatus = 'O' or orderItem.orderstatus = 'F'), orderItem -> orderItem.totalprice)\n" + - "as col_1\n" + - "from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement( - "select array_concat(\n" + - "filter(Customer.orders, orderItem -> orderItem.orderstatus = 'O'),\n" + - "filter(Customer.orders, orderItem -> orderItem.orderstatus = 'F'))\n" + - "as col_1\n" + - "from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("col_1")); - int count = 1; - while 
(resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - // test failed stmt - try (Connection connection = createConnection()) { - assertThatThrownBy(() -> { - PreparedStatement stmt = connection.prepareStatement( - "select filter(transform(Customer.orders, orderItem -> orderItem.orderstatus), orderItem -> orderItem.orderstatus = 'O' or orderItem.orderstatus = 'F')\n" + - "as col_1\n" + - "from Customer limit 100"); - stmt.executeQuery(); - }).hasMessageStartingWith("ERROR: Invalid statement"); - } - - // test failed stmt - try (Connection connection = createConnection()) { - assertThatThrownBy(() -> { - PreparedStatement stmt = connection.prepareStatement( - "select transform(array_concat(\n" + - "filter(Customer.orders, orderItem -> orderItem.orderstatus = 'O'),\n" + - "filter(Customer.orders, orderItem -> orderItem.orderstatus = 'F'))," + - "orderItem -> orderItem.totalprice)\n" + - "as col_1\n" + - "from Customer limit 100"); - stmt.executeQuery(); - }).hasMessageStartingWith("ERROR: accio function chain contains invalid function array_concat"); - } - - // test failed stmt - try (Connection connection = createConnection()) { - assertThatThrownBy(() -> { - PreparedStatement stmt = connection.prepareStatement( - "select transform(array_reverse(filter(Customer.orders, orderItem -> orderItem.orderstatus = 'O' or orderItem.orderstatus = 'F')), orderItem -> orderItem.totalprice)\n" + - "as col_1\n" + - "from Customer limit 100"); - stmt.executeQuery(); - }).hasMessageStartingWith("ERROR: accio function chain contains invalid function array_reverse"); - } - } - - @Test - public void testAggregateForArray() - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_sum(transform(orders, a -> a.totalprice)) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - 
int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_avg(transform(orders, a -> a.totalprice)) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_count(transform(orders, a -> a.totalprice)) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_max(transform(orders, a -> a.totalprice)) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_min(transform(orders, a -> a.totalprice)) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> 
resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_min(transform(filter(orders, a -> a.orderstatus = 'F'), a -> a.totalprice)) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_bool_or(transform(orders, a -> a.orderstatus = 'F')) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select array_every(transform(orders, a -> a.orderstatus = 'F')) as col_1 from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("col_1")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - catch (SQLException e) { - throw new RuntimeException(e); - } - } - - @Test - public void testGroupByRelationship() - throws Exception - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select customer, count(*) as totalcount from 
Orders group by customer"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("totalcount")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(1000); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select customer, count(*) as totalcount from Orders group by 1"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("totalcount")); - assertThatNoException().isThrownBy(() -> resultSet.getString("customer")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(1000); - } - } - - @Test - public void testAccessMultiRelationship() - throws Exception - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select linenumber, \"order\".orderstatus from Lineitem limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("linenumber")); - assertThatNoException().isThrownBy(() -> resultSet.getString("orderstatus")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select linenumber, \"order\".orderstatus, part.name from Lineitem limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("linenumber")); - assertThatNoException().isThrownBy(() -> resultSet.getString("orderstatus")); - assertThatNoException().isThrownBy(() -> resultSet.getString("name")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - 
PreparedStatement stmt = connection.prepareStatement("select linenumber, \"order\".customer.name from Lineitem limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getInt("linenumber")); - assertThatNoException().isThrownBy(() -> resultSet.getString("name")); - - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select name, orders[1].lineitems[2].extendedprice from Customer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("name")); - assertThatNoException().isThrownBy(() -> resultSet.getDouble("extendedprice")); - - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - @Test public void testEnum() throws Exception @@ -602,32 +155,6 @@ public void testView() assertThat(count).isEqualTo(100); } - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select * from useRelationship limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("name")); - int count = 1; - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select * from useRelationshipCustomer limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getString("name")); - assertThatNoException().isThrownBy(() -> resultSet.getInt("length")); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - try (Connection connection = 
createConnection()) { PreparedStatement stmt = connection.prepareStatement("select * from useMetric limit 100"); ResultSet resultSet = stmt.executeQuery(); @@ -671,152 +198,12 @@ public void testView() } } - @DataProvider - public Object[][] anyFunction() - { - return new Object[][] { - {"SELECT (any(filter(orders, orderItem -> orderItem.orderstatus = 'F')) IS NOT NULL) AS col_1 FROM Customer LIMIT 100"}, - {"SELECT any(filter(orders, orderItem -> orderItem.orderstatus = 'F')).totalprice FROM Customer LIMIT 100"}, - // useAny is a view that invoke any function - {"SELECT * FROM useAny"}, - {"select orders[1] from Customer LIMIT 100"}, - {"select any(orders) from Customer LIMIT 100"}, - {"select any(filter(orders, orderItem -> orderItem.orderstatus = 'F')) FROM Customer LIMIT 100"}, - }; - } - - @Test(dataProvider = "anyFunction") - public void testAnyFunction(@Language("sql") String sql) - throws SQLException - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement(sql); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getObject(1)); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - - @DataProvider - public Object[][] arraySort() - { - return new Object[][] { - {"SELECT array_sort(orders, totalprice, DESC) FROM Customer LIMIT 100"}, - {"SELECT array_sort(c.orders, totalprice, DESC) FROM Customer c LIMIT 100"}, - {"SELECT array_sort(filter(orders, orderItem -> orderItem.orderstatus = 'F'), totalprice, ASC) FROM Customer LIMIT 100"}, - {"SELECT array_sort(filter(orders, orderItem -> orderItem.orderstatus = 'F'), totalprice, ASC)[1].totalprice FROM Customer LIMIT 100"}, - }; - } - - @Test(dataProvider = "arraySort") - public void testArraySortFunction(String sql) - throws SQLException - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = 
connection.prepareStatement(sql); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getObject(1)); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - - @DataProvider - public Object[][] first() - { - return new Object[][] { - {"SELECT first(orders, totalprice, DESC) IS NOT NULL FROM Customer LIMIT 100"}, - {"SELECT first(c.orders, totalprice, desc).totalprice FROM Customer c LIMIT 100"}, - {"SELECT first(c.orders, totalprice, desc).customer.name FROM Customer c LIMIT 100"}, - {"SELECT first(c.orders, totalprice, desc).customer.nation.name FROM Customer c LIMIT 100"}, - {"SELECT first(filter(orders, orderItem -> orderItem.orderstatus = 'F'), totalprice, ASC) IS NOT NULL FROM Customer LIMIT 100"}, - {"SELECT first(filter(orders, orderItem -> orderItem.orderstatus = 'F'), totalprice, asc).totalprice FROM Customer LIMIT 100"}, - }; - } - - @Test(dataProvider = "first") - public void testFirstFunction(String sql) - throws SQLException - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement(sql); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getObject(1)); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - - @DataProvider - public Object[][] slice() - { - return new Object[][] { - {"SELECT slice(orders, 1, 5) FROM Customer LIMIT 100"}, - {"SELECT slice(c.orders, 1, 5) FROM Customer c LIMIT 100"}, - {"SELECT slice(c.orders, 1, 5)[0] FROM Customer c LIMIT 100"}, - {"SELECT slice(c.orders, 1, 5)[0].totalprice FROM Customer c LIMIT 100"}, - {"SELECT slice(filter(orders, orderItem -> orderItem.orderstatus = 'F'), 1, 5) FROM Customer LIMIT 100"}, - {"SELECT slice(filter(orders, orderItem -> orderItem.orderstatus = 'F'), 1, 5)[1].totalprice FROM Customer LIMIT 100"}, - 
{"SELECT slice(filter(orders, orderItem -> orderItem.orderstatus = 'F'), 1, 5)[1].totalprice FROM Customer LIMIT 100"}, - {"SELECT transform(slice(filter(orders, orderItem -> orderItem.orderstatus = 'F'), 1, 5), s -> s.totalprice) FROM Customer LIMIT 100"}, - {"SELECT slice(array_sort(filter(orders, orderItem -> orderItem.orderstatus = 'F'), totalprice, DESC), 1, 5) FROM Customer LIMIT 100"}, - {"SELECT array_reverse(slice(array_sort(filter(orders, orderItem -> orderItem.orderstatus = 'F'), totalprice, DESC), 1, 5)) FROM Customer LIMIT 100"}, - }; - } - - @Test(dataProvider = "slice") - public void testSlice(String sql) - throws SQLException - { - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement(sql); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getObject(1)); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - } - @Test public void testQuerySqlReservedWord() throws Exception { try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select \"order\".orderkey from Lineitem limit 100"); - ResultSet resultSet = stmt.executeQuery(); - resultSet.next(); - assertThatNoException().isThrownBy(() -> resultSet.getObject(1)); - int count = 1; - - while (resultSet.next()) { - count++; - } - assertThat(count).isEqualTo(100); - } - - try (Connection connection = createConnection()) { - PreparedStatement stmt = connection.prepareStatement("select transform(\"order\".lineitems, l -> l.shipdate)[1] from Lineitem limit 100"); + PreparedStatement stmt = connection.prepareStatement("select \"order\" from Lineitem limit 100"); ResultSet resultSet = stmt.executeQuery(); resultSet.next(); assertThatNoException().isThrownBy(() -> resultSet.getObject(1)); diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryPreAggregation.java 
b/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryPreAggregation.java index da29462e4..1d47824b5 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryPreAggregation.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryPreAggregation.java @@ -82,7 +82,7 @@ protected Properties getDefaultProperties() public void testType() throws SQLException { - String mappingName = preAggregationTableMapping.getPreAggregationInfoPair("canner-cml", "cml_temp", "PrintBigQueryType").getRequiredTableName(); + String mappingName = preAggregationTableMapping.get().getPreAggregationInfoPair("canner-cml", "cml_temp", "PrintBigQueryType").getRequiredTableName(); List tables = queryDuckdb("show tables"); Set tableNames = tables.stream().map(table -> table[0].toString()).collect(toImmutableSet()); @@ -152,7 +152,7 @@ public static Object[][] typesInPredicateProvider() public void testTypesInPredicate(String columnName, Object value) throws SQLException { - String mappingName = preAggregationTableMapping.getPreAggregationInfoPair("canner-cml", "cml_temp", "PrintBigQueryType").getRequiredTableName(); + String mappingName = preAggregationTableMapping.get().getPreAggregationInfoPair("canner-cml", "cml_temp", "PrintBigQueryType").getRequiredTableName(); List tables = queryDuckdb("show tables"); Set tableNames = tables.stream().map(table -> table[0].toString()).collect(toImmutableSet()); @@ -190,7 +190,7 @@ public static Object[][] typesInPredicateWithPreparedStatementProvider() public void testTypesInPredicateWithPreparedStatement(String columnName, Object value) throws SQLException { - String mappingName = preAggregationTableMapping.getPreAggregationInfoPair("canner-cml", "cml_temp", "PrintBigQueryType").getRequiredTableName(); + String mappingName = preAggregationTableMapping.get().getPreAggregationInfoPair("canner-cml", "cml_temp", "PrintBigQueryType").getRequiredTableName(); List tables = queryDuckdb("show tables"); Set tableNames = 
tables.stream().map(table -> table[0].toString()).collect(toImmutableSet()); diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryType.java b/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryType.java index 5600604fd..f9673ae42 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryType.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/TestBigQueryType.java @@ -24,7 +24,6 @@ import io.airlift.log.Logger; import org.postgresql.util.PGInterval; import org.postgresql.util.PGobject; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import java.math.BigDecimal; @@ -72,13 +71,17 @@ enum DataType STRUCT } - @BeforeClass - public void init() - throws SQLException + @Override + protected void prepare() { testSchemaTableName = new SchemaTableName("cml_temp", "test_bigquery_type_" + currentTimeMillis()); bigQueryClient = getInstance(Key.get(BigQueryClient.class)); - testCases = initTestcases(); + try { + testCases = initTestcases(); + } + catch (SQLException e) { + throw new RuntimeException(e); + } createBigQueryTable(); } diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/TestPreAggregation.java b/accio-tests/src/test/java/io/accio/testing/bigquery/TestPreAggregation.java index d60570b73..33b872383 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/TestPreAggregation.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/TestPreAggregation.java @@ -34,6 +34,7 @@ import java.util.Optional; import java.util.Set; import java.util.function.Function; +import java.util.function.Supplier; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.accio.base.type.IntegerType.INTEGER; @@ -47,8 +48,8 @@ public class TestPreAggregation extends AbstractPreAggregationTest { private static final Function dropTableStatement = (tableName) -> format("BEGIN TRANSACTION;DROP TABLE IF EXISTS %s;COMMIT;", tableName); - private 
final AccioMDL accioMDL = getInstance(Key.get(AccioMetastore.class)).getAccioMDL(); - private final DuckdbClient duckdbClient = getInstance(Key.get(DuckdbClient.class)); + private final Supplier accioMDL = () -> getInstance(Key.get(AccioMetastore.class)).getAccioMDL(); + private final Supplier duckdbClient = () -> getInstance(Key.get(DuckdbClient.class)); private final SessionContext defaultSessionContext = SessionContext.builder() .setCatalog("canner-cml") .setSchema("tpch_tiny") @@ -130,10 +131,10 @@ public void testExecuteRewrittenQuery() PreAggregationRewrite.rewrite( defaultSessionContext, "select custkey, revenue from Revenue limit 100", - preAggregationTableMapping::convertToAggregationTable, - accioMDL) + preAggregationTableMapping.get()::convertToAggregationTable, + accioMDL.get()) .orElseThrow(AssertionError::new); - try (ConnectorRecordIterator connectorRecordIterator = preAggregationManager.query(rewritten, ImmutableList.of())) { + try (ConnectorRecordIterator connectorRecordIterator = preAggregationManager.get().query(rewritten, ImmutableList.of())) { int count = 0; while (connectorRecordIterator.hasNext()) { count++; @@ -146,10 +147,10 @@ public void testExecuteRewrittenQuery() PreAggregationRewrite.rewrite( defaultSessionContext, "select custkey, revenue from Revenue where custkey = ?", - preAggregationTableMapping::convertToAggregationTable, - accioMDL) + preAggregationTableMapping.get()::convertToAggregationTable, + accioMDL.get()) .orElseThrow(AssertionError::new); - try (ConnectorRecordIterator connectorRecordIterator = preAggregationManager.query(withParam, ImmutableList.of(new Parameter(INTEGER, 1202)))) { + try (ConnectorRecordIterator connectorRecordIterator = preAggregationManager.get().query(withParam, ImmutableList.of(new Parameter(INTEGER, 1202)))) { Object[] result = connectorRecordIterator.next(); assertThat(result.length).isEqualTo(2); assertThat(result[0]).isEqualTo(1202L); @@ -164,7 +165,7 @@ public void 
testQueryMetricWithDroppedPreAggTable() String tableName = getDefaultPreAggregationInfoPair("ForDropTable").getRequiredTableName(); List origin = queryDuckdb(format("select * from %s", tableName)); assertThat(origin.size()).isGreaterThan(0); - duckdbClient.executeDDL(dropTableStatement.apply(tableName)); + duckdbClient.get().executeDDL(dropTableStatement.apply(tableName)); try (Connection connection = createConnection(); PreparedStatement stmt = connection.prepareStatement("select custkey, revenue from ForDropTable limit 100"); @@ -192,7 +193,6 @@ public void testModelPreAggregation() " , o_custkey custkey\n" + " , o_orderstatus orderstatus\n" + " , o_totalprice totalprice\n" + - " , 'relationship' customer\n" + " , o_orderdate orderdate" + " from `canner-cml`.tpch_tiny.orders"); assertThat(duckdbResult.size()).isEqualTo(bqResult.size()); diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/TestRefreshPreAggregation.java b/accio-tests/src/test/java/io/accio/testing/bigquery/TestRefreshPreAggregation.java index 5022e4702..e347bb9f1 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/TestRefreshPreAggregation.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/TestRefreshPreAggregation.java @@ -20,6 +20,7 @@ import org.testng.annotations.Test; import java.util.Optional; +import java.util.function.Supplier; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -28,7 +29,7 @@ public class TestRefreshPreAggregation extends AbstractPreAggregationTest { - private final AccioMDL accioMDL = getInstance(Key.get(AccioMetastore.class)).getAccioMDL(); + private final Supplier accioMDL = () -> getInstance(Key.get(AccioMetastore.class)).getAccioMDL(); @Override protected Optional getAccioMDLPath() @@ -41,7 +42,7 @@ public void testRefreshFrequently() throws InterruptedException { // manually reload pre-aggregation - preAggregationManager.createTaskUtilDone(accioMDL); + 
preAggregationManager.get().createTaskUtilDone(accioMDL.get()); // We have one pre-aggregation table and the most tables existing in duckdb is 2 assertThat(queryDuckdb("show tables").size()).isLessThan(3); for (int i = 0; i < 50; i++) { diff --git a/accio-tests/src/test/java/io/accio/testing/bigquery/TestReloadPreAggregation.java b/accio-tests/src/test/java/io/accio/testing/bigquery/TestReloadPreAggregation.java index 166abd21a..90020938a 100644 --- a/accio-tests/src/test/java/io/accio/testing/bigquery/TestReloadPreAggregation.java +++ b/accio-tests/src/test/java/io/accio/testing/bigquery/TestReloadPreAggregation.java @@ -79,7 +79,7 @@ public void testReloadPreAggregation() List tables = queryDuckdb("show tables"); Set tableNames = tables.stream().map(table -> table[0].toString()).collect(toImmutableSet()); assertThat(tableNames).doesNotContain(beforeMappingName); - assertThat(preAggregationManager.preAggregationScheduledFutureExists(beforeCatalogSchemaTableName)).isFalse(); + assertThat(preAggregationManager.get().preAggregationScheduledFutureExists(beforeCatalogSchemaTableName)).isFalse(); assertThatThrownBy(() -> getDefaultPreAggregationInfoPair(beforeMappingName).getRequiredTableName()).isInstanceOf(NullPointerException.class); rewriteFile("pre_agg/pre_agg_reload_1_mdl.json"); @@ -123,7 +123,7 @@ private void assertPreAggregation(String name) List tables = queryDuckdb("show tables"); Set tableNames = tables.stream().map(table -> table[0].toString()).collect(toImmutableSet()); assertThat(tableNames).contains(mappingName); - assertThat(preAggregationManager.preAggregationScheduledFutureExists(mapping)).isTrue(); + assertThat(preAggregationManager.get().preAggregationScheduledFutureExists(mapping)).isTrue(); } private void rewriteFile(String resourcePath) diff --git a/accio-tests/src/test/resources/tpch_mdl.json b/accio-tests/src/test/resources/tpch_mdl.json index c7ce647f3..8d669bdc4 100644 --- a/accio-tests/src/test/resources/tpch_mdl.json +++ 
b/accio-tests/src/test/resources/tpch_mdl.json @@ -26,6 +26,11 @@ "expression": "o_totalprice", "type": "float8" }, + { + "name": "nation_name", + "expression": "customer.nation.name", + "type": "varchar" + }, { "name": "customer", "type": "Customer", @@ -112,8 +117,8 @@ }, { "name": "order", - "type": "Orders", - "relationship": "OrdersLineitem" + "type": "int4", + "expression": "1" }, { "name": "part", @@ -276,14 +281,6 @@ "name": "useModel", "statement": "select * from Orders" }, - { - "name": "useRelationship", - "statement": "select orderkey, customer.name from Orders" - }, - { - "name": "useRelationshipCustomer", - "statement": "select name, array_length(orders) as length from Customer" - }, { "name": "useMetric", "statement": "select * from Revenue" @@ -295,10 +292,6 @@ { "name": "useUseMetric", "statement": "select * from useMetric" - }, - { - "name": "useAny", - "statement": "SELECT any(filter(orders, orderItem -> orderItem.orderstatus = 'F')).totalprice FROM Customer LIMIT 100" } ] } \ No newline at end of file diff --git a/pom.xml b/pom.xml index 45800ce93..56eed2371 100644 --- a/pom.xml +++ b/pom.xml @@ -374,6 +374,12 @@ 19.0.0 + + org.jgrapht + jgrapht-core + 1.5.2 + + org.postgresql postgresql diff --git a/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java b/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java index 96c698e09..7501b5403 100644 --- a/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java +++ b/trino-parser/src/main/java/io/trino/sql/ExpressionFormatter.java @@ -167,7 +167,7 @@ protected String visitNode(Node node, Void context) @Override protected String visitRow(Row node, Void context) { - String rowPrefix = dialect != POSTGRES ? "ROW" : ""; + String rowPrefix = (dialect == DEFAULT || dialect == BIGQUERY) ? 
"ROW" : ""; return rowPrefix + " (" + Joiner.on(", ").join(node.getItems().stream() .map(child -> process(child, context)) .collect(toList())) + ")"; diff --git a/trino-parser/src/main/java/io/trino/sql/tree/DereferenceExpression.java b/trino-parser/src/main/java/io/trino/sql/tree/DereferenceExpression.java index e516f0fc1..9e094d9f3 100644 --- a/trino-parser/src/main/java/io/trino/sql/tree/DereferenceExpression.java +++ b/trino-parser/src/main/java/io/trino/sql/tree/DereferenceExpression.java @@ -115,12 +115,12 @@ public static Expression from(QualifiedName name) { Expression result = null; - for (String part : name.getParts()) { + for (Identifier part : name.getOriginalParts()) { if (result == null) { - result = new Identifier(part); + result = part; } else { - result = new DereferenceExpression(result, new Identifier(part)); + result = new DereferenceExpression(result, part); } } diff --git a/trino-parser/src/main/java/io/trino/sql/tree/QualifiedName.java b/trino-parser/src/main/java/io/trino/sql/tree/QualifiedName.java index 703be6541..cea120ef0 100644 --- a/trino-parser/src/main/java/io/trino/sql/tree/QualifiedName.java +++ b/trino-parser/src/main/java/io/trino/sql/tree/QualifiedName.java @@ -110,6 +110,15 @@ public String getSuffix() return Iterables.getLast(parts); } + public boolean hasPrefix(QualifiedName prefix) + { + if (parts.size() < prefix.getParts().size()) { + return false; + } + + return parts.subList(0, prefix.getParts().size()).equals(prefix.getParts()); + } + @Override public boolean equals(Object o) {