Skip to content

Commit

Permalink
Fix #9005: Allow singular tests to be documented in properties.yml (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
aranke authored Sep 24, 2024
1 parent aa23af9 commit 3ac20ce
Show file tree
Hide file tree
Showing 11 changed files with 220 additions and 15 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240923-190758.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Allow singular tests to be documented in properties.yml
time: 2024-09-23T19:07:58.151069+01:00
custom:
Author: aranke
Issue: "9005"
13 changes: 4 additions & 9 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,8 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
return [VersionSpecifier.from_version_string(v) for v in versions]


def _all_source_paths(
model_paths: List[str],
seed_paths: List[str],
snapshot_paths: List[str],
analysis_paths: List[str],
macro_paths: List[str],
) -> List[str]:
paths = chain(model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths)
def _all_source_paths(*args: List[str]) -> List[str]:
paths = chain(*args)
# Strip trailing slashes since the path is the same even though the name is not
stripped_paths = map(lambda s: s.rstrip("/"), paths)
return list(set(stripped_paths))
Expand Down Expand Up @@ -409,7 +403,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])

all_source_paths: List[str] = _all_source_paths(
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths
model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths, test_paths
)

docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
Expand Down Expand Up @@ -652,6 +646,7 @@ def all_source_paths(self) -> List[str]:
self.snapshot_paths,
self.analysis_paths,
self.macro_paths,
self.test_paths,
)

@property
Expand Down
50 changes: 49 additions & 1 deletion core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
SavedQuery,
SeedNode,
SemanticModel,
SingularTestNode,
SourceDefinition,
UnitTestDefinition,
UnitTestFileFixture,
Expand Down Expand Up @@ -89,7 +90,7 @@
RefName = str


def find_unique_id_for_package(storage, key, package: Optional[PackageName]):
def find_unique_id_for_package(storage, key, package: Optional[PackageName]) -> Optional[UniqueID]:
if key not in storage:
return None

Expand Down Expand Up @@ -470,6 +471,43 @@ class AnalysisLookup(RefableLookup):
_versioned_types: ClassVar[set] = set()


class SingularTestLookup(dbtClassMixin):
def __init__(self, manifest: "Manifest") -> None:
self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
self.populate(manifest)

def get_unique_id(self, search_name, package: Optional[PackageName]) -> Optional[UniqueID]:
return find_unique_id_for_package(self.storage, search_name, package)

def find(
self, search_name, package: Optional[PackageName], manifest: "Manifest"
) -> Optional[SingularTestNode]:
unique_id = self.get_unique_id(search_name, package)
if unique_id is not None:
return self.perform_lookup(unique_id, manifest)
return None

def add_singular_test(self, source: SingularTestNode) -> None:
if source.search_name not in self.storage:
self.storage[source.search_name] = {}

self.storage[source.search_name][source.package_name] = source.unique_id

def populate(self, manifest: "Manifest") -> None:
for node in manifest.nodes.values():
if isinstance(node, SingularTestNode):
self.add_singular_test(node)

def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SingularTestNode:
if unique_id not in manifest.nodes:
raise dbt_common.exceptions.DbtInternalError(
f"Singular test {unique_id} found in cache but not found in manifest"
)
node = manifest.nodes[unique_id]
assert isinstance(node, SingularTestNode)
return node


def _packages_to_search(
current_project: str,
node_package: str,
Expand Down Expand Up @@ -869,6 +907,9 @@ class Manifest(MacroMethods, dbtClassMixin):
_analysis_lookup: Optional[AnalysisLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_singular_test_lookup: Optional[SingularTestLookup] = field(
default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
)
_parsing_info: ParsingInfo = field(
default_factory=ParsingInfo,
metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
Expand Down Expand Up @@ -1264,6 +1305,12 @@ def analysis_lookup(self) -> AnalysisLookup:
self._analysis_lookup = AnalysisLookup(self)
return self._analysis_lookup

@property
def singular_test_lookup(self) -> SingularTestLookup:
if self._singular_test_lookup is None:
self._singular_test_lookup = SingularTestLookup(self)
return self._singular_test_lookup

@property
def external_node_unique_ids(self):
return [node.unique_id for node in self.nodes.values() if node.is_external_node]
Expand Down Expand Up @@ -1708,6 +1755,7 @@ def __reduce_ex__(self, protocol):
self._semantic_model_by_measure_lookup,
self._disabled_lookup,
self._analysis_lookup,
self._singular_test_lookup,
)
return self.__class__, args

Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,11 @@ class ParsedMacroPatch(ParsedPatch):
arguments: List[MacroArgument] = field(default_factory=list)


@dataclass
class ParsedSingularTestPatch(ParsedPatch):
pass


# ====================================
# Node unions/categories
# ====================================
Expand Down
5 changes: 5 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ class UnparsedAnalysisUpdate(HasConfig, HasColumnDocs, HasColumnProps, HasYamlMe
access: Optional[str] = None


@dataclass
class UnparsedSingularTestUpdate(HasConfig, HasColumnProps, HasYamlMetadata):
pass


@dataclass
class UnparsedNodeUpdate(HasConfig, HasColumnTests, HasColumnAndTestProps, HasYamlMetadata):
quote_columns: Optional[bool] = None
Expand Down
2 changes: 2 additions & 0 deletions core/dbt/parser/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UnparsedMacroUpdate,
UnparsedModelUpdate,
UnparsedNodeUpdate,
UnparsedSingularTestUpdate,
)
from dbt.exceptions import ParsingError
from dbt.parser.search import FileBlock
Expand All @@ -38,6 +39,7 @@ def trimmed(inp: str) -> str:
UnpatchedSourceDefinition,
UnparsedExposure,
UnparsedModelUpdate,
UnparsedSingularTestUpdate,
)


Expand Down
60 changes: 59 additions & 1 deletion core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ModelNode,
ParsedMacroPatch,
ParsedNodePatch,
ParsedSingularTestPatch,
UnpatchedSourceDefinition,
)
from dbt.contracts.graph.unparsed import (
Expand All @@ -27,6 +28,7 @@
UnparsedMacroUpdate,
UnparsedModelUpdate,
UnparsedNodeUpdate,
UnparsedSingularTestUpdate,
UnparsedSourceDefinition,
)
from dbt.events.types import (
Expand Down Expand Up @@ -222,6 +224,10 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None:
parser = MacroPatchParser(self, yaml_block, "macros")
parser.parse()

if "data_tests" in dct:
parser = SingularTestPatchParser(self, yaml_block, "data_tests")
parser.parse()

# PatchParser.parse() (but never test_blocks)
if "analyses" in dct:
parser = AnalysisPatchParser(self, yaml_block, "analyses")
Expand Down Expand Up @@ -316,14 +322,17 @@ def _add_yaml_snapshot_nodes_to_manifest(
self.manifest.rebuild_ref_lookup()


Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch)
Parsed = TypeVar(
"Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch, ParsedSingularTestPatch
)
NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedModelUpdate)
NonSourceTarget = TypeVar(
"NonSourceTarget",
UnparsedNodeUpdate,
UnparsedAnalysisUpdate,
UnparsedMacroUpdate,
UnparsedModelUpdate,
UnparsedSingularTestUpdate,
)


Expand Down Expand Up @@ -1105,6 +1114,55 @@ def _target_type(self) -> Type[UnparsedAnalysisUpdate]:
return UnparsedAnalysisUpdate


class SingularTestPatchParser(PatchParser[UnparsedSingularTestUpdate, ParsedSingularTestPatch]):
def get_block(self, node: UnparsedSingularTestUpdate) -> TargetBlock:
return TargetBlock.from_yaml_block(self.yaml, node)

def _target_type(self) -> Type[UnparsedSingularTestUpdate]:
return UnparsedSingularTestUpdate

def parse_patch(self, block: TargetBlock[UnparsedSingularTestUpdate], refs: ParserRef) -> None:
patch = ParsedSingularTestPatch(
name=block.target.name,
description=block.target.description,
meta=block.target.meta,
docs=block.target.docs,
config=block.target.config,
original_file_path=block.target.original_file_path,
yaml_key=block.target.yaml_key,
package_name=block.target.package_name,
)

assert isinstance(self.yaml.file, SchemaSourceFile)
source_file: SchemaSourceFile = self.yaml.file

unique_id = self.manifest.singular_test_lookup.get_unique_id(
block.name, block.target.package_name
)
if not unique_id:
warn_or_error(
NoNodeForYamlKey(
patch_name=patch.name,
yaml_key=patch.yaml_key,
file_path=source_file.path.original_file_path,
)
)
return

node = self.manifest.nodes.get(unique_id)
assert node is not None

source_file.append_patch(patch.yaml_key, unique_id)
if patch.config:
self.patch_node_config(node, patch)

node.patch_path = patch.file_id
node.description = patch.description
node.created_at = time.time()
node.meta = patch.meta
node.docs = patch.docs


class MacroPatchParser(PatchParser[UnparsedMacroUpdate, ParsedMacroPatch]):
def get_block(self, node: UnparsedMacroUpdate) -> TargetBlock:
return TargetBlock.from_yaml_block(self.yaml, node)
Expand Down
32 changes: 32 additions & 0 deletions tests/functional/data_test_patch/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
tests__my_singular_test_sql = """
with my_cte as (
select 1 as id, 'foo' as name
union all
select 2 as id, 'bar' as name
)
select * from my_cte
"""

tests__schema_yml = """
data_tests:
- name: my_singular_test
description: "{{ doc('my_singular_test_documentation') }}"
config:
error_if: ">10"
meta:
some_key: some_val
"""

tests__doc_block_md = """
{% docs my_singular_test_documentation %}
Some docs from a doc block
{% enddocs %}
"""

tests__invalid_name_schema_yml = """
data_tests:
- name: my_double_test
description: documentation, but make it double
"""
53 changes: 53 additions & 0 deletions tests/functional/data_test_patch/test_singular_test_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

import pytest

from dbt.tests.util import get_artifact, run_dbt, run_dbt_and_capture
from tests.functional.data_test_patch.fixtures import (
tests__doc_block_md,
tests__invalid_name_schema_yml,
tests__my_singular_test_sql,
tests__schema_yml,
)


class TestPatchSingularTest:
@pytest.fixture(scope="class")
def tests(self):
return {
"my_singular_test.sql": tests__my_singular_test_sql,
"schema.yml": tests__schema_yml,
"doc_block.md": tests__doc_block_md,
}

def test_compile(self, project):
run_dbt(["compile"])
manifest = get_artifact(project.project_root, "target", "manifest.json")
assert len(manifest["nodes"]) == 1

my_singular_test_node = manifest["nodes"]["test.test.my_singular_test"]
assert my_singular_test_node["description"] == "Some docs from a doc block"
assert my_singular_test_node["config"]["error_if"] == ">10"
assert my_singular_test_node["config"]["meta"] == {"some_key": "some_val"}


class TestPatchSingularTestInvalidName:
@pytest.fixture(scope="class")
def tests(self):
return {
"my_singular_test.sql": tests__my_singular_test_sql,
"schema_with_invalid_name.yml": tests__invalid_name_schema_yml,
}

def test_compile(self, project):
_, log_output = run_dbt_and_capture(["compile"])

file_path = (
"tests\\schema_with_invalid_name.yml"
if os.name == "nt"
else "tests/schema_with_invalid_name.yml"
)
assert (
f"Did not find matching node for patch with name 'my_double_test' in the 'data_tests' section of file '{file_path}'"
in log_output
)
7 changes: 4 additions & 3 deletions tests/unit/config/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TestProjectMethods:
def test_all_source_paths(self, project: Project):
assert (
project.all_source_paths.sort()
== ["models", "seeds", "snapshots", "analyses", "macros"].sort()
== ["models", "seeds", "snapshots", "analyses", "macros", "tests"].sort()
)

def test_generic_test_paths(self, project: Project):
Expand Down Expand Up @@ -99,7 +99,8 @@ def test_defaults(self):
self.assertEqual(project.test_paths, ["tests"])
self.assertEqual(project.analysis_paths, ["analyses"])
self.assertEqual(
set(project.docs_paths), set(["models", "seeds", "snapshots", "analyses", "macros"])
set(project.docs_paths),
{"models", "seeds", "snapshots", "analyses", "macros", "tests"},
)
self.assertEqual(project.asset_paths, [])
self.assertEqual(project.target_path, "target")
Expand Down Expand Up @@ -128,7 +129,7 @@ def test_implicit_overrides(self):
)
self.assertEqual(
set(project.docs_paths),
set(["other-models", "seeds", "snapshots", "analyses", "macros"]),
{"other-models", "seeds", "snapshots", "analyses", "macros", "tests"},
)

def test_all_overrides(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/config/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_from_args(self):
self.assertEqual(config.test_paths, ["tests"])
self.assertEqual(config.analysis_paths, ["analyses"])
self.assertEqual(
set(config.docs_paths), set(["models", "seeds", "snapshots", "analyses", "macros"])
set(config.docs_paths), {"models", "seeds", "snapshots", "analyses", "macros", "tests"}
)
self.assertEqual(config.asset_paths, [])
self.assertEqual(config.target_path, "target")
Expand Down

0 comments on commit 3ac20ce

Please sign in to comment.