Skip to content

Commit

Permalink
Implement pydantic from_payload
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Feb 2, 2025
1 parent 1e1c2c8 commit 1e09d52
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion temporalio/contrib/pydantic/converter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
import json
from typing import Any, Optional
import typing
from typing import Any, Optional, Type

from pydantic import BaseModel, create_model
from pydantic.json import pydantic_encoder

import temporalio.workflow
from temporalio.api.common.v1 import Payload
from temporalio.converter import (
CompositePayloadConverter,
DataConverter,
DefaultPayloadConverter,
JSONPlainPayloadConverter,
)
from temporalio.worker.workflow_sandbox._restrictions import (
_RestrictedProxy,
_unwrap_restricted_proxy,
)


class PydanticJSONPayloadConverter(JSONPlainPayloadConverter):
Expand All @@ -33,6 +41,26 @@ def to_payload(self, value: Any) -> Optional[Payload]:
).encode(),
)

def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
data = json.loads(payload.data.decode())
if type_hint and typing.get_origin(type_hint) is list:
assert isinstance(data, list), "Expected list"
[type_hint] = typing.get_args(type_hint)
assert type_hint is not None, "Expected type hint"
assert issubclass(type_hint, BaseModel), "Expected BaseModel"
if temporalio.workflow.unsafe.in_sandbox():
type_hint = _unwrap_restricted_fields(type_hint)

return [self._from_dict(d, type_hint) for d in data]
return self._from_dict(data, type_hint)

def _from_dict(self, data: dict, type_hint: Optional[Type]) -> Any:
assert isinstance(data, dict), "Expected dict"
if type_hint and hasattr(type_hint, "validate"):
return type_hint.validate(data)

return data


class PydanticPayloadConverter(CompositePayloadConverter):
"""Payload converter that replaces Temporal JSON conversion with Pydantic
Expand All @@ -56,3 +84,13 @@ def __init__(self) -> None:
payload_converter_class=PydanticPayloadConverter
)
"""Data converter using Pydantic JSON conversion."""


def _unwrap_restricted_fields(
model: Type[BaseModel],
) -> Type[BaseModel]:
fields = {
name: (_unwrap_restricted_proxy(f.annotation), f)
for name, f in model.model_fields.items()
}
return create_model(model.__name__, **fields) # type: ignore

0 comments on commit 1e09d52

Please sign in to comment.