Skip to content

Commit

Permalink
Move with_param examples into script_annotations tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alicederyn committed Oct 14, 2024
1 parent e5708fe commit 6f32dce
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 140 deletions.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
"""
This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in
parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities
they may need to process.
"""

from typing import Annotated, List

from hera.shared import global_config
Expand All @@ -30,9 +24,8 @@ def consume(input: ConsumeInput) -> None:
print("Received value: {value}!".format(value=input.some_value))


# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted
with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w:
with DAG(name="d"):
with DAG(name="dag"):
g = generate(arguments={})
c = consume(with_param=g.get_parameter("some-values"))
g >> c
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
"""
This example showcases how clients can use Hera to dynamically generate tasks that process outputs from one task in
parallel. This is useful for batch jobs and instances where clients do not know ahead of time how many tasks/entities
they may need to process.
"""

from typing import Annotated, List

from hera.shared import global_config
Expand All @@ -22,9 +16,8 @@ def consume(some_value: Annotated[int, Parameter(name="some-value", description=
print("Received value: {value}!".format(value=some_value))


# assumes you used `hera.set_global_token` and `hera.set_global_host` so that the workflow can be submitted
with Workflow(generate_name="dynamic-fanout-", entrypoint="d") as w:
with DAG(name="d"):
with DAG(name="dag"):
g = generate(arguments={})
c = consume(with_param=g.get_parameter("some-values"))
g >> c
64 changes: 64 additions & 0 deletions tests/test_script_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,67 @@ def test_script_pydantic_without_experimental_flag(global_config_fixture):
"Unable to instantiate <class 'tests.script_annotations.pydantic_io_v1.ParamOnlyInput'> since it is an experimental feature."
in str(e.value)
)


def test_script_annotated_with_param(global_config_fixture):
"""Test that with_param works correctly with annotated types."""
# GIVEN
global_config_fixture.experimental_features["script_annotations"] = True
global_config_fixture.experimental_features["script_pydantic_io"] = True
# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version
module_name = "tests.script_annotations.with_param"

module = importlib.import_module(module_name)
importlib.reload(module)
workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
(dag,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "dag")
(consume_task,) = (t for t in dag["dag"]["tasks"] if t["name"] == "consume")

assert consume_task["arguments"]["parameters"] == [
{
"name": "some-value",
"value": "{{item}}",
"description": "this is some value",
}
]
assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}"


def test_script_pydantic_io_with_param(global_config_fixture):
"""Test that with_param works correctly with Pydantic IO types."""
# GIVEN
global_config_fixture.experimental_features["script_annotations"] = True
global_config_fixture.experimental_features["script_pydantic_io"] = True
# Force a reload of the test module, as the runner performs "importlib.import_module", which
# may fetch a cached version
module_name = "tests.script_annotations.pydantic_io_with_param"

module = importlib.import_module(module_name)
importlib.reload(module)
workflow = importlib.import_module(module.__name__).w

# WHEN
workflow_dict = workflow.to_dict()
assert workflow == Workflow.from_dict(workflow_dict)
assert workflow == Workflow.from_yaml(workflow.to_yaml())

# THEN
(dag,) = (t for t in workflow_dict["spec"]["templates"] if t["name"] == "dag")
(consume_task,) = (t for t in dag["dag"]["tasks"] if t["name"] == "consume")

assert consume_task["arguments"]["parameters"] == [
{
"name": "some-value",
"value": "{{item}}",
"description": "this is some value",
}
]
assert consume_task["withParam"] == "{{tasks.generate.outputs.parameters.some-values}}"

0 comments on commit 6f32dce

Please sign in to comment.