Skip to content

Commit

Permalink
Add RunnerInput/Output Pydantic V2 classes (#938)
Browse files Browse the repository at this point in the history
Part of #858, follow up to #920  
* Enables users to use Pydantic V2 objects in their scripts while
maintaining V1 usage internally for Hera
* RunnerInput/Output classes are created in hera.workflows.io depending
on the value of _PYDANTIC_VERSION - users will automatically get classes
using their (possibly pinned) version of Pydantic
* I have not yet managed to get an explicit test that uses the automatic
import of V1 classes - we may just have to rely on the Pydantic V1 CI
check.

---------

Signed-off-by: Elliot Gunton <[email protected]>
  • Loading branch information
elliotgunton authored Feb 12, 2024
1 parent c124309 commit b146956
Show file tree
Hide file tree
Showing 16 changed files with 464 additions and 63 deletions.
2 changes: 1 addition & 1 deletion src/hera/shared/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Config:
"""

allow_population_by_field_name = True
"""support populating Hera object fields via keyed dictionaries"""
"""support populating Hera object fields by their Field alias"""

allow_mutation = True
"""supports mutating Hera objects post instantiation"""
Expand Down
3 changes: 3 additions & 0 deletions src/hera/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from hera.workflows.env_from import ConfigMapEnvFrom, SecretEnvFrom
from hera.workflows.exceptions import InvalidDispatchType, InvalidTemplateCall, InvalidType
from hera.workflows.http_template import HTTP
from hera.workflows.io import RunnerInput, RunnerOutput
from hera.workflows.metrics import Counter, Gauge, Histogram, Label, Metric, Metrics
from hera.workflows.operator import Operator
from hera.workflows.parameter import Parameter
Expand Down Expand Up @@ -148,6 +149,8 @@
"Resources",
"RetryPolicy",
"RetryStrategy",
"RunnerInput",
"RunnerOutput",
"RunnerScriptConstructor",
"S3Artifact",
"ScaleIOVolume",
Expand Down
12 changes: 12 additions & 0 deletions src/hera/workflows/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Hera IO models."""
from importlib.util import find_spec

if find_spec("pydantic.v1"):
from hera.workflows.io.v2 import RunnerInput, RunnerOutput
else:
from hera.workflows.io.v1 import RunnerInput, RunnerOutput # type: ignore

__all__ = [
"RunnerInput",
"RunnerOutput",
]
4 changes: 2 additions & 2 deletions src/hera/workflows/io.py → src/hera/workflows/io/v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Input/output models for the Hera runner."""
"""Pydantic V1 input/output models for the Hera runner."""
from collections import ChainMap
from typing import Any, List, Optional, Union

Expand Down Expand Up @@ -81,7 +81,7 @@ class RunnerOutput(BaseModel):
"""

exit_code: int = 0
result: Any
result: Any = None

@classmethod
def _get_outputs(cls) -> List[Union[Artifact, Parameter]]:
Expand Down
110 changes: 110 additions & 0 deletions src/hera/workflows/io/v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Pydantic V2 input/output models for the Hera runner.
RunnerInput/Output are only defined in this file if Pydantic v2 is installed.
"""
from collections import ChainMap
from typing import Any, List, Optional, Union

from hera.shared.serialization import serialize
from hera.workflows.artifact import Artifact
from hera.workflows.parameter import Parameter

try:
from inspect import get_annotations # type: ignore
except ImportError:
from hera.workflows._inspect import get_annotations # type: ignore

try:
from typing import Annotated, get_args, get_origin # type: ignore
except ImportError:
from typing_extensions import Annotated, get_args, get_origin # type: ignore

from importlib.util import find_spec

if find_spec("pydantic.v1"):
from pydantic import BaseModel

class RunnerInput(BaseModel):
"""Input model usable by the Hera Runner.
RunnerInput is a Pydantic model which users can create a subclass of. When a subclass
of RunnerInput is used as a function parameter type, the Hera Runner will take the fields
of the user's subclass to create template input parameters and artifacts. See the example
for the script_pydantic_io experimental feature.
"""

@classmethod
def _get_parameters(cls, object_override: "Optional[RunnerInput]" = None) -> List[Parameter]:
parameters = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.model_fields: # type: ignore
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Parameter):
param = get_args(annotations[field])[1]
if object_override:
param.default = serialize(getattr(object_override, field))
elif cls.model_fields[field].default: # type: ignore
# Serialize the value (usually done in Parameter's validator)
param.default = serialize(cls.model_fields[field].default) # type: ignore
parameters.append(param)
else:
# Create a Parameter from basic type annotations
if object_override:
parameters.append(Parameter(name=field, default=serialize(getattr(object_override, field))))
else:
parameters.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore
return parameters

@classmethod
def _get_artifacts(cls) -> List[Artifact]:
artifacts = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.model_fields: # type: ignore
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], Artifact):
artifact = get_args(annotations[field])[1]
if artifact.path is None:
artifact.path = artifact._get_default_inputs_path()
artifacts.append(artifact)
return artifacts

class RunnerOutput(BaseModel):
"""Output model usable by the Hera Runner.
RunnerOutput is a Pydantic model which users can create a subclass of. When a subclass
of RunnerOutput is used as a function return type, the Hera Runner will take the fields
of the user's subclass to create template output parameters and artifacts. See the example
for the script_pydantic_io experimental feature.
"""

exit_code: int = 0
result: Any = None

@classmethod
def _get_outputs(cls) -> List[Union[Artifact, Parameter]]:
outputs = []
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}

for field in cls.model_fields: # type: ignore
if field in {"exit_code", "result"}:
continue
if get_origin(annotations[field]) is Annotated:
if isinstance(get_args(annotations[field])[1], (Parameter, Artifact)):
outputs.append(get_args(annotations[field])[1])
else:
# Create a Parameter from basic type annotations
outputs.append(Parameter(name=field, default=cls.model_fields[field].default)) # type: ignore
return outputs

@classmethod
def _get_output(cls, field_name: str) -> Union[Artifact, Parameter]:
annotations = {k: v for k, v in ChainMap(*(get_annotations(c) for c in cls.__mro__)).items()}
annotation = annotations[field_name]
if get_origin(annotation) is Annotated:
if isinstance(get_args(annotation)[1], (Parameter, Artifact)):
return get_args(annotation)[1]

# Create a Parameter from basic type annotations
return Parameter(name=field_name, default=cls.model_fields[field_name].default) # type: ignore
89 changes: 65 additions & 24 deletions src/hera/workflows/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,34 @@
from hera.shared.serialization import serialize
from hera.workflows import Artifact, Parameter
from hera.workflows.artifact import ArtifactLoader
from hera.workflows.io import RunnerInput, RunnerOutput
from hera.workflows.io.v1 import (
RunnerInput as RunnerInputV1,
RunnerOutput as RunnerOutputV1,
)

try:
from hera.workflows.io.v2 import ( # type: ignore
RunnerInput as RunnerInputV2,
RunnerOutput as RunnerOutputV2,
)
except ImportError:
from hera.workflows.io.v1 import ( # type: ignore
RunnerInput as RunnerInputV2,
RunnerOutput as RunnerOutputV2,
)
from hera.workflows.script import _extract_return_annotation_output

try:
from typing import Annotated, get_args, get_origin # type: ignore
except ImportError:
from typing_extensions import Annotated, get_args, get_origin # type: ignore

try:
from pydantic.type_adapter import TypeAdapter # type: ignore
from pydantic.v1 import parse_obj_as # type: ignore
except ImportError:
from pydantic import parse_obj_as


def _ignore_unmatched_kwargs(f):
"""Make function ignore unmatched kwargs.
Expand All @@ -33,18 +53,7 @@ def _ignore_unmatched_kwargs(f):
def inner(**kwargs):
# filter out kwargs that are not part of the function signature
# and transform them to the correct type
if os.environ.get("hera__script_pydantic_io", None) is None:
filtered_kwargs = {key: _parse(value, key, f) for key, value in kwargs.items() if _is_kwarg_of(key, f)}
return f(**filtered_kwargs)

# filter out kwargs that are not part of the function signature
# and transform them to the correct type. If any kwarg values are
# of RunnerType, pass them through without parsing.
filtered_kwargs = {}
for key, value in kwargs.items():
if _is_kwarg_of(key, f):
type_ = _get_type(key, f)
filtered_kwargs[key] = value if type_ and issubclass(type_, RunnerInput) else _parse(value, key, f)
filtered_kwargs = {key: _parse(value, key, f) for key, value in kwargs.items() if _is_kwarg_of(key, f)}
return f(**filtered_kwargs)

return inner
Expand Down Expand Up @@ -78,8 +87,21 @@ def _parse(value, key, f):
if _is_str_kwarg_of(key, f) or _is_artifact_loaded(key, f) or _is_output_kwarg(key, f):
return value
try:
return json.loads(value)
except json.JSONDecodeError:
if os.environ.get("hera__script_annotations", None) is None:
return json.loads(value)

type_ = _get_unannotated_type(key, f)
loaded_json_value = json.loads(value)

if not type_:
return loaded_json_value

_pydantic_mode = int(os.environ.get("hera__pydantic_mode", _PYDANTIC_VERSION))
if _pydantic_mode == 1:
return parse_obj_as(type_, loaded_json_value)
else:
return TypeAdapter(type_).validate_python(loaded_json_value)
except (json.JSONDecodeError, TypeError):
return value


Expand All @@ -95,6 +117,22 @@ def _get_type(key: str, f: Callable) -> Optional[type]:
return origin_type


def _get_unannotated_type(key: str, f: Callable) -> Optional[type]:
"""Get the type of function param without the 'Annotated' outer type."""
type_ = inspect.signature(f).parameters[key].annotation
if type_ is inspect.Parameter.empty:
return None
if get_origin(type_) is None:
return type_

origin_type = cast(type, get_origin(type_))
if origin_type is Annotated:
return get_args(type_)[0]

# Type could be a dict/list with subscript type
return type_


def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
type_ = _get_type(key, f)
Expand Down Expand Up @@ -174,7 +212,7 @@ def map_annotated_artifact(param_name: str, artifact_annotation: Artifact) -> No
elif artifact_annotation.loader is None:
mapped_kwargs[param_name] = artifact_annotation.path

T = TypeVar("T", bound=RunnerInput)
T = TypeVar("T", bound=Union[RunnerInputV1, RunnerInputV2])

def map_runner_input(param_name: str, runner_input_class: T):
"""Map argo input kwargs to the fields of the given RunnerInput.
Expand Down Expand Up @@ -215,18 +253,21 @@ def map_field(field: str) -> Optional[str]:
map_annotated_artifact(param_name, func_param_annotation)
else:
mapped_kwargs[param_name] = kwargs[param_name]
elif get_origin(func_param.annotation) is None and issubclass(func_param.annotation, RunnerInput):
elif get_origin(func_param.annotation) is None and issubclass(
func_param.annotation, (RunnerInputV1, RunnerInputV2)
):
map_runner_input(param_name, func_param.annotation)
else:
mapped_kwargs[param_name] = kwargs[param_name]

return mapped_kwargs


def _save_annotated_return_outputs(
function_outputs: Union[Tuple[Any], Any],
output_annotations: List[Union[Tuple[type, Union[Parameter, Artifact]], Type[RunnerOutput]]],
) -> Optional[RunnerOutput]:
output_annotations: List[
Union[Tuple[type, Union[Parameter, Artifact]], Union[Type[RunnerOutputV1], Type[RunnerOutputV2]]]
],
) -> Optional[Union[RunnerOutputV1, RunnerOutputV2]]:
"""Save the outputs of the function to the specified output destinations.
The output values are matched with the output annotations and saved using the schema:
Expand All @@ -244,7 +285,7 @@ def _save_annotated_return_outputs(
return_obj = None

for output_value, dest in zip(function_outputs, output_annotations):
if isinstance(output_value, RunnerOutput):
if isinstance(output_value, (RunnerOutputV1, RunnerOutputV2)):
if os.environ.get("hera__script_pydantic_io", None) is None:
raise ValueError("hera__script_pydantic_io environment variable is not set")

Expand Down Expand Up @@ -346,13 +387,13 @@ def _runner(entrypoint: str, kwargs_list: List) -> Any:
if _pydantic_mode == 2:
from pydantic import validate_call # type: ignore

function = validate_call(function) # TODO: v2 function blocks pydantic IO
function = validate_call(function)
else:
if _PYDANTIC_VERSION == 1:
from pydantic import validate_arguments
else:
from pydantic.v1 import validate_arguments # type: ignore
function = validate_arguments(function, config=dict(smart_union=True)) # type: ignore
function = validate_arguments(function, config=dict(smart_union=True, arbitrary_types_allowed=True)) # type: ignore

function = _ignore_unmatched_kwargs(function)

Expand Down Expand Up @@ -395,7 +436,7 @@ def _run():
if not result:
return

if isinstance(result, RunnerOutput):
if isinstance(result, (RunnerOutputV1, RunnerOutputV2)):
print(serialize(result.result))
exit(result.exit_code)

Expand Down
Loading

0 comments on commit b146956

Please sign in to comment.