diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java index 172a89070..a318c7ef7 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricSqlRender.java @@ -14,8 +14,10 @@ package io.accio.sqlrewrite; +import com.google.common.base.Joiner; import io.accio.base.AccioMDL; import io.accio.base.dto.Column; +import io.accio.base.dto.CumulativeMetric; import io.accio.base.dto.Metric; import io.accio.base.dto.Model; import io.accio.base.dto.Relationable; @@ -26,11 +28,14 @@ import java.util.List; import java.util.Map; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; import static io.accio.base.Utils.checkArgument; import static io.accio.sqlrewrite.Utils.parseExpression; +import static io.accio.sqlrewrite.Utils.parseQuery; import static java.lang.String.format; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; @@ -50,6 +55,43 @@ protected String initRefSql(Relationable relationable) return "SELECT * FROM " + relationable.getBaseObject(); } + @Override + public RelationInfo render() + { + Optional metricBaseModel = mdl.getModel(relationable.getBaseObject()); + // metric on model + if (metricBaseModel.isPresent()) { + return render(metricBaseModel.orElseThrow(() -> new IllegalArgumentException("model not found"))); + } + // metric on metric + Optional metricBaseMetric = mdl.getMetric(relationable.getBaseObject()); + if (metricBaseMetric.isPresent()) { + return renderBasedOnMetric(metricBaseMetric.orElseThrow(() -> new IllegalArgumentException("metric not found")).getName()); + } + // metric on cumulative metric + Optional metricBaseCumulativeMetric = mdl.getCumulativeMetric(relationable.getBaseObject()); + if (metricBaseCumulativeMetric.isPresent()) { + return renderBasedOnMetric(metricBaseCumulativeMetric.orElseThrow(() -> new IllegalArgumentException("metric not found")).getName()); + } + throw new IllegalArgumentException("invalid metric, cannot render metric sql"); + } + + // TODO: Refactor this out of MetricSqlRender since Metric currently can't be used in relationship in MDL + // thus no need to take care relations in column expression. + private RelationInfo renderBasedOnMetric(String metricName) + { + Metric metric = (Metric) relationable; + List selectItems = metric.getColumns().stream() + .filter(column -> column.getRelationship().isEmpty()) + .map(column -> format("%s AS %s", column.getExpression().orElse(column.getName()), column.getName())) + .collect(toList()); + String sql = getQuerySql(metric, Joiner.on(", ").join(selectItems), metricName); + return new RelationInfo( + relationable, + Set.of(metricName), + parseQuery(sql)); + } + @Override protected String getQuerySql(Relationable relationable, String selectItemsSql, String tableJoinsSql) { diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java index 61cade420..7e6414fdb 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/ModelSqlRender.java @@ -25,11 +25,14 @@ import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import static io.accio.base.Utils.checkArgument; import static io.accio.sqlrewrite.Utils.parseExpression; +import static io.accio.sqlrewrite.Utils.parseQuery; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.joining; import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toSet; @@ -58,6 +61,17 @@ else if (model.getBaseObject() != null) { } } + @Override + public RelationInfo render() + { + requireNonNull(relationable, "model is null"); + if (relationable.getColumns().isEmpty()) { + return new RelationInfo(relationable, Set.of(), parseQuery(refSql)); + } + + return render((Model) relationable); + } + @Override protected String getQuerySql(Relationable relationable, String selectItemsSql, String tableJoinsSql) { diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java index 05f5245f8..b3027c4fe 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/RelationableSqlRender.java @@ -44,7 +44,7 @@ public abstract class RelationableSqlRender { protected final Relationable relationable; protected final AccioMDL mdl; - private final String refSql; + protected final String refSql; // collect dependent models protected final Set requiredObjects; // key is alias_name.column_name, value is column name, this map is used to compose select items in model sql @@ -68,21 +68,11 @@ public RelationableSqlRender(Relationable relationable, AccioMDL mdl) protected abstract String initRefSql(Relationable relationable); - public RelationInfo render() - { - requireNonNull(relationable, "model is null"); - if (relationable.getColumns().isEmpty() && relationable instanceof Model) { - return new RelationInfo(relationable, Set.of(), parseQuery(refSql)); - } - - Model baseModel; - if (relationable instanceof Model) { - baseModel = (Model) relationable; - } - else { - baseModel = mdl.getModel(relationable.getBaseObject()).orElseThrow(() -> new IllegalArgumentException("model not found")); - } + public abstract RelationInfo render(); + protected RelationInfo render(Model baseModel) + { + requireNonNull(baseModel, "baseModel is null"); relationable.getColumns().stream() .filter(column -> column.getRelationship().isEmpty() && column.getExpression().isEmpty()) .forEach(column -> { diff --git a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java index a6d274ce7..81b5fb21e 100644 --- a/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java +++ b/accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/analyzer/StatementAnalyzer.java @@ -20,7 +20,6 @@ import io.accio.base.dto.CumulativeMetric; import io.accio.base.dto.Metric; import io.accio.base.dto.Model; -import io.accio.base.dto.Relationship; import io.accio.base.dto.View; import io.trino.sql.tree.AliasedRelation; import io.trino.sql.tree.AstVisitor; @@ -42,11 +41,9 @@ import io.trino.sql.tree.With; import io.trino.sql.tree.WithQuery; -import java.util.Collection; import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.stream.Stream; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.accio.base.Utils.checkArgument; @@ -78,17 +75,6 @@ public static Analysis analyze(Statement statement, SessionContext sessionContex .anyMatch(table -> table.getSchemaTableName().getTableName().equals(model.getName()))) .collect(toUnmodifiableSet())); - // add models required for relationships - analysis.addModels( - analysis.getRelationships().stream() - .map(Relationship::getModels) - .flatMap(List::stream) - .distinct() - .map(modelName -> - accioMDL.getModel(modelName) - .orElseThrow(() -> new IllegalArgumentException(format("relationship model %s not exists", modelName)))) - .collect(toUnmodifiableSet())); - Set metrics = analysis.getTables().stream() .map(accioMDL::getMetric) .filter(Optional::isPresent) @@ -101,16 +87,6 @@ public static Analysis analyze(Statement statement, SessionContext sessionContex // TODO: remove this check checkArgument(metrics.stream().noneMatch(metricInMetricRollups::contains), "duplicate metrics in metrics and metric rollups"); - - // add models required for metrics - analysis.addModels( - Stream.of(metrics, metricInMetricRollups) - .flatMap(Collection::stream) - .map(Metric::getBaseObject) - .distinct() - .map(model -> accioMDL.getModel(model).orElseThrow(() -> new IllegalArgumentException(format("metric model %s not exists", model)))) - .collect(toUnmodifiableSet())); - analysis.addMetrics(metrics); Set cumulativeMetrics = analysis.getTables().stream() @@ -118,13 +94,6 @@ public static Analysis analyze(Statement statement, SessionContext sessionContex .filter(Optional::isPresent) .map(Optional::get) .collect(toUnmodifiableSet()); - - analysis.addModels( - cumulativeMetrics.stream() - .map(CumulativeMetric::getBaseObject) - .distinct() - .map(model -> accioMDL.getModel(model).orElseThrow(() -> new IllegalArgumentException(format("cumulative metric model %s not exists", model)))) - .collect(toUnmodifiableSet())); analysis.addCumulativeMetrics(cumulativeMetrics); Set views = analysis.getTables().stream() diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java index 79bada2ae..0f93b2c49 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestCumulativeMetric.java @@ -18,6 +18,7 @@ import io.accio.base.AccioMDL; import io.accio.base.dto.DateSpine; import io.accio.base.dto.Manifest; +import io.accio.base.dto.Metric; import io.accio.base.dto.Model; import io.accio.base.dto.TimeUnit; import io.accio.testing.AbstractTestFramework; @@ -31,6 +32,7 @@ import static io.accio.base.dto.Column.column; import static io.accio.base.dto.CumulativeMetric.cumulativeMetric; import static io.accio.base.dto.Measure.measure; +import static io.accio.base.dto.Metric.metric; import static io.accio.base.dto.Model.model; import static io.accio.base.dto.Model.onBaseObject; import static io.accio.base.dto.Window.window; @@ -120,6 +122,28 @@ public void testModelOnCumulativeMetric() assertThat(result.size()).isEqualTo(53); } + @Test + public void testMetricOnCumulativeMetric() + { + List metrics = ImmutableList.builder() + .addAll(manifest.getMetrics()) + .add(metric( + "testMetricOnCumulativeMetric", + "DailyRevenue", + List.of(column("ordermonth", "DATE", null, false, "date_trunc('month', orderdate)")), + List.of(column("totalprice", INTEGER, null, false, "sum(totalprice)")), + List.of())) + .build(); + AccioMDL mdl = AccioMDL.fromManifest( + copyOf(manifest) + .setMetrics(metrics) + .build()); + + List> result = query(rewrite("SELECT * FROM testMetricOnCumulativeMetric ORDER BY ordermonth", mdl)); + assertThat(result.get(0).size()).isEqualTo(2); + assertThat(result.size()).isEqualTo(12); + } + private String rewrite(String sql) { return rewrite(sql, accioMDL); diff --git a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java index 82386cd20..864907484 100644 --- a/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java +++ b/accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import io.accio.base.AccioMDL; import io.accio.base.dto.Manifest; +import io.accio.base.dto.Metric; import io.accio.base.dto.Model; import io.accio.testing.AbstractTestFramework; import org.testng.annotations.Test; @@ -153,6 +154,25 @@ public void testModelOnMetric() assertThat(result.size()).isEqualTo(14958); } + @Test + public void testMetricOnMetric() + { + List metrics = ImmutableList.builder() + .addAll(manifest.getMetrics()) + .add(metric( + "testMetricOnMetric", + "RevenueByCustomerBaseOrders", + List.of(column("orderyear", VARCHAR, null, true, "DATE_TRUNC('YEAR', orderdate)")), + List.of(column("revenue", INTEGER, null, true, "sum(totalprice)")), + List.of())) + .build(); + AccioMDL mdl = AccioMDL.fromManifest(copyOf(manifest).setMetrics(metrics).build()); + + List> result = query(rewrite("SELECT * FROM testMetricOnMetric ORDER BY orderyear", mdl)); + assertThat(result.get(0).size()).isEqualTo(2); + assertThat(result.size()).isEqualTo(7); + } + private String rewrite(String sql) { return rewrite(sql, accioMDL);