From d8bbfcd26091ee2c45797963c6f3bf51badd38c5 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 12 Nov 2023 06:22:07 +0000 Subject: [PATCH 1/5] Annotate primitives & block_messenger --- pyro/distributions/distribution.py | 13 +++-- pyro/poutine/block_messenger.py | 48 +++++++++++++----- pyro/poutine/runtime.py | 62 +++++++++++++++++------ pyro/primitives.py | 79 +++++++++++++++++++++--------- 4 files changed, 148 insertions(+), 54 deletions(-) diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index eaa84ee273..fd7732b441 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -5,6 +5,8 @@ import inspect from abc import ABCMeta, abstractmethod +import torch + from pyro.distributions.score_parts import ScoreParts COERCIONS = [] @@ -50,7 +52,7 @@ class Distribution(metaclass=DistributionMeta): has_rsample = False has_enumerate_support = False - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> torch.Tensor: """ Samples a random value (just an alias for ``.sample(*args, **kwargs)``). @@ -62,8 +64,13 @@ def __call__(self, *args, **kwargs): """ return self.sample(*args, **kwargs) + @property + @abstractmethod + def event_dim(self) -> int: + raise NotImplementedError + @abstractmethod - def sample(self, *args, **kwargs): + def sample(self, *args, **kwargs) -> torch.Tensor: """ Samples a random value. @@ -80,7 +87,7 @@ def sample(self, *args, **kwargs): raise NotImplementedError @abstractmethod - def log_prob(self, x, *args, **kwargs): + def log_prob(self, x, *args, **kwargs) -> torch.Tensor: """ Evaluates log probability densities for each of a batch of samples. diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index 7c9f1847b2..05ed415c0b 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -2,11 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial +from typing import Callable, List, Optional from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message -def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg): +def _block_fn( + expose: List[str], + expose_types: List[str], + hide: List[str], + hide_types: List[str], + hide_all: bool, + msg: Message, +) -> bool: # handle observes if msg["type"] == "sample" and msg["is_observed"]: msg_type = "observe" @@ -27,7 +36,14 @@ def _block_fn(expose, expose_types, hide, hide_types, hide_all, msg): return False -def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose_types): +def _make_default_hide_fn( + hide_all: bool, + expose_all: bool, + hide: Optional[List[str]], + expose: Optional[List[str]], + hide_types: Optional[List[str]], + expose_types: Optional[List[str]], +) -> Callable[[Message], bool]: # first, some sanity checks: # hide_all and expose_all intersect? assert (hide_all is False and expose_all is False) or ( @@ -65,6 +81,13 @@ def _make_default_hide_fn(hide_all, expose_all, hide, expose, hide_types, expose return partial(_block_fn, expose, expose_types, hide, hide_types, hide_all) +def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]: + def negated_fn(msg: Message) -> bool: + return not fn(msg) + + return negated_fn + + class BlockMessenger(Messenger): """ This handler selectively hides Pyro primitive sites from the outside world. @@ -116,14 +139,14 @@ class BlockMessenger(Messenger): def __init__( self, - hide_fn=None, - expose_fn=None, - hide_all=True, - expose_all=False, - hide=None, - expose=None, - hide_types=None, - expose_types=None, + hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, + expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, + hide_all: bool = True, + expose_all: bool = False, + hide: Optional[List[str]] = None, + expose: Optional[List[str]] = None, + hide_types: Optional[List[str]] = None, + expose_types: Optional[List[str]] = None, ): super().__init__() if not (hide_fn is None or expose_fn is None): @@ -131,12 +154,11 @@ def __init__( if hide_fn is not None: self.hide_fn = hide_fn elif expose_fn is not None: - self.hide_fn = lambda msg: not expose_fn(msg) + self.hide_fn = _negate_fn(expose_fn) else: self.hide_fn = _make_default_hide_fn( hide_all, expose_all, hide, expose, hide_types, expose_types ) - def _process_message(self, msg): + def _process_message(self, msg: Message) -> None: msg["stop"] = bool(self.hide_fn(msg)) - return None diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 441b771080..8481ebca1d 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -4,16 +4,30 @@ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, +) +# overload, import torch -from typing_extensions import TypedDict +from typing_extensions import ParamSpec, TypedDict from pyro.params.param_store import ( # noqa: F401 _MODULE_NAMESPACE_DIVIDER, ParamStoreDict, ) +P = ParamSpec("P") +T = TypeVar("T") + if TYPE_CHECKING: from pyro.poutine.indep_messenger import CondIndepStackFrame from pyro.poutine.messenger import Messenger @@ -26,8 +40,8 @@ class Message(TypedDict, total=False): - type: Optional[str] - name: str + type: str + name: Optional[str] fn: Callable is_observed: bool args: Tuple @@ -252,7 +266,7 @@ def apply_stack(initial_msg: Message) -> None: cont(msg) -def am_i_wrapped(): +def am_i_wrapped() -> bool: """ Checks whether the current computation is wrapped in a poutine. :returns: bool @@ -260,7 +274,23 @@ def am_i_wrapped(): return len(_PYRO_STACK) > 0 -def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Callable: +# @overload +# def effectful( +# fn: None = ..., type: Optional[str] = ... +# ) -> Callable[[Optional[Callable[P, T]]], Callable[P, Union[T, torch.Tensor, None]]]: +# ... +# +# +# @overload +# def effectful( +# fn: Callable[P, T] = ..., type: Optional[str] = ... +# ) -> Callable[P, Union[T, torch.Tensor, None]]: +# ... + + +def effectful( + fn: Optional[Callable[P, T]] = None, type: Optional[str] = None +) -> Callable: """ :param fn: function or callable that performs an effectful computation :param str type: the type label of the operation, e.g. `"sample"` @@ -277,31 +307,33 @@ def effectful(fn: Optional[Callable] = None, type: Optional[str] = None) -> Call assert type != "message", "cannot use 'message' as keyword" @functools.wraps(fn) - def _fn(*args, **kwargs): - name = kwargs.pop("name", None) - infer = kwargs.pop("infer", {}) - - value = kwargs.pop("obs", None) - is_observed = value is not None + def _fn( + *args: P.args, + name: Optional[str] = None, + infer: Optional[Dict] = None, + obs: Optional[torch.Tensor] = None, + **kwargs: P.kwargs, + ) -> Union[T, torch.Tensor, None]: + is_observed = obs is not None if not am_i_wrapped(): return fn(*args, **kwargs) else: - msg = { + msg: Message = { "type": type, "name": name, "fn": fn, "is_observed": is_observed, "args": args, "kwargs": kwargs, - "value": value, + "value": obs, "scale": 1.0, "mask": None, "cond_indep_stack": (), "done": False, "stop": False, "continuation": None, - "infer": infer, + "infer": infer if infer is not None else {}, } # apply the stack and return its return value apply_stack(msg) diff --git a/pyro/primitives.py b/pyro/primitives.py index ca67d7b061..ed44a4e746 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -6,6 +6,7 @@ from collections import OrderedDict from contextlib import ExitStack, contextmanager from inspect import isclass +from typing import Callable, Iterator, Optional, Sequence, Union import torch @@ -14,10 +15,12 @@ import pyro.poutine as poutine from pyro.distributions import constraints from pyro.params import param_with_module_name +from pyro.params.param_store import ParamStoreDict from pyro.poutine.plate_messenger import PlateMessenger from pyro.poutine.runtime import ( _MODULE_NAMESPACE_DIVIDER, _PYRO_PARAM_STORE, + Message, am_i_wrapped, apply_stack, effectful, @@ -26,14 +29,14 @@ from pyro.util import deep_getattr, set_rng_seed # noqa: F401 -def get_param_store(): +def get_param_store() -> ParamStoreDict: """ Returns the global :class:`~pyro.params.param_store.ParamStoreDict`. """ return _PYRO_PARAM_STORE -def clear_param_store(): +def clear_param_store() -> None: """ Clears the global :class:`~pyro.params.param_store.ParamStoreDict`. @@ -42,13 +45,20 @@ def clear_param_store(): models), and before each unit test (to avoid leaking parameters across tests). """ - return _PYRO_PARAM_STORE.clear() + _PYRO_PARAM_STORE.clear() -_param = effectful(_PYRO_PARAM_STORE.get_param, type="param") +_param: Callable[..., torch.Tensor] = effectful( + _PYRO_PARAM_STORE.get_param, type="param" +) -def param(name, init_tensor=None, constraint=constraints.real, event_dim=None): +def param( + name: str, + init_tensor: Union[torch.Tensor, Callable[[], torch.Tensor], None] = None, + constraint: constraints.Constraint = constraints.real, + event_dim: Optional[int] = None, +) -> torch.Tensor: """ Saves the variable as a parameter in the param store. To interact with the param store or write to disk, @@ -74,11 +84,19 @@ def param(name, init_tensor=None, constraint=constraints.real, event_dim=None): :rtype: torch.Tensor """ # Note effectful(-) requires the double passing of name below. - args = (name,) if init_tensor is None else (name, init_tensor) - return _param(*args, constraint=constraint, event_dim=event_dim, name=name) + return _param( + name, init_tensor, constraint=constraint, event_dim=event_dim, name=name + ) -def _masked_observe(name, fn, obs, obs_mask, *args, **kwargs): +def _masked_observe( + name: str, + fn: dist.Distribution, + obs: Optional[torch.Tensor], + obs_mask: torch.Tensor, + *args, + **kwargs, +) -> torch.Tensor: # Split into two auxiliary sample sites. with poutine.mask(mask=obs_mask): observed = sample(f"{name}_observed", fn, *args, **kwargs, obs=obs) @@ -102,7 +120,15 @@ def _masked_observe(name, fn, obs, obs_mask, *args, **kwargs): return deterministic(name, value) -def sample(name, fn, *args, **kwargs): +def sample( + name: str, + fn: dist.Distribution, + *args, + obs: Optional[torch.Tensor] = None, + obs_mask: Optional[torch.Tensor] = None, + infer: Optional[dict] = None, + **kwargs, +) -> torch.Tensor: """ Calls the stochastic function ``fn`` with additional side-effects depending on ``name`` and the enclosing context (e.g. an inference algorithm). See @@ -123,15 +149,13 @@ def sample(name, fn, *args, **kwargs): :returns: sample """ # Transform obs_mask into multiple sample statements. - obs = kwargs.pop("obs", None) - obs_mask = kwargs.pop("obs_mask", None) if obs_mask is not None: return _masked_observe(name, fn, obs, obs_mask, *args, **kwargs) # Check if stack is empty. # if stack empty, default behavior (defined here) - infer = kwargs.pop("infer", {}).copy() - is_observed = infer.pop("is_observed", obs is not None) + infer = {} if infer is None else infer.copy() + is_observed: bool = infer.pop("is_observed", obs is not None) if not am_i_wrapped(): if obs is not None and not infer.get("_deterministic"): warnings.warn( @@ -143,7 +167,7 @@ def sample(name, fn, *args, **kwargs): # if stack not empty, apply everything in the stack? else: # initialize data structure to pass up/down the stack - msg = { + msg: Message = { "type": "sample", "name": name, "fn": fn, @@ -161,10 +185,13 @@ def sample(name, fn, *args, **kwargs): } # apply the stack and return its return value apply_stack(msg) + assert isinstance(msg["value"], torch.Tensor) return msg["value"] -def factor(name, log_factor, *, has_rsample=None): +def factor( + name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None +) -> None: """ Factor statement to add arbitrary log probability factor to a probabilisitic model. @@ -188,7 +215,9 @@ def factor(name, log_factor, *, has_rsample=None): sample(name, unit_dist, obs=unit_value, infer={"is_auxiliary": True}) -def deterministic(name, value, event_dim=None): +def deterministic( + name: str, value: torch.Tensor, event_dim: Optional[int] = None +) -> torch.Tensor: """ Deterministic statement to add a :class:`~pyro.distributions.Delta` site with name `name` and value `value` to the trace. This is useful when we @@ -215,7 +244,7 @@ def deterministic(name, value, event_dim=None): @effectful(type="subsample") -def subsample(data, event_dim): +def subsample(data: torch.Tensor, event_dim: int) -> torch.Tensor: """ Subsampling statement to subsample data tensors based on enclosing :class:`plate` s. @@ -374,7 +403,9 @@ def __init__(self, *args, **kwargs): @contextmanager -def plate_stack(prefix, sizes, rightmost_dim=-1): +def plate_stack( + prefix: str, sizes: Sequence[int], rightmost_dim: int = -1 +) -> Iterator[None]: """ Create a contiguous stack of :class:`plate` s with dimensions:: @@ -392,7 +423,9 @@ def plate_stack(prefix, sizes, rightmost_dim=-1): yield -def module(name, nn_module, update_module_params=False): +def module( + name: str, nn_module: torch.nn.Module, update_module_params: bool = False +) -> torch.nn.Module: """ Registers all parameters of a :class:`torch.nn.Module` with Pyro's :mod:`~pyro.params.param_store`. In conjunction with the @@ -462,7 +495,7 @@ def module(name, nn_module, update_module_params=False): param_name ] = target_state_dict[_name] else: - nn_module._parameters[mod_name] = target_state_dict[_name] + nn_module._parameters[mod_name] = target_state_dict[_name] # type: ignore[assignment] return nn_module @@ -508,7 +541,7 @@ def _fn(): @effectful(type="barrier") -def barrier(data): +def barrier(data: torch.Tensor) -> torch.Tensor: """ EXPERIMENTAL Ensures all values in ``data`` are ground, rather than lazy funsor values. This is useful in combination with @@ -517,7 +550,7 @@ def barrier(data): return data -def enable_validation(is_validate=True): +def enable_validation(is_validate: bool = True) -> None: """ Enable or disable validation checks in Pyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validating distribution @@ -544,7 +577,7 @@ def enable_validation(is_validate=True): @contextmanager -def validation_enabled(is_validate=True): +def validation_enabled(is_validate: bool = True) -> Iterator[None]: """ Context manager that is useful when temporarily enabling/disabling validation checks. From 76415f0238eedfd96022a145783a865e94879b41 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 12 Nov 2023 06:45:28 +0000 Subject: [PATCH 2/5] overload --- pyro/poutine/block_messenger.py | 1 + pyro/poutine/runtime.py | 26 +++++++++++++------------- pyro/primitives.py | 10 +++++----- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index 05ed415c0b..79b594a29e 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -82,6 +82,7 @@ def _make_default_hide_fn( def _negate_fn(fn: Callable[[Message], Optional[bool]]) -> Callable[[Message], bool]: + # typed version of lambda msg: not fn(msg) def negated_fn(msg: Message) -> bool: return not fn(msg) diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index 8481ebca1d..fab46eaa13 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -14,9 +14,9 @@ Tuple, TypeVar, Union, + overload, ) -# overload, import torch from typing_extensions import ParamSpec, TypedDict @@ -274,18 +274,18 @@ def am_i_wrapped() -> bool: return len(_PYRO_STACK) > 0 -# @overload -# def effectful( -# fn: None = ..., type: Optional[str] = ... -# ) -> Callable[[Optional[Callable[P, T]]], Callable[P, Union[T, torch.Tensor, None]]]: -# ... -# -# -# @overload -# def effectful( -# fn: Callable[P, T] = ..., type: Optional[str] = ... -# ) -> Callable[P, Union[T, torch.Tensor, None]]: -# ... +@overload +def effectful( + fn: None = ..., type: Optional[str] = ... +) -> Callable[[Callable[P, T]], Callable[..., Union[T, torch.Tensor, None]]]: + ... + + +@overload +def effectful( + fn: Callable[P, T] = ..., type: Optional[str] = ... +) -> Callable[..., Union[T, torch.Tensor, None]]: + ... def effectful( diff --git a/pyro/primitives.py b/pyro/primitives.py index ed44a4e746..8c4376685b 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -9,11 +9,11 @@ from typing import Callable, Iterator, Optional, Sequence, Union import torch +from torch.distributions import constraints import pyro.distributions as dist import pyro.infer as infer import pyro.poutine as poutine -from pyro.distributions import constraints from pyro.params import param_with_module_name from pyro.params.param_store import ParamStoreDict from pyro.poutine.plate_messenger import PlateMessenger @@ -48,9 +48,7 @@ def clear_param_store() -> None: _PYRO_PARAM_STORE.clear() -_param: Callable[..., torch.Tensor] = effectful( - _PYRO_PARAM_STORE.get_param, type="param" -) +_param = effectful(_PYRO_PARAM_STORE.get_param, type="param") def param( @@ -84,9 +82,11 @@ def param( :rtype: torch.Tensor """ # Note effectful(-) requires the double passing of name below. - return _param( + value = _param( name, init_tensor, constraint=constraint, event_dim=event_dim, name=name ) + assert isinstance(value, torch.Tensor) + return value def _masked_observe( From ba3b4ef2020ef00956d22f9d9761cff0e73eaf2f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 12 Nov 2023 07:33:35 +0000 Subject: [PATCH 3/5] TorchDistribution --- pyro/distributions/distribution.py | 5 --- pyro/distributions/torch_distribution.py | 4 +- pyro/primitives.py | 48 ++++++++++++------------ 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index fd7732b441..f52231766a 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -64,11 +64,6 @@ def __call__(self, *args, **kwargs) -> torch.Tensor: """ return self.sample(*args, **kwargs) - @property - @abstractmethod - def event_dim(self) -> int: - raise NotImplementedError - @abstractmethod def sample(self, *args, **kwargs) -> torch.Tensor: """ diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index afa0dff352..89d6cbfd7f 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -27,7 +27,7 @@ class TorchDistributionMixin(Distribution): from :class:`TorchDistributionMixin`. """ - def __call__(self, sample_shape=torch.Size()): + def __call__(self, sample_shape=torch.Size()) -> torch.Tensor: """ Samples a random value. @@ -51,7 +51,7 @@ def __call__(self, sample_shape=torch.Size()): ) @property - def event_dim(self): + def event_dim(self) -> int: """ :return: Number of dimensions of individual events. :rtype: int diff --git a/pyro/primitives.py b/pyro/primitives.py index 8c4376685b..f7f3813dbc 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -6,7 +6,7 @@ from collections import OrderedDict from contextlib import ExitStack, contextmanager from inspect import isclass -from typing import Callable, Iterator, Optional, Sequence, Union +from typing import Callable, Dict, Iterator, Optional, Sequence, Union import torch from torch.distributions import constraints @@ -14,6 +14,7 @@ import pyro.distributions as dist import pyro.infer as infer import pyro.poutine as poutine +from pyro.distributions import TorchDistribution from pyro.params import param_with_module_name from pyro.params.param_store import ParamStoreDict from pyro.poutine.plate_messenger import PlateMessenger @@ -85,13 +86,13 @@ def param( value = _param( name, init_tensor, constraint=constraint, event_dim=event_dim, name=name ) - assert isinstance(value, torch.Tensor) + assert value is not None # type narrowing guaranteed by _param return value def _masked_observe( name: str, - fn: dist.Distribution, + fn: TorchDistribution, obs: Optional[torch.Tensor], obs_mask: torch.Tensor, *args, @@ -122,11 +123,11 @@ def _masked_observe( def sample( name: str, - fn: dist.Distribution, + fn: TorchDistribution, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.Tensor] = None, - infer: Optional[dict] = None, + infer: Optional[Dict[str, Union[str, bool]]] = None, **kwargs, ) -> torch.Tensor: """ @@ -155,7 +156,8 @@ def sample( # Check if stack is empty. # if stack empty, default behavior (defined here) infer = {} if infer is None else infer.copy() - is_observed: bool = infer.pop("is_observed", obs is not None) + is_observed = infer.pop("is_observed", obs is not None) + assert isinstance(is_observed, bool) if not am_i_wrapped(): if obs is not None and not infer.get("_deterministic"): warnings.warn( @@ -167,25 +169,25 @@ def sample( # if stack not empty, apply everything in the stack? else: # initialize data structure to pass up/down the stack - msg: Message = { - "type": "sample", - "name": name, - "fn": fn, - "is_observed": is_observed, - "args": args, - "kwargs": kwargs, - "value": obs, - "infer": infer, - "scale": 1.0, - "mask": None, - "cond_indep_stack": (), - "done": False, - "stop": False, - "continuation": None, - } + msg = Message( + type="sample", + name=name, + fn=fn, + is_observed=is_observed, + args=args, + kwargs=kwargs, + value=obs, + infer=infer, + scale=1.0, + mask=None, + cond_indep_stack=(), + done=False, + stop=False, + continuation=None, + ) # apply the stack and return its return value apply_stack(msg) - assert isinstance(msg["value"], torch.Tensor) + assert msg["value"] is not None # type narrowing guaranteed by apply_stack return msg["value"] From 6dd80e51d532a248cdb01ff4b0fc4ac9bd969aed Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 12 Nov 2023 23:03:14 +0000 Subject: [PATCH 4/5] TYPE_CHECKING --- pyro/distributions/distribution.py | 8 ++- pyro/distributions/torch_distribution.py | 2 +- pyro/poutine/block_messenger.py | 6 ++- pyro/poutine/runtime.py | 64 ++++++++++++------------ pyro/primitives.py | 13 ++--- 5 files changed, 47 insertions(+), 46 deletions(-) diff --git a/pyro/distributions/distribution.py b/pyro/distributions/distribution.py index f52231766a..eaa84ee273 100644 --- a/pyro/distributions/distribution.py +++ b/pyro/distributions/distribution.py @@ -5,8 +5,6 @@ import inspect from abc import ABCMeta, abstractmethod -import torch - from pyro.distributions.score_parts import ScoreParts COERCIONS = [] @@ -52,7 +50,7 @@ class Distribution(metaclass=DistributionMeta): has_rsample = False has_enumerate_support = False - def __call__(self, *args, **kwargs) -> torch.Tensor: + def __call__(self, *args, **kwargs): """ Samples a random value (just an alias for ``.sample(*args, **kwargs)``). @@ -65,7 +63,7 @@ def __call__(self, *args, **kwargs) -> torch.Tensor: return self.sample(*args, **kwargs) @abstractmethod - def sample(self, *args, **kwargs) -> torch.Tensor: + def sample(self, *args, **kwargs): """ Samples a random value. @@ -82,7 +80,7 @@ def sample(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError @abstractmethod - def log_prob(self, x, *args, **kwargs) -> torch.Tensor: + def log_prob(self, x, *args, **kwargs): """ Evaluates log probability densities for each of a batch of samples. diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 89d6cbfd7f..124235a5c9 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -27,7 +27,7 @@ class TorchDistributionMixin(Distribution): from :class:`TorchDistributionMixin`. """ - def __call__(self, sample_shape=torch.Size()) -> torch.Tensor: + def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: """ Samples a random value. diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index 79b594a29e..1ba113774e 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -2,10 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message + +if TYPE_CHECKING: + from pyro.poutine.runtime import Message def _block_fn( diff --git a/pyro/poutine/runtime.py b/pyro/poutine/runtime.py index fab46eaa13..0256bc91c9 100644 --- a/pyro/poutine/runtime.py +++ b/pyro/poutine/runtime.py @@ -319,22 +319,22 @@ def _fn( if not am_i_wrapped(): return fn(*args, **kwargs) else: - msg: Message = { - "type": type, - "name": name, - "fn": fn, - "is_observed": is_observed, - "args": args, - "kwargs": kwargs, - "value": obs, - "scale": 1.0, - "mask": None, - "cond_indep_stack": (), - "done": False, - "stop": False, - "continuation": None, - "infer": infer if infer is not None else {}, - } + msg = Message( + type=type, + name=name, + fn=fn, + is_observed=is_observed, + args=args, + kwargs=kwargs, + value=obs, + scale=1.0, + mask=None, + cond_indep_stack=(), + done=False, + stop=False, + continuation=None, + infer=infer if infer is not None else {}, + ) # apply the stack and return its return value apply_stack(msg) return msg["value"] @@ -353,22 +353,22 @@ def _inspect() -> Message: :returns: A message with all effects applied. :rtype: dict """ - msg: Message = { - "type": "inspect", - "name": "_pyro_inspect", - "fn": lambda: True, - "is_observed": False, - "args": (), - "kwargs": {}, - "value": None, - "infer": {"_do_not_trace": True}, - "scale": 1.0, - "mask": None, - "cond_indep_stack": (), - "done": False, - "stop": False, - "continuation": None, - } + msg = Message( + type="inspect", + name="_pyro_inspect", + fn=lambda: True, + is_observed=False, + args=(), + kwargs={}, + value=None, + infer={"_do_not_trace": True}, + scale=1.0, + mask=None, + cond_indep_stack=(), + done=False, + stop=False, + continuation=None, + ) apply_stack(msg) return msg diff --git a/pyro/primitives.py b/pyro/primitives.py index f7f3813dbc..ba83973fab 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -6,7 +6,7 @@ from collections import OrderedDict from contextlib import ExitStack, contextmanager from inspect import isclass -from typing import Callable, Dict, Iterator, Optional, Sequence, Union +from typing import TYPE_CHECKING, Callable, Dict, Iterator, Optional, Sequence, Union import torch from torch.distributions import constraints @@ -14,9 +14,7 @@ import pyro.distributions as dist import pyro.infer as infer import pyro.poutine as poutine -from pyro.distributions import TorchDistribution from pyro.params import param_with_module_name -from pyro.params.param_store import ParamStoreDict from pyro.poutine.plate_messenger import PlateMessenger from pyro.poutine.runtime import ( _MODULE_NAMESPACE_DIVIDER, @@ -29,6 +27,10 @@ from pyro.poutine.subsample_messenger import SubsampleMessenger from pyro.util import deep_getattr, set_rng_seed # noqa: F401 +if TYPE_CHECKING: + from pyro.distributions import TorchDistribution + from pyro.params.param_store import ParamStoreDict + def get_param_store() -> ParamStoreDict: """ @@ -83,9 +85,8 @@ def param( :rtype: torch.Tensor """ # Note effectful(-) requires the double passing of name below. - value = _param( - name, init_tensor, constraint=constraint, event_dim=event_dim, name=name - ) + args = (name,) if init_tensor is None else (name, init_tensor) + value = _param(*args, constraint=constraint, event_dim=event_dim, name=name) assert value is not None # type narrowing guaranteed by _param return value From c5f6b401cf4dba72fca6149850314a374e97c3b9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sun, 12 Nov 2023 23:25:24 +0000 Subject: [PATCH 5/5] fix make docs --- pyro/poutine/block_messenger.py | 6 ++---- pyro/primitives.py | 8 +++----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/pyro/poutine/block_messenger.py b/pyro/poutine/block_messenger.py index 1ba113774e..79b594a29e 100644 --- a/pyro/poutine/block_messenger.py +++ b/pyro/poutine/block_messenger.py @@ -2,12 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import Callable, List, Optional from pyro.poutine.messenger import Messenger - -if TYPE_CHECKING: - from pyro.poutine.runtime import Message +from pyro.poutine.runtime import Message def _block_fn( diff --git a/pyro/primitives.py b/pyro/primitives.py index ba83973fab..5a7ef23d15 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -6,7 +6,7 @@ from collections import OrderedDict from contextlib import ExitStack, contextmanager from inspect import isclass -from typing import TYPE_CHECKING, Callable, Dict, Iterator, Optional, Sequence, Union +from typing import Callable, Dict, Iterator, Optional, Sequence, Union import torch from torch.distributions import constraints @@ -14,7 +14,9 @@ import pyro.distributions as dist import pyro.infer as infer import pyro.poutine as poutine +from pyro.distributions import TorchDistribution from pyro.params import param_with_module_name +from pyro.params.param_store import ParamStoreDict from pyro.poutine.plate_messenger import PlateMessenger from pyro.poutine.runtime import ( _MODULE_NAMESPACE_DIVIDER, @@ -27,10 +29,6 @@ from pyro.poutine.subsample_messenger import SubsampleMessenger from pyro.util import deep_getattr, set_rng_seed # noqa: F401 -if TYPE_CHECKING: - from pyro.distributions import TorchDistribution - from pyro.params.param_store import ParamStoreDict - def get_param_store() -> ParamStoreDict: """