Skip to content

Commit

Permalink
Fix #9608: Unit test fixture compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
aranke committed Mar 26, 2024
1 parent c6c0c79 commit da9b1be
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 22 deletions.
10 changes: 0 additions & 10 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,16 +954,6 @@ class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionResource):
def resource_class(cls) -> Type[UnitTestDefinitionResource]:
return UnitTestDefinitionResource

@property
def build_path(self):
# TODO: is this actually necessary?
return self.original_file_path

@property
def compiled_path(self):
# TODO: is this actually necessary?
return self.original_file_path

@property
def depends_on_nodes(self):
return self.depends_on.nodes
Expand Down
7 changes: 4 additions & 3 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):

common_fields = {
"resource_type": NodeType.Model,
"original_file_path": original_input_node.original_file_path,
# root directory for input and output fixtures
"original_file_path": unit_test_node.original_file_path,
"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
"alias": original_input_node.identifier,
Expand All @@ -144,7 +145,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
package_name=original_input_node.package_name,
unique_id=f"model.{original_input_node.package_name}.{input_name}",
name=input_name,
path=original_input_node.path or f"{input_name}.sql",
path=f"{input_name}.sql",
)
if (
original_input_node.resource_type == NodeType.Model
Expand All @@ -162,7 +163,7 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
package_name=original_input_node.package_name,
unique_id=f"model.{original_input_node.package_name}.{input_name}",
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
path=f"{input_name}.sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
)
# Sources need to go in the sources dictionary in order to create the right lookup
Expand Down
2 changes: 1 addition & 1 deletion core/dbt/task/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def print_run_result_error(result, newline: bool = True, is_warning: bool = Fals
else:
fire_event(RunResultErrorNoMessage(status=result.status))

if result.node.build_path is not None:
if result.node.compiled_path is not None:
with TextOnly():
fire_event(Formatting(""))
fire_event(SQLCompiledPath(path=result.node.compiled_path))
Expand Down
26 changes: 18 additions & 8 deletions core/dbt/task/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
from .compile import CompileRunner
from .run import RunTask

from dbt.contracts.graph.nodes import TestNode, UnitTestDefinition, UnitTestNode
from dbt.contracts.graph.nodes import (
TestNode,
UnitTestDefinition,
UnitTestNode,
GenericTestNode,
SingularTestNode,
)
from dbt.contracts.graph.manifest import Manifest
from dbt.artifacts.schemas.results import TestStatus
from dbt.artifacts.schemas.run import RunResult
Expand Down Expand Up @@ -180,7 +186,7 @@ def build_unit_test_manifest_from_test(

def execute_unit_test(
self, unit_test_def: UnitTestDefinition, manifest: Manifest
) -> UnitTestResultData:
) -> tuple[UnitTestNode, UnitTestResultData]:

unit_test_manifest = self.build_unit_test_manifest_from_test(unit_test_def, manifest)

Expand All @@ -190,6 +196,7 @@ def execute_unit_test(

# Compile the node
unit_test_node = self.compiler.compile_node(unit_test_node, unit_test_manifest, {})
assert isinstance(unit_test_node, UnitTestNode)

# generate_runtime_unit_test_context not strictly needed - this is to run the 'unit'
# materialization, not compile the node.compiled_code
Expand Down Expand Up @@ -243,18 +250,21 @@ def execute_unit_test(
rendered=rendered,
)

return UnitTestResultData(
unit_test_result_data = UnitTestResultData(
diff=diff,
should_error=should_error,
adapter_response=adapter_response,
)

def execute(self, test: Union[TestNode, UnitTestDefinition], manifest: Manifest):
return unit_test_node, unit_test_result_data

def execute(self, test: Union[TestNode, UnitTestNode], manifest: Manifest):
if isinstance(test, UnitTestDefinition):
unit_test_result = self.execute_unit_test(test, manifest)
return self.build_unit_test_run_result(test, unit_test_result)
unit_test_node, unit_test_result = self.execute_unit_test(test, manifest)
return self.build_unit_test_run_result(unit_test_node, unit_test_result)
else:
# Note: manifest here is a normal manifest
assert isinstance(test, (SingularTestNode, GenericTestNode))
test_result = self.execute_data_test(test, manifest)
return self.build_test_run_result(test, test_result)

Expand Down Expand Up @@ -293,7 +303,7 @@ def build_test_run_result(self, test: TestNode, result: TestResultData) -> RunRe
return run_result

def build_unit_test_run_result(
self, test: UnitTestDefinition, result: UnitTestResultData
self, test: UnitTestNode, result: UnitTestResultData
) -> RunResult:
thread_id = threading.current_thread().name

Expand All @@ -306,7 +316,7 @@ def build_unit_test_run_result(
failures = 1

return RunResult(
node=test, # type: ignore
node=test,
status=status,
timing=[],
thread_id=thread_id,
Expand Down
59 changes: 59 additions & 0 deletions tests/functional/unit_testing/test_unit_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
run_dbt,
write_file,
get_manifest,
run_dbt_and_capture,
read_file,
file_exists,
)
from dbt.contracts.results import NodeStatus
from dbt.exceptions import DuplicateResourceNameError, ParsingError
Expand Down Expand Up @@ -442,3 +445,59 @@ def test_unit_test_ext_nodes(
run_dbt(["run"], expect_pass=True)
results = run_dbt(["test", "--select", "valid_emails"], expect_pass=True)
assert len(results) == 1


subfolder_model_a_sql = """
select
1 as id, 'blue' as color
"""

subfolder_model_b_sql = """
select
id,
color
from {{ ref('model_a') }}
"""

subfolder_my_model_yml = """
unit_tests:
- name: my_unit_test
model: model_b
given:
- input: ref('model_a')
rows:
- { id: 1, color: 'blue' }
expect:
rows:
- { id: 1, color: 'red' }
"""


class TestUnitTestSubfolderPath:
@pytest.fixture(scope="class")
def models(self):
return {
"subfolder": {
"model_a.sql": subfolder_model_a_sql,
"model_b.sql": subfolder_model_b_sql,
"my_model.yml": subfolder_my_model_yml,
}
}

def test_subfolder_unit_test(self, project):
results, output = run_dbt_and_capture(["build"], expect_pass=False)

# Test that input fixture doesn't overwrite the original model
assert (
read_file("target/compiled/test/models/subfolder/model_a.sql").strip()
== subfolder_model_a_sql.strip()
)

# Test that correct path is written in logs
assert (
"target/compiled/test/models/subfolder/my_model.yml/models/subfolder/my_unit_test.sql"
in output
)
assert file_exists(
"target/compiled/test/models/subfolder/my_model.yml/models/subfolder/my_unit_test.sql"
)

0 comments on commit da9b1be

Please sign in to comment.