Skip to content

Commit

Permalink
Use overloads in script decorator types (#1180)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [X] Fixes #1174
- [X] Tests added
- [ ] Documentation/examples added
- [X] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

**Description of PR**
Currently, @script-decorated user functions return a union of callables,
making it impossible to type-check without suppressing mypy's 'call-arg'
error (or equivalent), which completely disables parameter
type-checking.

This PR changes the type annotations to use an overloaded callable
Protocol to select the signature most likely to be what the user was
intending to pick.

Unfortunately, the call signatures to Step and Task have to be
duplicated, as:
- mypy and pyright do not support having multiple ParamSpec-decorated
overloads
- the `name` parameter is required in Step and Task but optional when
calling a @script-decorated user function

## Screenshots

### VSCode tooltips
The help text here comes from the docstrings in the new Protocol; to
obtain this, I had to suppress [ruff's D418
rule](https://docs.astral.sh/ruff/rules/overload-with-docstring/).

![First VSCode tooltip, showing a no-argument overload that may return a
step, a task, or
None](https://github.com/user-attachments/assets/d60d03e6-02e3-4bb3-953e-9f1a089918cb)
![Second VSCode tooltip, showing an overload that may return a step or
task](https://github.com/user-attachments/assets/09fdea13-206f-47ed-ac3f-1a6415689ab7)
![Third VSCode tooltip, showing an overload that will return a
task](https://github.com/user-attachments/assets/0f4e35fc-65e0-4627-b1e1-156fe1448f8b)
![Final VSCode tooltip, showing an overload that will call the user
code](https://github.com/user-attachments/assets/e70f7205-f337-4ed0-9593-cb735695978d)

---------

Signed-off-by: Alice Purcell <[email protected]>
  • Loading branch information
alicederyn authored Aug 29, 2024
1 parent c4dd936 commit a50ba23
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 11 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ src = ["src"]

[tool.ruff.lint]
select = ["E", "F", "D"]
ignore = ["E501"]
ignore = [
"D418", # Bans docstrings on overloads, but VS Code displays these to users in tooltips.
"E501", # Line too long.
]
extend-select = ["I"]

[tool.ruff.lint.per-file-ignores]
Expand Down
97 changes: 87 additions & 10 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)

if sys.version_info >= (3, 9):
Expand All @@ -45,8 +49,10 @@
from hera.workflows._context import _context
from hera.workflows._meta_mixins import CallableTemplateMixin
from hera.workflows._mixins import (
ArgumentsT,
ContainerMixin,
EnvIOMixin,
OneOrMany,
ResourceMixin,
TemplateMixin,
VolumeMountMixin,
Expand All @@ -71,16 +77,21 @@
Output as OutputV2,
)
from hera.workflows.models import (
ContinueOn,
EnvVar,
Inputs as ModelInputs,
Lifecycle,
LifecycleHook,
Outputs as ModelOutputs,
ScriptTemplate as _ModelScriptTemplate,
SecurityContext,
Sequence as _ModelSequence,
Template as _ModelTemplate,
TemplateRef,
ValueFrom,
)
from hera.workflows.parameter import MISSING, Parameter
from hera.workflows.protocol import Templatable
from hera.workflows.steps import Step
from hera.workflows.task import Task
from hera.workflows.volume import _BaseVolume
Expand Down Expand Up @@ -606,27 +617,93 @@ def _output_annotations_used(source: Callable) -> bool:

FuncIns = ParamSpec("FuncIns") # For input types of given func to script decorator
FuncR = TypeVar("FuncR") # For return type of given func to script decorator
FuncRCov = TypeVar("FuncRCov", covariant=True)

ScriptIns = ParamSpec("ScriptIns") # For input attributes of Script class
StepIns = ParamSpec("StepIns") # # For input attributes of Step class
TaskIns = ParamSpec("TaskIns") # # For input attributes of Task class


# Pass actual classes of Script, Step and Task to bind inputs to the ParamSpecs above
class _ScriptDecoratedFunction(Generic[FuncIns, FuncRCov], Protocol):
"""Type assigned to functions decorated with @script."""

# Note: For more details about overload-overlap, see https://github.com/python/typeshed/issues/12178

@overload
def __call__( # type: ignore [overload-overlap]
self,
) -> Optional[Union[Step, Task]]:
"""@script-decorated function invoked within a workflow, step or task context.
May return None, a Step, or a Task, depending on the context. Use `assert isinstance(result, Step)`
or `assert isinstance(result, Task)` to select the correct type if using a type-checker.
"""

@overload
def __call__( # type: ignore [overload-overlap]
self,
*,
name: str = ...,
continue_on: Optional[ContinueOn] = ...,
hooks: Optional[Dict[str, LifecycleHook]] = ...,
on_exit: Optional[Union[str, Templatable]] = ...,
template: Optional[Union[str, _ModelTemplate, TemplateMixin, CallableTemplateMixin]] = ...,
template_ref: Optional[TemplateRef] = ...,
inline: Optional[Union[_ModelTemplate, TemplateMixin]] = ...,
when: Optional[str] = ...,
with_sequence: Optional[_ModelSequence] = ...,
arguments: ArgumentsT = ...,
with_param: Optional[Any] = ...,
with_items: Optional[OneOrMany[Any]] = ...,
) -> Union[Step, Task]:
"""@script-decorated function invoked within a step or task context.
May return a Step or a Task, depending on the context. Use `assert isinstance(result, Step)`
or `assert isinstance(result, Task)` to select the correct type if using a type-checker.
"""
# Note: signature must match the Step constructor, except that while name is required for Step,
# it is automatically inferred from the name of the decorated function for @script.

@overload
def __call__( # type: ignore [overload-overlap]
self,
*,
name: str = ...,
continue_on: Optional[ContinueOn] = ...,
hooks: Optional[Dict[str, LifecycleHook]] = ...,
on_exit: Optional[Union[str, Templatable]] = ...,
template: Optional[Union[str, _ModelTemplate, TemplateMixin, CallableTemplateMixin]] = ...,
template_ref: Optional[TemplateRef] = ...,
inline: Optional[Union[_ModelTemplate, TemplateMixin]] = ...,
when: Optional[str] = ...,
with_sequence: Optional[_ModelSequence] = ...,
arguments: ArgumentsT = ...,
with_param: Optional[Any] = ...,
with_items: Optional[OneOrMany[Any]] = ...,
dependencies: Optional[List[str]] = ...,
depends: Optional[str] = ...,
) -> Task:
"""@script-decorated function invoked within a task context."""
# Note: signature must match the Task constructor, except that while name is required for Task,
# it is automatically inferred from the name of the decorated function for @script.

@overload
def __call__(self, *args: FuncIns.args, **kwargs: FuncIns.kwargs) -> FuncRCov:
"""@script-decorated function invoked outside of a step or task context.
Will call the decorated function.
"""


# Pass actual class of Script to bind inputs to the ParamSpec above
def _add_type_hints(
_script: Callable[ScriptIns, Script],
_step: Callable[StepIns, Step],
_task: Callable[TaskIns, Task],
) -> Callable[
...,
Callable[
ScriptIns, # this adds Script type hints to the underlying *library* function kwargs, i.e. `script`
Callable[ # we will return a function that is a decorator
[Callable[FuncIns, FuncR]], # taking underlying *user* function
Union[ # able to return FuncR | Step | Task | None
Callable[FuncIns, FuncR],
Callable[StepIns, Optional[Step]],
Callable[TaskIns, Optional[Task]],
_ScriptDecoratedFunction[ # and returning an overloaded method that can additionally return Task or Step
FuncIns, FuncR
],
],
],
Expand All @@ -635,7 +712,7 @@ def _add_type_hints(
return lambda func: func


@_add_type_hints(Script, Step, Task) # type: ignore
@_add_type_hints(Script)
def script(**script_kwargs) -> Callable:
"""A decorator that wraps a function into a Script object.
Expand Down
9 changes: 9 additions & 0 deletions tests/typehints/test-mypy.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[mypy]
strict = true
python_version = "3.8"
mypy_path = "src"

plugins = [
"pydantic.mypy"
]

109 changes: 109 additions & 0 deletions tests/typehints/test_script_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from pathlib import Path
from subprocess import run
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import Iterator, Tuple

import pytest

from hera.workflows import Step, Task

SIMPLE_SCRIPT = """
from hera.workflows import script
@script()
def simple_script(input: str) -> int:
print(input)
return len(input)
""".strip()


def run_mypy(python_code: str):
with TemporaryDirectory() as d:
python_file = Path(d) / "example.py"
python_file.write_text(python_code)
mypy_cmd = ["mypy", "--config-file", "tests/typehints/test-mypy.toml", str(python_file)]
result = run(mypy_cmd, check=False, capture_output=True, encoding="utf-8")
if result.returncode != 0:
msg = f"Error calling {' '.join(mypy_cmd)}:\n{result.stderr}{result.stdout}"
raise AssertionError(msg)
return result.stdout.replace(d, "")


def test_underlying_function_args_invocation():
"""Verify the underlying implementation of a script can be invoked with positional arguments."""
STEP_INVOCATION = """
result = simple_script("Hello World!")
reveal_type(result)
"""
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION))
assert 'Revealed type is "builtins.int"' in result


def test_underlying_function_kwargs_invocation():
"""Verify the underlying implementation of a script can be invoked with named arguments."""
STEP_INVOCATION = """
result = simple_script(input="Hello World!")
reveal_type(result)
"""
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION))
assert 'Revealed type is "builtins.int"' in result


def test_invocation_with_no_arguments():
"""Verify a script can be invoked with no arguments.
Without knowing the invocation context, which the type system does not have access to,
the return type could be None, Step or Task.
"""
STEP_INVOCATION = """
result = simple_script()
reveal_type(result)
"""
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION))
assert 'Revealed type is "Union[hera.workflows.steps.Step, hera.workflows.task.Task, None]"' in result


def test_parameter_named_name():
"""If a script has a 'name' argument, the Step/Task overload will take precedence."""
INVOCATION = """
from hera.workflows import script
@script()
def simple_script(name: str) -> int:
return len(name)
result = simple_script(name="some_step")
reveal_type(result)
"""
result = run_mypy(dedent(INVOCATION))
assert 'Revealed type is "Union[hera.workflows.steps.Step, hera.workflows.task.Task]"' in result


def step_and_task_parameters() -> Iterator[Tuple[str, str, str]]:
"""Return all parameters on Step, Task or both"""
# pydantic-v1 syntax:
step_fields = {field.name for field in Step.__fields__.values()}
task_fields = {field.name for field in Task.__fields__.values()}

for field_name in step_fields | task_fields:
if field_name not in task_fields:
yield (field_name, "Step", "hera.workflows.steps.Step")
elif field_name not in step_fields:
yield (field_name, "Task", "hera.workflows.task.Task")
else:
yield (field_name, "Step", "Union[hera.workflows.steps.Step, hera.workflows.task.Task]")
yield (field_name, "Task", "Union[hera.workflows.steps.Step, hera.workflows.task.Task]")


@pytest.mark.parametrize(("parameter_name", "input_type", "revealed_type"), tuple(step_and_task_parameters()))
def test_optional_step_and_task_arguments(parameter_name: str, input_type: str, revealed_type: str) -> None:
"""Verify a script can be invoked with any argument that Step and/or Task accept."""
STEP_INVOCATION = f"""
from hera.workflows import Step, Task
def some_function(param: {input_type}) -> None:
result = simple_script({parameter_name}=param.{parameter_name})
reveal_type(result)
"""
result = run_mypy(SIMPLE_SCRIPT + dedent(STEP_INVOCATION))
assert f'Revealed type is "{revealed_type}"' in result

0 comments on commit a50ba23

Please sign in to comment.