Skip to content

Commit

Permalink
simplify jinja_static
Browse files Browse the repository at this point in the history
  • Loading branch information
MichelleArk committed Jul 22, 2024
1 parent eca8c3e commit dc10493
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 83 deletions.
69 changes: 19 additions & 50 deletions core/dbt/clients/jinja_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,68 +157,37 @@ def statically_parse_adapter_dispatch(func_call, ctx, db_wrapper):
return possible_macro_calls


def statically_parse_ref(input: str) -> RefArgs:
def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]:
"""
Returns a RefArgs object corresponding to an input jinja expression.
Returns a RefArgs or List[str] object, corresponding to ref or source respectively, given an input jinja expression.
input: str representing how input node is referenced in tested model sql
* examples:
- "ref('my_model_a')"
- "ref('my_model_a', version=3)"
- "ref('package', 'my_model_a', version=3)"
If input is not a well-formed jinja ref expression, a ParsingError is raised.
"""
try:
statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}")
except ExtractionError:
raise ParsingError(f"Invalid jinja expression: {input}")

if not statically_parsed.get("refs"):
raise ParsingError(f"Invalid ref expression: {input}")

ref = list(statically_parsed["refs"])[0]
return RefArgs(package=ref.get("package"), name=ref.get("name"), version=ref.get("version"))


def statically_parse_source(input: str) -> List[str]:
"""
Returns a RefArgs object corresponding to an input jinja expression.
input: str representing how input node is referenced in tested model sql
* examples:
- "source('my_source_schema', 'my_source_name')"
If input is not a well-formed jinja source expression, ParsingError is raised.
If input is not a well-formed jinja ref or source expression, a ParsingError is raised.
"""
try:
statically_parsed = py_extract_from_source(f"{{{{ {input} }}}}")
except ExtractionError:
raise ParsingError(f"Invalid jinja expression: {input}")

if not statically_parsed.get("sources"):
raise ParsingError(f"Invalid source expression: {input}")

source = list(statically_parsed["sources"])[0]
source_name, source_table_name = source
return [source_name, source_table_name]


def statically_parse_ref_or_source(expression: str) -> Union[RefArgs, List[str]]:
ref_or_source: Union[RefArgs, List[str]]
valid_ref = True
valid_source = True

try:
ref_or_source = statically_parse_ref(expression)
except ParsingError:
valid_ref = False
try:
ref_or_source = statically_parse_source(expression)
except ParsingError:
valid_source = False

if not valid_ref and not valid_source:
raise ParsingError(f"Invalid ref or source syntax: {expression}.")
statically_parsed = py_extract_from_source(f"{{{{ {expression} }}}}")
except ExtractionError:
raise ParsingError(f"Invalid jinja expression: {expression}")

if statically_parsed.get("refs"):
raw_ref = list(statically_parsed["refs"])[0]
ref_or_source = RefArgs(
package=raw_ref.get("package"),
name=raw_ref.get("name"),
version=raw_ref.get("version"),
)
elif statically_parsed.get("sources"):
source_name, source_table_name = list(statically_parsed["sources"])[0]
ref_or_source = [source_name, source_table_name]
else:
raise ParsingError(f"Invalid ref or source expression: {expression}")

Check warning on line 191 in core/dbt/clients/jinja_static.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/clients/jinja_static.py#L191

Added line #L191 was not covered by tests

return ref_or_source
33 changes: 0 additions & 33 deletions tests/unit/clients/test_jinja_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from dbt.artifacts.resources import RefArgs
from dbt.clients.jinja_static import (
statically_extract_macro_calls,
statically_parse_ref,
statically_parse_ref_or_source,
statically_parse_source,
)
from dbt.context.base import generate_base_context
from dbt.exceptions import ParsingError
Expand Down Expand Up @@ -61,37 +59,6 @@ def test_extract_macro_calls(macro_string, expected_possible_macro_calls):
assert possible_macro_calls == expected_possible_macro_calls


class TestStaticallyParseRef:
@pytest.mark.parametrize("invalid_expression", ["invalid", "source('schema', 'table')"])
def test_invalid_expression(self, invalid_expression):
with pytest.raises(ParsingError):
statically_parse_ref(invalid_expression)

@pytest.mark.parametrize(
"ref_expression,expected_ref_args",
[
("ref('model')", RefArgs(name="model")),
("ref('package','model')", RefArgs(name="model", package="package")),
("ref('model',v=3)", RefArgs(name="model", version=3)),
("ref('package','model',v=3)", RefArgs(name="model", package="package", version=3)),
],
)
def test_valid_ref_expression(self, ref_expression, expected_ref_args):
ref_args = statically_parse_ref(ref_expression)
assert ref_args == expected_ref_args


class TestStaticallyParseSource:
@pytest.mark.parametrize("invalid_expression", ["invalid", "ref('package', 'model')"])
def test_invalid_expression(self, invalid_expression):
with pytest.raises(ParsingError):
statically_parse_source(invalid_expression)

def test_valid_ref_expression(self):
parsed_source = statically_parse_source("source('schema', 'table')")
assert parsed_source == ["schema", "table"]


class TestStaticallyParseRefOrSource:
def test_invalid_expression(self):
with pytest.raises(ParsingError):
Expand Down

0 comments on commit dc10493

Please sign in to comment.