Skip to content

Commit

Permalink
Ele 3295 api dbt runner (#1577)
Browse files Browse the repository at this point in the history
* add basic dbt runner tests

* dbt runner - refactor in preparation for api runner

* add dbt api runner!

* api_dbt_runner: bugfix

* get_dbt_runner -> create_dbt_runner

* api_dbt_runner: support env vars

* python 3.8 compat

* more compat changes

* add decorator requirement

* unittests fix

* remove unnecessary decorator dependency

* Delete slim dbt runner
  • Loading branch information
haritamar authored Jul 4, 2024
1 parent 1a09bad commit e1f2138
Show file tree
Hide file tree
Showing 51 changed files with 1,029 additions and 438 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,4 @@ venv/

# elementary outputs
edr_target/
tests/tests_with_db/dbt_project/dbt_packages/
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ types-PyYAML
types-setuptools
pandas-stubs
types-retry
types-decorator
60 changes: 60 additions & 0 deletions elementary/clients/dbt/api_dbt_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import json
from dataclasses import dataclass
from typing import List, Optional, cast

from dbt.cli.main import dbtRunner, dbtRunnerResult
from google.protobuf.json_format import MessageToDict

from elementary.clients.dbt.command_line_dbt_runner import (
CommandLineDbtRunner,
DbtCommandResult,
)
from elementary.exceptions.exceptions import DbtCommandError
from elementary.utils.env_vars_context import env_vars_context
from elementary.utils.log import get_logger

logger = get_logger(__name__)


@dataclass
class APIDbtCommandResult(DbtCommandResult):
    """Command result produced by the in-process dbt API runner.

    Extends ``DbtCommandResult`` with the raw object returned by
    ``dbtRunner.invoke`` so callers can read structured results instead of
    re-parsing log output.
    """

    # Raw result object returned by ``dbtRunner.invoke``.
    result_obj: dbtRunnerResult


class APIDbtRunner(CommandLineDbtRunner):
    """dbt runner that invokes dbt in-process through the ``dbtRunner`` API."""

    def _inner_run_command(
        self,
        dbt_command_args: List[str],
        capture_output: bool,
        quiet: bool,
        log_output: bool,
        log_format: str,
    ) -> DbtCommandResult:
        """Invoke dbt with ``dbt_command_args`` and collect its Jinja log events.

        ``capture_output``, ``quiet``, ``log_output`` and ``log_format`` are
        accepted for interface compatibility but are not used here: logs are
        collected through an event callback rather than by capturing stdout.

        Returns:
            An ``APIDbtCommandResult`` whose ``output`` is the newline-joined
            JSON dump of every ``JinjaLogInfo`` event (``None`` if there were
            none) and whose ``result_obj`` is the raw ``dbtRunnerResult``.

        Raises:
            DbtCommandError: if the invocation did not succeed and
                ``self.raise_on_failure`` is set.
        """
        # The dbt python API always prints the output and we collect the logs
        # using a programmatic callback, so there is no need to capture the
        # output here. Work on a copy so the caller's list is not mutated.
        dbt_command_args = list(dbt_command_args)
        if "-q" not in dbt_command_args and "--quiet" not in dbt_command_args:
            dbt_command_args.append("--quiet")

        dbt_logs: List[str] = []

        def collect_dbt_command_logs(event):
            event_dump = json.dumps(MessageToDict(event))  # type: ignore[arg-type]
            logger.debug(f"dbt event msg: {event_dump}")
            if event.info.name == "JinjaLogInfo":
                dbt_logs.append(event_dump)

        with env_vars_context(self.env_vars):
            dbt = dbtRunner(callbacks=[collect_dbt_command_logs])
            res: dbtRunnerResult = dbt.invoke(dbt_command_args)

        output = "\n".join(dbt_logs) or None
        if self.raise_on_failure and not res.success:
            err_msg = output
            # Only JinjaLogInfo events are collected above, so a failure that
            # surfaces as an exception on the result would otherwise be lost.
            if res.exception is not None:
                exception_text = str(res.exception) or repr(res.exception)
                err_msg = (
                    f"{output}\n{exception_text}" if output else exception_text
                )
            raise DbtCommandError(base_command_args=dbt_command_args, err_msg=err_msg)

        return APIDbtCommandResult(success=res.success, output=output, result_obj=res)

    def _parse_ls_command_result(
        self, select: Optional[str], result: DbtCommandResult
    ) -> List[str]:
        """Return the ``dbt ls`` result carried by the API result object.

        ``select`` is unused here; it is part of the hook's signature.
        """
        ls_result = cast(APIDbtCommandResult, result).result_obj.result
        if ls_result is None:
            # No result was produced (e.g. a failed invocation that did not
            # raise); honour the declared list return type.
            return []
        return cast(List[str], ls_result)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import subprocess
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import yaml

Expand All @@ -10,13 +10,18 @@
from elementary.exceptions.exceptions import DbtCommandError, DbtLsCommandError
from elementary.monitor.dbt_project_utils import is_dbt_package_up_to_date
from elementary.utils.env_vars import is_debug
from elementary.utils.json_utils import try_load_json
from elementary.utils.log import get_logger

logger = get_logger(__name__)


class DbtRunner(BaseDbtRunner):
@dataclass
class DbtCommandResult:
    """Outcome of a single dbt command invocation."""

    # Whether the dbt command completed successfully.
    success: bool
    # Output collected from the command, or None when nothing was collected.
    output: Optional[str]


class CommandLineDbtRunner(BaseDbtRunner):
ELEMENTARY_LOG_PREFIX = "Elementary: "

def __init__(
Expand Down Expand Up @@ -47,6 +52,21 @@ def __init__(
elif run_deps_if_needed:
self._run_deps_if_needed()

def _inner_run_command(
    self,
    dbt_command_args: List[str],
    capture_output: bool,
    quiet: bool,
    log_output: bool,
    log_format: str,
) -> DbtCommandResult:
    """Execute the prepared dbt arguments; must be overridden by subclasses.

    ``_run_command`` builds the argument list and delegates the actual
    execution to this hook.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

def _parse_ls_command_result(
    self, select: Optional[str], result: DbtCommandResult
) -> List[str]:
    """Turn the result of a ``dbt ls`` run into a list; must be overridden.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError

def _run_command(
self,
command_args: List[str],
Expand All @@ -55,82 +75,72 @@ def _run_command(
vars: Optional[dict] = None,
quiet: bool = False,
log_output: bool = True,
) -> Tuple[bool, Optional[str]]:
dbt_command = ["dbt"]
) -> DbtCommandResult:
dbt_command_args = []
if capture_output:
dbt_command.extend(["--log-format", log_format])
dbt_command.extend(command_args)
dbt_command.extend(["--project-dir", self.project_dir])
dbt_command_args.extend(["--log-format", log_format])
dbt_command_args.extend(command_args)
dbt_command_args.extend(["--project-dir", self.project_dir])
if self.profiles_dir:
dbt_command.extend(["--profiles-dir", self.profiles_dir])
dbt_command_args.extend(["--profiles-dir", self.profiles_dir])
if self.target:
dbt_command.extend(["--target", self.target])
dbt_command_args.extend(["--target", self.target])

all_vars = self._get_all_vars(vars)
if all_vars:
log_command = dbt_command.copy()
log_command.extend(
log_command_args = dbt_command_args.copy()
log_command_args.extend(
[
"--vars",
json.dumps(self._get_secret_masked_vars(all_vars)),
]
)
dbt_command.extend(["--vars", json.dumps(all_vars)])
dbt_command_args.extend(["--vars", json.dumps(all_vars)])
else:
log_command = dbt_command
log_command_args = dbt_command_args

log_msg = f"Running {' '.join(log_command)}"
log_msg = f"Running dbt command {' '.join(log_command_args)}"
if not quiet:
logger.info(log_msg)
else:
logger.debug(log_msg)
try:
result = subprocess.run(
dbt_command,
check=self.raise_on_failure,
capture_output=(capture_output or quiet),
env=self._get_command_env(),
cwd=self.project_dir,
)
except subprocess.CalledProcessError as err:
logs = list(parse_dbt_output(err.output.decode())) if err.output else []
if capture_output and (log_output or is_debug()):
for log in logs:
logger.info(log.msg)
raise DbtCommandError(err, command_args, logs=logs)

output = None
if capture_output:
output = result.stdout.decode("utf-8")
result = self._inner_run_command(
dbt_command_args,
capture_output=capture_output,
quiet=quiet,
log_output=log_output,
log_format=log_format,
)

if capture_output and result.output:
logger.debug(
f"Result bytes size for command '{log_command}' is {len(result.stdout)}"
f"Result bytes size for command '{log_command_args}' is {len(result.output)}"
)
if log_output or is_debug():
for log in parse_dbt_output(output):
for log in parse_dbt_output(result.output, log_format):
logger.info(log.msg)

if result.returncode != 0:
return False, output
return True, output
return result

def deps(self, quiet: bool = False, capture_output: bool = True) -> bool:
    """Run ``dbt deps`` and return whether the command succeeded."""
    result = self._run_command(
        command_args=["deps"], quiet=quiet, capture_output=capture_output
    )
    return result.success

def seed(self, select: Optional[str] = None, full_refresh: bool = False) -> bool:
    """Run ``dbt seed`` and return whether the command succeeded.

    Adds ``--full-refresh`` when ``full_refresh`` is set and ``-s <select>``
    when a selection is given.
    """
    command_args = ["seed"]
    if full_refresh:
        command_args.append("--full-refresh")
    if select:
        command_args.extend(["-s", select])
    result = self._run_command(command_args)
    return result.success

def snapshot(self) -> bool:
    """Run ``dbt snapshot`` and return whether the command succeeded."""
    result = self._run_command(["snapshot"])
    return result.success

def run_operation(
self,
Expand Down Expand Up @@ -158,20 +168,20 @@ def run_operation(
command_args = ["run-operation", macro_to_run]
json_args = json.dumps(macro_to_run_args, ensure_ascii=False)
command_args.extend(["--args", json_args])
success, command_output = self._run_command(
result = self._run_command(
command_args=command_args,
capture_output=capture_output,
vars=vars,
quiet=quiet,
log_output=log_output,
)
if log_errors and not success:
if log_errors and not result.success:
logger.error(
f'Failed to run macro: "{macro_name}"\nRun output: {command_output}'
f'Failed to run macro: "{macro_name}"\nRun output: {result.output}'
)
run_operation_results = []
if capture_output and command_output is not None:
for log in parse_dbt_output(command_output):
if capture_output and result.output is not None:
for log in parse_dbt_output(result.output):
if log_errors and log.level == "error":
logger.error(log.msg)
continue
Expand Down Expand Up @@ -200,13 +210,13 @@ def run(
command_args.extend(["-s", select])
if selector:
command_args.extend(["--selector", selector])
success, _ = self._run_command(
result = self._run_command(
command_args=command_args,
vars=vars,
quiet=quiet,
capture_output=capture_output,
)
return success
return result.success

def test(
self,
Expand All @@ -218,58 +228,37 @@ def test(
command_args = ["test"]
if select:
command_args.extend(["-s", select])
success, _ = self._run_command(
result = self._run_command(
command_args=command_args,
vars=vars,
quiet=quiet,
capture_output=capture_output,
)
return success

def _get_command_env(self):
env = os.environ.copy()
if self.env_vars is not None:
env.update(self.env_vars)
return env
return result.success

def debug(self, quiet: bool = False) -> bool:
    """Run ``dbt debug`` and return whether the command succeeded."""
    result = self._run_command(command_args=["debug"], quiet=quiet)
    return result.success

def retry(self, quiet: bool = False) -> bool:
    """Run ``dbt retry`` and return whether the command succeeded."""
    result = self._run_command(command_args=["retry"], quiet=quiet)
    return result.success

def ls(self, select: Optional[str] = None) -> list:
    """Run ``dbt ls`` (quiet, text log format) and return the parsed result.

    Interpreting the command result is delegated to
    ``_parse_ls_command_result`` so each runner can read its own result shape.

    Raises:
        DbtLsCommandError: if the underlying dbt command raises
            ``DbtCommandError``.
    """
    command_args = ["-q", "ls"]
    if select:
        command_args.extend(["-s", select])
    try:
        result = self._run_command(
            command_args=command_args, capture_output=True, log_format="text"
        )
        return self._parse_ls_command_result(select, result)
    except DbtCommandError as err:
        # Keep the original failure attached as the explicit cause.
        raise DbtLsCommandError(select) from err

def source_freshness(self) -> bool:
    """Run ``dbt source freshness`` and return whether the command succeeded."""
    result = self._run_command(command_args=["source", "freshness"])
    return result.success

def _get_installed_packages_names(self):
packages_dir = os.path.join(
Expand Down
7 changes: 5 additions & 2 deletions elementary/clients/dbt/dbt_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ def __str__(self) -> str:
return as_string


def parse_dbt_output(output: str, log_format: str = "json") -> Iterator[DbtLog]:
    """Yield a ``DbtLog`` for each line of dbt output.

    Args:
        output: Raw dbt output; surrounding whitespace is stripped before it
            is split into lines.
        log_format: ``"json"`` parses each line with ``DbtLog.from_log_line``;
            ``"text"`` wraps each line as an info-level log. Any other value
            yields nothing.

    Lines that fail JSON decoding are logged at debug level and skipped.
    """
    for log_line in output.strip().splitlines():
        try:
            if log_format == "json":
                yield DbtLog.from_log_line(log_line)
            elif log_format == "text":
                yield DbtLog(msg=log_line, level="info", exception=None)
        except json.JSONDecodeError:
            logger.debug(f"Unable to parse dbt log message: {log_line}", exc_info=True)
Loading

0 comments on commit e1f2138

Please sign in to comment.