Skip to content

Commit

Permalink
rm types.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy committed Oct 30, 2023
1 parent 13690ce commit 7b958f1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
6 changes: 5 additions & 1 deletion pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
import torch
from torch.distributions import constraints, transform_to
from torch.serialization import MAP_LOCATION
from typing_extensions import TypedDict

from pyro.types import StateDict

class StateDict(TypedDict):
params: Dict[str, torch.Tensor]
constraints: Dict[str, constraints.Constraint]


class ParamStoreDict:
Expand Down
21 changes: 20 additions & 1 deletion pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
)

if TYPE_CHECKING:
from typing_extensions import TypedDict

from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger
from pyro.types import Message

# the global pyro stack
_PYRO_STACK: List[Messenger] = []
Expand All @@ -25,6 +26,24 @@
_PYRO_PARAM_STORE = ParamStoreDict()


class Message(TypedDict, total=False):
type: Optional[str]
name: str
fn: Callable
is_observed: bool
args: Tuple
kwargs: Dict
value: Optional[torch.Tensor]
scale: float
mask: Union[bool, torch.Tensor, None]
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
done: bool
stop: bool
continuation: Optional[Callable[[Message], None]]
infer: Optional[Dict[str, Union[str, bool]]]
obs: Optional[torch.Tensor]


class _DimAllocator:
"""
Dimension allocator for internal use by :class:`plate`.
Expand Down
11 changes: 6 additions & 5 deletions pyro/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

from __future__ import annotations

from typing import Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union

import torch
from torch.distributions import constraints
from typing_extensions import TypedDict
if TYPE_CHECKING:
import torch
from torch.distributions import constraints
from typing_extensions import TypedDict

from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.indep_messenger import CondIndepStackFrame


class Message(TypedDict, total=False):
Expand Down

0 comments on commit 7b958f1

Please sign in to comment.