Skip to content

Commit

Permalink
Type annotate pyro.primitives & poutine.block_messenger (#3292)
Browse files Browse the repository at this point in the history
  • Loading branch information
ordabayevy authored Nov 13, 2023
1 parent 71e0f9c commit e274bca
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 97 deletions.
4 changes: 2 additions & 2 deletions pyro/distributions/torch_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TorchDistributionMixin(Distribution):
from :class:`TorchDistributionMixin`.
"""

def __call__(self, sample_shape=torch.Size()):
def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
"""
Samples a random value.
Expand All @@ -51,7 +51,7 @@ def __call__(self, sample_shape=torch.Size()):
)

@property
def event_dim(self):
def event_dim(self) -> int:
"""
:return: Number of dimensions of individual events.
:rtype: int
Expand Down
49 changes: 36 additions & 13 deletions pyro/poutine/block_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,20 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, List, Optional

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
def _block_fn(
expose: List[str],
expose_types: List[str],
hide: List[str],
hide_types: List[str],
hide_all: bool,
msg: Message,
) -> bool:
# handle observes
if msg["type"] == "sample" and msg["is_observed"]:
msg_type = "observe"
Expand All @@ -27,7 +36,14 @@ def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg):
return False


def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types):
def _make_default_hide_fn(
hide_all: bool,
expose_all: bool,
hide: Optional[List[str]],
expose: Optional[List[str]],
hide_types: Optional[List[str]],
expose_types: Optional[List[str]],
) -> Callable[[Message], bool]:
# first, some sanity checks:
# hide_all and expose_all intersect?
assert (hide_all is False and expose_all is False) or (
Expand Down Expand Up @@ -65,6 +81,14 @@ def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose
return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]:
# typed version of lambda msg: not fn(msg)
def negated_fn(msg: Message) -> bool:
return not fn(msg)

return negated_fn


class BlockMessenger(Messenger):
"""
This handler selectively hides Pyro primitive sites from the outside world.
Expand Down Expand Up @@ -116,27 +140,26 @@ class BlockMessenger(Messenger):

def __init__(
self,
hide_fn=None,
expose_fn=None,
hide_all=True,
expose_all=False,
hide=None,
expose=None,
hide_types=None,
expose_types=None,
hide_fn: Optional[Callable[[Message], Optional[bool]]] = None,
expose_fn: Optional[Callable[[Message], Optional[bool]]] = None,
hide_all: bool = True,
expose_all: bool = False,
hide: Optional[List[str]] = None,
expose: Optional[List[str]] = None,
hide_types: Optional[List[str]] = None,
expose_types: Optional[List[str]] = None,
):
super().__init__()
if not (hide_fn is None or expose_fn is None):
raise ValueError("Only specify one of hide_fn or expose_fn")
if hide_fn is not None:
self.hide_fn = hide_fn
elif expose_fn is not None:
self.hide_fn = lambda msg: not expose_fn(msg)
self.hide_fn = _negate_fn(expose_fn)
else:
self.hide_fn = _make_default_hide_fn(
hide_all, expose_all, hide, expose, hide_types, expose_types
)

def _process_message(self, msg):
def _process_message(self, msg: Message) -> None:
msg["stop"] = bool(self.hide_fn(msg))
return None
120 changes: 76 additions & 44 deletions pyro/poutine/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,30 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)

import torch
from typing_extensions import TypedDict
from typing_extensions import ParamSpec, TypedDict

from pyro.params.param_store import ( # noqa: F401
_MODULE_NAMESPACE_DIVIDER,
ParamStoreDict,
)

P = ParamSpec("P")
T = TypeVar("T")

if TYPE_CHECKING:
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.messenger import Messenger
Expand All @@ -26,8 +40,8 @@


class Message(TypedDict, total=False):
type: Optional[str]
name: str
type: str
name: Optional[str]
fn: Callable
is_observed: bool
args: Tuple
Expand Down Expand Up @@ -252,15 +266,31 @@ def apply_stack(initial_msg: Message) -> None:
cont(msg)


def am_i_wrapped():
def am_i_wrapped() -> bool:
"""
Checks whether the current computation is wrapped in a poutine.
:returns: bool
"""
return len(_PYRO_STACK) > 0


def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Callable:
@overload
def effectful(
fn: None = ..., type: Optional[str] = ...
) -> Callable[[Callable[P, T]], Callable[..., Union[T, torch.Tensor, None]]]:
...


@overload
def effectful(
fn: Callable[P, T] = ..., type: Optional[str] = ...
) -> Callable[..., Union[T, torch.Tensor, None]]:
...


def effectful(
fn: Optional[Callable[P, T]] = None, type: Optional[str] = None
) -> Callable:
"""
:param fn: function or callable that performs an effectful computation
:param str type: the type label of the operation, e.g. `"sample"`
Expand All @@ -277,32 +307,34 @@ def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Call
assert type != "message", "cannot use 'message' as keyword"

@functools.wraps(fn)
def _fn(*args, **kwargs):
name = kwargs.pop("name", None)
infer = kwargs.pop("infer", {})

value = kwargs.pop("obs", None)
is_observed = value is not None
def _fn(
*args: P.args,
name: Optional[str] = None,
infer: Optional[Dict] = None,
obs: Optional[torch.Tensor] = None,
**kwargs: P.kwargs,
) -> Union[T, torch.Tensor, None]:
is_observed = obs is not None

if not am_i_wrapped():
return fn(*args, **kwargs)
else:
msg = {
"type": type,
"name": name,
"fn": fn,
"is_observed": is_observed,
"args": args,
"kwargs": kwargs,
"value": value,
"scale": 1.0,
"mask": None,
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None,
"infer": infer,
}
msg = Message(
type=type,
name=name,
fn=fn,
is_observed=is_observed,
args=args,
kwargs=kwargs,
value=obs,
scale=1.0,
mask=None,
cond_indep_stack=(),
done=False,
stop=False,
continuation=None,
infer=infer if infer is not None else {},
)
# apply the stack and return its return value
apply_stack(msg)
return msg["value"]
Expand All @@ -321,22 +353,22 @@ def _inspect() -> Message:
:returns: A message with all effects applied.
:rtype: dict
"""
msg: Message = {
"type": "inspect",
"name": "_pyro_inspect",
"fn": lambda: True,
"is_observed": False,
"args": (),
"kwargs": {},
"value": None,
"infer": {"_do_not_trace": True},
"scale": 1.0,
"mask": None,
"cond_indep_stack": (),
"done": False,
"stop": False,
"continuation": None,
}
msg = Message(
type="inspect",
name="_pyro_inspect",
fn=lambda: True,
is_observed=False,
args=(),
kwargs={},
value=None,
infer={"_do_not_trace": True},
scale=1.0,
mask=None,
cond_indep_stack=(),
done=False,
stop=False,
continuation=None,
)
apply_stack(msg)
return msg

Expand Down
Loading

0 comments on commit e274bca

Please sign in to comment.