From ba187d40ff7869ae5f616b76a2ca7d53417fdd4c Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sun, 7 Jan 2024 15:43:28 +0000
Subject: [PATCH 1/9] warn_unreachable=True

---
 pyro/ops/gaussian.py                | 10 +++++-----
 pyro/poutine/block_messenger.py     |  2 +-
 pyro/poutine/indep_messenger.py     | 17 +++++------------
 pyro/poutine/lift_messenger.py      |  2 +-
 pyro/poutine/subsample_messenger.py |  7 -------
 setup.cfg                           |  1 +
 6 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py
index 12f17e973c..d1e2adec57 100644
--- a/pyro/ops/gaussian.py
+++ b/pyro/ops/gaussian.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch.distributions.utils import lazy_property
@@ -111,7 +111,7 @@ def event_permute(self, perm) -> "Gaussian":
         precision = self.precision[..., perm][..., perm, :]
         return Gaussian(self.log_normalizer, info_vec, precision)
 
-    def __add__(self, other: "Gaussian") -> "Gaussian":
+    def __add__(self, other: Union["Gaussian", float, torch.Tensor]) -> "Gaussian":
         """
         Adds two Gaussians in log-density space.
         """
@@ -122,12 +122,12 @@ def __add__(self, other: Union["Gaussian", float, torch.Tensor]) -> "Gaussian":
                 self.info_vec + other.info_vec,
                 self.precision + other.precision,
             )
-        if isinstance(other, (int, float, torch.Tensor)):
+        if isinstance(other, (float, torch.Tensor)):
             return Gaussian(self.log_normalizer + other, self.info_vec, self.precision)
         raise ValueError("Unsupported type: {}".format(type(other)))
 
-    def __sub__(self, other: "Gaussian") -> "Gaussian":
-        if isinstance(other, (int, float, torch.Tensor)):
+    def __sub__(self, other: Union["Gaussian", float, torch.Tensor]) -> "Gaussian":
+        if isinstance(other, (float, torch.Tensor)):
             return Gaussian(self.log_normalizer - other, self.info_vec, self.precision)
         raise ValueError("Unsupported type: {}".format(type(other)))
 
diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py
index 79b594a29e..072f915e17 100644
--- a/pyro/poutine/block_messenger.py
+++ b/pyro/poutine/block_messenger.py
@@ -148,7 +148,7 @@ def __init__(
         expose: Optional[List[str]] = None,
         hide_types: Optional[List[str]] = None,
         expose_types: Optional[List[str]] = None,
-    ):
+    ) -> None:
         super().__init__()
         if not (hide_fn is None or expose_fn is None):
             raise ValueError("Only specify one of hide_fn or expose_fn")
diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py
index bfb3f3d1ae..b3e547a9e2 100644
--- a/pyro/poutine/indep_messenger.py
+++ b/pyro/poutine/indep_messenger.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0 -import numbers from typing import Iterator, NamedTuple, Optional, Tuple import torch @@ -9,7 +8,6 @@ from pyro.poutine.messenger import Messenger from pyro.poutine.runtime import _DIM_ALLOCATOR, Message -from pyro.util import ignore_jit_warnings class CondIndepStackFrame(NamedTuple): @@ -24,11 +22,7 @@ def vectorized(self) -> bool: return self.dim is not None def _key(self) -> Tuple[str, Optional[int], int, int]: - with ignore_jit_warnings(["Converting a tensor to a Python number"]): - size = ( - self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined] - ) - return self.name, self.dim, size, self.counter + return self.name, self.dim, self.size, self.counter def __eq__(self, other: object) -> bool: if not isinstance(other, CondIndepStackFrame): @@ -117,11 +111,10 @@ def __iter__(self) -> Iterator[int]: self._vectorized = False self.dim = None - with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): - for i in self.indices: - self.next_context() - with self: - yield i if isinstance(i, numbers.Number) else i.item() + for i in self.indices: + self.next_context() + with self: + yield i def _reset(self) -> None: if self._vectorized: diff --git a/pyro/poutine/lift_messenger.py b/pyro/poutine/lift_messenger.py index 09703e6825..a86df5bd3b 100644 --- a/pyro/poutine/lift_messenger.py +++ b/pyro/poutine/lift_messenger.py @@ -119,7 +119,7 @@ def _pyro_param(self, msg: Message) -> None: msg["args"] = msg["args"][1:] else: # otherwise leave as is - return None + return None # type: ignore[unreachable] msg["type"] = "sample" if name in self._samples_cache: # Multiple pyro.param statements with the same diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index 6a13a5a528..a6717feebe 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -47,7 +47,6 @@ def __init__( with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): self.device = device or torch.Tensor().device - @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ :returns: a random subsample of `range(size)` @@ -165,12 +164,6 @@ def _process_message(self, msg: Message) -> None: full_size=self.size, # used for param initialization ) msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] - if isinstance(self.size, torch.Tensor) or isinstance( - self.subsample_size, torch.Tensor - ): - if not isinstance(msg["scale"], torch.Tensor): - with ignore_jit_warnings(): - msg["scale"] = torch.tensor(msg["scale"]) msg["scale"] = msg["scale"] * self.size / self.subsample_size def _postprocess_message(self, msg: Message) -> None: diff --git a/setup.cfg b/setup.cfg index e3ca753137..20d7fc3dd9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,6 +41,7 @@ warn_return_any = True warn_unused_configs = True warn_incomplete_stub = True ignore_missing_imports = True +warn_unreachable = True # Per-module options: From 74c3cf9e3641bce1d60524c370822875bb4de4e8 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 7 Jan 2024 16:37:20 +0000 Subject: [PATCH 2/9] Remove unused test function --- tests/infer/test_jit.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index fa013339c6..c1e49f78ee 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -28,7 +28,6 @@ infer_discrete, ) from pyro.optim 
import Adam -from pyro.poutine.indep_messenger import CondIndepStackFrame from pyro.util import ignore_jit_warnings from tests.common import assert_close, assert_equal @@ -585,25 +584,6 @@ def hmm(transition, means, data): assert_equal(state, jit_state) -@pytest.mark.parametrize( - "x,y", - [ - ( - CondIndepStackFrame("a", -1, torch.tensor(2000), 2), - CondIndepStackFrame("a", -1, 2000, 2), - ), - ( - CondIndepStackFrame("a", -1, 1, 2), - CondIndepStackFrame("a", -1, torch.tensor(1), 2), - ), - ], -) -def test_cond_indep_equality(x, y): - assert x == y - assert not x != y - assert hash(x) == hash(y) - - def test_jit_arange_workaround(): def fn(x): y = torch.ones(x.shape[0], dtype=torch.long, device=x.device) From 7d4bd4ab6bc502c4b6babc167113da7820bf7d5c Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Jan 2024 00:36:56 +0000 Subject: [PATCH 3/9] annotate messengers --- pyro/poutine/do_messenger.py | 2 +- pyro/poutine/enum_messenger.py | 17 +++++---- pyro/poutine/guide.py | 3 +- pyro/poutine/indep_messenger.py | 19 +++++++--- pyro/poutine/infer_config_messenger.py | 2 +- pyro/poutine/lift_messenger.py | 3 +- pyro/poutine/markov_messenger.py | 33 ++++++++++++----- pyro/poutine/replay_messenger.py | 21 ++++++++--- pyro/poutine/runtime.py | 3 +- pyro/poutine/subsample_messenger.py | 7 ++++ pyro/poutine/trace_messenger.py | 2 +- pyro/poutine/trace_struct.py | 26 +++++++------ pyro/poutine/uncondition_messenger.py | 2 +- pyro/poutine/util.py | 51 +++++++++++++++++--------- tests/infer/test_jit.py | 20 ++++++++++ 15 files changed, 146 insertions(+), 65 deletions(-) diff --git a/pyro/poutine/do_messenger.py b/pyro/poutine/do_messenger.py index bfd1bd0c73..ac686061dc 100644 --- a/pyro/poutine/do_messenger.py +++ b/pyro/poutine/do_messenger.py @@ -49,7 +49,7 @@ class DoMessenger(Messenger): :returns: stochastic function decorated with a :class:`~pyro.poutine.do_messenger.DoMessenger` """ - def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]): + def __init__(self, data: Dict[str, Union[torch.Tensor, numbers.Number]]) -> None: super().__init__() self.data = data self._intervener_id = str(id(self)) diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index fe87f76d7f..962cc89ebc 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -1,6 +1,7 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from collections import Counter from typing import Any, Dict, List, Optional import torch @@ -177,14 +178,16 @@ def _pyro_sample(self, msg: Message) -> None: assert isinstance(msg["name"], str) assert msg["infer"] is not None # Compute upstream dims in scope; these are unsafe to use for this site's target_dim. 
- scope = msg["infer"].get("_markov_scope") # site name -> markov depth + scope = msg["infer"].setdefault( + "_markov_scope", Counter() + ) # site name -> markov depth param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id - if scope is not None: - for name, depth in scope.items(): - if ( - self._markov_depths[name] == depth - ): # hide sites whose markov context has exited - param_dims.update(self._value_dims[name]) + for name, depth in scope.items(): + if ( + self._markov_depths[name] == depth + ): # hide sites whose markov context has exited + param_dims.update(self._value_dims[name]) + if scope: self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] self._param_dims[msg["name"]] = param_dims if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py index c50d51bb77..0de8640477 100644 --- a/pyro/poutine/guide.py +++ b/pyro/poutine/guide.py @@ -143,11 +143,12 @@ def get_traces(self) -> Tuple[Trace, Trace]: """ guide_trace = prune_subsample_sites(self.trace) model_trace = model_trace = guide_trace.copy() - for name, guide_site in list(guide_trace.nodes.items()): + for name, guide_site in guide_trace.nodes.items(): if guide_site["type"] != "sample" or guide_site["is_observed"]: del guide_trace.nodes[name] continue model_site = model_trace.nodes[name].copy() + assert guide_site["infer"] is not None model_site["fn"] = guide_site["infer"]["prior"] model_trace.nodes[name] = model_site return model_trace, guide_trace diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index b3e547a9e2..13a175403d 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -1,6 +1,7 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. 
# SPDX-License-Identifier: Apache-2.0 +import numbers from typing import Iterator, NamedTuple, Optional, Tuple import torch @@ -8,6 +9,7 @@ from pyro.poutine.messenger import Messenger from pyro.poutine.runtime import _DIM_ALLOCATOR, Message +from pyro.util import ignore_jit_warnings class CondIndepStackFrame(NamedTuple): @@ -22,7 +24,11 @@ def vectorized(self) -> bool: return self.dim is not None def _key(self) -> Tuple[str, Optional[int], int, int]: - return self.name, self.dim, self.size, self.counter + with ignore_jit_warnings(["Converting a tensor to a Python number"]): + size = ( + self.size.item() if isinstance(self.size, torch.Tensor) else self.size # type: ignore[attr-defined, unreachable] + ) + return self.name, self.dim, size, self.counter def __eq__(self, other: object) -> bool: if not isinstance(other, CondIndepStackFrame): @@ -65,7 +71,7 @@ def __init__( size: int, dim: Optional[int] = None, device: Optional[str] = None, - ): + ) -> None: if not torch._C._get_tracing_state() and size == 0: raise ZeroDivisionError("size cannot be zero") @@ -111,10 +117,11 @@ def __iter__(self) -> Iterator[int]: self._vectorized = False self.dim = None - for i in self.indices: - self.next_context() - with self: - yield i + with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]): + for i in self.indices: + self.next_context() + with self: + yield i if isinstance(i, numbers.Number) else i.item() def _reset(self) -> None: if self._vectorized: diff --git a/pyro/poutine/infer_config_messenger.py b/pyro/poutine/infer_config_messenger.py index f70e8cbfb8..1b2a57a727 100644 --- a/pyro/poutine/infer_config_messenger.py +++ b/pyro/poutine/infer_config_messenger.py @@ -18,7 +18,7 @@ class InferConfigMessenger(Messenger): :returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger` """ - def __init__(self, config_fn: Callable[[Message], InferDict]): + def __init__(self, config_fn: Callable[[Message], InferDict]) -> None: """ :param config_fn: a callable taking a site and returning an infer dict diff --git a/pyro/poutine/lift_messenger.py b/pyro/poutine/lift_messenger.py index a86df5bd3b..2d071c9ac4 100644 --- a/pyro/poutine/lift_messenger.py +++ b/pyro/poutine/lift_messenger.py @@ -118,8 +118,7 @@ def _pyro_param(self, msg: Message) -> None: msg["fn"] = self.prior msg["args"] = msg["args"][1:] else: - # otherwise leave as is - return None # type: ignore[unreachable] + raise TypeError("unreachable") msg["type"] = "sample" if name in self._samples_cache: # Multiple pyro.param statements with the same diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index 1d68c9e06a..26a7f0d7a7 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -2,9 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 from collections import Counter -from contextlib import ExitStack # python 3 +from contextlib import ExitStack +from typing import Iterable, Iterator, List, Optional, Set -from .reentrant_messenger import ReentrantMessenger +from typing_extensions import Self + +from pyro.poutine.reentrant_messenger import ReentrantMessenger +from pyro.poutine.runtime import Message class MarkovMessenger(ReentrantMessenger): @@ -27,7 +31,13 @@ class MarkovMessenger(ReentrantMessenger): Interface stub, behavior not yet implemented. 
""" - def __init__(self, history=1, keep=False, dim=None, name=None): + def __init__( + self, + history: int = 1, + keep: bool = False, + dim: Optional[int] = None, + name: Optional[str] = None, + ) -> None: assert history >= 0 self.history = history self.keep = keep @@ -41,34 +51,35 @@ def __init__(self, history=1, keep=False, dim=None, name=None): raise NotImplementedError( "vectorized markov not yet implemented, try setting name to None" ) - self._iterable = None + self._iterable: Optional[Iterable] = None self._pos = -1 - self._stack = [] + self._stack: List[Set[str]] = [] super().__init__() - def generator(self, iterable): + def generator(self, iterable: Iterable) -> Self: self._iterable = iterable return self - def __iter__(self): + def __iter__(self) -> Iterator: with ExitStack() as stack: + assert self._iterable is not None for value in self._iterable: stack.enter_context(self) yield value - def __enter__(self): + def __enter__(self) -> Self: self._pos += 1 if len(self._stack) <= self._pos: self._stack.append(set()) return super().__enter__() - def __exit__(self, *args, **kwargs): + def __exit__(self, *args, **kwargs) -> None: if not self.keep: self._stack.pop() self._pos -= 1 return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: if msg["done"] or type(msg["fn"]).__name__ == "_Subsample": return @@ -76,6 +87,8 @@ def _pyro_sample(self, msg): # go out of scope when any one of their markov contexts exits. # This accounting can be done by users of these fields, # e.g. EnumMessenger. + assert msg["name"] is not None + assert msg["infer"] is not None infer = msg["infer"] scope = infer.setdefault( "_markov_scope", Counter() diff --git a/pyro/poutine/replay_messenger.py b/pyro/poutine/replay_messenger.py index 548c971473..7e2ea27c3c 100644 --- a/pyro/poutine/replay_messenger.py +++ b/pyro/poutine/replay_messenger.py @@ -1,7 +1,13 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from .messenger import Messenger +from typing import Dict, Optional + +import torch + +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message +from pyro.poutine.trace_struct import Trace class ReplayMessenger(Messenger): @@ -32,7 +38,11 @@ class ReplayMessenger(Messenger): :returns: a stochastic function decorated with a :class:`~pyro.poutine.replay_messenger.ReplayMessenger` """ - def __init__(self, trace=None, params=None): + def __init__( + self, + trace: Optional[Trace] = None, + params: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: """ :param trace: a trace whose values should be reused @@ -45,7 +55,7 @@ def __init__(self, trace=None, params=None): self.trace = trace self.params = params - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: """ :param msg: current message at a trace site. @@ -56,6 +66,7 @@ def _pyro_sample(self, msg): At a sample site that does not appear in self.trace, reverts to default Messenger._pyro_sample behavior with no additional side effects. 
""" + assert msg["name"] is not None name = msg["name"] if self.trace is not None and name in self.trace: guide_msg = self.trace.nodes[name] @@ -66,9 +77,8 @@ def _pyro_sample(self, msg): msg["done"] = True msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"] - return None - def _pyro_param(self, msg): + def _pyro_param(self, msg: Message) -> None: name = msg["name"] if self.params is not None and name in self.params: assert hasattr( @@ -76,4 +86,3 @@ def _pyro_param(self, msg): ), "param {} must be constrained value".format(name) msg["done"] = True msg["value"] = self.params[name] - return None diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index cbb0b6aa73..99188b8c4e 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +from collections import Counter from typing import ( TYPE_CHECKING, Callable, @@ -57,7 +58,7 @@ class InferDict(TypedDict, total=False): _dim_to_symbol: Dict[int, str] _do_not_trace: bool _enumerate_symbol: str - _markov_scope: Optional[Dict[str, int]] + _markov_scope: Counter _enumerate_dim: int _dim_to_id: Dict[int, int] _markov_depth: int diff --git a/pyro/poutine/subsample_messenger.py b/pyro/poutine/subsample_messenger.py index a6717feebe..58fba6faa8 100644 --- a/pyro/poutine/subsample_messenger.py +++ b/pyro/poutine/subsample_messenger.py @@ -47,6 +47,7 @@ def __init__( with ignore_jit_warnings(["torch.Tensor results are registered as constants"]): self.device = device or torch.Tensor().device + @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ :returns: a random subsample of `range(size)` @@ -164,6 +165,12 @@ def _process_message(self, msg: Message) -> None: full_size=self.size, # used for param initialization ) msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] + if isinstance(self.size, torch.Tensor) or isinstance( # type: ignore[unreachable] + self.subsample_size, torch.Tensor # type: ignore[unreachable] + ): + if not isinstance(msg["scale"], torch.Tensor): # type: ignore[unreachable] + with ignore_jit_warnings(): + msg["scale"] = torch.tensor(msg["scale"]) msg["scale"] = msg["scale"] * self.size / self.subsample_size def _postprocess_message(self, msg: Message) -> None: diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py index 2b7609a3b5..d812034b3c 100644 --- a/pyro/poutine/trace_messenger.py +++ b/pyro/poutine/trace_messenger.py @@ -162,7 +162,7 @@ class TraceHandler: We can also use this for visualization. """ - def __init__(self, msngr: TraceMessenger, fn: Callable): + def __init__(self, msngr: TraceMessenger, fn: Callable) -> None: self.fn = fn self.msngr = msngr diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 4c2b58bb23..0fe8f90066 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -465,17 +465,20 @@ def pack_tensors(self, plate_to_symbol: Optional[Dict[str, str]] = None) -> None ) ).with_traceback(traceback) from e - def format_shapes(self, title="Trace Shapes:", last_site=None): + def format_shapes( + self, title: str = "Trace Shapes:", last_site: Optional[str] = None + ) -> str: """ Returns a string showing a table of the shapes of all sites in the trace. 
""" if not self.nodes: return title - rows = [[title]] + rows: List[List[Optional[str]]] = [[title]] rows.append(["Param Sites:"]) for name, site in self.nodes.items(): + assert isinstance(site["value"], torch.Tensor) if site["type"] == "param": rows.append([name, None] + [str(size) for size in site["value"].shape]) if name == last_site: @@ -520,7 +523,7 @@ def format_shapes(self, title="Trace Shapes:", last_site=None): return _format_table(rows) -def _format_table(rows): +def _format_table(rows: List[List[Optional[str]]]) -> str: """ Formats a right justified table using None as column separator. """ @@ -538,8 +541,9 @@ def _format_table(rows): column_widths[j] = max(column_widths[j], widths[j]) # justify columns - for i, row in enumerate(rows): - cols = [[], [], []] + justified_rows: List[List[str]] = [] + for row in rows: + cols: List[List[str]] = [[], [], []] j = 0 for cell in row: if cell is None: @@ -552,16 +556,16 @@ def _format_table(rows): else col + [""] * (width - len(col)) for width, col, direction in zip(column_widths, cols, "rrl") ] - rows[i] = sum(cols, []) + justified_rows.append(sum(cols, [])) # compute cell widths - cell_widths = [0] * len(rows[0]) - for row in rows: - for j, cell in enumerate(row): + cell_widths = [0] * len(justified_rows[0]) + for justified_row in justified_rows: + for j, cell in enumerate(justified_row): cell_widths[j] = max(cell_widths[j], len(cell)) # justify cells return "\n".join( - " ".join(cell.rjust(width) for cell, width in zip(row, cell_widths)) - for row in rows + " ".join(cell.rjust(width) for cell, width in zip(justified_row, cell_widths)) + for justified_row in justified_rows ) diff --git a/pyro/poutine/uncondition_messenger.py b/pyro/poutine/uncondition_messenger.py index 1978ba9a85..6e1e778972 100644 --- a/pyro/poutine/uncondition_messenger.py +++ b/pyro/poutine/uncondition_messenger.py @@ -11,7 +11,7 @@ class UnconditionMessenger(Messenger): distribution, ignoring observations. """ - def __init__(self): + def __init__(self) -> None: super().__init__() def _pyro_sample(self, msg: Message) -> None: diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py index 3a0ec0316b..e91a7972f3 100644 --- a/pyro/poutine/util.py +++ b/pyro/poutine/util.py @@ -1,36 +1,43 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING, List, Optional + from .. import settings +if TYPE_CHECKING: + from pyro.distributions.distribution import Distribution + from pyro.poutine.runtime import Message + from pyro.poutine.trace_struct import Trace + _VALIDATION_ENABLED = __debug__ settings.register("validate_poutine", __name__, "_VALIDATION_ENABLED") -def enable_validation(is_validate): +def enable_validation(is_validate: bool) -> None: global _VALIDATION_ENABLED _VALIDATION_ENABLED = is_validate -def is_validation_enabled(): +def is_validation_enabled() -> bool: return _VALIDATION_ENABLED -def site_is_subsample(site): +def site_is_subsample(site: "Message") -> bool: """ Determines whether a trace site originated from a subsample statement inside an `plate`. """ return site["type"] == "sample" and type(site["fn"]).__name__ == "_Subsample" -def site_is_factor(site): +def site_is_factor(site: "Message") -> bool: """ Determines whether a trace site originated from a factor statement. """ return site["type"] == "sample" and type(site["fn"]).__name__ == "Unit" -def prune_subsample_sites(trace): +def prune_subsample_sites(trace: "Trace") -> "Trace": """ Copies and removes all subsample sites from a trace. 
""" @@ -41,7 +48,9 @@ def prune_subsample_sites(trace): return trace -def enum_extend(trace, msg, num_samples=None): +def enum_extend( + trace: "Trace", msg: "Message", num_samples: Optional[int] = None +) -> List["Trace"]: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -57,18 +66,23 @@ def enum_extend(trace, msg, num_samples=None): num_samples = -1 extended_traces = [] + assert msg["name"] is not None + if TYPE_CHECKING: + assert isinstance(msg["fn"], Distribution) for i, s in enumerate(msg["fn"].enumerate_support(*msg["args"], **msg["kwargs"])): if i > num_samples and num_samples >= 0: break msg_copy = msg.copy() - msg_copy.update(value=s) + msg_copy.update(value=s) # type: ignore[call-arg] tr_cp = trace.copy() tr_cp.add_node(msg["name"], **msg_copy) extended_traces.append(tr_cp) return extended_traces -def mc_extend(trace, msg, num_samples=None): +def mc_extend( + trace: "Trace", msg: "Message", num_samples: Optional[int] = None +) -> List["Trace"]: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -88,12 +102,13 @@ def mc_extend(trace, msg, num_samples=None): msg_copy = msg.copy() msg_copy["value"] = msg_copy["fn"](*msg_copy["args"], **msg_copy["kwargs"]) tr_cp = trace.copy() + assert msg_copy["name"] is not None tr_cp.add_node(msg_copy["name"], **msg_copy) extended_traces.append(tr_cp) return extended_traces -def discrete_escape(trace, msg): +def discrete_escape(trace: "Trace", msg: "Message") -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -105,14 +120,15 @@ def discrete_escape(trace, msg): Subroutine for integrating out discrete variables for variance reduction. """ return ( - (msg["type"] == "sample") - and (not msg["is_observed"]) - and (msg["name"] not in trace) - and (getattr(msg["fn"], "has_enumerate_support", False)) + msg["type"] == "sample" + and not msg["is_observed"] + and msg["name"] is not None + and msg["name"] not in trace + and getattr(msg["fn"], "has_enumerate_support", False) ) -def all_escape(trace, msg): +def all_escape(trace: "Trace", msg: "Message") -> bool: """ :param trace: a partial trace :param msg: the message at a Pyro primitive site @@ -124,7 +140,8 @@ def all_escape(trace, msg): Subroutine for approximately integrating out variables for variance reduction. 
""" return ( - (msg["type"] == "sample") - and (not msg["is_observed"]) - and (msg["name"] not in trace) + msg["type"] == "sample" + and not msg["is_observed"] + and msg["name"] is not None + and msg["name"] not in trace ) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index c1e49f78ee..fa013339c6 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -28,6 +28,7 @@ infer_discrete, ) from pyro.optim import Adam +from pyro.poutine.indep_messenger import CondIndepStackFrame from pyro.util import ignore_jit_warnings from tests.common import assert_close, assert_equal @@ -584,6 +585,25 @@ def hmm(transition, means, data): assert_equal(state, jit_state) +@pytest.mark.parametrize( + "x,y", + [ + ( + CondIndepStackFrame("a", -1, torch.tensor(2000), 2), + CondIndepStackFrame("a", -1, 2000, 2), + ), + ( + CondIndepStackFrame("a", -1, 1, 2), + CondIndepStackFrame("a", -1, torch.tensor(1), 2), + ), + ], +) +def test_cond_indep_equality(x, y): + assert x == y + assert not x != y + assert hash(x) == hash(y) + + def test_jit_arange_workaround(): def fn(x): y = torch.ones(x.shape[0], dtype=torch.long, device=x.device) From c66ad12d63f0fb6b982bb5fdea8687d428f18c01 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Jan 2024 01:38:55 +0000 Subject: [PATCH 4/9] revert enum_messenger --- pyro/poutine/enum_messenger.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 962cc89ebc..fe87f76d7f 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -1,7 +1,6 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from collections import Counter from typing import Any, Dict, List, Optional import torch @@ -178,16 +177,14 @@ def _pyro_sample(self, msg: Message) -> None: assert isinstance(msg["name"], str) assert msg["infer"] is not None # Compute upstream dims in scope; these are unsafe to use for this site's target_dim. - scope = msg["infer"].setdefault( - "_markov_scope", Counter() - ) # site name -> markov depth + scope = msg["infer"].get("_markov_scope") # site name -> markov depth param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id - for name, depth in scope.items(): - if ( - self._markov_depths[name] == depth - ): # hide sites whose markov context has exited - param_dims.update(self._value_dims[name]) - if scope: + if scope is not None: + for name, depth in scope.items(): + if ( + self._markov_depths[name] == depth + ): # hide sites whose markov context has exited + param_dims.update(self._value_dims[name]) self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] self._param_dims[msg["name"]] = param_dims if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": From af4ad19120c62b2d7b1cc33b1a97502795f8d302 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Jan 2024 02:01:44 +0000 Subject: [PATCH 5/9] collapse_messenger --- pyro/poutine/collapse_messenger.py | 53 +++++++++++++----------------- pyro/poutine/trace_struct.py | 2 +- 2 files changed, 24 insertions(+), 31 deletions(-) diff --git a/pyro/poutine/collapse_messenger.py b/pyro/poutine/collapse_messenger.py index 6206894943..50264d758c 100644 --- a/pyro/poutine/collapse_messenger.py +++ b/pyro/poutine/collapse_messenger.py @@ -1,16 +1,19 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 + from functools import reduce, singledispatch +from typing import TYPE_CHECKING, FrozenSet, Tuple + +from typing_extensions import Self import pyro from pyro.distributions.distribution import COERCIONS from pyro.ops.linalg import ignore_torch_deprecation_warnings +from pyro.poutine.runtime import _PYRO_STACK, Message +from pyro.poutine.trace_messenger import TraceMessenger from pyro.poutine.util import site_is_subsample -from .runtime import _PYRO_STACK -from .trace_messenger import TraceMessenger - # TODO Remove import guard once funsor is a required dependency. try: import funsor @@ -24,20 +27,8 @@ Funsor = type("Funsor", (), {}) Variable = type("Variable", (), {}) - -@singledispatch -def _get_free_vars(x): - return x - - -@_get_free_vars.register(Variable) -def _(x): - return frozenset((x.name,)) - - -@_get_free_vars.register(tuple) -def _(x, subs): - return frozenset().union(*map(_get_free_vars, x)) +if TYPE_CHECKING: + from funsor.distribution import Distribution @singledispatch @@ -92,7 +83,7 @@ class CollapseMessenger(TraceMessenger): _coerce = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if CollapseMessenger._coerce is None: import funsor from funsor.distribution import CoerceDistributionToFunsor @@ -102,18 +93,20 @@ def __init__(self, *args, **kwargs): self._block = False super().__init__(*args, **kwargs) - def _process_message(self, msg): + def _process_message(self, msg: Message) -> None: if self._block: return if site_is_subsample(msg): return super()._process_message(msg) - def _pyro_sample(self, msg): + def _pyro_sample(self, msg: Message) -> None: # Eagerly convert fn and value to Funsor. dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]} dim_to_name.update(self.preserved_plates) msg["fn"] = funsor.to_funsor(msg["fn"], funsor.Real, dim_to_name) + if TYPE_CHECKING: + assert isinstance(msg["fn"], Distribution) domain = msg["fn"].inputs["value"] if msg["value"] is None: msg["value"] = funsor.Variable(msg["name"], domain) @@ -123,14 +116,14 @@ def _pyro_sample(self, msg): msg["done"] = True msg["stop"] = True - def _pyro_post_sample(self, msg): + def _pyro_post_sample(self, msg: Message) -> None: if self._block: return if site_is_subsample(msg): return super()._pyro_post_sample(msg) - def _pyro_barrier(self, msg): + def _pyro_barrier(self, msg: Message) -> None: # Get log_prob and record factor. 
name, log_prob, log_joint, sampled_vars = self._get_log_prob() self._block = True @@ -151,14 +144,14 @@ def _pyro_barrier(self, msg): value = _substitute(value, samples) msg["value"] = value - def __enter__(self): + def __enter__(self) -> Self: self.preserved_plates = { h.dim: h.name for h in _PYRO_STACK if isinstance(h, pyro.plate) } COERCIONS.append(self._coerce) return super().__enter__() - def __exit__(self, *args): + def __exit__(self, *args) -> None: _coerce = COERCIONS.pop() assert _coerce is self._coerce super().__exit__(*args) @@ -168,20 +161,20 @@ def __exit__(self, *args): pyro.factor(name, log_prob.data) @ignore_torch_deprecation_warnings() - def _get_log_prob(self): + def _get_log_prob(self) -> Tuple[str, Funsor, Funsor, FrozenSet[str]]: # Convert delayed statements to pyro.factor() - reduced_vars = [] + reduced_vars_list = [] log_prob_terms = [] - plates = frozenset() + plates: FrozenSet[str] = frozenset() for name, site in self.trace.nodes.items(): if not site["is_observed"]: - reduced_vars.append(name) + reduced_vars_list.append(name) log_prob_terms.append(site["fn"](value=site["value"])) plates |= frozenset( f.name for f in site["cond_indep_stack"] if f.vectorized ) - name = reduced_vars[0] - reduced_vars = frozenset(reduced_vars) + name = reduced_vars_list[0] + reduced_vars = frozenset(reduced_vars_list) assert log_prob_terms, "nothing to collapse" self.trace.nodes.clear() reduced_plates = plates - frozenset(self.preserved_plates.values()) diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 0fe8f90066..5638e611cc 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -478,8 +478,8 @@ def format_shapes( rows.append(["Param Sites:"]) for name, site in self.nodes.items(): - assert isinstance(site["value"], torch.Tensor) if site["type"] == "param": + assert isinstance(site["value"], torch.Tensor) rows.append([name, None] + [str(size) for size in site["value"].shape]) if name == last_site: break From 7893ec845d0e04f2126c66fee140ca16b8e2933a Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Jan 2024 03:05:54 +0000 Subject: [PATCH 6/9] fix gaussian test --- pyro/ops/gaussian.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyro/ops/gaussian.py b/pyro/ops/gaussian.py index d1e2adec57..b3da2a0915 100644 --- a/pyro/ops/gaussian.py +++ b/pyro/ops/gaussian.py @@ -111,7 +111,7 @@ def event_permute(self, perm) -> "Gaussian": precision = self.precision[..., perm][..., perm, :] return Gaussian(self.log_normalizer, info_vec, precision) - def __add__(self, other: Union["Gaussian", float, torch.Tensor]) -> "Gaussian": + def __add__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian": """ Adds two Gaussians in log-density space. 
""" @@ -122,12 +122,12 @@ def __add__(self, other: Union["Gaussian", float, torch.Tensor]) -> "Gaussian": self.info_vec + other.info_vec, self.precision + other.precision, ) - if isinstance(other, (float, torch.Tensor)): + if isinstance(other, (int, float, torch.Tensor)): return Gaussian(self.log_normalizer + other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) - def __sub__(self, other: Union["Gaussian", float, torch.Tensor]) -> "Gaussian": - if isinstance(other, (float, torch.Tensor)): + def __sub__(self, other: Union["Gaussian", int, float, torch.Tensor]) -> "Gaussian": + if isinstance(other, (int, float, torch.Tensor)): return Gaussian(self.log_normalizer - other, self.info_vec, self.precision) raise ValueError("Unsupported type: {}".format(type(other))) From 0057c83aca8a6b74439645b6ce98abd5bfe28405 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 16 Jan 2024 03:47:53 +0000 Subject: [PATCH 7/9] TYPE_CHECKING --- pyro/poutine/block_messenger.py | 22 ++++++++++------- pyro/poutine/broadcast_messenger.py | 8 +++--- pyro/poutine/collapse_messenger.py | 12 +++++---- pyro/poutine/condition_messenger.py | 8 +++--- pyro/poutine/guide.py | 12 +++++---- pyro/poutine/infer_config_messenger.py | 12 +++++---- pyro/poutine/lift_messenger.py | 12 +++++---- pyro/poutine/markov_messenger.py | 8 +++--- pyro/poutine/mask_messenger.py | 8 +++--- pyro/poutine/replay_messenger.py | 14 ++++++----- pyro/poutine/runtime.py | 5 ++-- pyro/poutine/scale_messenger.py | 8 +++--- pyro/poutine/substitute_messenger.py | 17 +++++++------ pyro/poutine/trace_messenger.py | 10 +++++--- pyro/poutine/trace_struct.py | 34 ++++++++++++++++---------- pyro/poutine/uncondition_messenger.py | 8 ++++-- pyro/poutine/util.py | 2 +- 17 files changed, 121 insertions(+), 79 deletions(-) diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index 072f915e17..e7ebb5bbe0 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -2,10 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message def _block_fn( @@ -14,7 +16,7 @@ def _block_fn( hide: List[str], hide_types: List[str], hide_all: bool, - msg: Message, + msg: "Message", ) -> bool: # handle observes if msg["type"] == "sample" and msg["is_observed"]: @@ -43,7 +45,7 @@ def _make_default_hide_fn( expose: Optional[List[str]], hide_types: Optional[List[str]], expose_types: Optional[List[str]], -) -> Callable[[Message], bool]: +) -> Callable[["Message"], bool]: # first, some sanity checks: # hide_all and expose_all intersect? 
assert (hide_all is False and expose_all is False) or ( @@ -81,9 +83,11 @@ def _make_default_hide_fn( return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all) -def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]: +def _negate_fn( + fn: Callable[["Message"], Optional[bool]] +) -> Callable[["Message"], bool]: # typed version of lambda msg: not fn(msg) - def negated_fn(msg: Message) -> bool: + def negated_fn(msg: "Message") -> bool: return not fn(msg) return negated_fn @@ -140,8 +144,8 @@ class BlockMessenger(Messenger): def __init__( self, - hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, - expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, + hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None, + expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, @@ -161,5 +165,5 @@ def __init__( hide_all, expose_all, hide, expose, hide_types, expose_types ) - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: msg["stop"] = bool(self.hide_fn(msg)) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 445cd1caea..87e9dd2f7b 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -1,13 +1,15 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.util import ignore_jit_warnings +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class BroadcastMessenger(Messenger): """ @@ -41,7 +43,7 @@ class BroadcastMessenger(Messenger): @staticmethod @ignore_jit_warnings(["Converting a tensor to a Python boolean"]) - def _pyro_sample(msg: Message) -> None: + def _pyro_sample(msg: "Message") -> None: """ :param msg: current message at a trace site. """ diff --git a/pyro/poutine/collapse_messenger.py b/pyro/poutine/collapse_messenger.py index 50264d758c..4a7bd05479 100644 --- a/pyro/poutine/collapse_messenger.py +++ b/pyro/poutine/collapse_messenger.py @@ -10,7 +10,7 @@ import pyro from pyro.distributions.distribution import COERCIONS from pyro.ops.linalg import ignore_torch_deprecation_warnings -from pyro.poutine.runtime import _PYRO_STACK, Message +from pyro.poutine.runtime import _PYRO_STACK from pyro.poutine.trace_messenger import TraceMessenger from pyro.poutine.util import site_is_subsample @@ -30,6 +30,8 @@ if TYPE_CHECKING: from funsor.distribution import Distribution + from pyro.poutine.runtime import Message + @singledispatch def _substitute(x, subs): @@ -93,14 +95,14 @@ def __init__(self, *args, **kwargs) -> None: self._block = False super().__init__(*args, **kwargs) - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: if self._block: return if site_is_subsample(msg): return super()._process_message(msg) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: # Eagerly convert fn and value to Funsor. 
dim_to_name = {f.dim: f.name for f in msg["cond_indep_stack"]} dim_to_name.update(self.preserved_plates) @@ -116,14 +118,14 @@ def _pyro_sample(self, msg: Message) -> None: msg["done"] = True msg["stop"] = True - def _pyro_post_sample(self, msg: Message) -> None: + def _pyro_post_sample(self, msg: "Message") -> None: if self._block: return if site_is_subsample(msg): return super()._pyro_post_sample(msg) - def _pyro_barrier(self, msg: Message) -> None: + def _pyro_barrier(self, msg: "Message") -> None: # Get log_prob and record factor. name, log_prob, log_joint, sampled_vars = self._get_log_prob() self._block = True diff --git a/pyro/poutine/condition_messenger.py b/pyro/poutine/condition_messenger.py index 16ef9dd250..9ce259cd1f 100644 --- a/pyro/poutine/condition_messenger.py +++ b/pyro/poutine/condition_messenger.py @@ -1,14 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Union +from typing import TYPE_CHECKING, Dict, Union import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.trace_struct import Trace +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class ConditionMessenger(Messenger): """ @@ -46,7 +48,7 @@ def __init__(self, data: Union[Dict[str, torch.Tensor], Trace]) -> None: super().__init__() self.data = data - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. :returns: a sample from the stochastic function at the site. diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py index 0de8640477..ca40214c5b 100644 --- a/pyro/poutine/guide.py +++ b/pyro/poutine/guide.py @@ -2,17 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union import torch import pyro.distributions as dist from pyro.distributions.torch_distribution import TorchDistributionMixin -from pyro.poutine.runtime import Message from pyro.poutine.trace_messenger import TraceMessenger from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites, site_is_subsample +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class GuideMessenger(TraceMessenger, ABC): """ @@ -60,7 +62,7 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore[ samples[name] = site["value"] return samples - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: if msg["is_observed"] or site_is_subsample(msg): return assert isinstance(msg["name"], str) @@ -75,7 +77,7 @@ def _pyro_sample(self, msg: Message) -> None: posterior = posterior.expand(prior.batch_shape) msg["fn"] = posterior - def _pyro_post_sample(self, msg: Message) -> None: + def _pyro_post_sample(self, msg: "Message") -> None: # Manually apply outer plates. 
assert msg["infer"] is not None prior = msg["infer"].get("prior") @@ -143,7 +145,7 @@ def get_traces(self) -> Tuple[Trace, Trace]: """ guide_trace = prune_subsample_sites(self.trace) model_trace = model_trace = guide_trace.copy() - for name, guide_site in guide_trace.nodes.items(): + for name, guide_site in list(guide_trace.nodes.items()): if guide_site["type"] != "sample" or guide_site["is_observed"]: del guide_trace.nodes[name] continue diff --git a/pyro/poutine/infer_config_messenger.py b/pyro/poutine/infer_config_messenger.py index 1b2a57a727..362ae2cf86 100644 --- a/pyro/poutine/infer_config_messenger.py +++ b/pyro/poutine/infer_config_messenger.py @@ -1,10 +1,12 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import TYPE_CHECKING, Callable from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import InferDict, Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import InferDict, Message class InferConfigMessenger(Messenger): @@ -18,7 +20,7 @@ class InferConfigMessenger(Messenger): :returns: stochastic function decorated with :class:`~pyro.poutine.infer_config_messenger.InferConfigMessenger` """ - def __init__(self, config_fn: Callable[[Message], InferDict]) -> None: + def __init__(self, config_fn: Callable[["Message"], "InferDict"]) -> None: """ :param config_fn: a callable taking a site and returning an infer dict @@ -28,7 +30,7 @@ def __init__(self, config_fn: Callable[[Message], InferDict]) -> None: super().__init__() self.config_fn = config_fn - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. @@ -41,7 +43,7 @@ def _pyro_sample(self, msg: Message) -> None: assert msg["infer"] is not None msg["infer"].update(self.config_fn(msg)) - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: """ :param msg: current message at a trace site. diff --git a/pyro/poutine/lift_messenger.py b/pyro/poutine/lift_messenger.py index 2d071c9ac4..f40de4d381 100644 --- a/pyro/poutine/lift_messenger.py +++ b/pyro/poutine/lift_messenger.py @@ -2,16 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from typing import Callable, Dict, Set, Union +from typing import TYPE_CHECKING, Callable, Dict, Set, Union from typing_extensions import Self from pyro import params from pyro.distributions.distribution import Distribution from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class LiftMessenger(Messenger): """ @@ -55,7 +57,7 @@ def __init__( """ super().__init__() self.prior = prior - self._samples_cache: Dict[str, Message] = {} + self._samples_cache: Dict[str, "Message"] = {} def __enter__(self) -> Self: self._samples_cache = {} @@ -77,10 +79,10 @@ def __exit__(self, *args, **kwargs) -> None: ) return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: return None - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: """ Overrides the `pyro.param` call with samples sampled from the distribution specified in the prior. 
The prior can be a diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py index 26a7f0d7a7..a9910eb728 100644 --- a/pyro/poutine/markov_messenger.py +++ b/pyro/poutine/markov_messenger.py @@ -3,12 +3,14 @@ from collections import Counter from contextlib import ExitStack -from typing import Iterable, Iterator, List, Optional, Set +from typing import TYPE_CHECKING, Iterable, Iterator, List, Optional, Set from typing_extensions import Self from pyro.poutine.reentrant_messenger import ReentrantMessenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message class MarkovMessenger(ReentrantMessenger): @@ -79,7 +81,7 @@ def __exit__(self, *args, **kwargs) -> None: self._pos -= 1 return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: if msg["done"] or type(msg["fn"]).__name__ == "_Subsample": return diff --git a/pyro/poutine/mask_messenger.py b/pyro/poutine/mask_messenger.py index c3c375d8a2..132acf3b33 100644 --- a/pyro/poutine/mask_messenger.py +++ b/pyro/poutine/mask_messenger.py @@ -1,12 +1,14 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Union +from typing import TYPE_CHECKING, Union import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message class MaskMessenger(Messenger): @@ -34,5 +36,5 @@ def __init__(self, mask: Union[bool, torch.BoolTensor]) -> None: super().__init__() self.mask = mask - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: msg["mask"] = self.mask if msg["mask"] is None else msg["mask"] & self.mask diff --git a/pyro/poutine/replay_messenger.py b/pyro/poutine/replay_messenger.py index 7e2ea27c3c..d6c88a95bf 100644 --- a/pyro/poutine/replay_messenger.py +++ b/pyro/poutine/replay_messenger.py @@ -1,13 +1,15 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message -from pyro.poutine.trace_struct import Trace + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + from pyro.poutine.trace_struct import Trace class ReplayMessenger(Messenger): @@ -40,7 +42,7 @@ class ReplayMessenger(Messenger): def __init__( self, - trace: Optional[Trace] = None, + trace: Optional["Trace"] = None, params: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """ @@ -55,7 +57,7 @@ def __init__( self.trace = trace self.params = params - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: """ :param msg: current message at a trace site. 
@@ -78,7 +80,7 @@ def _pyro_sample(self, msg: Message) -> None: msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"] - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: name = msg["name"] if self.params is not None and name in self.params: assert hasattr( diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 99188b8c4e..1a25a3405c 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import functools -from collections import Counter from typing import ( TYPE_CHECKING, Callable, @@ -29,6 +28,8 @@ _T = TypeVar("_T") if TYPE_CHECKING: + from collections import Counter + from pyro.distributions.score_parts import ScoreParts from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.indep_messenger import CondIndepStackFrame @@ -58,7 +59,7 @@ class InferDict(TypedDict, total=False): _dim_to_symbol: Dict[int, str] _do_not_trace: bool _enumerate_symbol: str - _markov_scope: Counter + _markov_scope: "Counter" _enumerate_dim: int _dim_to_id: Dict[int, int] _markov_depth: int diff --git a/pyro/poutine/scale_messenger.py b/pyro/poutine/scale_messenger.py index 121ecf6bc4..48e96c7255 100644 --- a/pyro/poutine/scale_messenger.py +++ b/pyro/poutine/scale_messenger.py @@ -1,14 +1,16 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Union +from typing import TYPE_CHECKING, Union import torch from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + class ScaleMessenger(Messenger): """ @@ -47,5 +49,5 @@ def __init__(self, scale: Union[float, torch.Tensor]) -> None: super().__init__() self.scale = scale - def _process_message(self, msg: Message) -> None: + def _process_message(self, msg: "Message") -> None: msg["scale"] = self.scale * msg["scale"] diff --git a/pyro/poutine/substitute_messenger.py b/pyro/poutine/substitute_messenger.py index 2b28616381..69caa6f298 100644 --- a/pyro/poutine/substitute_messenger.py +++ b/pyro/poutine/substitute_messenger.py @@ -2,16 +2,19 @@ # SPDX-License-Identifier: Apache-2.0 import warnings -from typing import Dict, Set +from typing import TYPE_CHECKING, Dict, Set -import torch from typing_extensions import Self from pyro import params from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled +if TYPE_CHECKING: + import torch + + from pyro.poutine.runtime import Message + class SubstituteMessenger(Messenger): """ @@ -32,14 +35,14 @@ class SubstituteMessenger(Messenger): :returns: ``fn`` decorated with a :class:`~pyro.poutine.substitute_messenger.SubstituteMessenger` """ - def __init__(self, data: Dict[str, torch.Tensor]) -> None: + def __init__(self, data: Dict[str, "torch.Tensor"]) -> None: """ :param data: values for the parameters. 
Constructor """ super().__init__() self.data = data - self._data_cache: Dict[str, Message] = {} + self._data_cache: Dict[str, "Message"] = {} def __enter__(self) -> Self: self._data_cache = {} @@ -61,10 +64,10 @@ def __exit__(self, *args, **kwargs) -> None: ) return super().__exit__(*args, **kwargs) - def _pyro_sample(self, msg: Message) -> None: + def _pyro_sample(self, msg: "Message") -> None: return None - def _pyro_param(self, msg: Message) -> None: + def _pyro_param(self, msg: "Message") -> None: """ Overrides the `pyro.param` with substituted values. If the param name does not match the name the keys in `data`, diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py index d812034b3c..735828c4fd 100644 --- a/pyro/poutine/trace_messenger.py +++ b/pyro/poutine/trace_messenger.py @@ -2,15 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import sys -from typing import Any, Callable, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional from typing_extensions import Self from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message from pyro.poutine.trace_struct import Trace from pyro.poutine.util import site_is_subsample +if TYPE_CHECKING: + from pyro.poutine.runtime import Message + def identify_dense_edges(trace: Trace) -> None: """ @@ -134,7 +136,7 @@ def _reset(self) -> None: self.trace = tr super()._reset() - def _pyro_post_sample(self, msg: Message) -> None: + def _pyro_post_sample(self, msg: "Message") -> None: if self.param_only: return assert msg["name"] is not None @@ -145,7 +147,7 @@ def _pyro_post_sample(self, msg: Message) -> None: return self.trace.add_node(msg["name"], **msg.copy()) - def _pyro_post_param(self, msg: Message) -> None: + def _pyro_post_param(self, msg: "Message") -> None: assert msg["name"] is not None self.trace.add_node(msg["name"], **msg.copy()) diff --git a/pyro/poutine/trace_struct.py b/pyro/poutine/trace_struct.py index 5638e611cc..7b7e286747 100644 --- a/pyro/poutine/trace_struct.py +++ b/pyro/poutine/trace_struct.py @@ -4,6 +4,7 @@ import sys from collections import OrderedDict from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -18,18 +19,21 @@ ) import opt_einsum -import torch -from pyro.distributions.distribution import Distribution from pyro.distributions.score_parts import ScoreParts from pyro.distributions.util import scale_and_mask from pyro.ops.packed import pack -from pyro.poutine.runtime import Message from pyro.poutine.util import is_validation_enabled from pyro.util import warn_if_inf, warn_if_nan +if TYPE_CHECKING: + import torch -def allow_all_sites(name: str, site: Message) -> bool: + from pyro.distributions.distribution import Distribution + from pyro.poutine.runtime import Message + + +def allow_all_sites(name: str, site: "Message") -> bool: return True @@ -95,7 +99,7 @@ def __init__(self, graph_type: Literal["flat", "dense"] = "flat") -> None: graph_type ) self.graph_type = graph_type - self.nodes: OrderedDict[str, Message] = OrderedDict() + self.nodes: OrderedDict[str, "Message"] = OrderedDict() self._succ: OrderedDict[str, Set[str]] = OrderedDict() self._pred: OrderedDict[str, Set[str]] = OrderedDict() @@ -198,8 +202,8 @@ def topological_sort(self, reverse: bool = False) -> List[str]: def log_prob_sum( self, - site_filter: Callable[[str, Message], bool] = allow_all_sites, - ) -> Union[torch.Tensor, float]: + site_filter: Callable[[str, "Message"], bool] = allow_all_sites, + ) -> Union["torch.Tensor", float]: """ Compute the site-wise log 
         Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
@@ -212,7 +216,8 @@ def log_prob_sum(
         result = 0.0
         for name, site in self.nodes.items():
             if site["type"] == "sample" and site_filter(name, site):
-                assert isinstance(site["fn"], Distribution)
+                if TYPE_CHECKING:
+                    assert isinstance(site["fn"], Distribution)
                 if "log_prob_sum" in site:
                     log_p = site["log_prob_sum"]
                 else:
@@ -242,7 +247,7 @@ def log_prob_sum(
 
     def compute_log_prob(
         self,
-        site_filter: Callable[[str, Message], bool] = allow_all_sites,
+        site_filter: Callable[[str, "Message"], bool] = allow_all_sites,
     ) -> None:
         """
         Compute the site-wise log probabilities of the trace.
@@ -252,7 +257,8 @@ def compute_log_prob(
         """
         for name, site in self.nodes.items():
             if site["type"] == "sample" and site_filter(name, site):
-                assert isinstance(site["fn"], Distribution)
+                if TYPE_CHECKING:
+                    assert isinstance(site["fn"], Distribution)
                 if "log_prob" not in site:
                     try:
                         log_p = site["fn"].log_prob(
@@ -290,7 +296,8 @@ def compute_score_parts(self) -> None:
         """
         for name, site in self.nodes.items():
             if site["type"] == "sample" and "score_parts" not in site:
-                assert isinstance(site["fn"], Distribution)
+                if TYPE_CHECKING:
+                    assert isinstance(site["fn"], Distribution)
                 # Note that ScoreParts overloads the multiplication operator
                 # to correctly scale each of its three parts.
                 try:
@@ -380,7 +387,7 @@ def nonreparam_stochastic_nodes(self) -> List[str]:
         """
         return list(set(self.stochastic_nodes) - set(self.reparameterized_nodes))
 
-    def iter_stochastic_nodes(self) -> Iterator[Tuple[str, Message]]:
+    def iter_stochastic_nodes(self) -> Iterator[Tuple[str, "Message"]]:
         """
         :return: an iterator over stochastic nodes in the trace.
         """
@@ -479,7 +486,8 @@ def format_shapes(
             rows.append(["Param Sites:"])
         for name, site in self.nodes.items():
             if site["type"] == "param":
-                assert isinstance(site["value"], torch.Tensor)
+                if TYPE_CHECKING:
+                    assert isinstance(site["value"], torch.Tensor)
                 rows.append([name, None] + [str(size) for size in site["value"].shape])
             if name == last_site:
                 break
diff --git a/pyro/poutine/uncondition_messenger.py b/pyro/poutine/uncondition_messenger.py
index 6e1e778972..34febd0543 100644
--- a/pyro/poutine/uncondition_messenger.py
+++ b/pyro/poutine/uncondition_messenger.py
@@ -1,8 +1,12 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import TYPE_CHECKING
+
 from pyro.poutine.messenger import Messenger
-from pyro.poutine.runtime import Message
+
+if TYPE_CHECKING:
+    from pyro.poutine.runtime import Message
 
 
 class UnconditionMessenger(Messenger):
@@ -14,7 +18,7 @@ class UnconditionMessenger(Messenger):
     def __init__(self) -> None:
         super().__init__()
 
-    def _pyro_sample(self, msg: Message) -> None:
+    def _pyro_sample(self, msg: "Message") -> None:
         """
         :param msg: current message at a trace site.
 
diff --git a/pyro/poutine/util.py b/pyro/poutine/util.py
index e91a7972f3..8c682e2aef 100644
--- a/pyro/poutine/util.py
+++ b/pyro/poutine/util.py
@@ -3,7 +3,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from .. import settings
+from pyro import settings
 
 if TYPE_CHECKING:
     from pyro.distributions.distribution import Distribution

From b4564c19e27e51df2f347553a099e8aade0008a1 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Tue, 16 Jan 2024 05:55:15 +0000
Subject: [PATCH 8/9] MarkovMessenger

---
 pyro/poutine/guide.py             | 16 ++++++-----
 pyro/poutine/handlers.py          | 47 ++++++++++++++++++++++++++++++-
 pyro/poutine/markov_messenger.py  |  6 ++--
 pyro/poutine/reparam_messenger.py | 12 ++++----
 4 files changed, 65 insertions(+), 16 deletions(-)

diff --git a/pyro/poutine/guide.py b/pyro/poutine/guide.py
index ca40214c5b..6685fc4df4 100644
--- a/pyro/poutine/guide.py
+++ b/pyro/poutine/guide.py
@@ -7,12 +7,12 @@
 import torch
 
 import pyro.distributions as dist
-from pyro.distributions.torch_distribution import TorchDistributionMixin
 from pyro.poutine.trace_messenger import TraceMessenger
 from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import prune_subsample_sites, site_is_subsample
 
 if TYPE_CHECKING:
+    from pyro.distributions.torch_distribution import TorchDistributionMixin
     from pyro.poutine.runtime import Message
 
 
@@ -65,9 +65,10 @@ def __call__(self, *args, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore[
     def _pyro_sample(self, msg: "Message") -> None:
         if msg["is_observed"] or site_is_subsample(msg):
             return
-        assert isinstance(msg["name"], str)
-        assert isinstance(msg["fn"], TorchDistributionMixin)
-        assert msg["infer"] is not None
+        if TYPE_CHECKING:
+            assert isinstance(msg["name"], str)
+            assert isinstance(msg["fn"], TorchDistributionMixin)
+            assert msg["infer"] is not None
         prior = msg["fn"]
         msg["infer"]["prior"] = prior
         posterior = self.get_posterior(msg["name"], prior)
@@ -82,15 +83,16 @@ def _pyro_post_sample(self, msg: "Message") -> None:
         assert msg["infer"] is not None
         prior = msg["infer"].get("prior")
         if prior is not None:
-            assert isinstance(msg["fn"], TorchDistributionMixin)
+            if TYPE_CHECKING:
+                assert isinstance(msg["fn"], TorchDistributionMixin)
             if prior.batch_shape != msg["fn"].batch_shape:
                 msg["infer"]["prior"] = prior.expand(msg["fn"].batch_shape)
         return super()._pyro_post_sample(msg)
 
     @abstractmethod
     def get_posterior(
-        self, name: str, prior: TorchDistributionMixin
-    ) -> Union[TorchDistributionMixin, torch.Tensor]:
+        self, name: str, prior: "TorchDistributionMixin"
+    ) -> Union["TorchDistributionMixin", torch.Tensor]:
         """
         Abstract method to compute a posterior distribution or sample a
         posterior value given a prior distribution conditioned on upstream
diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py
index 969c5476d6..f0116b9be2 100644
--- a/pyro/poutine/handlers.py
+++ b/pyro/poutine/handlers.py
@@ -51,6 +51,9 @@
 
 import collections
 import functools
+from typing import Callable, Iterable, Optional, TypeVar, Union, overload
+
+from typing_extensions import ParamSpec
 
 from pyro.poutine import util
 
@@ -74,6 +77,9 @@
 from .trace_messenger import TraceMessenger
 from .uncondition_messenger import UnconditionMessenger
 
+_P = ParamSpec("_P")
+_T = TypeVar("_T")
+
 ############################################
 # Begin primitive operations
 ############################################
@@ -276,7 +282,46 @@ def _fn(*args, **kwargs):
     return wrapper(fn) if fn is not None else wrapper
 
 
-def markov(fn=None, history=1, keep=False, dim=None, name=None):
+@overload
+def markov(
+    fn: None = ...,
+    history: int = 1,
+    keep: bool = False,
+    dim: Optional[int] = None,
+    name: Optional[str] = None,
+) -> MarkovMessenger:
+    ...
+
+
+@overload
+def markov(
+    fn: Iterable[int] = ...,
+    history: int = 1,
+    keep: bool = False,
+    dim: Optional[int] = None,
+    name: Optional[str] = None,
+) -> MarkovMessenger:
+    ...
+
+
+@overload
+def markov(
+    fn: Callable[_P, _T] = ...,
+    history: int = 1,
+    keep: bool = False,
+    dim: Optional[int] = None,
+    name: Optional[str] = None,
+) -> Callable[_P, _T]:
+    ...
+
+
+def markov(
+    fn: Optional[Union[Iterable[int], Callable]] = None,
+    history: int = 1,
+    keep: bool = False,
+    dim: Optional[int] = None,
+    name: Optional[str] = None,
+) -> Union[MarkovMessenger, Callable]:
     """
     Markov dependency declaration.
 
diff --git a/pyro/poutine/markov_messenger.py b/pyro/poutine/markov_messenger.py
index a9910eb728..6c41594fb9 100644
--- a/pyro/poutine/markov_messenger.py
+++ b/pyro/poutine/markov_messenger.py
@@ -53,16 +53,16 @@ def __init__(
             raise NotImplementedError(
                 "vectorized markov not yet implemented, try setting name to None"
             )
-        self._iterable: Optional[Iterable] = None
+        self._iterable: Optional[Iterable[int]] = None
         self._pos = -1
         self._stack: List[Set[str]] = []
         super().__init__()
 
-    def generator(self, iterable: Iterable) -> Self:
+    def generator(self, iterable: Iterable[int]) -> Self:
         self._iterable = iterable
         return self
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterator[int]:
         with ExitStack() as stack:
             assert self._iterable is not None
             for value in self._iterable:
diff --git a/pyro/poutine/reparam_messenger.py b/pyro/poutine/reparam_messenger.py
index acf6ff5d40..10405e0330 100644
--- a/pyro/poutine/reparam_messenger.py
+++ b/pyro/poutine/reparam_messenger.py
@@ -16,12 +16,13 @@
 import torch
 from typing_extensions import ParamSpec
 
-from pyro.distributions.torch_distribution import TorchDistributionMixin
 from pyro.poutine.messenger import Messenger
-from pyro.poutine.runtime import Message, effectful
+from pyro.poutine.runtime import effectful
 
 if TYPE_CHECKING:
+    from pyro.distributions.torch_distribution import TorchDistributionMixin
     from pyro.infer.reparam.reparam import Reparam
+    from pyro.poutine.runtime import Message
 
 _P = ParamSpec("_P")
 _T = TypeVar("_T")
@@ -59,7 +60,7 @@ class ReparamMessenger(Messenger):
 
     def __init__(
         self,
-        config: Union[Dict[str, "Reparam"], Callable[[Message], Optional["Reparam"]]],
+        config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]],
     ) -> None:
         super().__init__()
         assert isinstance(config, dict) or callable(config)
@@ -69,11 +70,12 @@ def __init__(
     def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
         return ReparamHandler(self, fn)
 
-    def _pyro_sample(self, msg: Message) -> None:
+    def _pyro_sample(self, msg: "Message") -> None:
         if type(msg["fn"]).__name__ == "_Subsample":
             return
         assert msg["name"] is not None
-        assert isinstance(msg["fn"], TorchDistributionMixin)
+        if TYPE_CHECKING:
+            assert isinstance(msg["fn"], TorchDistributionMixin)
         if isinstance(self.config, dict):
             reparam = self.config.get(msg["name"])
         else:

From 1885e22e1e2c989571aa700c806d80869a64a01e Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Tue, 16 Jan 2024 06:29:57 +0000
Subject: [PATCH 9/9] replay messenger

---
 pyro/poutine/replay_messenger.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyro/poutine/replay_messenger.py b/pyro/poutine/replay_messenger.py
index d6c88a95bf..9c26490528 100644
--- a/pyro/poutine/replay_messenger.py
+++ b/pyro/poutine/replay_messenger.py
@@ -3,11 +3,11 @@
 
 from typing import TYPE_CHECKING, Dict, Optional
 
-import torch
-
 from pyro.poutine.messenger import Messenger
 
 if TYPE_CHECKING:
+    import torch
+
     from pyro.poutine.runtime import Message
     from pyro.poutine.trace_struct import Trace
 
@@ -43,7 +43,7 @@ class ReplayMessenger(Messenger):
     def __init__(
         self,
         trace: Optional["Trace"] = None,
-        params: Optional[Dict[str, torch.Tensor]] = None,
+        params: Optional[Dict[str, "torch.Tensor"]] = None,
     ) -> None:
         """
         :param trace: a trace whose values should be reused