Skip to content

Commit

Permalink
Type annotate messengers (#3309)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Jan 6, 2024
1 parent f4a6168 commit 670e9cb
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 80 deletions.
3 changes: 2 additions & 1 deletion pyro/distributions/torch_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import warnings
from collections import OrderedDict
from typing import Callable

import torch
from torch.distributions.kl import kl_divergence, register_kl
Expand All @@ -15,7 +16,7 @@
from .util import broadcast_shape, scale_and_mask


class TorchDistributionMixin(Distribution):
class TorchDistributionMixin(Distribution, Callable):
"""
Mixin to provide Pyro compatibility for PyTorch distributions.
Expand Down
29 changes: 10 additions & 19 deletions pyro/infer/reparam/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,20 @@
from typing import Callable, Optional

import torch
from typing_extensions import TypedDict

try:
from typing import TypedDict
except ImportError:

def TypedDict(*args, **kwargs):
return dict
class ReparamMessage(TypedDict):
name: str
fn: Callable
value: Optional[torch.Tensor]
is_observed: Optional[bool]


ReparamMessage = TypedDict(
"ReparamMessage",
name=str,
fn=Callable,
value=Optional[torch.Tensor],
is_observed=Optional[bool],
)

ReparamResult = TypedDict(
"ReparamResult",
fn=Callable,
value=Optional[torch.Tensor],
is_observed=Optional[bool],
)
class ReparamResult(TypedDict):
fn: Callable
value: Optional[torch.Tensor]
is_observed: bool


class Reparam(ABC):
Expand Down
25 changes: 17 additions & 8 deletions pyro/poutine/messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
from contextlib import contextmanager
from functools import partial
from types import TracebackType
from typing import Any, Callable, Iterator, List, Optional, Type, TypeVar, cast
from typing import (
Any,
Callable,
Iterator,
List,
Optional,
Type,
TypeVar,
)

from typing_extensions import Self
from typing_extensions import ParamSpec, Self

from .runtime import _PYRO_STACK, Message
from pyro.poutine.runtime import _PYRO_STACK, Message

_F = TypeVar("_F", bound=Callable)
_P = ParamSpec("_P")
_T = TypeVar("_T")


def _context_wrap(
Expand Down Expand Up @@ -76,13 +85,13 @@ class Messenger:
Most inference operations are implemented in subclasses of this.
"""

def __call__(self, fn: _F) -> _F:
def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
if not callable(fn):
raise ValueError(
f"{fn!r} is not callable, did you mean to pass it as a keyword arg?"
)
wraps = _bound_partial(partial(_context_wrap, self, fn))
return cast(_F, wraps)
return wraps

def __enter__(self) -> Self:
"""
Expand Down Expand Up @@ -118,8 +127,8 @@ def __enter__(self) -> Self:

def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Expand Down
25 changes: 17 additions & 8 deletions pyro/poutine/plate_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional

from .broadcast_messenger import BroadcastMessenger
from .messenger import block_messengers
from .subsample_messenger import SubsampleMessenger
from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.messenger import Messenger, block_messengers
from pyro.poutine.subsample_messenger import SubsampleMessenger

if TYPE_CHECKING:
import torch

from pyro.poutine.runtime import Message


class PlateMessenger(SubsampleMessenger):
Expand All @@ -14,19 +20,21 @@ class PlateMessenger(SubsampleMessenger):
combines shape inference, independence annotation, and subsampling
"""

def _process_message(self, msg):
def _process_message(self, msg: "Message") -> None:
super()._process_message(msg)
return BroadcastMessenger._pyro_sample(msg)
BroadcastMessenger._pyro_sample(msg)

def __enter__(self):
def __enter__(self) -> Optional["torch.Tensor"]: # type: ignore[override]
super().__enter__()
if self._vectorized and self._indices is not None:
return self.indices
return None


@contextmanager
def block_plate(name=None, dim=None, *, strict=True):
def block_plate(
name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True
) -> Iterator[None]:
"""
EXPERIMENTAL Context manager to temporarily block a single enclosing plate.
Expand Down Expand Up @@ -63,13 +71,14 @@ def model_2(data):
assert isinstance(dim, int)
assert dim < 0

def predicate(messenger):
def predicate(messenger: Messenger) -> bool:
if not isinstance(messenger, PlateMessenger):
return False
if name is not None:
return messenger.name == name
if dim is not None:
return messenger.dim == dim
raise ValueError("Unreachable")

with block_messengers(predicate) as matches:
if strict and len(matches) != 1:
Expand Down
22 changes: 17 additions & 5 deletions pyro/poutine/reentrant_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@
# SPDX-License-Identifier: Apache-2.0

import functools
from types import TracebackType
from typing import Callable, Optional, Type, TypeVar

from .messenger import Messenger
from typing_extensions import ParamSpec, Self

from pyro.poutine.messenger import Messenger

_P = ParamSpec("_P")
_T = TypeVar("_T")


class ReentrantMessenger(Messenger):
def __init__(self):
def __init__(self) -> None:
self._ref_count = 0
super().__init__()

def __call__(self, fn):
def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
return functools.wraps(fn)(super().__call__(fn))

def __enter__(self):
def __enter__(self) -> Self:
self._ref_count += 1
if self._ref_count == 1:
super().__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self._ref_count -= 1
if self._ref_count == 0:
super().__exit__(exc_type, exc_value, traceback)
54 changes: 39 additions & 15 deletions pyro/poutine/reparam_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,33 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Callable, Dict, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)

import torch
from typing_extensions import ParamSpec

from .messenger import Messenger
from .runtime import effectful
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, effectful

if TYPE_CHECKING:
from pyro.infer.reparam.reparam import Reparam

_P = ParamSpec("_P")
_T = TypeVar("_T")


@effectful(type="get_init_messengers")
def _get_init_messengers():
def _get_init_messengers() -> List[Messenger]:
return []


Expand All @@ -34,24 +51,29 @@ class ReparamMessenger(Messenger):
:param config: Configuration, either a dict mapping site name to
:class:`~pyro.infer.reparam.reparam.Reparameterizer` , or a function
mapping site to :class:`~pyro.infer.reparam.reparam.Reparameterizer` or
mapping site to :class:`~pyro.infer.reparam.reparam.Reparam` or
None. See :mod:`pyro.infer.reparam.strategies` for built-in
configuration strategies.
:type config: dict or callable
"""

def __init__(self, config: Union[Dict[str, object], Callable]):
def __init__(
self,
config: Union[Dict[str, "Reparam"], Callable[[Message], Optional["Reparam"]]],
) -> None:
super().__init__()
assert isinstance(config, dict) or callable(config)
self.config = config
self._args_kwargs = None

def __call__(self, fn):
def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
return ReparamHandler(self, fn)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
if type(msg["fn"]).__name__ == "_Subsample":
return
assert msg["name"] is not None
assert isinstance(msg["fn"], TorchDistributionMixin)
if isinstance(self.config, dict):
reparam = self.config.get(msg["name"])
else:
Expand Down Expand Up @@ -79,11 +101,13 @@ def _pyro_sample(self, msg):
# ReplayMessenger we would need to ensure those messengers can
# similarly be safely applied twice, with the second application
# avoiding overwriting the original application.
for m in _get_init_messengers():
m._pyro_sample(msg)
_get_init_messengers_iter = _get_init_messengers()
assert _get_init_messengers_iter is not None
for m in _get_init_messengers_iter:
m._process_message(msg)

# Pass args_kwargs to the reparam via a side channel.
reparam.args_kwargs = self._args_kwargs
reparam.args_kwargs = self._args_kwargs # type: ignore[attr-defined]
try:
new_msg = reparam.apply(
{
Expand All @@ -94,7 +118,7 @@ def _pyro_sample(self, msg):
}
)
finally:
reparam.args_kwargs = None
reparam.args_kwargs = None # type: ignore[attr-defined]

if new_msg["value"] is not None:
# Validate while the original msg["fn"] is known.
Expand All @@ -121,17 +145,17 @@ def _pyro_sample(self, msg):
msg["is_observed"] = new_msg["is_observed"]


class ReparamHandler(object):
class ReparamHandler(Generic[_P, _T]):
"""
Reparameterization poutine.
"""

def __init__(self, msngr, fn):
def __init__(self, msngr, fn: Callable[_P, _T]) -> None:
self.msngr = msngr
self.fn = fn
super().__init__()

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
# This saves args,kwargs for optional use by reparameterizers.
self.msngr._args_kwargs = args, kwargs
try:
Expand Down
Loading

0 comments on commit 670e9cb

Please sign in to comment.