Skip to content

Commit

Permalink
Reorganize
Browse files Browse the repository at this point in the history
  • Loading branch information
dandavison committed Feb 10, 2025
1 parent 1770a75 commit 40914be
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 307 deletions.
11 changes: 11 additions & 0 deletions tests/contrib/pydantic/activities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List

from temporalio import activity
from tests.contrib.pydantic.models import PydanticModels


@activity.defn
async def pydantic_models_activity(
models: List[PydanticModels],
) -> List[PydanticModels]:
return models
323 changes: 16 additions & 307 deletions tests/contrib/test_pydantic.py → tests/contrib/pydantic/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,9 @@
)

from annotated_types import Len
from pydantic import BaseModel, Field, WithJsonSchema, create_model
from pydantic import BaseModel, Field, WithJsonSchema
from typing_extensions import TypedDict

from temporalio import activity, workflow
from temporalio.client import Client
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.worker import Worker

SequenceType = TypeVar("SequenceType", bound=Sequence[Any])
ShortSequence = Annotated[SequenceType, Len(max_length=2)]

Expand Down Expand Up @@ -573,21 +568,6 @@ def make_pydantic_timedelta_object() -> PydanticTimedeltaModel:
)


PydanticModels = Union[
StandardTypesModel,
ComplexTypesModel,
SpecialTypesModel,
ParentModel,
FieldFeaturesModel,
AnnotatedFieldsModel,
GenericModel[Any],
UnionModel,
PydanticDatetimeModel,
PydanticDateModel,
PydanticTimedeltaModel,
]


def _assert_datetime_validity(dt: datetime):
assert isinstance(dt, datetime)
assert issubclass(dt.__class__, datetime)
Expand All @@ -603,6 +583,21 @@ def _assert_timedelta_validity(td: timedelta):
assert issubclass(td.__class__, timedelta)


PydanticModels = Union[
StandardTypesModel,
ComplexTypesModel,
SpecialTypesModel,
ParentModel,
FieldFeaturesModel,
AnnotatedFieldsModel,
GenericModel[Any],
UnionModel,
PydanticDatetimeModel,
PydanticDateModel,
PydanticTimedeltaModel,
]


def make_list_of_pydantic_objects() -> List[PydanticModels]:
objects = [
make_standard_types_object(),
Expand All @@ -622,126 +617,6 @@ def make_list_of_pydantic_objects() -> List[PydanticModels]:
return objects # type: ignore


@activity.defn
async def pydantic_models_activity(
models: List[PydanticModels],
) -> List[PydanticModels]:
return models


@workflow.defn
class InstantiateModelsWorkflow:
@workflow.run
async def run(self) -> None:
make_list_of_pydantic_objects()


@workflow.defn
class RoundTripObjectsWorkflow:
@workflow.run
async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]:
return await workflow.execute_activity(
pydantic_models_activity,
objects,
start_to_close_timeout=timedelta(minutes=1),
)


def clone_objects(objects: List[PydanticModels]) -> List[PydanticModels]:
new_objects = []
for o in objects:
fields = {}
for name, f in o.model_fields.items():
fields[name] = (f.annotation, f)
model = create_model(o.__class__.__name__, **fields) # type: ignore
new_objects.append(model(**o.model_dump(by_alias=True)))
for old, new in zip(objects, new_objects):
assert old.model_dump() == new.model_dump()
return new_objects


@workflow.defn
class CloneObjectsWorkflow:
@workflow.run
async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]:
return clone_objects(objects)


async def test_instantiation_outside_sandbox():
make_list_of_pydantic_objects()


async def test_instantiation_inside_sandbox(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

async with Worker(
client,
task_queue=task_queue_name,
workflows=[InstantiateModelsWorkflow],
):
await client.execute_workflow(
InstantiateModelsWorkflow.run,
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)


async def test_round_trip_pydantic_objects(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

orig_objects = make_list_of_pydantic_objects()

async with Worker(
client,
task_queue=task_queue_name,
workflows=[RoundTripObjectsWorkflow],
activities=[pydantic_models_activity],
):
returned_objects = await client.execute_workflow(
RoundTripObjectsWorkflow.run,
orig_objects,
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)
assert returned_objects == orig_objects
for o in returned_objects:
o._check_instance()


async def test_clone_objects_outside_sandbox():
clone_objects(make_list_of_pydantic_objects())


async def test_clone_objects_in_sandbox(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

orig_objects = make_list_of_pydantic_objects()

async with Worker(
client,
task_queue=task_queue_name,
workflows=[CloneObjectsWorkflow],
):
returned_objects = await client.execute_workflow(
CloneObjectsWorkflow.run,
orig_objects,
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)
assert returned_objects == orig_objects
for o in returned_objects:
o._check_instance()


@dataclasses.dataclass(order=True)
class MyDataClass:
# The name int_field also occurs in StandardTypesModel and currently unions can match them up incorrectly.
Expand All @@ -753,170 +628,4 @@ def make_dataclass_objects() -> List[MyDataClass]:


ComplexCustomType = Tuple[List[MyDataClass], List[PydanticModels]]


@workflow.defn
class ComplexCustomTypeWorkflow:
@workflow.run
async def run(
self,
input: ComplexCustomType,
) -> ComplexCustomType:
data_classes, pydantic_objects = input
pydantic_objects = await workflow.execute_activity(
pydantic_models_activity,
pydantic_objects,
start_to_close_timeout=timedelta(minutes=1),
)
return data_classes, pydantic_objects


async def test_complex_custom_type(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

orig_dataclass_objects = make_dataclass_objects()
orig_pydantic_objects = make_list_of_pydantic_objects()

async with Worker(
client,
task_queue=task_queue_name,
workflows=[ComplexCustomTypeWorkflow],
activities=[pydantic_models_activity],
):
(
returned_dataclass_objects,
returned_pydantic_objects,
) = await client.execute_workflow(
ComplexCustomTypeWorkflow.run,
(orig_dataclass_objects, orig_pydantic_objects),
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)
assert orig_dataclass_objects == returned_dataclass_objects
assert orig_pydantic_objects == returned_pydantic_objects
for o in returned_pydantic_objects:
o._check_instance()


ComplexCustomUnionType = List[Union[MyDataClass, PydanticModels]]


@workflow.defn
class ComplexCustomUnionTypeWorkflow:
@workflow.run
async def run(
self,
input: ComplexCustomUnionType,
) -> ComplexCustomUnionType:
data_classes = []
pydantic_objects: List[PydanticModels] = []
for o in input:
if dataclasses.is_dataclass(o):
data_classes.append(o)
elif isinstance(o, BaseModel):
pydantic_objects.append(o)
else:
raise TypeError(f"Unexpected type: {type(o)}")
pydantic_objects = await workflow.execute_activity(
pydantic_models_activity,
pydantic_objects,
start_to_close_timeout=timedelta(minutes=1),
)
return data_classes + pydantic_objects # type: ignore


async def test_complex_custom_union_type(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

orig_dataclass_objects = make_dataclass_objects()
orig_pydantic_objects = make_list_of_pydantic_objects()
orig_objects = orig_dataclass_objects + orig_pydantic_objects
import random

random.shuffle(orig_objects)

async with Worker(
client,
task_queue=task_queue_name,
workflows=[ComplexCustomUnionTypeWorkflow],
activities=[pydantic_models_activity],
):
returned_objects = await client.execute_workflow(
ComplexCustomUnionTypeWorkflow.run,
orig_objects,
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)
returned_dataclass_objects, returned_pydantic_objects = [], []
for o in returned_objects:
if isinstance(o, MyDataClass):
returned_dataclass_objects.append(o)
elif isinstance(o, BaseModel):
returned_pydantic_objects.append(o)
else:
raise TypeError(f"Unexpected type: {type(o)}")
assert sorted(orig_dataclass_objects) == sorted(returned_dataclass_objects)
assert sorted(orig_pydantic_objects, key=lambda o: o.__class__.__name__) == sorted(
returned_pydantic_objects, key=lambda o: o.__class__.__name__
)
for o in returned_pydantic_objects:
o._check_instance()


@workflow.defn
class PydanticModelUsageWorkflow:
@workflow.run
async def run(self) -> None:
for o in make_list_of_pydantic_objects():
o._check_instance()


async def test_pydantic_model_usage_in_workflow(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

async with Worker(
client,
task_queue=task_queue_name,
workflows=[PydanticModelUsageWorkflow],
):
await client.execute_workflow(
PydanticModelUsageWorkflow.run,
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)


@workflow.defn
class DatetimeUsageWorkflow:
@workflow.run
async def run(self) -> None:
dt = workflow.now()
assert isinstance(dt, datetime)
assert issubclass(dt.__class__, datetime)


async def test_datetime_usage_in_workflow(client: Client):
new_config = client.config()
new_config["data_converter"] = pydantic_data_converter
client = Client(**new_config)
task_queue_name = str(uuid.uuid4())

async with Worker(
client,
task_queue=task_queue_name,
workflows=[DatetimeUsageWorkflow],
):
await client.execute_workflow(
DatetimeUsageWorkflow.run,
id=str(uuid.uuid4()),
task_queue=task_queue_name,
)
Loading

0 comments on commit 40914be

Please sign in to comment.