diff --git a/tests/contrib/pydantic/activities.py b/tests/contrib/pydantic/activities.py new file mode 100644 index 00000000..78020335 --- /dev/null +++ b/tests/contrib/pydantic/activities.py @@ -0,0 +1,11 @@ +from typing import List + +from temporalio import activity +from tests.contrib.pydantic.models import PydanticModels + + +@activity.defn +async def pydantic_models_activity( + models: List[PydanticModels], +) -> List[PydanticModels]: + return models diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/pydantic/models.py similarity index 70% rename from tests/contrib/test_pydantic.py rename to tests/contrib/pydantic/models.py index e5766d0e..4d9d8ae8 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/pydantic/models.py @@ -26,14 +26,9 @@ ) from annotated_types import Len -from pydantic import BaseModel, Field, WithJsonSchema, create_model +from pydantic import BaseModel, Field, WithJsonSchema from typing_extensions import TypedDict -from temporalio import activity, workflow -from temporalio.client import Client -from temporalio.contrib.pydantic import pydantic_data_converter -from temporalio.worker import Worker - SequenceType = TypeVar("SequenceType", bound=Sequence[Any]) ShortSequence = Annotated[SequenceType, Len(max_length=2)] @@ -573,21 +568,6 @@ def make_pydantic_timedelta_object() -> PydanticTimedeltaModel: ) -PydanticModels = Union[ - StandardTypesModel, - ComplexTypesModel, - SpecialTypesModel, - ParentModel, - FieldFeaturesModel, - AnnotatedFieldsModel, - GenericModel[Any], - UnionModel, - PydanticDatetimeModel, - PydanticDateModel, - PydanticTimedeltaModel, -] - - def _assert_datetime_validity(dt: datetime): assert isinstance(dt, datetime) assert issubclass(dt.__class__, datetime) @@ -603,6 +583,21 @@ def _assert_timedelta_validity(td: timedelta): assert issubclass(td.__class__, timedelta) +PydanticModels = Union[ + StandardTypesModel, + ComplexTypesModel, + SpecialTypesModel, + ParentModel, + FieldFeaturesModel, + AnnotatedFieldsModel, + GenericModel[Any], + UnionModel, + PydanticDatetimeModel, + PydanticDateModel, + PydanticTimedeltaModel, +] + + def make_list_of_pydantic_objects() -> List[PydanticModels]: objects = [ make_standard_types_object(), @@ -622,126 +617,6 @@ def make_list_of_pydantic_objects() -> List[PydanticModels]: return objects # type: ignore -@activity.defn -async def pydantic_models_activity( - models: List[PydanticModels], -) -> List[PydanticModels]: - return models - - -@workflow.defn -class InstantiateModelsWorkflow: - @workflow.run - async def run(self) -> None: - make_list_of_pydantic_objects() - - -@workflow.defn -class RoundTripObjectsWorkflow: - @workflow.run - async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]: - return await workflow.execute_activity( - pydantic_models_activity, - objects, - start_to_close_timeout=timedelta(minutes=1), - ) - - -def clone_objects(objects: List[PydanticModels]) -> List[PydanticModels]: - new_objects = [] - for o in objects: - fields = {} - for name, f in o.model_fields.items(): - fields[name] = (f.annotation, f) - model = create_model(o.__class__.__name__, **fields) # type: ignore - new_objects.append(model(**o.model_dump(by_alias=True))) - for old, new in zip(objects, new_objects): - assert old.model_dump() == new.model_dump() - return new_objects - - -@workflow.defn -class CloneObjectsWorkflow: - @workflow.run - async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]: - return clone_objects(objects) - - -async def test_instantiation_outside_sandbox(): - make_list_of_pydantic_objects() - - -async def test_instantiation_inside_sandbox(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[InstantiateModelsWorkflow], - ): - await client.execute_workflow( - InstantiateModelsWorkflow.run, - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) - - -async def test_round_trip_pydantic_objects(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - orig_objects = make_list_of_pydantic_objects() - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[RoundTripObjectsWorkflow], - activities=[pydantic_models_activity], - ): - returned_objects = await client.execute_workflow( - RoundTripObjectsWorkflow.run, - orig_objects, - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) - assert returned_objects == orig_objects - for o in returned_objects: - o._check_instance() - - -async def test_clone_objects_outside_sandbox(): - clone_objects(make_list_of_pydantic_objects()) - - -async def test_clone_objects_in_sandbox(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - orig_objects = make_list_of_pydantic_objects() - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[CloneObjectsWorkflow], - ): - returned_objects = await client.execute_workflow( - CloneObjectsWorkflow.run, - orig_objects, - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) - assert returned_objects == orig_objects - for o in returned_objects: - o._check_instance() - - @dataclasses.dataclass(order=True) class MyDataClass: # The name int_field also occurs in StandardTypesModel and currently unions can match them up incorrectly. @@ -753,170 +628,4 @@ def make_dataclass_objects() -> List[MyDataClass]: ComplexCustomType = Tuple[List[MyDataClass], List[PydanticModels]] - - -@workflow.defn -class ComplexCustomTypeWorkflow: - @workflow.run - async def run( - self, - input: ComplexCustomType, - ) -> ComplexCustomType: - data_classes, pydantic_objects = input - pydantic_objects = await workflow.execute_activity( - pydantic_models_activity, - pydantic_objects, - start_to_close_timeout=timedelta(minutes=1), - ) - return data_classes, pydantic_objects - - -async def test_complex_custom_type(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - orig_dataclass_objects = make_dataclass_objects() - orig_pydantic_objects = make_list_of_pydantic_objects() - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[ComplexCustomTypeWorkflow], - activities=[pydantic_models_activity], - ): - ( - returned_dataclass_objects, - returned_pydantic_objects, - ) = await client.execute_workflow( - ComplexCustomTypeWorkflow.run, - (orig_dataclass_objects, orig_pydantic_objects), - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) - assert orig_dataclass_objects == returned_dataclass_objects - assert orig_pydantic_objects == returned_pydantic_objects - for o in returned_pydantic_objects: - o._check_instance() - - ComplexCustomUnionType = List[Union[MyDataClass, PydanticModels]] - - -@workflow.defn -class ComplexCustomUnionTypeWorkflow: - @workflow.run - async def run( - self, - input: ComplexCustomUnionType, - ) -> ComplexCustomUnionType: - data_classes = [] - pydantic_objects: List[PydanticModels] = [] - for o in input: - if dataclasses.is_dataclass(o): - data_classes.append(o) - elif isinstance(o, BaseModel): - pydantic_objects.append(o) - else: - raise TypeError(f"Unexpected type: {type(o)}") - pydantic_objects = await workflow.execute_activity( - pydantic_models_activity, - pydantic_objects, - start_to_close_timeout=timedelta(minutes=1), - ) - return data_classes + pydantic_objects # type: ignore - - -async def test_complex_custom_union_type(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - orig_dataclass_objects = make_dataclass_objects() - orig_pydantic_objects = make_list_of_pydantic_objects() - orig_objects = orig_dataclass_objects + orig_pydantic_objects - import random - - random.shuffle(orig_objects) - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[ComplexCustomUnionTypeWorkflow], - activities=[pydantic_models_activity], - ): - returned_objects = await client.execute_workflow( - ComplexCustomUnionTypeWorkflow.run, - orig_objects, - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) - returned_dataclass_objects, returned_pydantic_objects = [], [] - for o in returned_objects: - if isinstance(o, MyDataClass): - returned_dataclass_objects.append(o) - elif isinstance(o, BaseModel): - returned_pydantic_objects.append(o) - else: - raise TypeError(f"Unexpected type: {type(o)}") - assert sorted(orig_dataclass_objects) == sorted(returned_dataclass_objects) - assert sorted(orig_pydantic_objects, key=lambda o: o.__class__.__name__) == sorted( - returned_pydantic_objects, key=lambda o: o.__class__.__name__ - ) - for o in returned_pydantic_objects: - o._check_instance() - - -@workflow.defn -class PydanticModelUsageWorkflow: - @workflow.run - async def run(self) -> None: - for o in make_list_of_pydantic_objects(): - o._check_instance() - - -async def test_pydantic_model_usage_in_workflow(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[PydanticModelUsageWorkflow], - ): - await client.execute_workflow( - PydanticModelUsageWorkflow.run, - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) - - -@workflow.defn -class DatetimeUsageWorkflow: - @workflow.run - async def run(self) -> None: - dt = workflow.now() - assert isinstance(dt, datetime) - assert issubclass(dt.__class__, datetime) - - -async def test_datetime_usage_in_workflow(client: Client): - new_config = client.config() - new_config["data_converter"] = pydantic_data_converter - client = Client(**new_config) - task_queue_name = str(uuid.uuid4()) - - async with Worker( - client, - task_queue=task_queue_name, - workflows=[DatetimeUsageWorkflow], - ): - await client.execute_workflow( - DatetimeUsageWorkflow.run, - id=str(uuid.uuid4()), - task_queue=task_queue_name, - ) diff --git a/tests/contrib/pydantic/test_pydantic.py b/tests/contrib/pydantic/test_pydantic.py new file mode 100644 index 00000000..d1636707 --- /dev/null +++ b/tests/contrib/pydantic/test_pydantic.py @@ -0,0 +1,205 @@ +import dataclasses +import uuid + +from pydantic import BaseModel + +from temporalio.client import Client +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.worker import Worker +from tests.contrib.pydantic.models import ( + make_dataclass_objects, + make_list_of_pydantic_objects, +) +from tests.contrib.pydantic.workflows import ( + CloneObjectsWorkflow, + ComplexCustomTypeWorkflow, + ComplexCustomUnionTypeWorkflow, + DatetimeUsageWorkflow, + InstantiateModelsWorkflow, + PydanticModelUsageWorkflow, + RoundTripObjectsWorkflow, + clone_objects, + pydantic_models_activity, +) + + +async def test_instantiation_outside_sandbox(): + make_list_of_pydantic_objects() + + +async def test_instantiation_inside_sandbox(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[InstantiateModelsWorkflow], + ): + await client.execute_workflow( + InstantiateModelsWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) + + +async def test_round_trip_pydantic_objects(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + orig_objects = make_list_of_pydantic_objects() + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[RoundTripObjectsWorkflow], + activities=[pydantic_models_activity], + ): + returned_objects = await client.execute_workflow( + RoundTripObjectsWorkflow.run, + orig_objects, + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) + assert returned_objects == orig_objects + for o in returned_objects: + o._check_instance() + + +async def test_clone_objects_outside_sandbox(): + clone_objects(make_list_of_pydantic_objects()) + + +async def test_clone_objects_in_sandbox(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + orig_objects = make_list_of_pydantic_objects() + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[CloneObjectsWorkflow], + ): + returned_objects = await client.execute_workflow( + CloneObjectsWorkflow.run, + orig_objects, + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) + assert returned_objects == orig_objects + for o in returned_objects: + o._check_instance() + + +async def test_complex_custom_type(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + orig_dataclass_objects = make_dataclass_objects() + orig_pydantic_objects = make_list_of_pydantic_objects() + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[ComplexCustomTypeWorkflow], + activities=[pydantic_models_activity], + ): + ( + returned_dataclass_objects, + returned_pydantic_objects, + ) = await client.execute_workflow( + ComplexCustomTypeWorkflow.run, + (orig_dataclass_objects, orig_pydantic_objects), + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) + assert orig_dataclass_objects == returned_dataclass_objects + assert orig_pydantic_objects == returned_pydantic_objects + for o in returned_pydantic_objects: + o._check_instance() + + +async def test_complex_custom_union_type(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + orig_dataclass_objects = make_dataclass_objects() + orig_pydantic_objects = make_list_of_pydantic_objects() + orig_objects = orig_dataclass_objects + orig_pydantic_objects + import random + + random.shuffle(orig_objects) + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[ComplexCustomUnionTypeWorkflow], + activities=[pydantic_models_activity], + ): + returned_objects = await client.execute_workflow( + ComplexCustomUnionTypeWorkflow.run, + orig_objects, + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) + returned_dataclass_objects, returned_pydantic_objects = [], [] + for o in returned_objects: + if dataclasses.is_dataclass(o): + returned_dataclass_objects.append(o) + elif isinstance(o, BaseModel): + returned_pydantic_objects.append(o) + else: + raise TypeError(f"Unexpected type: {type(o)}") + assert sorted(orig_dataclass_objects) == sorted(returned_dataclass_objects) + assert sorted(orig_pydantic_objects, key=lambda o: o.__class__.__name__) == sorted( + returned_pydantic_objects, key=lambda o: o.__class__.__name__ + ) + for o in returned_pydantic_objects: + o._check_instance() + + +async def test_pydantic_model_usage_in_workflow(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[PydanticModelUsageWorkflow], + ): + await client.execute_workflow( + PydanticModelUsageWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) + + +async def test_datetime_usage_in_workflow(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + task_queue_name = str(uuid.uuid4()) + + async with Worker( + client, + task_queue=task_queue_name, + workflows=[DatetimeUsageWorkflow], + ): + await client.execute_workflow( + DatetimeUsageWorkflow.run, + id=str(uuid.uuid4()), + task_queue=task_queue_name, + ) diff --git a/tests/contrib/pydantic/workflows.py b/tests/contrib/pydantic/workflows.py new file mode 100644 index 00000000..4a2a99cb --- /dev/null +++ b/tests/contrib/pydantic/workflows.py @@ -0,0 +1,111 @@ +import dataclasses +from datetime import datetime, timedelta +from typing import List + +from pydantic import BaseModel, create_model + +from temporalio import workflow + +with workflow.unsafe.imports_passed_through(): + from tests.contrib.pydantic.activities import pydantic_models_activity + from tests.contrib.pydantic.models import ( + ComplexCustomType, + ComplexCustomUnionType, + PydanticModels, + make_list_of_pydantic_objects, + ) + + +def clone_objects(objects: List[PydanticModels]) -> List[PydanticModels]: + new_objects = [] + for o in objects: + fields = {} + for name, f in o.model_fields.items(): + fields[name] = (f.annotation, f) + model = create_model(o.__class__.__name__, **fields) # type: ignore + new_objects.append(model(**o.model_dump(by_alias=True))) + for old, new in zip(objects, new_objects): + assert old.model_dump() == new.model_dump() + return new_objects + + +@workflow.defn +class InstantiateModelsWorkflow: + @workflow.run + async def run(self) -> None: + make_list_of_pydantic_objects() + + +@workflow.defn +class RoundTripObjectsWorkflow: + @workflow.run + async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]: + return await workflow.execute_activity( + pydantic_models_activity, + objects, + start_to_close_timeout=timedelta(minutes=1), + ) + + +@workflow.defn +class CloneObjectsWorkflow: + @workflow.run + async def run(self, objects: List[PydanticModels]) -> List[PydanticModels]: + return clone_objects(objects) + + +@workflow.defn +class ComplexCustomUnionTypeWorkflow: + @workflow.run + async def run( + self, + input: ComplexCustomUnionType, + ) -> ComplexCustomUnionType: + data_classes = [] + pydantic_objects: List[PydanticModels] = [] + for o in input: + if dataclasses.is_dataclass(o): + data_classes.append(o) + elif isinstance(o, BaseModel): + pydantic_objects.append(o) + else: + raise TypeError(f"Unexpected type: {type(o)}") + pydantic_objects = await workflow.execute_activity( + pydantic_models_activity, + pydantic_objects, + start_to_close_timeout=timedelta(minutes=1), + ) + return data_classes + pydantic_objects # type: ignore + + +@workflow.defn +class ComplexCustomTypeWorkflow: + @workflow.run + async def run( + self, + input: ComplexCustomType, + ) -> ComplexCustomType: + data_classes, pydantic_objects = input + pydantic_objects = await workflow.execute_activity( + pydantic_models_activity, + pydantic_objects, + start_to_close_timeout=timedelta(minutes=1), + ) + return data_classes, pydantic_objects + + +@workflow.defn +class PydanticModelUsageWorkflow: + @workflow.run + async def run(self) -> None: + for o in make_list_of_pydantic_objects(): + o._check_instance() + + +@workflow.defn +class DatetimeUsageWorkflow: + @workflow.run + async def run(self) -> None: + dt = workflow.now() + assert isinstance(dt, datetime) + assert issubclass(dt.__class__, datetime)