Skip to content

Commit

Permalink
Create publication.py, various Publication classes, Dependency class
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Apr 12, 2023
1 parent f9b28bc commit 0daea9f
Show file tree
Hide file tree
Showing 8 changed files with 343 additions and 120 deletions.
195 changes: 100 additions & 95 deletions core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,94 @@ def write_graph(self, outfile: str, manifest: Manifest):
with open(outfile, "wb") as outfh:
pickle.dump(out_graph, outfh, protocol=pickle.HIGHEST_PROTOCOL)

def link_node(self, node: GraphMemberNode, manifest: Manifest):
    """Register ``node`` in the graph and record each of its dependencies.

    Every id in ``node.depends_on_nodes`` is looked up in ``manifest.nodes``,
    then ``manifest.sources``, then ``manifest.metrics``; the first mapping
    containing the id supplies the ``unique_id`` handed to ``self.dependency``.

    Raises:
        GraphDependencyNotFoundError: if an id is in none of the three mappings.
    """
    self.add_node(node.unique_id)

    # Order matters: it fixes which mapping wins if an id appears in several.
    lookup_order = (manifest.nodes, manifest.sources, manifest.metrics)

    for dep_id in node.depends_on_nodes:
        for registry in lookup_order:
            if dep_id in registry:
                self.dependency(node.unique_id, registry[dep_id].unique_id)
                break
        else:
            # No mapping knew about this id.
            raise GraphDependencyNotFoundError(node, dep_id)

def link_graph(self, manifest: Manifest, add_test_edges: bool = False):
    """Populate the graph from ``manifest`` and reject cyclic results.

    Sources are added as bare nodes; nodes, exposures and metrics (in that
    order) are each passed to ``self.link_node``. When ``add_test_edges`` is
    true, the manifest's parent/child maps are built and
    ``self.add_test_edges`` is run afterwards.

    Raises:
        RuntimeError: if ``self.find_cycles()`` returns a truthy value.
    """
    for src in manifest.sources.values():
        self.add_node(src.unique_id)

    # Nodes first, then exposures, then metrics.
    for members in (manifest.nodes, manifest.exposures, manifest.metrics):
        for member in members.values():
            self.link_node(member, manifest)

    found_cycle = self.find_cycles()
    if found_cycle:
        raise RuntimeError("Found a cycle: {}".format(found_cycle))

    if not add_test_edges:
        return

    manifest.build_parent_and_child_maps()
    self.add_test_edges(manifest)

def add_test_edges(self, manifest: Manifest) -> None:
    """Add extra DAG edges from upstream tests to downstream non-test nodes.

    For each graph node that is in ``manifest.nodes`` and is not a test, an
    edge ``test -> node`` is added for every test attached to one of the
    node's upstream nodes, provided everything that test depends on is itself
    upstream of the node.
    """

    # Given a graph:
    #   model1 --> model2 --> model3
    #     |          |
    #     |          \/
    #     \/       test2
    #   test1
    #
    # the result additionally contains these edges:
    #   test1 --> model2
    #   test1 --> model3
    #   test2 --> model3

    for target_id in self.graph:
        # Only executable nodes (present in manifest.nodes) that are not
        # themselves tests receive extra incoming edges.
        if target_id not in manifest.nodes:
            continue
        if manifest.nodes[target_id].resource_type == NodeType.Test:
            continue

        # Walk the reversed graph to collect *everything* upstream, then
        # drop the starting node itself.
        reachable = nx.traversal.bfs_tree(self.graph, target_id, reverse=True)
        ancestors = {member for member in reachable if member != target_id}

        # Gather every test hanging off any upstream node.
        candidate_tests = []
        for ancestor in ancestors:
            candidate_tests.extend(_get_tests_for_node(manifest, ancestor))

        for test_id in candidate_tests:
            # A test may depend on several nodes (ex: relationship tests),
            # and test nodes do not distinguish the node being "tested"
            # from their other dependencies, so all of them must be
            # upstream before the edge is added.
            requirements = set(manifest.nodes[test_id].depends_on_nodes)
            if requirements.issubset(ancestors):
                self.graph.add_edge(test_id, target_id)

def get_graph(self, manifest: Manifest) -> Graph:
    """Link ``manifest`` into this linker (without test edges) and wrap the result.

    Returns a ``Graph`` constructed from ``self.graph`` after
    ``self.link_graph(manifest)`` has run.
    """
    self.link_graph(manifest)
    linked = self.graph
    return Graph(linked)


class Compiler:
def __init__(self, config):
Expand Down Expand Up @@ -385,104 +473,13 @@ def _compile_code(

return node

def write_graph_file(self, linker: Linker, manifest: Manifest):
    """Write the linker's graph under the configured target path.

    The destination is ``graph_file_name`` joined onto
    ``self.config.target_path``. Nothing is written unless the
    ``WRITE_JSON`` flag returned by ``get_flags()`` is truthy.
    """
    # The path is computed before the flag check, as in the original.
    destination = os.path.join(self.config.target_path, graph_file_name)
    if not get_flags().WRITE_JSON:
        return
    linker.write_graph(destination, manifest)

def link_node(self, linker: Linker, node: GraphMemberNode, manifest: Manifest):
    """Register ``node`` with ``linker`` and record each of its dependencies.

    Every id in ``node.depends_on_nodes`` is looked up in ``manifest.nodes``,
    then ``manifest.sources``, then ``manifest.metrics``; the first mapping
    containing the id supplies the ``unique_id`` handed to
    ``linker.dependency``.

    Raises:
        GraphDependencyNotFoundError: if an id is in none of the three mappings.
    """
    linker.add_node(node.unique_id)

    # Order matters: it fixes which mapping wins if an id appears in several.
    lookup_order = (manifest.nodes, manifest.sources, manifest.metrics)

    for dep_id in node.depends_on_nodes:
        for registry in lookup_order:
            if dep_id in registry:
                linker.dependency(node.unique_id, registry[dep_id].unique_id)
                break
        else:
            # No mapping knew about this id.
            raise GraphDependencyNotFoundError(node, dep_id)

def link_graph(self, linker: Linker, manifest: Manifest, add_test_edges: bool = False):
    """Populate ``linker`` from ``manifest`` and reject cyclic results.

    Sources are added as bare nodes; nodes, exposures and metrics (in that
    order) are each passed to ``self.link_node``. When ``add_test_edges`` is
    true, the manifest's parent/child maps are built and
    ``self.add_test_edges`` is run afterwards.

    Raises:
        RuntimeError: if ``linker.find_cycles()`` returns a truthy value.
    """
    for src in manifest.sources.values():
        linker.add_node(src.unique_id)

    # Nodes first, then exposures, then metrics.
    for members in (manifest.nodes, manifest.exposures, manifest.metrics):
        for member in members.values():
            self.link_node(linker, member, manifest)

    found_cycle = linker.find_cycles()
    if found_cycle:
        raise RuntimeError("Found a cycle: {}".format(found_cycle))

    if not add_test_edges:
        return

    manifest.build_parent_and_child_maps()
    self.add_test_edges(linker, manifest)

def add_test_edges(self, linker: Linker, manifest: Manifest) -> None:
    """Add extra DAG edges from upstream tests to downstream non-test nodes.

    For each node of ``linker.graph`` that is in ``manifest.nodes`` and is
    not a test, an edge ``test -> node`` is added for every test attached to
    one of the node's upstream nodes, provided everything that test depends
    on is itself upstream of the node.
    """

    # Given a graph:
    #   model1 --> model2 --> model3
    #     |          |
    #     |          \/
    #     \/       test2
    #   test1
    #
    # the result additionally contains these edges:
    #   test1 --> model2
    #   test1 --> model3
    #   test2 --> model3

    dag = linker.graph

    for target_id in dag:
        # Only executable nodes (present in manifest.nodes) that are not
        # themselves tests receive extra incoming edges.
        if target_id not in manifest.nodes:
            continue
        if manifest.nodes[target_id].resource_type == NodeType.Test:
            continue

        # Walk the reversed graph to collect *everything* upstream, then
        # drop the starting node itself.
        reachable = nx.traversal.bfs_tree(dag, target_id, reverse=True)
        ancestors = {member for member in reachable if member != target_id}

        # Gather every test hanging off any upstream node.
        candidate_tests = []
        for ancestor in ancestors:
            candidate_tests.extend(_get_tests_for_node(manifest, ancestor))

        for test_id in candidate_tests:
            # A test may depend on several nodes (ex: relationship tests),
            # and test nodes do not distinguish the node being "tested"
            # from their other dependencies, so all of them must be
            # upstream before the edge is added.
            requirements = set(manifest.nodes[test_id].depends_on_nodes)
            if requirements.issubset(ancestors):
                dag.add_edge(test_id, target_id)

# This method doesn't actually "compile" any of the nodes. That is done by the
# "compile_node" method. This creates a Linker and builds the networkx graph,
# writes out the graph.gpickle file, and prints the stats, returning a Graph object.
def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph:
self.initialize()
linker = Linker()

self.link_graph(linker, manifest, add_test_edges)

stats = _generate_stats(manifest)
linker.link_graph(manifest, add_test_edges)

if write:
self.write_graph_file(linker, manifest)
Expand All @@ -492,10 +489,18 @@ def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph
self.config.args.__class__ == argparse.Namespace
and self.config.args.cls == list_task.ListTask
):
stats = _generate_stats(manifest)
print_compile_stats(stats)

return Graph(linker.graph)

def write_graph_file(self, linker: Linker, manifest: Manifest):
    """Write the linker's graph under the configured target path.

    The destination is ``graph_file_name`` joined onto
    ``self.config.target_path``. Nothing is written unless the
    ``WRITE_JSON`` flag returned by ``get_flags()`` is truthy.
    """
    # The path is computed before the flag check, as in the original.
    destination = os.path.join(self.config.target_path, graph_file_name)
    if not get_flags().WRITE_JSON:
        return
    linker.write_graph(destination, manifest)

# writes the "compiled_code" into the target/compiled directory
def _write_node(self, node: ManifestSQLNode) -> ManifestSQLNode:
if not node.extra_ctes_injected or node.resource_type in (
Expand Down
3 changes: 3 additions & 0 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from typing_extensions import Protocol
from uuid import UUID

from dbt.contracts.publication import Dependencies

from dbt.contracts.graph.nodes import (
Macro,
Documentation,
Expand Down Expand Up @@ -633,6 +635,7 @@ class Manifest(MacroMethods, DataClassMessagePackMixin, dbtClassMixin):
source_patches: MutableMapping[SourceKey, SourcePatch] = field(default_factory=dict)
disabled: MutableMapping[str, List[GraphMemberNode]] = field(default_factory=dict)
env_vars: MutableMapping[str, str] = field(default_factory=dict)
dependencies: Optional[Dependencies] = None

_doc_lookup: Optional[DocLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
Expand Down
46 changes: 46 additions & 0 deletions core/dbt/contracts/publication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional, List, Dict, Any
from dbt.dataclass_schema import dbtClassMixin

from dataclasses import dataclass, field

from dbt.contracts.util import BaseArtifactMetadata, ArtifactMixin, schema_version


@dataclass
class DependentProjects(dbtClassMixin):
    # One entry in Dependencies.projects. Both fields are required.
    # NOTE: kept as comments rather than a docstring so the auto-generated
    # dataclass __doc__ is left untouched.
    # presumably the name of the project being depended on -- TODO confirm
    name: str
    # presumably the environment of that project to read from -- TODO confirm
    environment: str


@dataclass
class Dependencies(dbtClassMixin):
    # Container for the DependentProjects entries; defaults to an empty list.
    # NOTE: kept as comments rather than a docstring so the auto-generated
    # dataclass __doc__ is left untouched.
    #
    # `typing.List` is used instead of the builtin `list[...]` generic:
    # dataclass field annotations are evaluated when the class is created, and
    # subscripting the builtin `list` raises TypeError before Python 3.9.
    # This also matches the `List[...]` annotations used elsewhere in this file.
    projects: List[DependentProjects] = field(default_factory=list)


@dataclass
class PublicationMetadata(BaseArtifactMetadata):
    # Metadata attached to a Publication artifact.
    # NOTE: kept as comments rather than a docstring so the auto-generated
    # dataclass __doc__ is left untouched.
    # The lambda defers the lookup of `Publication`, which is defined further
    # down in this module, until an instance is actually created.
    dbt_schema_version: str = field(
        default_factory=lambda: str(Publication.dbt_schema_version)
    )
    adapter_type: Optional[str] = None
    quoting: Dict[str, Any] = field(default_factory=dict)


@dataclass
class PublicModel(dbtClassMixin):
    # One model entry, as stored in Publication.public_models.
    # NOTE: kept as comments rather than a docstring so the auto-generated
    # dataclass __doc__ is left untouched.
    relation_name: str
    # not implemented yet
    latest: bool = False
    # list of model unique_ids
    public_dependencies: List[str] = field(default_factory=list)


@dataclass
class PublicationMandatory:
    # Holds the one Publication field that has no default value.
    # presumably split into its own base class so this required field can
    # precede the defaulted fields declared on Publication -- TODO confirm
    project_name: str


@dataclass
@schema_version("publication", 1)
class Publication(ArtifactMixin, PublicationMandatory):
    # Versioned artifact ("publication", version 1). The required
    # `project_name` field is inherited from PublicationMandatory.
    # NOTE: kept as comments rather than a docstring so the auto-generated
    # dataclass __doc__ is left untouched.
    # keyed by str; presumably the model unique_id -- TODO confirm
    public_models: Dict[str, PublicModel] = field(default_factory=dict)
    metadata: PublicationMetadata = field(default_factory=PublicationMetadata)
    # list of project name strings
    dependencies: List[str] = field(default_factory=list)
Loading

0 comments on commit 0daea9f

Please sign in to comment.