diff --git a/src/hera/workflows/script.py b/src/hera/workflows/script.py index 52dbfb0ee..03ef6280c 100644 --- a/src/hera/workflows/script.py +++ b/src/hera/workflows/script.py @@ -629,12 +629,24 @@ class _ScriptDecoratedFunction(Generic[FuncIns, FuncRCov], Protocol): @overload def __call__( # type: ignore [overload-overlap] + self, *args: FuncIns.args, **kwargs: FuncIns.kwargs + ) -> FuncRCov: + # Note: this overload is for calling the decorated function. + # No docstring is provided, so VS Code will use the docstring of the decorated function. + ... + + @overload + def __call__( # type: ignore [overload-overlap, misc] self, ) -> Optional[Union[Step, Task]]: - """@script-decorated function invoked within a workflow, step or task context. + """Create a Step or Task or add the script as a template to the workflow, depending on the context. - May return None, a Step, or a Task, depending on the context. Use `assert isinstance(result, Step)` - or `assert isinstance(result, Task)` to select the correct type if using a type-checker. + * Under a DAG context, creates and returns a Task. + * Under a Steps or Parallel context, creates and returns a Step. + * Under a Workflow context, adds the script as a template to the Workflow and returns None. + + Use `assert isinstance(result, Step)` or `assert isinstance(result, Task)` to select + the correct type if using a type-checker. """ @overload @@ -654,13 +666,16 @@ def __call__( # type: ignore [overload-overlap] with_param: Optional[Any] = ..., with_items: Optional[OneOrMany[Any]] = ..., ) -> Union[Step, Task]: - """@script-decorated function invoked within a step or task context. + """Create a Step or Task, depending on context. + + * Under a DAG context, creates and returns a Task. + * Under a Steps or Parallel context, creates and returns a Step. - May return a Step or a Task, depending on the context. Use `assert isinstance(result, Step)` - or `assert isinstance(result, Task)` to select the correct type if using a type-checker. + Use `assert isinstance(result, Step)` or `assert isinstance(result, Task)` to select + the correct type if using a type-checker. """ # Note: signature must match the Step constructor, except that while name is required for Step, - # it is automatically inferred from the name of the decorated function for @script. + # it is automatically inferred from the name of the decorated function when invoked. @overload def __call__( # type: ignore [overload-overlap] @@ -681,16 +696,12 @@ def __call__( # type: ignore [overload-overlap] dependencies: Optional[List[str]] = ..., depends: Optional[str] = ..., ) -> Task: - """@script-decorated function invoked within a task context.""" - # Note: signature must match the Task constructor, except that while name is required for Task, - # it is automatically inferred from the name of the decorated function for @script. - - @overload - def __call__(self, *args: FuncIns.args, **kwargs: FuncIns.kwargs) -> FuncRCov: - """@script-decorated function invoked outside of a step or task context. + """Create and return a Task. - Will call the decorated function. + Must be invoked under a DAG context. """ + # Note: signature must match the Task constructor, except that while name is required for Task, + # it is automatically inferred from the name of the decorated function when invoked. # Pass actual class of Script to bind inputs to the ParamSpec above diff --git a/tests/typehints/test_script_decorator.py b/tests/typehints/test_script_decorator.py index 8a9a687fc..6f23aa881 100644 --- a/tests/typehints/test_script_decorator.py +++ b/tests/typehints/test_script_decorator.py @@ -65,7 +65,7 @@ def test_invocation_with_no_arguments(): def test_parameter_named_name(): - """If a script has a 'name' argument, the Step/Task overload will take precedence.""" + """If a script has a 'name' argument, the user function will take precedence over other overloads.""" INVOCATION = """ from hera.workflows import script @@ -77,7 +77,7 @@ def simple_script(name: str) -> int: reveal_type(result) """ result = run_mypy(dedent(INVOCATION)) - assert 'Revealed type is "Union[hera.workflows.steps.Step, hera.workflows.task.Task]"' in result + assert 'Revealed type is "builtins.int"' in result def step_and_task_parameters() -> Iterator[Tuple[str, str, str]]: