Skip to content

Commit

Permalink
Support model on metric and cumulative metric (#385)
Browse files Browse the repository at this point in the history
* Rename onModel to onBaseObject

* Implement model on metric and cumulative metric
  • Loading branch information
brandboat committed Nov 10, 2023
1 parent abef955 commit a3b9518
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 39 deletions.
2 changes: 1 addition & 1 deletion accio-base/src/main/java/io/accio/base/dto/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static Model model(String name, String refSql, List<Column> columns, Stri
return new Model(name, refSql, null, columns, primaryKey, false, null, description);
}

public static Model onModel(String name, String baseObject, List<Column> columns, String primaryKey)
public static Model onBaseObject(String name, String baseObject, List<Column> columns, String primaryKey)
{
return new Model(name, null, baseObject, columns, primaryKey, false, null, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,15 @@ public Statement apply(Statement root, SessionContext sessionContext, Analysis a
.addAll(viewDescriptors)
.addAll(cumulativeMetricDescriptors)
.build();

// initDescriptors gathers queries that need to be placed at the beginning of a Common Table Expression (CTE).
Set<QueryDescriptor> initDescriptors = new HashSet<>();
if (cumulativeMetricDescriptors.size() > 0) {
initDescriptors.add(DateSpineInfo.get(accioMDL.getDateSpine()));
}
return apply(root, sessionContext, analysis, accioMDL, allDescriptors, initDescriptors);
return apply(root, sessionContext, analysis, accioMDL, allDescriptors);
}

private Statement apply(
Statement root,
SessionContext sessionContext,
Analysis analysis,
AccioMDL accioMDL,
Set<QueryDescriptor> allDescriptors,
Set<QueryDescriptor> initDescriptors)
Set<QueryDescriptor> allDescriptors)
{
DirectedAcyclicGraph<String, Object> graph = new DirectedAcyclicGraph<>(Object.class);
Set<QueryDescriptor> requiredQueryDescriptors = new HashSet<>();
Expand All @@ -94,7 +87,6 @@ private Statement apply(
requiredQueryDescriptors.forEach(queryDescriptor -> descriptorMap.put(queryDescriptor.getName(), queryDescriptor));

List<WithQuery> withQueries = new ArrayList<>();
initDescriptors.forEach(queryDescriptor -> withQueries.add(getWithQuery(queryDescriptor)));
graph.iterator().forEachRemaining(objectName -> {
QueryDescriptor queryDescriptor = descriptorMap.get(objectName);
checkArgument(queryDescriptor != null, objectName + " not found in query descriptors");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public static MetricInfo get(CumulativeMetric metric, AccioMDL mdl)
{
return new MetricInfo(
metric.getName(),
Set.of(metric.getBaseObject()),
Set.of(metric.getBaseObject(), DateSpineInfo.NAME),
Utils.parseCumulativeMetricSql(metric, mdl));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ protected void collectRelationship(Column column, Model baseModel)
requiredExpressions,
tableJoins));
// collect all required models in relationships
requiredModels.addAll(
requiredObjects.addAll(
relationshipInfos.stream()
.map(ExpressionRelationshipInfo::getRelationships)
.flatMap(List::stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ protected void collectRelationship(Column column, Model baseModel)
column.getName(),
tableJoins));
// collect all required models in relationships
requiredModels.addAll(
requiredObjects.addAll(
relationshipInfos.stream()
.map(ExpressionRelationshipInfo::getRelationships)
.flatMap(List::stream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ static QueryDescriptor of(String name, AccioMDL mdl, SessionContext sessionConte
if (view.isPresent()) {
return ViewInfo.get(view.get(), mdl, sessionContext);
}
if (name.equals(DateSpineInfo.NAME)) {
return DateSpineInfo.get(mdl.getDateSpine());
}
throw new IllegalArgumentException(name + " not found in accio mdl");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public abstract class RelationableSqlRender
protected final AccioMDL mdl;
private final String refSql;
// collect dependent models
protected final Set<String> requiredModels;
protected final Set<String> requiredObjects;
// key is alias_name.column_name, value is column name, this map is used to compose select items in model sql
protected final List<String> selectItems = new ArrayList<>();

Expand All @@ -60,9 +60,9 @@ public RelationableSqlRender(Relationable relationable, AccioMDL mdl)
this.relationable = requireNonNull(relationable);
this.mdl = requireNonNull(mdl);
this.refSql = initRefSql(relationable);
this.requiredModels = new HashSet<>();
this.requiredObjects = new HashSet<>();
if (relationable.getBaseObject() != null) {
requiredModels.add(relationable.getBaseObject());
requiredObjects.add(relationable.getBaseObject());
}
}

Expand All @@ -72,7 +72,7 @@ public RelationInfo render()
{
requireNonNull(relationable, "model is null");
if (relationable.getColumns().isEmpty() && relationable instanceof Model) {
return new RelationInfo((Model) relationable, Set.of(), parseQuery(refSql));
return new RelationInfo(relationable, Set.of(), parseQuery(refSql));
}

Model baseModel;
Expand Down Expand Up @@ -108,7 +108,7 @@ public RelationInfo render()

return new RelationInfo(
relationable,
requiredModels,
requiredObjects,
parseQuery(getQuerySql(relationable, join(", ", selectItems), tableJoinsSql)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@

package io.accio.sqlrewrite;

import com.google.common.collect.ImmutableList;
import io.accio.base.AccioMDL;
import io.accio.base.dto.DateSpine;
import io.accio.base.dto.Manifest;
import io.accio.base.dto.Model;
import io.accio.base.dto.TimeUnit;
import io.accio.testing.AbstractTestFramework;
import org.testng.annotations.Test;
Expand All @@ -29,18 +32,20 @@
import static io.accio.base.dto.CumulativeMetric.cumulativeMetric;
import static io.accio.base.dto.Measure.measure;
import static io.accio.base.dto.Model.model;
import static io.accio.base.dto.Model.onBaseObject;
import static io.accio.base.dto.Window.window;
import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE;
import static org.assertj.core.api.Assertions.assertThat;

public class TestCumulativeMetric
extends AbstractTestFramework
{
private static AccioMDL accioMDL;
private final Manifest manifest;
private final AccioMDL accioMDL;

public TestCumulativeMetric()
{
accioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema()
manifest = withDefaultCatalogSchema()
.setModels(List.of(
model("Orders",
"select * from main.orders",
Expand Down Expand Up @@ -71,7 +76,8 @@ public TestCumulativeMetric()
"Orders", measure("totalprice", INTEGER, "sum", "totalprice"),
window("orderdate", "orderdate", TimeUnit.YEAR, "1994-01-01", "1998-12-31"))))
.setDateSpine(new DateSpine(TimeUnit.DAY, "1970-01-01", "2077-12-31"))
.build());
.build();
accioMDL = AccioMDL.fromManifest(manifest);
}

@Override
Expand All @@ -91,6 +97,29 @@ public void testCumulativeMetric()
assertThat(query(rewrite("select * from YearlyRevenue")).size()).isEqualTo(5);
}

@Test
public void testModelOnCumulativeMetric()
{
List<Model> models = ImmutableList.<Model>builder()
.addAll(manifest.getModels())
.add(onBaseObject(
"testModelOnCumulativeMetric",
"WeeklyRevenue",
List.of(
column("totalprice", INTEGER, null, false),
column("orderdate", "DATE", null, false)),
"orderdate"))
.build();
AccioMDL mdl = AccioMDL.fromManifest(
copyOf(manifest)
.setModels(models)
.build());

List<List<Object>> result = query(rewrite("select * from testModelOnCumulativeMetric", mdl));
assertThat(result.get(0).size()).isEqualTo(2);
assertThat(result.size()).isEqualTo(53);
}

private String rewrite(String sql)
{
return rewrite(sql, accioMDL);
Expand Down
35 changes: 32 additions & 3 deletions accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package io.accio.sqlrewrite;

import com.google.common.collect.ImmutableList;
import io.accio.base.AccioMDL;
import io.accio.base.dto.Manifest;
import io.accio.base.dto.Model;
import io.accio.testing.AbstractTestFramework;
import org.testng.annotations.Test;

Expand All @@ -27,6 +30,7 @@
import static io.accio.base.dto.JoinType.MANY_TO_ONE;
import static io.accio.base.dto.Metric.metric;
import static io.accio.base.dto.Model.model;
import static io.accio.base.dto.Model.onBaseObject;
import static io.accio.base.dto.Relationship.relationship;
import static io.accio.base.dto.TimeGrain.timeGrain;
import static io.accio.base.dto.TimeUnit.DAY;
Expand All @@ -39,11 +43,12 @@
public class TestMetric
extends AbstractTestFramework
{
private static AccioMDL accioMDL;
private final Manifest manifest;
private final AccioMDL accioMDL;

public TestMetric()
{
accioMDL = AccioMDL.fromManifest(withDefaultCatalogSchema()
manifest = withDefaultCatalogSchema()
.setModels(List.of(
model("Orders",
"select * from main.orders",
Expand Down Expand Up @@ -97,7 +102,8 @@ public TestMetric()
List.of(column("orderdate", DATE, null, true)),
List.of(column("count_of_customer", INTEGER, null, true, "count(distinct customer.name)")),
List.of())))
.build());
.build();
accioMDL = AccioMDL.fromManifest(manifest);
}

@Override
Expand All @@ -124,6 +130,29 @@ public void testMetricUseToOneRelationship()
assertThat(measureRelationship.size()).isEqualTo(2401);
}

@Test
public void testModelOnMetric()
{
List<Model> models = ImmutableList.<Model>builder()
.addAll(manifest.getModels())
.add(onBaseObject(
"testModelOnMetric",
"RevenueByCustomerBaseOrders",
List.of(
column("name", VARCHAR, null, true),
column("revenue", INTEGER, null, true, "totalprice")),
"name"))
.build();
AccioMDL mdl = AccioMDL.fromManifest(
copyOf(manifest)
.setModels(models)
.build());

List<List<Object>> result = query(rewrite("select * from testModelOnMetric", mdl));
assertThat(result.get(0).size()).isEqualTo(2);
assertThat(result.size()).isEqualTo(14958);
}

private String rewrite(String sql)
{
return rewrite(sql, accioMDL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import static io.accio.base.dto.JoinType.ONE_TO_MANY;
import static io.accio.base.dto.JoinType.ONE_TO_ONE;
import static io.accio.base.dto.Model.model;
import static io.accio.base.dto.Model.onModel;
import static io.accio.base.dto.Model.onBaseObject;
import static io.accio.base.dto.Relationship.relationship;
import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE;
import static io.trino.sql.SqlFormatter.formatSql;
Expand Down Expand Up @@ -305,7 +305,7 @@ public void testModelOnModel()
List<Model> models = ImmutableList.<Model>builder()
.addAll(DEFAULT_MANIFEST.getModels())
.add(
onModel(
onBaseObject(
"BookReplica",
"Book",
List.of(
Expand Down Expand Up @@ -394,18 +394,6 @@ private static void assertSqlEquals(String actual, String expected)
.isEqualTo(formatSql(expectedStmt));
}

private static Manifest.Builder copyOf(Manifest manifest)
{
return Manifest.builder()
.setCatalog(manifest.getCatalog())
.setSchema(manifest.getSchema())
.setModels(manifest.getModels())
.setRelationships(manifest.getRelationships())
.setMetrics(manifest.getMetrics())
.setViews(manifest.getViews())
.setEnumDefinitions(manifest.getEnumDefinitions());
}

private void assertSqlEqualsAndValid(@Language("SQL") String actual, @Language("SQL") String expected)
{
assertSqlEquals(actual, expected);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,17 @@ protected void exec(@Language("SQL") String sql)
{
duckdbClient.executeDDL(sql);
}

protected static Manifest.Builder copyOf(Manifest manifest)
{
return Manifest.builder()
.setCatalog(manifest.getCatalog())
.setSchema(manifest.getSchema())
.setModels(manifest.getModels())
.setRelationships(manifest.getRelationships())
.setMetrics(manifest.getMetrics())
.setCumulativeMetrics(manifest.getCumulativeMetrics())
.setViews(manifest.getViews())
.setEnumDefinitions(manifest.getEnumDefinitions());
}
}

0 comments on commit a3b9518

Please sign in to comment.