diff --git a/pyro/distributions/score_parts.py b/pyro/distributions/score_parts.py
index 15d39156d7..7816c45012 100644
--- a/pyro/distributions/score_parts.py
+++ b/pyro/distributions/score_parts.py
@@ -1,20 +1,28 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
-from collections import namedtuple
+from typing import NamedTuple, Optional, Union
+
+import torch
 
 from pyro.distributions.util import scale_and_mask
 
 
-class ScoreParts(
-    namedtuple("ScoreParts", ["log_prob", "score_function", "entropy_term"])
-):
+class ScoreParts(NamedTuple):
     """
     This data structure stores terms used in stochastic gradient estimators that
     combine the pathwise estimator and the score function estimator.
     """
 
-    def scale_and_mask(self, scale=1.0, mask=None):
+    log_prob: torch.Tensor
+    score_function: torch.Tensor
+    entropy_term: torch.Tensor
+
+    def scale_and_mask(
+        self,
+        scale: Union[float, torch.Tensor] = 1.0,
+        mask: Optional[torch.BoolTensor] = None,
+    ) -> "ScoreParts":
         """
         Scale and mask appropriate terms of a gradient estimator by a data multiplicity factor.
         Note that the `score_function` term should not be scaled or masked.
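Note: moving from a `collections.namedtuple` subclass to `typing.NamedTuple` is behavior-preserving; it only gives mypy concrete field types. A minimal sketch of the pattern, using a toy `Point` class that is not part of this diff:

```python
from typing import NamedTuple


class Point(NamedTuple):
    x: float
    y: float

    def scaled(self, factor: float = 1.0) -> "Point":
        # Methods work exactly as on a collections.namedtuple subclass.
        return Point(self.x * factor, self.y * factor)


p = Point(1.0, 2.0)
x, y = p  # field access, unpacking, and immutability are unchanged
assert p.scaled(2.0) == Point(2.0, 4.0)
```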
diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py
index 8820b002ee..379dd8f01a 100644
--- a/pyro/distributions/torch_distribution.py
+++ b/pyro/distributions/torch_distribution.py
@@ -118,7 +118,7 @@ def infer_shapes(cls, **arg_shapes):
             event_shape = torch.Size()
         return batch_shape, event_shape
 
-    def expand(self, batch_shape, _instance=None):
+    def expand(self, batch_shape, _instance=None) -> "ExpandedDistribution":
         """
         Returns a new :class:`ExpandedDistribution` instance with batch
         dimensions expanded to `batch_shape`.
diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py
index 22889caa3f..c50d51bb77 100644
--- a/pyro/poutine/guide.py
+++ b/pyro/poutine/guide.py
@@ -2,16 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, Tuple, Union
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import torch
 
 import pyro.distributions as dist
-from pyro.distributions.distribution import Distribution
-
-from .trace_messenger import TraceMessenger
-from .trace_struct import Trace
-from .util import prune_subsample_sites, site_is_subsample
+from pyro.distributions.torch_distribution import TorchDistributionMixin
+from pyro.poutine.runtime import Message
+from pyro.poutine.trace_messenger import TraceMessenger
+from pyro.poutine.trace_struct import Trace
+from pyro.poutine.util import prune_subsample_sites, site_is_subsample
 
 
 class GuideMessenger(TraceMessenger, ABC):
@@ -21,19 +21,19 @@ class GuideMessenger(TraceMessenger, ABC):
     Derived classes must implement the :meth:`get_posterior` method.
     """
 
-    def __init__(self, model: Callable):
+    def __init__(self, model: Callable) -> None:
         super().__init__()
         # Do not register model as submodule
         self._model = (model,)
 
     @property
-    def model(self):
+    def model(self) -> Callable:
         return self._model[0]
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, object]:
         # Avoid pickling the trace.
-        state = super().__getstate__()
-        state.pop("trace")
+        state = self.__dict__.copy()
+        del state["trace"]
         return state
 
     def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[override]
@@ -53,16 +53,19 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[
         del self.args_kwargs
 
         model_trace, guide_trace = self.get_traces()
-        samples = {
-            name: site["value"]
-            for name, site in model_trace.nodes.items()
-            if site["type"] == "sample"
-        }
+        samples = {}
+        for name, site in model_trace.nodes.items():
+            if site["type"] == "sample":
+                assert isinstance(site["value"], torch.Tensor)
+                samples[name] = site["value"]
         return samples
 
-    def _pyro_sample(self, msg):
+    def _pyro_sample(self, msg: Message) -> None:
         if msg["is_observed"] or site_is_subsample(msg):
             return
+        assert isinstance(msg["name"], str)
+        assert isinstance(msg["fn"], TorchDistributionMixin)
+        assert msg["infer"] is not None
         prior = msg["fn"]
         msg["infer"]["prior"] = prior
         posterior = self.get_posterior(msg["name"], prior)
@@ -72,17 +75,20 @@ def _pyro_sample(self, msg):
             posterior = posterior.expand(prior.batch_shape)
         msg["fn"] = posterior
 
-    def _pyro_post_sample(self, msg):
+    def _pyro_post_sample(self, msg: Message) -> None:
         # Manually apply outer plates.
+        assert msg["infer"] is not None
         prior = msg["infer"].get("prior")
-        if prior is not None and prior.batch_shape != msg["fn"].batch_shape:
-            msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
+        if prior is not None:
+            assert isinstance(msg["fn"], TorchDistributionMixin)
+            if prior.batch_shape != msg["fn"].batch_shape:
+                msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
         return super()._pyro_post_sample(msg)
 
     @abstractmethod
     def get_posterior(
-        self, name: str, prior: Distribution
-    ) -> Union[Distribution, torch.Tensor]:
+        self, name: str, prior: TorchDistributionMixin
+    ) -> Union[TorchDistributionMixin, torch.Tensor]:
         """
         Abstract method to compute a posterior distribution or sample a
         posterior value given a prior distribution conditioned on upstream
@@ -112,7 +118,7 @@ def get_posterior(
         """
         raise NotImplementedError
 
-    def upstream_value(self, name: str):
+    def upstream_value(self, name: str) -> Optional[torch.Tensor]:
         """
         For use in :meth:`get_posterior` .
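Note: under the narrowed signature, `get_posterior` receives a `TorchDistributionMixin` and returns either a distribution or a plain `Tensor` sample. A sketch of a subclass in the spirit of the `GuideMessenger` docstring example; the class name `MeanFieldNormalGuide` and the parameter names are illustrative only (pyro ships similar logic in `AutoNormalMessenger`):

```python
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.poutine.guide import GuideMessenger


class MeanFieldNormalGuide(GuideMessenger):
    def get_posterior(self, name, prior):
        # `prior` is a TorchDistributionMixin, so .shape() and .event_dim
        # are available without any cast.
        loc = pyro.param(f"{name}_loc", lambda: torch.zeros(prior.shape()))
        scale = pyro.param(
            f"{name}_scale",
            lambda: torch.ones(prior.shape()),
            constraint=constraints.positive,
        )
        return dist.Normal(loc, scale).to_event(prior.event_dim)
```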
diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py
index 77e10a06c0..f3dc3146f6 100644
--- a/pyro/poutine/runtime.py
+++ b/pyro/poutine/runtime.py
@@ -29,6 +29,7 @@
 T = TypeVar("T")
 
 if TYPE_CHECKING:
+    from pyro.distributions.score_parts import ScoreParts
     from pyro.distributions.torch_distribution import TorchDistributionMixin
     from pyro.poutine.indep_messenger import CondIndepStackFrame
     from pyro.poutine.messenger import Messenger
@@ -49,9 +50,12 @@ class InferDict(TypedDict, total=False):
     is_auxiliary: bool
     is_observed: bool
     num_samples: int
+    prior: TorchDistributionMixin
     tmc: Literal["diagonal", "mixture"]
     _deterministic: bool
+    _dim_to_symbol: Dict[int, str]
     _do_not_trace: bool
+    _enumerate_symbol: str
     _markov_scope: Optional[Dict[str, int]]
     _enumerate_dim: int
     _dim_to_id: Dict[int, int]
@@ -74,6 +78,11 @@ class Message(TypedDict, total=False):
     continuation: Optional[Callable[[Message], None]]
     infer: Optional[InferDict]
     obs: Optional[torch.Tensor]
+    log_prob: torch.Tensor
+    log_prob_sum: torch.Tensor
+    unscaled_log_prob: torch.Tensor
+    score_parts: ScoreParts
+    packed: "Message"
     _intervener_id: Optional[str]
 
 
diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py
index 2e2790db8d..2b7609a3b5 100644
--- a/pyro/poutine/trace_messenger.py
+++ b/pyro/poutine/trace_messenger.py
@@ -2,13 +2,17 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import sys
+from typing import Any, Callable, Literal, Optional
 
-from .messenger import Messenger
-from .trace_struct import Trace
-from .util import site_is_subsample
+from typing_extensions import Self
 
+from pyro.poutine.messenger import Messenger
+from pyro.poutine.runtime import Message
+from pyro.poutine.trace_struct import Trace
+from pyro.poutine.util import site_is_subsample
 
-def identify_dense_edges(trace):
+
+def identify_dense_edges(trace: Trace) -> None:
     """
     Modifies a trace in-place by adding all edges based on the
     `cond_indep_stack` information stored at each site.
@@ -63,7 +67,11 @@ class TraceMessenger(Messenger):
     :returns: stochastic function decorated with a :class:`~pyro.poutine.trace_messenger.TraceMessenger`
     """
 
-    def __init__(self, graph_type=None, param_only=None):
+    def __init__(
+        self,
+        graph_type: Optional[Literal["flat", "dense"]] = None,
+        param_only: Optional[bool] = None,
+    ) -> None:
         """
         :param string graph_type: string that specifies the type of graph
             to construct (currently only "flat" or "dense" supported)
@@ -79,11 +87,11 @@ def __init__(self, graph_type=None, param_only=None):
         self.param_only = param_only
         self.trace = Trace(graph_type=self.graph_type)
 
-    def __enter__(self):
+    def __enter__(self) -> Self:
         self.trace = Trace(graph_type=self.graph_type)
         return super().__enter__()
 
-    def __exit__(self, *args, **kwargs):
+    def __exit__(self, *args, **kwargs) -> None:
         """
         Adds appropriate edges based on cond_indep_stack information
         upon exiting the context.
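Note: `Message` and `InferDict` are `total=False` TypedDicts with `Optional` values, which is why the handler changes in this PR narrow with `assert` before use. A minimal reproduction of the narrowing pattern, with a hypothetical `Site` dict that is not part of this diff:

```python
from typing import Optional, TypedDict


class Site(TypedDict, total=False):
    infer: Optional[dict]


def read_infer(site: Site) -> dict:
    infer = site["infer"]
    assert infer is not None  # narrows Optional[dict] to dict for mypy
    return infer


assert read_infer({"infer": {"enumerate": "parallel"}}) == {"enumerate": "parallel"}
```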
@@ -91,18 +99,19 @@ def __exit__(self, *args, **kwargs):
         if self.param_only:
             for node in list(self.trace.nodes.values()):
                 if node["type"] != "param":
+                    assert node["name"] is not None
                     self.trace.remove_node(node["name"])
         if self.graph_type == "dense":
             identify_dense_edges(self.trace)
         return super().__exit__(*args, **kwargs)
 
-    def __call__(self, fn):
+    def __call__(self, fn: Callable) -> "TraceHandler":  # type: ignore[override]
         """
         TODO docs
         """
         return TraceHandler(self, fn)
 
-    def get_trace(self):
+    def get_trace(self) -> Trace:
         """
         :returns: data structure
         :rtype: pyro.poutine.Trace
@@ -112,7 +121,7 @@ def get_trace(self):
         """
         return self.trace.copy()
 
-    def _reset(self):
+    def _reset(self) -> None:
         tr = Trace(graph_type=self.graph_type)
         if "_INPUT" in self.trace.nodes:
             tr.add_node(
@@ -125,16 +134,19 @@ def _reset(self):
         self.trace = tr
         super()._reset()
 
-    def _pyro_post_sample(self, msg):
+    def _pyro_post_sample(self, msg: Message) -> None:
         if self.param_only:
             return
+        assert msg["name"] is not None
+        assert msg["infer"] is not None
         if msg["infer"].get("_do_not_trace"):
             assert msg["infer"].get("is_auxiliary")
             assert not msg["is_observed"]
             return
         self.trace.add_node(msg["name"], **msg.copy())
 
-    def _pyro_post_param(self, msg):
+    def _pyro_post_param(self, msg: Message) -> None:
+        assert msg["name"] is not None
         self.trace.add_node(msg["name"], **msg.copy())
 
 
@@ -150,11 +162,11 @@ class TraceHandler:
     We can also use this for visualization.
     """
 
-    def __init__(self, msngr, fn):
+    def __init__(self, msngr: TraceMessenger, fn: Callable):
         self.fn = fn
         self.msngr = msngr
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> Any:
         """
         Runs the stochastic function stored in this poutine,
         with additional side effects.
@@ -175,6 +187,7 @@ def __call__(self, *args, **kwargs):
         except (ValueError, RuntimeError) as e:
             exc_type, exc_value, traceback = sys.exc_info()
             shapes = self.msngr.trace.format_shapes()
+            assert exc_type is not None
             exc = exc_type("{}\n{}".format(exc_value, shapes))
             exc = exc.with_traceback(traceback)
             raise exc from e
@@ -184,10 +197,10 @@ def __call__(self, *args, **kwargs):
         return ret
 
     @property
-    def trace(self):
+    def trace(self) -> Trace:
         return self.msngr.trace
 
-    def get_trace(self, *args, **kwargs):
+    def get_trace(self, *args, **kwargs) -> Trace:
         """
         :returns: data structure
         :rtype: pyro.poutine.Trace
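Note: with these annotations the common tracing round trip is typed end to end: `TraceMessenger.__call__` returns a `TraceHandler`, whose `get_trace` returns a `Trace`. A small usage sketch with a toy model, using the standard `pyro.poutine` API:

```python
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine


def model():
    return pyro.sample("z", dist.Normal(0.0, 1.0))


tr = poutine.trace(model).get_trace()  # a Trace, via TraceHandler.get_trace
assert "z" in tr  # Trace.__contains__ checks site names
```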
("flat", "dense"), "{} not a valid graph type".format( graph_type ) self.graph_type = graph_type - self.nodes = OrderedDict() - self._succ = OrderedDict() - self._pred = OrderedDict() + self.nodes: OrderedDict[str, Message] = OrderedDict() + self._succ: OrderedDict[str, Set[str]] = OrderedDict() + self._pred: OrderedDict[str, Set[str]] = OrderedDict() - def __contains__(self, name): + def __contains__(self, name: str) -> bool: return name in self.nodes - def __iter__(self): + def __iter__(self) -> Iterable[str]: return iter(self.nodes.keys()) - def __len__(self): + def __len__(self) -> int: return len(self.nodes) @property - def edges(self): + def edges(self) -> Iterable[Tuple[str, str]]: for site, adj_nodes in self._succ.items(): for adj_node in adj_nodes: yield site, adj_node - def add_node(self, site_name, **kwargs): + def add_node(self, site_name: str, **kwargs: Any) -> None: """ :param string site_name: the name of the site to be added @@ -117,18 +137,18 @@ def add_node(self, site_name, **kwargs): ) # XXX should copy in case site gets mutated, or dont bother? - self.nodes[site_name] = kwargs + self.nodes[site_name] = kwargs # type: ignore[assignment] self._pred[site_name] = set() self._succ[site_name] = set() - def add_edge(self, site1, site2): + def add_edge(self, site1: str, site2: str) -> None: for site in (site1, site2): if site not in self.nodes: self.add_node(site) self._succ[site1].add(site2) self._pred[site2].add(site1) - def remove_node(self, site_name): + def remove_node(self, site_name: str) -> None: self.nodes.pop(site_name) for p in self._pred[site_name]: self._succ[p].remove(site_name) @@ -137,13 +157,13 @@ def remove_node(self, site_name): self._pred.pop(site_name) self._succ.pop(site_name) - def predecessors(self, site_name): + def predecessors(self, site_name: str) -> Set[str]: return self._pred[site_name] - def successors(self, site_name): + def successors(self, site_name: str) -> Set[str]: return self._succ[site_name] - def copy(self): + def copy(self) -> "Trace": """ Makes a shallow copy of self with nodes and edges preserved. """ @@ -153,7 +173,7 @@ def copy(self): new_tr._pred.update(self._pred) return new_tr - def _dfs(self, site, visited): + def _dfs(self, site: str, visited: Set[str]) -> Iterable[str]: if site in visited: return for s in self._succ[site]: @@ -162,21 +182,24 @@ def _dfs(self, site, visited): visited.add(site) yield site - def topological_sort(self, reverse=False): + def topological_sort(self, reverse: bool = False) -> List[str]: """ Return a list of nodes (site names) in topologically sorted order. :param bool reverse: Return the list in reverse order. :return: list of topologically sorted nodes (site names). """ - visited = set() + visited: Set[str] = set() top_sorted = [] for s in self._succ: for node in self._dfs(s, visited): top_sorted.append(node) return top_sorted if reverse else list(reversed(top_sorted)) - def log_prob_sum(self, site_filter=lambda name, site: True): + def log_prob_sum( + self, + site_filter: Callable[[str, Message], bool] = allow_all_sites, + ) -> Union[torch.Tensor, float]: """ Compute the site-wise log probabilities of the trace. Each ``log_prob`` has shape equal to the corresponding ``batch_shape``. 
@@ -189,6 +212,7 @@ def log_prob_sum(self, site_filter=lambda name, site: True):
         result = 0.0
         for name, site in self.nodes.items():
             if site["type"] == "sample" and site_filter(name, site):
+                assert isinstance(site["fn"], Distribution)
                 if "log_prob_sum" in site:
                     log_p = site["log_prob_sum"]
                 else:
@@ -213,10 +237,13 @@ def log_prob_sum(self, site_filter=lambda name, site: True):
                     "log_prob_sum at site '{}'".format(name),
                     allow_neginf=True,
                 )
-                result = result + log_p
+                result = result + log_p  # type: ignore[assignment]
         return result
 
-    def compute_log_prob(self, site_filter=lambda name, site: True):
+    def compute_log_prob(
+        self,
+        site_filter: Callable[[str, Message], bool] = allow_all_sites,
+    ) -> None:
         """
         Compute the site-wise log probabilities of the trace.
         Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
@@ -225,6 +252,7 @@ def compute_log_prob(self, site_filter=lambda name, site: True):
         """
         for name, site in self.nodes.items():
             if site["type"] == "sample" and site_filter(name, site):
+                assert isinstance(site["fn"], Distribution)
                 if "log_prob" not in site:
                     try:
                         log_p = site["fn"].log_prob(
@@ -253,7 +281,7 @@ def compute_log_prob(self, site_filter=lambda name, site: True):
                     allow_neginf=True,
                 )
 
-    def compute_score_parts(self):
+    def compute_score_parts(self) -> None:
         """
         Compute the batched local score parts at each site of the trace.
         Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
@@ -262,6 +290,7 @@ def compute_score_parts(self):
         """
         for name, site in self.nodes.items():
             if site["type"] == "sample" and "score_parts" not in site:
+                assert isinstance(site["fn"], Distribution)
                 # Note that ScoreParts overloads the multiplication operator
                 # to correctly scale each of its three parts.
                 try:
@@ -291,16 +320,17 @@ def compute_score_parts(self):
                     allow_neginf=True,
                 )
 
-    def detach_(self):
+    def detach_(self) -> None:
         """
         Detach values (in-place) at each sample site of the trace.
         """
         for _, site in self.nodes.items():
             if site["type"] == "sample":
+                assert site["value"] is not None
                 site["value"] = site["value"].detach()
 
     @property
-    def observation_nodes(self):
+    def observation_nodes(self) -> List[str]:
         """
         :return: a list of names of observe sites
         """
@@ -311,14 +341,14 @@ def observation_nodes(self):
         ]
 
     @property
-    def param_nodes(self):
+    def param_nodes(self) -> List[str]:
         """
         :return: a list of names of param sites
         """
         return [name for name, node in self.nodes.items() if node["type"] == "param"]
 
     @property
-    def stochastic_nodes(self):
+    def stochastic_nodes(self) -> List[str]:
         """
         :return: a list of names of sample sites
         """
@@ -329,7 +359,7 @@ def stochastic_nodes(self):
         ]
 
     @property
-    def reparameterized_nodes(self):
+    def reparameterized_nodes(self) -> List[str]:
         """
         :return: a list of names of sample sites whose stochastic functions
             are reparameterizable primitive distributions
@@ -343,14 +373,14 @@ def reparameterized_nodes(self):
         ]
 
     @property
-    def nonreparam_stochastic_nodes(self):
+    def nonreparam_stochastic_nodes(self) -> List[str]:
         """
         :return: a list of names of sample sites whose stochastic functions
             are not reparameterizable primitive distributions
         """
         return list(set(self.stochastic_nodes) - set(self.reparameterized_nodes))
 
-    def iter_stochastic_nodes(self):
+    def iter_stochastic_nodes(self) -> Iterator[Tuple[str, Message]]:
         """
         :return: an iterator over stochastic nodes in the trace.
         """
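Note: `log_prob_sum` now advertises `Union[torch.Tensor, float]` because `result` starts as the float `0.0` and stays a float when no site passes `site_filter`; the `# type: ignore[assignment]` covers that widening. Typical usage with a toy model, using the standard pyro API:

```python
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine


def model():
    pyro.sample("z", dist.Normal(0.0, 1.0))


tr = poutine.trace(model).get_trace()
tr.compute_log_prob()      # fills site["log_prob"] at each sample site
total = tr.log_prob_sum()  # a torch.Tensor here
empty = tr.log_prob_sum(lambda name, site: False)  # the float 0.0
assert isinstance(empty, float)
```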
""" @@ -358,7 +388,7 @@ def iter_stochastic_nodes(self): if node["type"] == "sample" and not node["is_observed"]: yield name, node - def symbolize_dims(self, plate_to_symbol=None): + def symbolize_dims(self, plate_to_symbol: Optional[Dict[int, str]] = None) -> None: """ Assign unique symbols to all tensor dimensions. """ @@ -369,7 +399,7 @@ def symbolize_dims(self, plate_to_symbol=None): continue # allocate even symbols for plate dims - dim_to_symbol = {} + dim_to_symbol: Dict[int, str] = {} for frame in site["cond_indep_stack"]: if frame.vectorized: if frame.name in plate_to_symbol: @@ -381,6 +411,7 @@ def symbolize_dims(self, plate_to_symbol=None): dim_to_symbol[frame.dim] = symbol # allocate odd symbols for enum dims + assert site["infer"] is not None for dim, id_ in site["infer"].get("_dim_to_id", {}).items(): symbol = opt_einsum.get_symbol(1 + 2 * id_) symbol_to_dim[symbol] = dim @@ -393,7 +424,7 @@ def symbolize_dims(self, plate_to_symbol=None): self.plate_to_symbol = plate_to_symbol self.symbol_to_dim = symbol_to_dim - def pack_tensors(self, plate_to_symbol=None): + def pack_tensors(self, plate_to_symbol: Optional[Dict[int, str]] = None) -> None: """ Computes packed representations of tensors in the trace. This should be called after :meth:`compute_log_prob` or :meth:`compute_score_parts`. @@ -402,6 +433,7 @@ def pack_tensors(self, plate_to_symbol=None): for site in self.nodes.values(): if site["type"] != "sample": continue + assert site["infer"] is not None dim_to_symbol = site["infer"]["_dim_to_symbol"] packed = site.setdefault("packed", {}) try: