Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support passing args as keyword argument for run-operation in programmatic invocations #10473

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240722-133729.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support passing `args` as keyword argument for `run-operation` in programmatic invocations
time: 2024-07-22T13:37:29.285621-06:00
custom:
Author: dbeatty10
Issue: "10473"
4 changes: 2 additions & 2 deletions core/dbt/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def __init__(
callbacks = []
self.callbacks = callbacks

def invoke(self, args: List[str], **kwargs) -> dbtRunnerResult:
def invoke(self, invocation_args: List[str], /, **kwargs) -> dbtRunnerResult:
try:
dbt_ctx = cli.make_context(cli.name, args.copy())
dbt_ctx = cli.make_context(cli.name, invocation_args.copy())
dbt_ctx.obj = {
"manifest": self.manifest,
"callbacks": self.callbacks,
Expand Down
26 changes: 11 additions & 15 deletions core/dbt/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,28 +70,25 @@
# run_dbt(["run", "--vars", "seed_name: base"])
# If the command is expected to fail, pass in "expect_pass=False"):
# run_dbt(["test"], expect_pass=False)
def run_dbt(
args: Optional[List[str]] = None,
expect_pass: bool = True,
):
def run_dbt(invocation_args: Optional[List[str]] = None, /, expect_pass: bool = True, **kwargs):
# reset global vars
reset_metadata_vars()

if args is None:
args = ["run"]
if invocation_args is None:
invocation_args = ["run"]

Check warning on line 78 in core/dbt/tests/util.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/tests/util.py#L77-L78

Added lines #L77 - L78 were not covered by tests

print("\n\nInvoking dbt with {}".format(args))
print("\n\nInvoking dbt with {}".format(invocation_args))

Check warning on line 80 in core/dbt/tests/util.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/tests/util.py#L80

Added line #L80 was not covered by tests
from dbt.flags import get_flags

flags = get_flags()
project_dir = getattr(flags, "PROJECT_DIR", None)
profiles_dir = getattr(flags, "PROFILES_DIR", None)
if project_dir and "--project-dir" not in args:
args.extend(["--project-dir", project_dir])
if profiles_dir and "--profiles-dir" not in args:
args.extend(["--profiles-dir", profiles_dir])
if project_dir and "--project-dir" not in invocation_args:
invocation_args.extend(["--project-dir", project_dir])
if profiles_dir and "--profiles-dir" not in invocation_args:
invocation_args.extend(["--profiles-dir", profiles_dir])

Check warning on line 89 in core/dbt/tests/util.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/tests/util.py#L86-L89

Added lines #L86 - L89 were not covered by tests
dbt = dbtRunner()
res = dbt.invoke(args)
res = dbt.invoke(invocation_args, **kwargs)

Check warning on line 91 in core/dbt/tests/util.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/tests/util.py#L91

Added line #L91 was not covered by tests

# the exception is immediately raised to be caught in tests
# using a pattern like `with pytest.raises(SomeException):`
Expand All @@ -109,13 +106,12 @@
# start with the "--debug" flag. The structured schema log CI test
# will turn the logs into json, so you have to be prepared for that.
def run_dbt_and_capture(
args: Optional[List[str]] = None,
expect_pass: bool = True,
invocation_args: Optional[List[str]] = None, /, expect_pass: bool = True, **kwargs
):
try:
stringbuf = StringIO()
capture_stdout_logs(stringbuf)
res = run_dbt(args, expect_pass=expect_pass)
res = run_dbt(invocation_args, expect_pass=expect_pass, **kwargs)

Check warning on line 114 in core/dbt/tests/util.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/tests/util.py#L114

Added line #L114 was not covered by tests
stdout = stringbuf.getvalue()

finally:
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/colors/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_no_use_colors(self, project):
)

def assert_colors_used(self, flag, expect_colors):
_, stdout = run_dbt_and_capture(args=[flag, "run"], expect_pass=False)
_, stdout = run_dbt_and_capture([flag, "run"], expect_pass=False)
# pattern to match formatted log output
pattern = re.compile(r"\[31m.*|\[33m.*")
stdout_contains_formatting_characters = bool(pattern.search(stdout))
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def run_dbt_ls(self, args=None, expect_pass=True):
full_args = ["ls"]
if args is not None:
full_args += args
result = run_dbt(args=full_args, expect_pass=expect_pass)
result = run_dbt(full_args, expect_pass=expect_pass)

return result

Expand Down
4 changes: 2 additions & 2 deletions tests/functional/run_operations/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@
{% endmacro %}


{% macro print_something() %}
{{ print("You're doing awesome!") }}
{% macro print_something(message="You're doing awesome!") %}
{{ print(message) }}
{% endmacro %}
"""

Expand Down
6 changes: 6 additions & 0 deletions tests/functional/run_operations/test_run_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def test_macro_args(self, project):
self.run_operation("table_name_args", table_name="my_fancy_table")
check_table_does_exist(project.adapter, "my_fancy_table")

def test_args_as_keyword(self, project):
results, log_output = run_dbt_and_capture(
["run-operation", "print_something"], args={"message": "Morning coffee"}
)
assert "Morning coffee" in log_output

def test_macro_exception(self, project):
self.run_operation("syntax_error", False)

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/threading/test_thread_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def profiles_config_update(self):
return {"threads": 2}

def test_threading_8x(self, project):
results = run_dbt(args=["run", "--threads", "16"])
results = run_dbt(["run", "--threads", "16"])
assert len(results), 20
Loading