Skip to content

Commit 34d1ae2

Browse files
committed
Only use pydantic in sandbox if it can be imported
1 parent d3d9ed3 commit 34d1ae2

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@
3232
cast,
3333
)
3434

35-
from pydantic import GetCoreSchemaHandler
36-
from pydantic_core import CoreSchema, core_schema
35+
try:
36+
import pydantic
37+
import pydantic_core
38+
except ImportError:
39+
pydantic = None # type: ignore
3740

3841
import temporalio.workflow
3942

@@ -986,7 +989,7 @@ def __init__(self, *args, **kwargs) -> None:
986989
_trace("__init__ unrecognized with args %s", args)
987990

988991
def __getattribute__(self, __name: str) -> Any:
989-
if __name == "__get_pydantic_core_schema__":
992+
if pydantic and __name == "__get_pydantic_core_schema__":
990993
return object.__getattribute__(self, "__get_pydantic_core_schema__")
991994
state = _RestrictionState.from_proxy(self)
992995
_trace("__getattribute__ %s on %s", __name, state.name)
@@ -1037,14 +1040,17 @@ def __getitem__(self, key: Any) -> Any:
10371040
)
10381041
return ret
10391042

1040-
# Instruct pydantic to use the proxied type when determining the schema
1041-
@classmethod
1042-
def __get_pydantic_core_schema__(
1043-
cls, source_type: Any, handler: GetCoreSchemaHandler
1044-
) -> CoreSchema:
1045-
return core_schema.no_info_after_validator_function(
1046-
cls, handler(RestrictionContext.unwrap_if_proxied(source_type))
1047-
)
1043+
if pydantic:
1044+
# Instruct pydantic to use the proxied type when determining the schema
1045+
@classmethod
1046+
def __get_pydantic_core_schema__(
1047+
cls,
1048+
source_type: Any,
1049+
handler: pydantic.GetCoreSchemaHandler, # type: ignore
1050+
) -> pydantic_core.CoreSchema:
1051+
return pydantic_core.core_schema.no_info_after_validator_function(
1052+
cls, handler(RestrictionContext.unwrap_if_proxied(source_type))
1053+
)
10481054

10491055
__doc__ = _RestrictedProxyLookup( # type: ignore
10501056
class_value=__doc__, fallback_func=lambda self: type(self).__doc__, is_attr=True

0 commit comments

Comments
 (0)