diff --git a/pyro/params/param_store.py b/pyro/params/param_store.py index c191227f67..4e2ba9e74d 100644 --- a/pyro/params/param_store.py +++ b/pyro/params/param_store.py @@ -19,8 +19,12 @@ import torch from torch.distributions import constraints, transform_to from torch.serialization import MAP_LOCATION +from typing_extensions import TypedDict -from pyro.types import StateDict + +class StateDict(TypedDict): + params: Dict[str, torch.Tensor] + constraints: Dict[str, constraints.Constraint] class ParamStoreDict: diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 64640f106a..2c51f8189f 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -14,9 +14,10 @@ ) if TYPE_CHECKING: + from typing_extensions import TypedDict + from pyro.poutine.indep_messenger import CondIndepStackFrame from pyro.poutine.messenger import Messenger - from pyro.types import Message # the global pyro stack _PYRO_STACK: List[Messenger] = [] @@ -25,6 +26,24 @@ _PYRO_PARAM_STORE = ParamStoreDict() +class Message(TypedDict, total=False): + type: Optional[str] + name: str + fn: Callable + is_observed: bool + args: Tuple + kwargs: Dict + value: Optional[torch.Tensor] + scale: float + mask: Union[bool, torch.Tensor, None] + cond_indep_stack: Tuple[CondIndepStackFrame, ...] + done: bool + stop: bool + continuation: Optional[Callable[[Message], None]] + infer: Optional[Dict[str, Union[str, bool]]] + obs: Optional[torch.Tensor] + + class _DimAllocator: """ Dimension allocator for internal use by :class:`plate`. diff --git a/pyro/types.py b/pyro/types.py index 3a62caf8e3..2ca6791f10 100644 --- a/pyro/types.py +++ b/pyro/types.py @@ -3,13 +3,14 @@ from __future__ import annotations -from typing import Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union -import torch -from torch.distributions import constraints -from typing_extensions import TypedDict +if TYPE_CHECKING: + import torch + from torch.distributions import constraints + from typing_extensions import TypedDict -from pyro.poutine.indep_messenger import CondIndepStackFrame + from pyro.poutine.indep_messenger import CondIndepStackFrame class Message(TypedDict, total=False):