Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy warn_unreachable=True #3312

Merged
merged 10 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyro/ops/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import torch
from torch.distributions.utils import lazy_property
Expand Down Expand Up @@ -111,7 +111,7 @@ def event_permute(self, perm) -> "Gaussian":
precision = self.precision[..., perm][..., perm, :]
return Gaussian(self.log_normalizer, info_vec, precision)

def __add__(self, other: "Gaussian") -> "Gaussian":
def __add__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian":
"""
Adds two Gaussians in log-density space.
"""
Expand All @@ -126,7 +126,7 @@ def __add__(self, other: "Gaussian") -> "Gaussian":
return Gaussian(self.log_normalizer + other, self.info_vec, self.precision)
raise ValueError("Unsupported type: {}".format(type(other)))

def __sub__(self, other: "Gaussian") -> "Gaussian":
def __sub__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian":
if isinstance(other, (int, float, torch.Tensor)):
return Gaussian(self.log_normalizer - other, self.info_vec, self.precision)
raise ValueError("Unsupported type: {}".format(type(other)))
Expand Down
24 changes: 14 additions & 10 deletions pyro/poutine/block_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

from functools import partial
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message

if TYPE_CHECKING:
from pyro.poutine.runtime import Message


def _block_fn(
Expand All @@ -14,7 +16,7 @@ def _block_fn(
hide: List[str],
hide_types: List[str],
hide_all: bool,
msg: Message,
msg: "Message",
) -> bool:
# handle observes
if msg["type"] == "sample" and msg["is_observed"]:
Expand Down Expand Up @@ -43,7 +45,7 @@ def _make_default_hide_fn(
expose: Optional[List[str]],
hide_types: Optional[List[str]],
expose_types: Optional[List[str]],
) -> Callable[[Message], bool]:
) -> Callable[["Message"], bool]:
# first, some sanity checks:
# hide_all and expose_all intersect?
assert (hide_all is False and expose_all is False) or (
Expand Down Expand Up @@ -81,9 +83,11 @@ def _make_default_hide_fn(
return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all)


def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]:
def _negate_fn(
fn: Callable[["Message"], Optional[bool]]
) -> Callable[["Message"], bool]:
# typed version of lambda msg: not fn(msg)
def negated_fn(msg: Message) -> bool:
def negated_fn(msg: "Message") -> bool:
return not fn(msg)

return negated_fn
Expand Down Expand Up @@ -140,15 +144,15 @@ class BlockMessenger(Messenger):

def __init__(
self,
hide_fn: Optional[Callable[[Message], Optional[bool]]] = None,
expose_fn: Optional[Callable[[Message], Optional[bool]]] = None,
hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None,
hide_all: bool = True,
expose_all: bool = False,
hide: Optional[List[str]] = None,
expose: Optional[List[str]] = None,
hide_types: Optional[List[str]] = None,
expose_types: Optional[List[str]] = None,
):
) -> None:
super().__init__()
if not (hide_fn is None or expose_fn is None):
raise ValueError("Only specify one of hide_fn or expose_fn")
Expand All @@ -161,5 +165,5 @@ def __init__(
hide_all, expose_all, hide, expose, hide_types, expose_types
)

def _process_message(self, msg: Message) -> None:
def _process_message(self, msg: "Message") -> None:
msg["stop"] = bool(self.hide_fn(msg))
8 changes: 5 additions & 3 deletions pyro/poutine/broadcast_messenger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.util import ignore_jit_warnings

if TYPE_CHECKING:
from pyro.poutine.runtime import Message


class BroadcastMessenger(Messenger):
"""
Expand Down Expand Up @@ -41,7 +43,7 @@ class BroadcastMessenger(Messenger):

@staticmethod
@ignore_jit_warnings(["Converting a tensor to a Python boolean"])
def _pyro_sample(msg: Message) -> None:
def _pyro_sample(msg: "Message") -> None:
"""
:param msg: current message at a trace site.
"""
Expand Down
53 changes: 24 additions & 29 deletions pyro/poutine/collapse_messenger.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from functools import reduce, singledispatch
from typing import TYPE_CHECKING, FrozenSet, Tuple

from typing_extensions import Self

import pyro
from pyro.distributions.distribution import COERCIONS
from pyro.ops.linalg import ignore_torch_deprecation_warnings
from pyro.poutine.runtime import _PYRO_STACK
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.util import site_is_subsample

from .runtime import _PYRO_STACK
from .trace_messenger import TraceMessenger

# TODO Remove import guard once funsor is a required dependency.
try:
import funsor
Expand All @@ -24,20 +27,10 @@
Funsor = type("Funsor", (), {})
Variable = type("Variable", (), {})

if TYPE_CHECKING:
from funsor.distribution import Distribution

@singledispatch
def _get_free_vars(x):
return x


@_get_free_vars.register(Variable)
def _(x):
return frozenset((x.name,))


@_get_free_vars.register(tuple)
def _(x, subs):
return frozenset().union(*map(_get_free_vars, x))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These seem to be obsolete

from pyro.poutine.runtime import Message


@singledispatch
Expand Down Expand Up @@ -92,7 +85,7 @@ class CollapseMessenger(TraceMessenger):

_coerce = None

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
if CollapseMessenger._coerce is None:
import funsor
from funsor.distribution import CoerceDistributionToFunsor
Expand All @@ -102,18 +95,20 @@ def __init__(self, *args, **kwargs):
self._block = False
super().__init__(*args, **kwargs)

def _process_message(self, msg):
def _process_message(self, msg: "Message") -> None:
if self._block:
return
if site_is_subsample(msg):
return
super()._process_message(msg)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: "Message") -> None:
# Eagerly convert fn and value to Funsor.
dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]}
dim_to_name.update(self.preserved_plates)
msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name)
if TYPE_CHECKING:
assert isinstance(msg["fn"], Distribution)
domain = msg["fn"].inputs["value"]
if msg["value"] is None:
msg["value"] = funsor.Variable(msg["name"], domain)
Expand All @@ -123,14 +118,14 @@ def _pyro_sample(self, msg):
msg["done"] = True
msg["stop"] = True

def _pyro_post_sample(self, msg):
def _pyro_post_sample(self, msg: "Message") -> None:
if self._block:
return
if site_is_subsample(msg):
return
super()._pyro_post_sample(msg)

def _pyro_barrier(self, msg):
def _pyro_barrier(self, msg: "Message") -> None:
# Get log_prob and record factor.
name, log_prob, log_joint, sampled_vars = self._get_log_prob()
self._block = True
Expand All @@ -151,14 +146,14 @@ def _pyro_barrier(self, msg):
value = _substitute(value, samples)
msg["value"] = value

def __enter__(self):
def __enter__(self) -> Self:
self.preserved_plates = {
h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate)
}
COERCIONS.append(self._coerce)
return super().__enter__()

def __exit__(self, *args):
def __exit__(self, *args) -> None:
_coerce = COERCIONS.pop()
assert _coerce is self._coerce
super().__exit__(*args)
Expand All @@ -168,20 +163,20 @@ def __exit__(self, *args):
pyro.factor(name, log_prob.data)

@ignore_torch_deprecation_warnings()
def _get_log_prob(self):
def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]:
# Convert delayed statements to pyro.factor()
reduced_vars = []
reduced_vars_list = []
log_prob_terms = []
plates = frozenset()
plates: FrozenSet[str] = frozenset()
for name, site in self.trace.nodes.items():
if not site["is_observed"]:
reduced_vars.append(name)
reduced_vars_list.append(name)
log_prob_terms.append(site["fn"](value=site["value"]))
plates |= frozenset(
f.name for f in site["cond_indep_stack"] if f.vectorized
)
name = reduced_vars[0]
reduced_vars = frozenset(reduced_vars)
name = reduced_vars_list[0]
reduced_vars = frozenset(reduced_vars_list)
assert log_prob_terms, "nothing to collapse"
self.trace.nodes.clear()
reduced_plates = plates - frozenset(self.preserved_plates.values())
Expand Down
8 changes: 5 additions & 3 deletions pyro/poutine/condition_messenger.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Union
from typing import TYPE_CHECKING, Dict, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message
from pyro.poutine.trace_struct import Trace

if TYPE_CHECKING:
from pyro.poutine.runtime import Message


class ConditionMessenger(Messenger):
"""
Expand Down Expand Up @@ -46,7 +48,7 @@ def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None:
super().__init__()
self.data = data

def _pyro_sample(self, msg: Message) -> None:
def _pyro_sample(self, msg: "Message") -> None:
"""
:param msg: current message at a trace site.
:returns: a sample from the stochastic function at the site.
Expand Down
2 changes: 1 addition & 1 deletion pyro/poutine/do_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class DoMessenger(Messenger):
:returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger`
"""

def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]):
def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None:
super().__init__()
self.data = data
self._intervener_id = str(id(self))
Expand Down
27 changes: 16 additions & 11 deletions pyro/poutine/guide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union

import torch

import pyro.distributions as dist
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.runtime import Message
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites, site_is_subsample

if TYPE_CHECKING:
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.runtime import Message


class GuideMessenger(TraceMessenger, ABC):
"""
Expand Down Expand Up @@ -60,12 +62,13 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[
samples[name] = site["value"]
return samples

def _pyro_sample(self, msg: Message) -> None:
def _pyro_sample(self, msg: "Message") -> None:
if msg["is_observed"] or site_is_subsample(msg):
return
assert isinstance(msg["name"], str)
assert isinstance(msg["fn"], TorchDistributionMixin)
assert msg["infer"] is not None
if TYPE_CHECKING:
assert isinstance(msg["name"], str)
assert isinstance(msg["fn"], TorchDistributionMixin)
assert msg["infer"] is not None
prior = msg["fn"]
msg["infer"]["prior"] = prior
posterior = self.get_posterior(msg["name"], prior)
Expand All @@ -75,20 +78,21 @@ def _pyro_sample(self, msg: Message) -> None:
posterior = posterior.expand(prior.batch_shape)
msg["fn"] = posterior

def _pyro_post_sample(self, msg: Message) -> None:
def _pyro_post_sample(self, msg: "Message") -> None:
# Manually apply outer plates.
assert msg["infer"] is not None
prior = msg["infer"].get("prior")
if prior is not None:
assert isinstance(msg["fn"], TorchDistributionMixin)
if TYPE_CHECKING:
assert isinstance(msg["fn"], TorchDistributionMixin)
if prior.batch_shape != msg["fn"].batch_shape:
msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
return super()._pyro_post_sample(msg)

@abstractmethod
def get_posterior(
self, name: str, prior: TorchDistributionMixin
) -> Union[TorchDistributionMixin, torch.Tensor]:
self, name: str, prior: "TorchDistributionMixin"
) -> Union["TorchDistributionMixin", torch.Tensor]:
"""
Abstract method to compute a posterior distribution or sample a
posterior value given a prior distribution conditioned on upstream
Expand Down Expand Up @@ -148,6 +152,7 @@ def get_traces(self) -> Tuple[Trace, Trace]:
del guide_trace.nodes[name]
continue
model_site = model_trace.nodes[name].copy()
assert guide_site["infer"] is not None
model_site["fn"] = guide_site["infer"]["prior"]
model_trace.nodes[name] = model_site
return model_trace, guide_trace
Loading
Loading