-
Notifications
You must be signed in to change notification settings - Fork 106
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use overloads in script decorator types (#1180)
**Pull Request Checklist** - [X] Fixes #1174 - [X] Tests added - [ ] Documentation/examples added - [X] [Good commit messages](https://cbea.ms/git-commit/) and/or PR title **Description of PR** Currently, @script-decorated user functions return a union of callables, making it impossible to type-check without suppressing mypy's 'call-arg' error (or equivalent), which completely disables parameter type-checking. This PR changes the type annotations to use an overloaded callable Protocol to select the signature most likely to be what the user was intending to pick. Unfortunately, the call signatures to Step and Task have to be duplicated, as: - mypy and pyright do not support having multiple ParamSpec-decorated overloads - the `name` parameter is required in Step and Task but optional when calling a @script-decorated user function ## Screenshots ### VSCode tooltips The help text here comes from the docstrings in the new Protocol; to obtain this, I had to suppress [ruff's D418 rule](https://docs.astral.sh/ruff/rules/overload-with-docstring/). ![First VSCode tooltip, showing a no-argument overload that may return a step, a task, or None](https://github.com/user-attachments/assets/d60d03e6-02e3-4bb3-953e-9f1a089918cb) ![Second VSCode tooltip, showing an overload that may return a step or task](https://github.com/user-attachments/assets/09fdea13-206f-47ed-ac3f-1a6415689ab7) ![Third VSCode tooltip, showing an overload that will return a task](https://github.com/user-attachments/assets/0f4e35fc-65e0-4627-b1e1-156fe1448f8b) ![Final VSCode tooltip, showing an overload that will call the user code](https://github.com/user-attachments/assets/e70f7205-f337-4ed0-9593-cb735695978d) --------- Signed-off-by: Alice Purcell <[email protected]>
- Loading branch information
1 parent
c4dd936
commit a50ba23
Showing
4 changed files
with
209 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[mypy] | ||
strict = true | ||
python_version = "3.8" | ||
mypy_path = "src" | ||
|
||
plugins = [ | ||
"pydantic.mypy" | ||
] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from pathlib import Path | ||
from subprocess import run | ||
from tempfile import TemporaryDirectory | ||
from textwrap import dedent | ||
from typing import Iterator, Tuple | ||
|
||
import pytest | ||
|
||
from hera.workflows import Step, Task | ||
|
||
SIMPLE_SCRIPT = """ | ||
from hera.workflows import script | ||
@script() | ||
def simple_script(input: str) -> int: | ||
print(input) | ||
return len(input) | ||
""".strip() | ||
|
||
|
||
def run_mypy(python_code: str): | ||
with TemporaryDirectory() as d: | ||
python_file = Path(d) / "example.py" | ||
python_file.write_text(python_code) | ||
mypy_cmd = ["mypy", "--config-file", "tests/typehints/test-mypy.toml", str(python_file)] | ||
result = run(mypy_cmd, check=False, capture_output=True, encoding="utf-8") | ||
if result.returncode != 0: | ||
msg = f"Error calling {' '.join(mypy_cmd)}:\n{result.stderr}{result.stdout}" | ||
raise AssertionError(msg) | ||
return result.stdout.replace(d, "") | ||
|
||
|
||
def test_underlying_function_args_invocation(): | ||
"""Verify the underlying implementation of a script can be invoked with positional arguments.""" | ||
STEP_INVOCATION = """ | ||
result = simple_script("Hello World!") | ||
reveal_type(result) | ||
""" | ||
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION)) | ||
assert 'Revealed type is "builtins.int"' in result | ||
|
||
|
||
def test_underlying_function_kwargs_invocation(): | ||
"""Verify the underlying implementation of a script can be invoked with named arguments.""" | ||
STEP_INVOCATION = """ | ||
result = simple_script(input="Hello World!") | ||
reveal_type(result) | ||
""" | ||
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION)) | ||
assert 'Revealed type is "builtins.int"' in result | ||
|
||
|
||
def test_invocation_with_no_arguments(): | ||
"""Verify a script can be invoked with no arguments. | ||
Without knowing the invocation context, which the type system does not have access to, | ||
the return type could be None, Step or Task. | ||
""" | ||
STEP_INVOCATION = """ | ||
result = simple_script() | ||
reveal_type(result) | ||
""" | ||
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION)) | ||
assert 'Revealed type is "Union[hera.workflows.steps.Step, hera.workflows.task.Task, None]"' in result | ||
|
||
|
||
def test_parameter_named_name(): | ||
"""If a script has a 'name' argument, the Step/Task overload will take precedence.""" | ||
INVOCATION = """ | ||
from hera.workflows import script | ||
@script() | ||
def simple_script(name: str) -> int: | ||
return len(name) | ||
result = simple_script(name="some_step") | ||
reveal_type(result) | ||
""" | ||
result = run_mypy(dedent(INVOCATION)) | ||
assert 'Revealed type is "Union[hera.workflows.steps.Step, hera.workflows.task.Task]"' in result | ||
|
||
|
||
def step_and_task_parameters() -> Iterator[Tuple[str, str, str]]: | ||
"""Return all parameters on Step, Task or both""" | ||
# pydantic-v1 syntax: | ||
step_fields = {field.name for field in Step.__fields__.values()} | ||
task_fields = {field.name for field in Task.__fields__.values()} | ||
|
||
for field_name in step_fields | task_fields: | ||
if field_name not in task_fields: | ||
yield (field_name, "Step", "hera.workflows.steps.Step") | ||
elif field_name not in step_fields: | ||
yield (field_name, "Task", "hera.workflows.task.Task") | ||
else: | ||
yield (field_name, "Step", "Union[hera.workflows.steps.Step, hera.workflows.task.Task]") | ||
yield (field_name, "Task", "Union[hera.workflows.steps.Step, hera.workflows.task.Task]") | ||
|
||
|
||
@pytest.mark.parametrize(("parameter_name", "input_type", "revealed_type"), tuple(step_and_task_parameters())) | ||
def test_optional_step_and_task_arguments(parameter_name: str, input_type: str, revealed_type: str) -> None: | ||
"""Verify a script can be invoked with any argument that Step and/or Task accept.""" | ||
STEP_INVOCATION = f""" | ||
from hera.workflows import Step, Task | ||
def some_function(param: {input_type}) -> None: | ||
result = simple_script({parameter_name}=param.{parameter_name}) | ||
reveal_type(result) | ||
""" | ||
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION)) | ||
assert f'Revealed type is "{revealed_type}"' in result |