diff --git a/temporalio/contrib/pydantic/converter.py b/temporalio/contrib/pydantic/converter.py index 81997e81..ee4037ca 100644 --- a/temporalio/contrib/pydantic/converter.py +++ b/temporalio/contrib/pydantic/converter.py @@ -1,7 +1,11 @@ import json -from typing import Any, Optional +import typing +from typing import Any, Optional, Type +from pydantic import BaseModel, create_model from pydantic.json import pydantic_encoder + +import temporalio.workflow from temporalio.api.common.v1 import Payload from temporalio.converter import ( CompositePayloadConverter, @@ -9,6 +13,10 @@ DefaultPayloadConverter, JSONPlainPayloadConverter, ) +from temporalio.worker.workflow_sandbox._restrictions import ( + _RestrictedProxy, + _unwrap_restricted_proxy, +) class PydanticJSONPayloadConverter(JSONPlainPayloadConverter): @@ -33,6 +41,26 @@ def to_payload(self, value: Any) -> Optional[Payload]: ).encode(), ) + def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: + data = json.loads(payload.data.decode()) + if type_hint and typing.get_origin(type_hint) is list: + assert isinstance(data, list), "Expected list" + [type_hint] = typing.get_args(type_hint) + assert type_hint is not None, "Expected type hint" + assert issubclass(type_hint, BaseModel), "Expected BaseModel" + if temporalio.workflow.unsafe.in_sandbox(): + type_hint = _unwrap_restricted_fields(type_hint) + + return [self._from_dict(d, type_hint) for d in data] + return self._from_dict(data, type_hint) + + def _from_dict(self, data: dict, type_hint: Optional[Type]) -> Any: + assert isinstance(data, dict), "Expected dict" + if type_hint and hasattr(type_hint, "validate"): + return type_hint.validate(data) + + return data + class PydanticPayloadConverter(CompositePayloadConverter): """Payload converter that replaces Temporal JSON conversion with Pydantic @@ -56,3 +84,13 @@ def __init__(self) -> None: payload_converter_class=PydanticPayloadConverter ) """Data converter using Pydantic JSON conversion.""" + + +def _unwrap_restricted_fields( + model: Type[BaseModel], +) -> Type[BaseModel]: + fields = { + name: (_unwrap_restricted_proxy(f.annotation), f) + for name, f in model.model_fields.items() + } + return create_model(model.__name__, **fields) # type: ignore