Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sqlglot to parse queries to insert select literals #52

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions dbt_dry_run/literals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import re
from typing import Callable, Dict, cast
from typing import Callable, Dict, List, Optional
from uuid import uuid4

import sqlglot

from dbt_dry_run.exception import UpstreamFailedException
from dbt_dry_run.models import BigQueryFieldMode, BigQueryFieldType, Table, TableField
from dbt_dry_run.models.manifest import Node
from dbt_dry_run.results import DryRunStatus, Results
from dbt_dry_run.results import DryRunResult, DryRunStatus, Results

_EXAMPLE_VALUES: Dict[BigQueryFieldType, Callable[[], str]] = {
BigQueryFieldType.STRING: lambda: f"'{uuid4()}'",
Expand Down Expand Up @@ -85,15 +86,45 @@ def get_sql_literal_from_table(table: Table) -> str:
return select_literal


def replace_upstream_sql(node_sql: str, node: Node, table: Table) -> str:
upstream_table_ref = node.to_table_ref_literal()
regex = re.compile(
rf"((?:from|join)(?:\s--.*)?[\r\n\s]*)({upstream_table_ref})",
flags=re.IGNORECASE | re.MULTILINE,
def convert_ast_to_sql(trees: List[sqlglot.Expression]) -> str:
return ";\n".join(tree.sql(sqlglot.dialects.BigQuery) for tree in trees)


def _table_from_node(node: Node) -> sqlglot.Expression:
return sqlglot.exp.table_(
catalog=node.database, db=node.db_schema, table=node.alias, quoted=True
)
select_literal = get_sql_literal_from_table(table)
new_node_sql = regex.sub(r"\1" + select_literal, node_sql)
return new_node_sql


def _remove_alias_from_table(exp: sqlglot.Expression) -> Optional[sqlglot.Expression]:
if isinstance(exp, sqlglot.exp.TableAlias):
return None
return exp


def replace_upstream_sql(node_sql: str, upstream_results: List[DryRunResult]) -> str:
parsed_statements = sqlglot.parse(node_sql, dialect=sqlglot.dialects.BigQuery)
upstream_literals = {
_table_from_node(upstream.node): get_sql_literal_from_table(upstream.table)
for upstream in upstream_results
if upstream.table
}

def transformer(exp: sqlglot.Expression) -> sqlglot.Expression:
if isinstance(exp, sqlglot.exp.Table):
table_without_alias = exp.transform(_remove_alias_from_table)
literal = upstream_literals.get(table_without_alias)
if literal:
new_alias = exp.alias or exp.name
return sqlglot.parse_one(
literal, dialect=sqlglot.dialects.BigQuery
).as_(new_alias)
return exp

transformed_trees = [
parsed.transform(transformer) for parsed in parsed_statements if parsed
]
return convert_ast_to_sql(transformed_trees)


def insert_dependant_sql_literals(node: Node, results: Results) -> str:
Expand All @@ -114,9 +145,5 @@ def insert_dependant_sql_literals(node: Node, results: Results) -> str:
raise UpstreamFailedException(msg)
completed_upstreams = [r for r in upstream_results if r.table]

node_new_sql = node.compiled_code
for upstream in completed_upstreams:
node_new_sql = replace_upstream_sql(
node_new_sql, upstream.node, cast(Table, upstream.table)
)
node_new_sql = replace_upstream_sql(node.compiled_code, completed_upstreams)
return node_new_sql
4 changes: 4 additions & 0 deletions dbt_dry_run/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def with_linting_errors(self, linting_errors: List[LintingError]) -> "DryRunResu
linting_status=linting_status,
)

@classmethod
def successful(cls, node: Node, table: Table) -> "DryRunResult":
return cls(node=node, table=table, status=DryRunStatus.SUCCESS, exception=None, total_bytes_processed=0)


class Results:
def __init__(self) -> None:
Expand Down
10 changes: 8 additions & 2 deletions dbt_dry_run/test/node_runner/test_incremental_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import MagicMock, call

import pytest
import sqlglot

from dbt_dry_run import flags
from dbt_dry_run.exception import SchemaChangeException
Expand All @@ -17,7 +18,12 @@
)
from dbt_dry_run.results import DryRunResult, DryRunStatus, Results
from dbt_dry_run.scheduler import ManifestScheduler
from dbt_dry_run.test.utils import SimpleNode, assert_result_has_table, get_executed_sql
from dbt_dry_run.test.utils import (
SimpleNode,
assert_ast_equivalent,
assert_result_has_table,
get_executed_sql,
)

enable_test_example_values(True)

Expand Down Expand Up @@ -103,7 +109,7 @@ def test_partitioned_incremental_model_declares_dbt_max_partition_variable() ->

executed_sql = get_executed_sql(mock_sql_runner)
assert executed_sql.startswith(dbt_max_partition_declaration)
assert node.compiled_code in executed_sql
assert_ast_equivalent(node.compiled_code, executed_sql.split(";")[1])


def test_incremental_model_that_does_not_exist_returns_dry_run_schema() -> None:
Expand Down
2 changes: 1 addition & 1 deletion dbt_dry_run/test/node_runner/test_table_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_model_with_dependency_inserts_sql_literal() -> None:

executed_sql = get_executed_sql(mock_sql_runner)
assert result.status == DryRunStatus.SUCCESS
assert executed_sql == "SELECT * FROM (SELECT 'foo' as `a`)"
assert executed_sql == "SELECT * FROM (SELECT 'foo' AS `a`) AS upstream"


def test_model_with_sql_header_executes_header_first() -> None:
Expand Down
2 changes: 1 addition & 1 deletion dbt_dry_run/test/node_runner/test_view_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def test_model_with_dependency_inserts_sql_literal() -> None:
assert result.status == DryRunStatus.SUCCESS
assert result.total_bytes_processed == A_TOTAL_BYTES_PROCESSED
assert executed_sql == sql_with_view_creation(
node, "SELECT * FROM (SELECT 'foo' as `a`)"
node, "SELECT * FROM (SELECT 'foo' AS `a`) AS upstream"
)


Expand Down
Loading