diff --git a/accio-base/src/main/java/io/accio/base/dto/Model.java b/accio-base/src/main/java/io/accio/base/dto/Model.java index d870ec604..2bad55e15 100644 --- a/accio-base/src/main/java/io/accio/base/dto/Model.java +++ b/accio-base/src/main/java/io/accio/base/dto/Model.java @@ -57,7 +57,7 @@ public static Model model(String name, String refSql, List columns, Stri return new Model(name, refSql, null, columns, primaryKey, false, null, description); } - public static Model onModel(String name, String baseObject, List columns, String primaryKey) + public static Model onBaseObject(String name, String baseObject, List columns, String primaryKey) { return new Model(name, null, baseObject, columns, primaryKey, false, null, null); } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java index 5c5888324..b90b17c09 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/AccioSqlRewrite.java @@ -67,13 +67,7 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a .addAll(viewDescriptors) .addAll(cumulativeMetricDescriptors) .build(); - - // initDescriptors gathers queries that need to be placed at the beginning of a Common Table Expression (CTE). - Set initDescriptors = new HashSet<>(); - if (cumulativeMetricDescriptors.size() > 0) { - initDescriptors.add(DateSpineInfo.get(accioMDL.getDateSpine())); - } - return apply(root, sessionContext, analysis, accioMDL, allDescriptors, initDescriptors); + return apply(root, sessionContext, analysis, accioMDL, allDescriptors); } private Statement apply( @@ -81,8 +75,7 @@ private Statement apply( SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL, - Set allDescriptors, - Set initDescriptors) + Set allDescriptors) { DirectedAcyclicGraph graph = new DirectedAcyclicGraph<>(Object.class); Set requiredQueryDescriptors = new HashSet<>(); @@ -94,7 +87,6 @@ private Statement apply( requiredQueryDescriptors.forEach(queryDescriptor -> descriptorMap.put(queryDescriptor.getName(), queryDescriptor)); List withQueries = new ArrayList<>(); - initDescriptors.forEach(queryDescriptor -> withQueries.add(getWithQuery(queryDescriptor))); graph.iterator().forEachRemaining(objectName -> { QueryDescriptor queryDescriptor = descriptorMap.get(objectName); checkArgument(queryDescriptor != null, objectName + " not found in query descriptors"); diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricInfo.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricInfo.java index 53d3d12d7..e19bfc9a1 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricInfo.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricInfo.java @@ -33,7 +33,7 @@ public static MetricInfo get(CumulativeMetric metric, AccioMDL mdl) { return new MetricInfo( metric.getName(), - Set.of(metric.getBaseObject()), + Set.of(metric.getBaseObject(), DateSpineInfo.NAME), Utils.parseCumulativeMetricSql(metric, mdl)); } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java index 529a9d201..172a89070 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java @@ -130,7 +130,7 @@ protected void collectRelationship(Column column, Model baseModel) requiredExpressions, tableJoins)); // collect all required models in relationships - requiredModels.addAll( + requiredObjects.addAll( relationshipInfos.stream() .map(ExpressionRelationshipInfo::getRelationships) .flatMap(List::stream) diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java index 034371f14..61cade420 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java @@ -115,7 +115,7 @@ protected void collectRelationship(Column column, Model baseModel) column.getName(), tableJoins)); // collect all required models in relationships - requiredModels.addAll( + requiredObjects.addAll( relationshipInfos.stream() .map(ExpressionRelationshipInfo::getRelationships) .flatMap(List::stream) diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/QueryDescriptor.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/QueryDescriptor.java index 39671d8dd..5bca0b6f7 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/QueryDescriptor.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/QueryDescriptor.java @@ -51,6 +51,9 @@ static QueryDescriptor of(String name, AccioMDL mdl, SessionContext sessionConte if (view.isPresent()) { return ViewInfo.get(view.get(), mdl, sessionContext); } + if (name.equals(DateSpineInfo.NAME)) { + return DateSpineInfo.get(mdl.getDateSpine()); + } throw new IllegalArgumentException(name + " not found in accio mdl"); } } diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java index 6f188c06d..05f5245f8 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java @@ -46,7 +46,7 @@ public abstract class RelationableSqlRender protected final AccioMDL mdl; private final String refSql; // collect dependent models - protected final Set requiredModels; + protected final Set requiredObjects; // key is alias_name.column_name, value is column name, this map is used to compose select items in model sql protected final List selectItems = new ArrayList<>(); @@ -60,9 +60,9 @@ public RelationableSqlRender(Relationable relationable, AccioMDL mdl) this.relationable = requireNonNull(relationable); this.mdl = requireNonNull(mdl); this.refSql = initRefSql(relationable); - this.requiredModels = new HashSet<>(); + this.requiredObjects = new HashSet<>(); if (relationable.getBaseObject() != null) { - requiredModels.add(relationable.getBaseObject()); + requiredObjects.add(relationable.getBaseObject()); } } @@ -72,7 +72,7 @@ public RelationInfo render() { requireNonNull(relationable, "model is null"); if (relationable.getColumns().isEmpty() && relationable instanceof Model) { - return new RelationInfo((Model) relationable, Set.of(), parseQuery(refSql)); + return new RelationInfo(relationable, Set.of(), parseQuery(refSql)); } Model baseModel; @@ -108,7 +108,7 @@ public RelationInfo render() return new RelationInfo( relationable, - requiredModels, + requiredObjects, parseQuery(getQuerySql(relationable, join(", ", selectItems), tableJoinsSql))); } diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java index 9ed5ba071..79bada2ae 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java @@ -14,8 +14,11 @@ package io.accio.sqlrewrite; +import com.google.common.collect.ImmutableList; import io.accio.base.AccioMDL; import io.accio.base.dto.DateSpine; +import io.accio.base.dto.Manifest; +import io.accio.base.dto.Model; import io.accio.base.dto.TimeUnit; import io.accio.testing.AbstractTestFramework; import org.testng.annotations.Test; @@ -29,6 +32,7 @@ import static io.accio.base.dto.CumulativeMetric.cumulativeMetric; import static io.accio.base.dto.Measure.measure; import static io.accio.base.dto.Model.model; +import static io.accio.base.dto.Model.onBaseObject; import static io.accio.base.dto.Window.window; import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE; import static org.assertj.core.api.Assertions.assertThat; @@ -36,11 +40,12 @@ public class TestCumulativeMetric extends AbstractTestFramework { - private static AccioMDL accioMDL; + private final Manifest manifest; + private final AccioMDL accioMDL; public TestCumulativeMetric() { - accioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() + manifest = withDefaultCatalogSchema() .setModels(List.of( model("Orders", "select * from main.orders", @@ -71,7 +76,8 @@ public TestCumulativeMetric() "Orders", measure("totalprice", INTEGER, "sum", "totalprice"), window("orderdate", "orderdate", TimeUnit.YEAR, "1994-01-01", "1998-12-31")))) .setDateSpine(new DateSpine(TimeUnit.DAY, "1970-01-01", "2077-12-31")) - .build()); + .build(); + accioMDL = AccioMDL.fromManifest(manifest); } @Override @@ -91,6 +97,29 @@ public void testCumulativeMetric() assertThat(query(rewrite("select * from YearlyRevenue")).size()).isEqualTo(5); } + @Test + public void testModelOnCumulativeMetric() + { + List models = ImmutableList.builder() + .addAll(manifest.getModels()) + .add(onBaseObject( + "testModelOnCumulativeMetric", + "WeeklyRevenue", + List.of( + column("totalprice", INTEGER, null, false), + column("orderdate", "DATE", null, false)), + "orderdate")) + .build(); + AccioMDL mdl = AccioMDL.fromManifest( + copyOf(manifest) + .setModels(models) + .build()); + + List> result = query(rewrite("select * from testModelOnCumulativeMetric", mdl)); + assertThat(result.get(0).size()).isEqualTo(2); + assertThat(result.size()).isEqualTo(53); + } + private String rewrite(String sql) { return rewrite(sql, accioMDL); diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java index 5f3bff5cf..82386cd20 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java @@ -14,7 +14,10 @@ package io.accio.sqlrewrite; +import com.google.common.collect.ImmutableList; import io.accio.base.AccioMDL; +import io.accio.base.dto.Manifest; +import io.accio.base.dto.Model; import io.accio.testing.AbstractTestFramework; import org.testng.annotations.Test; @@ -27,6 +30,7 @@ import static io.accio.base.dto.JoinType.MANY_TO_ONE; import static io.accio.base.dto.Metric.metric; import static io.accio.base.dto.Model.model; +import static io.accio.base.dto.Model.onBaseObject; import static io.accio.base.dto.Relationship.relationship; import static io.accio.base.dto.TimeGrain.timeGrain; import static io.accio.base.dto.TimeUnit.DAY; @@ -39,11 +43,12 @@ public class TestMetric extends AbstractTestFramework { - private static AccioMDL accioMDL; + private final Manifest manifest; + private final AccioMDL accioMDL; public TestMetric() { - accioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema() + manifest = withDefaultCatalogSchema() .setModels(List.of( model("Orders", "select * from main.orders", @@ -97,7 +102,8 @@ public TestMetric() List.of(column("orderdate", DATE, null, true)), List.of(column("count_of_customer", INTEGER, null, true, "count(distinct customer.name)")), List.of()))) - .build()); + .build(); + accioMDL = AccioMDL.fromManifest(manifest); } @Override @@ -124,6 +130,29 @@ public void testMetricUseToOneRelationship() assertThat(measureRelationship.size()).isEqualTo(2401); } + @Test + public void testModelOnMetric() + { + List models = ImmutableList.builder() + .addAll(manifest.getModels()) + .add(onBaseObject( + "testModelOnMetric", + "RevenueByCustomerBaseOrders", + List.of( + column("name", VARCHAR, null, true), + column("revenue", INTEGER, null, true, "totalprice")), + "name")) + .build(); + AccioMDL mdl = AccioMDL.fromManifest( + copyOf(manifest) + .setModels(models) + .build()); + + List> result = query(rewrite("select * from testModelOnMetric", mdl)); + assertThat(result.get(0).size()).isEqualTo(2); + assertThat(result.size()).isEqualTo(14958); + } + private String rewrite(String sql) { return rewrite(sql, accioMDL); diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java index 1c28452a0..1e3c31fcf 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestModelSqlRewrite.java @@ -33,7 +33,7 @@ import static io.accio.base.dto.JoinType.ONE_TO_MANY; import static io.accio.base.dto.JoinType.ONE_TO_ONE; import static io.accio.base.dto.Model.model; -import static io.accio.base.dto.Model.onModel; +import static io.accio.base.dto.Model.onBaseObject; import static io.accio.base.dto.Relationship.relationship; import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE; import static io.trino.sql.SqlFormatter.formatSql; @@ -305,7 +305,7 @@ public void testModelOnModel() List models = ImmutableList.builder() .addAll(DEFAULT_MANIFEST.getModels()) .add( - onModel( + onBaseObject( "BookReplica", "Book", List.of( @@ -394,18 +394,6 @@ private static void assertSqlEquals(String actual, String expected) .isEqualTo(formatSql(expectedStmt)); } - private static Manifest.Builder copyOf(Manifest manifest) - { - return Manifest.builder() - .setCatalog(manifest.getCatalog()) - .setSchema(manifest.getSchema()) - .setModels(manifest.getModels()) - .setRelationships(manifest.getRelationships()) - .setMetrics(manifest.getMetrics()) - .setViews(manifest.getViews()) - .setEnumDefinitions(manifest.getEnumDefinitions()); - } - private void assertSqlEqualsAndValid(@Language("SQL") String actual, @Language("SQL") String expected) { assertSqlEquals(actual, expected); diff --git a/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java b/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java index 3dc401fab..504968e9d 100644 --- a/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java +++ b/accio-testing/src/main/java/io/accio/testing/AbstractTestFramework.java @@ -82,4 +82,17 @@ protected void exec(@Language("SQL") String sql) { duckdbClient.executeDDL(sql); } + + protected static Manifest.Builder copyOf(Manifest manifest) + { + return Manifest.builder() + .setCatalog(manifest.getCatalog()) + .setSchema(manifest.getSchema()) + .setModels(manifest.getModels()) + .setRelationships(manifest.getRelationships()) + .setMetrics(manifest.getMetrics()) + .setCumulativeMetrics(manifest.getCumulativeMetrics()) + .setViews(manifest.getViews()) + .setEnumDefinitions(manifest.getEnumDefinitions()); + } }