Skip to content

Commit

Permalink
Ignore parent tests added edges for build selection (#7431)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-luu authored Apr 28, 2023
1 parent a7eb89d commit f1dddaa
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 9 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230421-172428.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: dbt build selection of tests' descendants
time: 2023-04-21T17:24:28.335866975+02:00
custom:
Author: b-luu
Issue: "7289"
2 changes: 1 addition & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,7 +491,7 @@ def add_test_edges(self, linker: Linker, manifest: Manifest) -> None:
# is a subset of all upstream nodes of the current node,
# add an edge from the upstream test to the current node.
if test_depends_on.issubset(upstream_nodes):
linker.graph.add_edge(upstream_test, node_id)
linker.graph.add_edge(upstream_test, node_id, edge_type="parent_test")

def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph:
self.initialize()
Expand Down
17 changes: 15 additions & 2 deletions core/dbt/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,29 @@ def ancestors(self, node: UniqueId, max_depth: Optional[int] = None) -> Set[Uniq
"""Returns all nodes having a path to `node` in `graph`"""
if not self.graph.has_node(node):
raise DbtInternalError(f"Node {node} not found in the graph!")
filtered_graph = self.exclude_edge_type("parent_test")
return {
child
for _, child in nx.bfs_edges(self.graph, node, reverse=True, depth_limit=max_depth)
for _, child in nx.bfs_edges(filtered_graph, node, reverse=True, depth_limit=max_depth)
}

def descendants(self, node: UniqueId, max_depth: Optional[int] = None) -> Set[UniqueId]:
"""Returns all nodes reachable from `node` in `graph`"""
if not self.graph.has_node(node):
raise DbtInternalError(f"Node {node} not found in the graph!")
return {child for _, child in nx.bfs_edges(self.graph, node, depth_limit=max_depth)}
filtered_graph = self.exclude_edge_type("parent_test")
return {child for _, child in nx.bfs_edges(filtered_graph, node, depth_limit=max_depth)}

def exclude_edge_type(self, edge_type_to_exclude):
return nx.restricted_view(
self.graph,
nodes=[],
edges=(
(a, b)
for a, b in self.graph.edges
if self.graph[a][b].get("edge_type") == edge_type_to_exclude
),
)

def select_childrens_parents(self, selected: Set[UniqueId]) -> Set[UniqueId]:
ancestors_for = self.select_children(selected) | selected
Expand Down
21 changes: 21 additions & 0 deletions tests/functional/build/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,27 @@
- not_null
"""

models_triple_blocking__test_yml = """
version: 2
models:
- name: model_a
columns:
- name: id
tests:
- not_null
- name: model_b
columns:
- name: id
tests:
- not_null
- name: model_c
columns:
- name: id
tests:
- not_null
"""

models_interdependent__model_a_sql = """
select 1 as id
"""
Expand Down
32 changes: 32 additions & 0 deletions tests/functional/build/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
models_simple_blocking__model_a_sql,
models_simple_blocking__model_b_sql,
models_simple_blocking__test_yml,
models_triple_blocking__test_yml,
models_interdependent__test_yml,
models_interdependent__model_a_sql,
models_interdependent__model_b_sql,
Expand Down Expand Up @@ -196,3 +197,34 @@ def test_interdependent_models_fail(self, project):
actual = [str(r.status) for r in results]
expected = ["error"] * 4 + ["skipped"] * 7 + ["pass"] * 2 + ["success"] * 3
assert sorted(actual) == sorted(expected)


class TestDownstreamSelection:
@pytest.fixture(scope="class")
def models(self):
return {
"model_a.sql": models_simple_blocking__model_a_sql,
"model_b.sql": models_simple_blocking__model_b_sql,
"test.yml": models_simple_blocking__test_yml,
}

def test_downstream_selection(self, project):
"""Ensure that selecting test+ does not select model_a's other children"""
results = run_dbt(["build", "--select", "model_a not_null_model_a_id+"], expect_pass=True)
assert len(results) == 2


class TestLimitedUpstreamSelection:
@pytest.fixture(scope="class")
def models(self):
return {
"model_a.sql": models_interdependent__model_a_sql,
"model_b.sql": models_interdependent__model_b_sql,
"model_c.sql": models_interdependent__model_c_sql,
"test.yml": models_triple_blocking__test_yml,
}

def test_limited_upstream_selection(self, project):
"""Ensure that selecting 1+model_c only selects up to model_b (+ tests of both)"""
results = run_dbt(["build", "--select", "1+model_c"], expect_pass=True)
assert len(results) == 4
11 changes: 5 additions & 6 deletions tests/functional/defer_state/test_run_results_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ def test_build_run_results_state(self, project):
results = run_dbt(
["build", "--select", "result:fail+", "--state", "./state"], expect_pass=False
)
assert len(results) == 2
assert len(results) == 1
nodes = set([elem.node.name for elem in results])
assert nodes == {"table_model", "unique_view_model_id"}
assert nodes == {"unique_view_model_id"}

results = run_dbt(["ls", "--select", "result:fail+", "--state", "./state"])
assert len(results) == 1
Expand All @@ -240,9 +240,9 @@ def test_build_run_results_state(self, project):
results = run_dbt(
["build", "--select", "result:warn+", "--state", "./state"], expect_pass=True
)
assert len(results) == 2 # includes table_model to be run
assert len(results) == 1
nodes = set([elem.node.name for elem in results])
assert nodes == {"table_model", "unique_view_model_id"}
assert nodes == {"unique_view_model_id"}

results = run_dbt(["ls", "--select", "result:warn+", "--state", "./state"])
assert len(results) == 1
Expand Down Expand Up @@ -483,12 +483,11 @@ def test_concurrent_selectors_build_run_results_state(self, project):
],
expect_pass=False,
)
assert len(results) == 5
assert len(results) == 4
nodes = set([elem.node.name for elem in results])
assert nodes == {
"error_model",
"downstream_of_error_model",
"table_model_modified_example",
"table_model",
"unique_view_model_id",
}

0 comments on commit f1dddaa

Please sign in to comment.