Skip to content

Commit

Permalink
Support metric on metric/cumulative metric (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandboat committed Nov 15, 2023
1 parent e37e44e commit 67713f2
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

package io.accio.sqlrewrite;

import com.google.common.base.Joiner;
import io.accio.base.AccioMDL;
import io.accio.base.dto.Column;
import io.accio.base.dto.CumulativeMetric;
import io.accio.base.dto.Metric;
import io.accio.base.dto.Model;
import io.accio.base.dto.Relationable;
Expand All @@ -26,11 +28,14 @@

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static io.accio.base.Utils.checkArgument;
import static io.accio.sqlrewrite.Utils.parseExpression;
import static io.accio.sqlrewrite.Utils.parseQuery;
import static java.lang.String.format;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
Expand All @@ -50,6 +55,43 @@ protected String initRefSql(Relationable relationable)
return "SELECT * FROM " + relationable.getBaseObject();
}

@Override
public RelationInfo render()
{
Optional<Model> metricBaseModel = mdl.getModel(relationable.getBaseObject());
// metric on model
if (metricBaseModel.isPresent()) {
return render(metricBaseModel.orElseThrow(() -> new IllegalArgumentException("model not found")));
}
// metric on metric
Optional<Metric> metricBaseMetric = mdl.getMetric(relationable.getBaseObject());
if (metricBaseMetric.isPresent()) {
return renderBasedOnMetric(metricBaseMetric.orElseThrow(() -> new IllegalArgumentException("metric not found")).getName());
}
// metric on cumulative metric
Optional<CumulativeMetric> metricBaseCumulativeMetric = mdl.getCumulativeMetric(relationable.getBaseObject());
if (metricBaseCumulativeMetric.isPresent()) {
return renderBasedOnMetric(metricBaseCumulativeMetric.orElseThrow(() -> new IllegalArgumentException("metric not found")).getName());
}
throw new IllegalArgumentException("invalid metric, cannot render metric sql");
}

// TODO: Refactor this out of MetricSqlRender since Metric currently can't be used in relationship in MDL
// thus no need to take care relations in column expression.
private RelationInfo renderBasedOnMetric(String metricName)
{
Metric metric = (Metric) relationable;
List<String> selectItems = metric.getColumns().stream()
.filter(column -> column.getRelationship().isEmpty())
.map(column -> format("%s AS %s", column.getExpression().orElse(column.getName()), column.getName()))
.collect(toList());
String sql = getQuerySql(metric, Joiner.on(", ").join(selectItems), metricName);
return new RelationInfo(
relationable,
Set.of(metricName),
parseQuery(sql));
}

@Override
protected String getQuerySql(Relationable relationable, String selectItemsSql, String tableJoinsSql)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,14 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static io.accio.base.Utils.checkArgument;
import static io.accio.sqlrewrite.Utils.parseExpression;
import static io.accio.sqlrewrite.Utils.parseQuery;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;
Expand Down Expand Up @@ -58,6 +61,17 @@ else if (model.getBaseObject() != null) {
}
}

@Override
public RelationInfo render()
{
requireNonNull(relationable, "model is null");
if (relationable.getColumns().isEmpty()) {
return new RelationInfo(relationable, Set.of(), parseQuery(refSql));
}

return render((Model) relationable);
}

@Override
protected String getQuerySql(Relationable relationable, String selectItemsSql, String tableJoinsSql)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public abstract class RelationableSqlRender
{
protected final Relationable relationable;
protected final AccioMDL mdl;
private final String refSql;
protected final String refSql;
// collect dependent models
protected final Set<String> requiredObjects;
// key is alias_name.column_name, value is column name, this map is used to compose select items in model sql
Expand All @@ -68,21 +68,11 @@ public RelationableSqlRender(Relationable relationable, AccioMDL mdl)

protected abstract String initRefSql(Relationable relationable);

public RelationInfo render()
{
requireNonNull(relationable, "model is null");
if (relationable.getColumns().isEmpty() && relationable instanceof Model) {
return new RelationInfo(relationable, Set.of(), parseQuery(refSql));
}

Model baseModel;
if (relationable instanceof Model) {
baseModel = (Model) relationable;
}
else {
baseModel = mdl.getModel(relationable.getBaseObject()).orElseThrow(() -> new IllegalArgumentException("model not found"));
}
public abstract RelationInfo render();

protected RelationInfo render(Model baseModel)
{
requireNonNull(baseModel, "baseModel is null");
relationable.getColumns().stream()
.filter(column -> column.getRelationship().isEmpty() && column.getExpression().isEmpty())
.forEach(column -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import io.accio.base.dto.CumulativeMetric;
import io.accio.base.dto.Metric;
import io.accio.base.dto.Model;
import io.accio.base.dto.Relationship;
import io.accio.base.dto.View;
import io.trino.sql.tree.AliasedRelation;
import io.trino.sql.tree.AstVisitor;
Expand All @@ -42,11 +41,9 @@
import io.trino.sql.tree.With;
import io.trino.sql.tree.WithQuery;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.accio.base.Utils.checkArgument;
Expand Down Expand Up @@ -78,17 +75,6 @@ public static Analysis analyze(Statement statement, SessionContext sessionContex
.anyMatch(table -> table.getSchemaTableName().getTableName().equals(model.getName())))
.collect(toUnmodifiableSet()));

// add models required for relationships
analysis.addModels(
analysis.getRelationships().stream()
.map(Relationship::getModels)
.flatMap(List::stream)
.distinct()
.map(modelName ->
accioMDL.getModel(modelName)
.orElseThrow(() -> new IllegalArgumentException(format("relationship model %s not exists", modelName))))
.collect(toUnmodifiableSet()));

Set<Metric> metrics = analysis.getTables().stream()
.map(accioMDL::getMetric)
.filter(Optional::isPresent)
Expand All @@ -101,30 +87,13 @@ public static Analysis analyze(Statement statement, SessionContext sessionContex

// TODO: remove this check
checkArgument(metrics.stream().noneMatch(metricInMetricRollups::contains), "duplicate metrics in metrics and metric rollups");

// add models required for metrics
analysis.addModels(
Stream.of(metrics, metricInMetricRollups)
.flatMap(Collection::stream)
.map(Metric::getBaseObject)
.distinct()
.map(model -> accioMDL.getModel(model).orElseThrow(() -> new IllegalArgumentException(format("metric model %s not exists", model))))
.collect(toUnmodifiableSet()));

analysis.addMetrics(metrics);

Set<CumulativeMetric> cumulativeMetrics = analysis.getTables().stream()
.map(accioMDL::getCumulativeMetric)
.filter(Optional::isPresent)
.map(Optional::get)
.collect(toUnmodifiableSet());

analysis.addModels(
cumulativeMetrics.stream()
.map(CumulativeMetric::getBaseObject)
.distinct()
.map(model -> accioMDL.getModel(model).orElseThrow(() -> new IllegalArgumentException(format("cumulative metric model %s not exists", model))))
.collect(toUnmodifiableSet()));
analysis.addCumulativeMetrics(cumulativeMetrics);

Set<View> views = analysis.getTables().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.accio.base.AccioMDL;
import io.accio.base.dto.DateSpine;
import io.accio.base.dto.Manifest;
import io.accio.base.dto.Metric;
import io.accio.base.dto.Model;
import io.accio.base.dto.TimeUnit;
import io.accio.testing.AbstractTestFramework;
Expand All @@ -31,6 +32,7 @@
import static io.accio.base.dto.Column.column;
import static io.accio.base.dto.CumulativeMetric.cumulativeMetric;
import static io.accio.base.dto.Measure.measure;
import static io.accio.base.dto.Metric.metric;
import static io.accio.base.dto.Model.model;
import static io.accio.base.dto.Model.onBaseObject;
import static io.accio.base.dto.Window.window;
Expand Down Expand Up @@ -120,6 +122,28 @@ public void testModelOnCumulativeMetric()
assertThat(result.size()).isEqualTo(53);
}

@Test
public void testMetricOnCumulativeMetric()
{
List<Metric> metrics = ImmutableList.<Metric>builder()
.addAll(manifest.getMetrics())
.add(metric(
"testMetricOnCumulativeMetric",
"DailyRevenue",
List.of(column("ordermonth", "DATE", null, false, "date_trunc('month', orderdate)")),
List.of(column("totalprice", INTEGER, null, false, "sum(totalprice)")),
List.of()))
.build();
AccioMDL mdl = AccioMDL.fromManifest(
copyOf(manifest)
.setMetrics(metrics)
.build());

List<List<Object>> result = query(rewrite("SELECT * FROM testMetricOnCumulativeMetric ORDER BY ordermonth", mdl));
assertThat(result.get(0).size()).isEqualTo(2);
assertThat(result.size()).isEqualTo(12);
}

private String rewrite(String sql)
{
return rewrite(sql, accioMDL);
Expand Down
20 changes: 20 additions & 0 deletions accio-sqlrewrite/src/test/java/io/accio/sqlrewrite/TestMetric.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableList;
import io.accio.base.AccioMDL;
import io.accio.base.dto.Manifest;
import io.accio.base.dto.Metric;
import io.accio.base.dto.Model;
import io.accio.testing.AbstractTestFramework;
import org.testng.annotations.Test;
Expand Down Expand Up @@ -153,6 +154,25 @@ public void testModelOnMetric()
assertThat(result.size()).isEqualTo(14958);
}

@Test
public void testMetricOnMetric()
{
List<Metric> metrics = ImmutableList.<Metric>builder()
.addAll(manifest.getMetrics())
.add(metric(
"testMetricOnMetric",
"RevenueByCustomerBaseOrders",
List.of(column("orderyear", VARCHAR, null, true, "DATE_TRUNC('YEAR', orderdate)")),
List.of(column("revenue", INTEGER, null, true, "sum(totalprice)")),
List.of()))
.build();
AccioMDL mdl = AccioMDL.fromManifest(copyOf(manifest).setMetrics(metrics).build());

List<List<Object>> result = query(rewrite("SELECT * FROM testMetricOnMetric ORDER BY orderyear", mdl));
assertThat(result.get(0).size()).isEqualTo(2);
assertThat(result.size()).isEqualTo(7);
}

private String rewrite(String sql)
{
return rewrite(sql, accioMDL);
Expand Down

0 comments on commit 67713f2

Please sign in to comment.