From 6ce705339924348b9e062c295b29d15582dcb95c Mon Sep 17 00:00:00 2001 From: Benoit Perigaud <8754100+b-per@users.noreply.github.com> Date: Wed, 11 Sep 2024 12:22:16 +0200 Subject: [PATCH] Add ability to import other YML files --- core/dbt/clients/yaml_helper.py | 109 +++++++++++++++++++++++++++++++- 1 file changed, 106 insertions(+), 3 deletions(-) diff --git a/core/dbt/clients/yaml_helper.py b/core/dbt/clients/yaml_helper.py index a0a51099331..bcff21fa21a 100644 --- a/core/dbt/clients/yaml_helper.py +++ b/core/dbt/clients/yaml_helper.py @@ -1,9 +1,10 @@ -from typing import Any, Dict, Optional +import os +from functools import cached_property +from typing import Any, Dict, List, Optional, Union, overload import yaml import dbt_common.exceptions -import dbt_common.exceptions.base # the C version is faster, but it doesn't always exist try: @@ -52,8 +53,33 @@ def contextualized_yaml_error(raw_contents, error): ) +class LoaderWithInclude(Loader): + """Loader with a name being set.""" + + def __init__(self, stream: Any) -> None: + """Initialize a safe line loader.""" + self.stream = stream + + # Set name in same way as the Python loader does in yaml.reader.__init__ + if isinstance(stream, str): + self.name = "" + elif isinstance(stream, bytes): + self.name = "" + else: + self.name = getattr(stream, "name", "") + + super().__init__(stream) + + @cached_property + def get_name(self) -> str: + """Get the name of the loader.""" + return self.name + + def safe_load(contents) -> Optional[Dict[str, Any]]: - return yaml.load(contents, Loader=SafeLoader) + loader = LoaderWithInclude + loader.add_constructor("!include", _include_yaml) + return yaml.load(contents, Loader=loader) def load_yaml_text(contents, path=None): @@ -66,3 +92,80 @@ def load_yaml_text(contents, path=None): error = str(e) raise dbt_common.exceptions.base.DbtValidationError(error) + + +JSON_TYPE = Union[List, Dict, str] + + +def parse_yaml(content: Any, secrets=None) -> JSON_TYPE: + """Parse YAML with the fastest available loader.""" + return _parse_yaml(LoaderWithInclude, content, secrets) + + +def _parse_yaml( + loader: LoaderWithInclude, + content: Any, + secrets: Optional[str] = None, +) -> JSON_TYPE: + """Load a YAML file.""" + return yaml.load(content, LoaderWithInclude) # type: ignore[arg-type] + + +def load_yaml(fname: Any) -> Optional[JSON_TYPE]: + """Load a YAML file.""" + try: + with open(fname, encoding="utf-8") as conf_file: + return parse_yaml(conf_file, None) + except UnicodeDecodeError as exc: + raise dbt_common.exceptions.base.DbtValidationError(str(exc)) + + +@overload +def _add_reference( + obj: list, + loader: LoaderWithInclude, + node: yaml.nodes.Node, +) -> list: ... + + +@overload +def _add_reference( + obj: str, + loader: LoaderWithInclude, + node: yaml.nodes.Node, +) -> str: ... + + +@overload +def _add_reference(obj: dict, loader: LoaderWithInclude, node: yaml.nodes.Node) -> dict: ... + + +def _add_reference(obj, loader: LoaderWithInclude, node: yaml.nodes.Node): # type: ignore[no-untyped-def] + """Add file reference information to an object.""" + if isinstance(obj, list): + obj = obj + if isinstance(obj, str): + obj = obj + try: # noqa: SIM105 suppress is much slower + setattr(obj, "__config_file__", loader.get_name) + setattr(obj, "__line__", node.start_mark.line + 1) + except AttributeError: + pass + return obj + + +def _include_yaml(loader: LoaderWithInclude, node: yaml.nodes.Node) -> JSON_TYPE: + """Load another YAML file and embed it using the !include tag. + + Example: + +schema: !include schema_config.yml + + """ + fname = os.path.join(os.path.dirname(loader.get_name), node.value) + try: + loaded_yaml = load_yaml(fname) + if loaded_yaml is None: + loaded_yaml = {} + return _add_reference(loaded_yaml, loader, node) + except FileNotFoundError as exc: + raise dbt_common.exceptions.base.DbtValidationError(str(exc))