Skip to content

Commit

Permalink
Add metric view to DAG and move metric rollup to subquery (#381)
Browse files Browse the repository at this point in the history
* Add metric view to DAG

* Move metric rollup to subquery instead of CTE
  • Loading branch information
brandboat committed Nov 2, 2023
1 parent 9f33916 commit 031894b
Show file tree
Hide file tree
Showing 13 changed files with 443 additions and 256 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

import static io.accio.sqlrewrite.AccioSqlRewrite.ACCIO_SQL_REWRITE;
import static io.accio.sqlrewrite.EnumRewrite.ENUM_REWRITE;
import static io.accio.sqlrewrite.MetricViewSqlRewrite.METRIC_VIEW_SQL_REWRITE;
import static io.accio.sqlrewrite.MetricRollupRewrite.METRIC_ROLLUP_REWRITE;
import static io.trino.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL;

public class AccioPlanner
{
public static final List<AccioRule> ALL_RULES = List.of(
METRIC_VIEW_SQL_REWRITE,
METRIC_ROLLUP_REWRITE,
ACCIO_SQL_REWRITE,
ENUM_REWRITE);
private static final SqlParser SQL_PARSER = new SqlParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import io.accio.sqlrewrite.analyzer.Analysis;
import io.accio.sqlrewrite.analyzer.StatementAnalyzer;
import io.trino.sql.tree.DereferenceExpression;
import io.trino.sql.tree.FunctionRelation;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
Expand All @@ -34,15 +33,16 @@
import org.jgrapht.graph.GraphCycleProhibitedException;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.base.Strings.nullToEmpty;
import static io.accio.sqlrewrite.Utils.parseQuery;
import static java.lang.String.format;
import static io.accio.base.Utils.checkArgument;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toSet;
import static java.util.stream.Collectors.toUnmodifiableList;
Expand All @@ -56,46 +56,80 @@ private AccioSqlRewrite() {}

/**
 * Gathers a {@link QueryDescriptor} for every model, metric, cumulative metric and view that
 * the analysis recorded, then hands them to the descriptor-based overload which turns them
 * into WITH queries.
 *
 * @param root statement to rewrite
 * @param sessionContext session used when resolving views
 * @param analysis analysis result listing the objects referenced by {@code root}
 * @param accioMDL the MDL the referenced objects are looked up in
 * @return the rewritten statement
 */
@Override
public Statement apply(Statement root, SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL)
{
    Set<QueryDescriptor> models = analysis.getModels().stream()
            .map(model -> ModelInfo.get(model, accioMDL))
            .collect(toSet());
    Set<QueryDescriptor> metrics = analysis.getMetrics().stream()
            .map(MetricInfo::get)
            .collect(toSet());
    Set<QueryDescriptor> cumulativeMetrics = analysis.getCumulativeMetrics().stream()
            .map(cumulativeMetric -> MetricInfo.get(cumulativeMetric, accioMDL))
            .collect(toSet());
    Set<QueryDescriptor> views = analysis.getViews().stream()
            .map(view -> ViewInfo.get(view, accioMDL, sessionContext))
            .collect(toSet());

    ImmutableSet.Builder<QueryDescriptor> everything = ImmutableSet.builder();
    everything.addAll(models);
    everything.addAll(metrics);
    everything.addAll(views);
    everything.addAll(cumulativeMetrics);

    // Descriptors collected here are emitted ahead of every other WITH query.
    // The date spine is added only when at least one cumulative metric is referenced.
    Set<QueryDescriptor> leading = new HashSet<>();
    if (!cumulativeMetrics.isEmpty()) {
        leading.add(DateSpineInfo.get(accioMDL.getDateSpine()));
    }

    return apply(root, sessionContext, analysis, accioMDL, everything.build(), leading);
}

private Statement apply(
Statement root,
SessionContext sessionContext,
Analysis analysis,
AccioMDL accioMDL,
Set<QueryDescriptor> allDescriptors,
Set<QueryDescriptor> initDescriptors)
{
DirectedAcyclicGraph<String, Object> graph = new DirectedAcyclicGraph<>(Object.class);
Set<ModelInfo> modelInfos = analysis.getModels().stream().map(model -> ModelInfo.get(model, accioMDL)).collect(toSet());
Set<ModelInfo> requiredModelInfos = new HashSet<>();
modelInfos.forEach(modelInfo -> addModelToGraph(modelInfo, graph, accioMDL, requiredModelInfos));
Set<ModelInfo> allModelInfos = ImmutableSet.<ModelInfo>builder().addAll(modelInfos).addAll(requiredModelInfos).build();
Set<QueryDescriptor> requiredQueryDescriptors = new HashSet<>();
// add to graph
allDescriptors.forEach(queryDescriptor -> addSqlDescriptorToGraph(queryDescriptor, graph, accioMDL, requiredQueryDescriptors, sessionContext));

Map<String, QueryDescriptor> descriptorMap = new HashMap<>();
allDescriptors.forEach(queryDescriptor -> descriptorMap.put(queryDescriptor.getName(), queryDescriptor));
requiredQueryDescriptors.forEach(queryDescriptor -> descriptorMap.put(queryDescriptor.getName(), queryDescriptor));

List<WithQuery> withQueries = new ArrayList<>();
graph.iterator().forEachRemaining(modelName -> {
ModelInfo modelInfo = allModelInfos.stream()
.filter(info -> info.getModel().getName().equals(modelName))
.findAny()
.orElseThrow(() -> new IllegalArgumentException(format("Missing model name %s in graph", modelName)));
withQueries.add(new WithQuery(new Identifier(modelInfo.getModel().getName()), parseQuery(modelInfo.getSql()), Optional.empty()));
initDescriptors.forEach(queryDescriptor -> withQueries.add(getWithQuery(queryDescriptor)));
graph.iterator().forEachRemaining(objectName -> {
QueryDescriptor queryDescriptor = descriptorMap.get(objectName);
checkArgument(queryDescriptor != null, objectName + " not found in query descriptors");
withQueries.add(getWithQuery(queryDescriptor));
});

Node rewriteWith = new WithRewriter(withQueries).process(root);
return (Statement) new Rewriter(accioMDL, analysis).process(rewriteWith);
}

private static void addModelToGraph(ModelInfo modelInfo, DirectedAcyclicGraph<String, Object> graph, AccioMDL mdl, Set<ModelInfo> modelInfos)
private static void addSqlDescriptorToGraph(
QueryDescriptor queryDescriptor,
DirectedAcyclicGraph<String, Object> graph,
AccioMDL mdl,
Set<QueryDescriptor> queryDescriptors,
SessionContext sessionContext)
{
// add vertex
graph.addVertex(modelInfo.getModel().getName());
modelInfo.getRequiredModels().forEach(graph::addVertex);
graph.addVertex(queryDescriptor.getName());
queryDescriptor.getRequiredObjects().forEach(graph::addVertex);

//add edge
try {
modelInfo.getRequiredModels().forEach(modelName ->
graph.addEdge(modelName, modelInfo.getModel().getName()));
queryDescriptor.getRequiredObjects().forEach(modelName ->
graph.addEdge(modelName, queryDescriptor.getName()));
}
catch (GraphCycleProhibitedException ex) {
throw new IllegalArgumentException("found cycle in models", ex);
}

// add required models to graph
for (String modelName : modelInfo.getRequiredModels()) {
ModelInfo info = ModelInfo.get(mdl.getModel(modelName).orElseThrow(), mdl);
modelInfos.add(info);
addModelToGraph(info, graph, mdl, modelInfos);
for (String objectName : queryDescriptor.getRequiredObjects()) {
QueryDescriptor descriptor = QueryDescriptor.of(objectName, mdl, sessionContext);
queryDescriptors.add(descriptor);
addSqlDescriptorToGraph(descriptor, graph, mdl, queryDescriptors, sessionContext);
}
}

Expand Down Expand Up @@ -175,20 +209,15 @@ protected Node visitDereferenceExpression(DereferenceExpression dereferenceExpre
return dereferenceExpression;
}

/**
 * Swaps a metric-rollup function relation for a plain table reference named after the
 * rolled-up metric.
 *
 * @throws IllegalArgumentException if the node was not recorded as a metric rollup by the analysis
 */
// NOTE(review): MetricRollupRewrite also overrides visitFunctionRelation for the same nodes; confirm this override is still needed.
@Override
protected Node visitFunctionRelation(FunctionRelation node, Void context)
{
    NodeRef<FunctionRelation> rollupRef = NodeRef.of(node);
    if (!analysis.getMetricRollups().containsKey(rollupRef)) {
        // this should not happen, every MetricRollup node should be captured and syntax checked in StatementAnalyzer
        throw new IllegalArgumentException("MetricRollup node is not replaced");
    }
    String metricName = analysis.getMetricRollups().get(rollupRef).getMetric().getName();
    return new Table(QualifiedName.of(metricName));
}

// The model is supplied by a WITH query, so the reference keeps only the last part of the
// table name; the catalog and schema qualifiers are dropped.
private Node applyModelRule(Table table)
{
    String unqualifiedName = table.getName().getSuffix();
    return new Table(QualifiedName.of(unqualifiedName));
}
}

/**
 * Wraps a descriptor's query in a {@link WithQuery} named after the descriptor, without
 * explicit column names.
 */
private static WithQuery getWithQuery(QueryDescriptor queryDescriptor)
{
    Identifier withName = new Identifier(queryDescriptor.getName());
    return new WithQuery(withName, queryDescriptor.getQuery(), Optional.empty());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.accio.sqlrewrite;

import io.accio.base.dto.DateSpine;
import io.trino.sql.tree.Query;

import java.util.Set;

import static io.accio.sqlrewrite.Utils.createDateSpineQuery;
import static java.util.Objects.requireNonNull;

/**
 * {@link QueryDescriptor} for the date spine: a query built from a {@link DateSpine}
 * definition, exposed under the fixed name {@link #NAME} and requiring no other objects.
 */
public class DateSpineInfo
        implements QueryDescriptor
{
    public static final String NAME = "date_spine";

    private final Query query;

    private DateSpineInfo(DateSpine dateSpine)
    {
        // Fail on a null definition before the query is built.
        requireNonNull(dateSpine);
        this.query = createDateSpineQuery(dateSpine);
    }

    /**
     * Creates the descriptor for the given date spine definition.
     *
     * @param dateSpine date spine definition; must not be null
     * @return a descriptor whose query is produced by {@code createDateSpineQuery}
     */
    public static DateSpineInfo get(DateSpine dateSpine)
    {
        return new DateSpineInfo(dateSpine);
    }

    /** Always {@link #NAME}. */
    @Override
    public String getName()
    {
        return NAME;
    }

    /** The date spine depends on nothing, so this is always empty. */
    @Override
    public Set<String> getRequiredObjects()
    {
        return Set.of();
    }

    /** The query built at construction time. */
    @Override
    public Query getQuery()
    {
        return query;
    }
}
73 changes: 73 additions & 0 deletions accio-sqlrewrite/src/main/java/io/accio/sqlrewrite/MetricInfo.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.accio.sqlrewrite;

import io.accio.base.AccioMDL;
import io.accio.base.dto.CumulativeMetric;
import io.accio.base.dto.Metric;
import io.trino.sql.tree.Query;

import java.util.Set;

import static java.util.Objects.requireNonNull;

/**
 * {@link QueryDescriptor} for a metric or a cumulative metric. In both cases the only
 * required object is the metric's base model.
 */
public class MetricInfo
        implements QueryDescriptor
{
    private final String name;
    private final Set<String> requiredObjects;
    private final Query query;

    private MetricInfo(String name, Set<String> requiredObjects, Query query)
    {
        this.name = requireNonNull(name);
        this.requiredObjects = requireNonNull(requiredObjects);
        this.query = requireNonNull(query);
    }

    /**
     * Describes a metric; its query comes from {@code Utils.parseMetricSql}.
     *
     * @param metric the metric to describe
     * @return descriptor named after the metric and depending on its base model
     */
    public static MetricInfo get(Metric metric)
    {
        String metricName = metric.getName();
        Set<String> baseModel = Set.of(metric.getBaseModel());
        Query metricQuery = Utils.parseMetricSql(metric);
        return new MetricInfo(metricName, baseModel, metricQuery);
    }

    /**
     * Describes a cumulative metric; its query comes from {@code Utils.parseCumulativeMetricSql}.
     *
     * @param metric the cumulative metric to describe
     * @param mdl the MDL passed to the SQL builder
     * @return descriptor named after the metric and depending on its base model
     */
    public static MetricInfo get(CumulativeMetric metric, AccioMDL mdl)
    {
        String metricName = metric.getName();
        Set<String> baseModel = Set.of(metric.getBaseModel());
        Query metricQuery = Utils.parseCumulativeMetricSql(metric, mdl);
        return new MetricInfo(metricName, baseModel, metricQuery);
    }

    /** Name of the described metric. */
    @Override
    public String getName()
    {
        return name;
    }

    /** Names of the objects this metric's query depends on. */
    @Override
    public Set<String> getRequiredObjects()
    {
        return requiredObjects;
    }

    /** The parsed query for this metric. */
    @Override
    public Query getQuery()
    {
        return query;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.accio.sqlrewrite;

import io.accio.base.AccioMDL;
import io.accio.base.SessionContext;
import io.accio.sqlrewrite.analyzer.Analysis;
import io.accio.sqlrewrite.analyzer.MetricRollupInfo;
import io.accio.sqlrewrite.analyzer.StatementAnalyzer;
import io.trino.sql.tree.AliasedRelation;
import io.trino.sql.tree.FunctionRelation;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.Statement;
import io.trino.sql.tree.TableSubquery;

import java.util.List;

import static io.accio.sqlrewrite.Utils.parseMetricRollupSql;

/**
 * Rule that replaces each metric-rollup function relation with an aliased subquery.
 * The subquery is produced by {@code parseMetricRollupSql} and aliased with the metric's name.
 */
public class MetricRollupRewrite
        implements AccioRule
{
    public static final MetricRollupRewrite METRIC_ROLLUP_REWRITE = new MetricRollupRewrite();

    /**
     * Analyzes {@code root} first, then applies the rule with that analysis.
     */
    @Override
    public Statement apply(Statement root, SessionContext sessionContext, AccioMDL accioMDL)
    {
        Analysis analysis = StatementAnalyzer.analyze(root, sessionContext, accioMDL);
        return apply(root, sessionContext, analysis, accioMDL);
    }

    /**
     * Rewrites {@code root} using the metric rollups recorded in {@code analysis}.
     */
    @Override
    public Statement apply(Statement root, SessionContext sessionContext, Analysis analysis, AccioMDL accioMDL)
    {
        Rewriter rewriter = new Rewriter(analysis);
        return (Statement) rewriter.process(root);
    }

    /** Tree rewriter that performs the function-relation substitution. */
    private static class Rewriter
            extends BaseRewriter<Void>
    {
        private final Analysis analysis;

        Rewriter(Analysis analysis)
        {
            this.analysis = analysis;
        }

        /**
         * @throws IllegalArgumentException if the node was not recorded as a metric rollup
         */
        @Override
        protected Node visitFunctionRelation(FunctionRelation node, Void context)
        {
            NodeRef<FunctionRelation> rollupRef = NodeRef.of(node);
            if (!analysis.getMetricRollups().containsKey(rollupRef)) {
                // this should not happen, every MetricRollup node should be captured and syntax checked in StatementAnalyzer
                throw new IllegalArgumentException("MetricRollup node is not replaced");
            }

            MetricRollupInfo rollup = analysis.getMetricRollups().get(rollupRef);
            Query rollupQuery = parseMetricRollupSql(rollup);
            Identifier alias = new Identifier(rollup.getMetric().getName());
            return new AliasedRelation(new TableSubquery(rollupQuery), alias, List.of());
        }
    }
}
Loading

0 comments on commit 031894b

Please sign in to comment.