From 014119e874ecedfc7868ac0fd1586e52bce2ec19 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 19 Feb 2024 16:53:14 +0000 Subject: [PATCH 01/11] wip --- dim.py | 122 ++++++++++++++++++++++++ pyro/contrib/named/infer/elbo.py | 53 +++++++++++ pyro/distributions/named.py | 119 ++++++++++++++++++++++++ pyro/poutine/broadcast_messenger.py | 7 ++ pyro/poutine/enum_messenger.py | 115 ++++++++++++----------- pyro/poutine/indep_messenger.py | 8 +- torchdim.py | 139 ++++++++++++++++++++++++++++ 7 files changed, 506 insertions(+), 57 deletions(-) create mode 100644 dim.py create mode 100644 pyro/contrib/named/infer/elbo.py create mode 100644 pyro/distributions/named.py create mode 100644 torchdim.py diff --git a/dim.py b/dim.py new file mode 100644 index 0000000000..db2c23dc3b --- /dev/null +++ b/dim.py @@ -0,0 +1,122 @@ +import torch +from functorch.dim import dims + +import pyro +import pyro.distributions as dist + +i, j = dims(2) +mu = torch.ones(3, 2) +normal = dist.Normal(mu[i, j], 1, validate_args=True) +import pdb + +pdb.set_trace() + +with pyro.plate("z_plate", 5) as z: + import pdb + + pdb.set_trace() + print(f"z = {z}") + +f, c, d = dims(3) +p = torch.ones(5, 4, 3) +i = torch.ones(3).long()[d] # c +pi = p[f, i] +pc = p[f, c] +pass + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from torch.distributions import constraints +from tqdm import tqdm + +import pyro +from pyro.distributions import * +from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate +from pyro.optim import Adam + +assert pyro.__version__.startswith("1.8.6") +pyro.set_rng_seed(0) + +device = torch.device("cuda") + +data = torch.cat( + ( + MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]), + MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]), + MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]), + MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50]), + ) +) + +data = data.to(device) +# plt.scatter(data[:, 0], data[:, 1]) +# plt.title("Data Samples from Mixture of 4 Gaussians") +# plt.show() +N = data.shape[0] +num_particles = 10 + + +######################################## +def mix_weights(beta): + beta1m_cumprod = (1 - beta).cumprod(-1) + return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1) + + +######################################## + + +def model(data): + with pyro.plate("beta_plate", T - 1, device=device): + beta = pyro.sample("beta", Beta(1, alpha)) + + with pyro.plate("mu_plate", T, device=device): + mu = pyro.sample( + "mu", + MultivariateNormal( + torch.zeros(2, device=device), 5 * torch.eye(2, device=device) + ), + ) + + with pyro.plate("data", N, device=device) as idx: + # dim=-4 is an enumeration dim for z (e.g. T = 6 clusters) + # dim=-3 is a particle vectorization (e.g. 
num_particles = 10 particles) + # dim=-2 is allocated for "data" plate (1 value broadcasted over a batch) + # dim=-1 is allocated as event dimension (2 values) + z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2))[idx]) + pyro.sample( + "obs", MultivariateNormal(mu[z], torch.eye(2, device=device)), obs=data[idx] + ) + + +######################################## +def guide(data): + kappa = pyro.param( + "kappa", + lambda: Uniform( + torch.tensor(0.0, device=device), torch.tensor(2.0, device=device) + ).sample([T - 1]), + constraint=constraints.positive, + ) + tau = pyro.param( + "tau", + lambda: MultivariateNormal( + torch.zeros(2, device=device), 3 * torch.eye(2, device=device) + ).sample([T]), + ) + phi = pyro.param( + "phi", + lambda: Dirichlet(1 / T * torch.ones(T, device=device)).sample([N]), + constraint=constraints.simplex, + ) + + with pyro.plate("beta_plate", T - 1, device=device): + q_beta = pyro.sample("beta", Beta(torch.ones(T - 1, device=device), kappa)) + + with pyro.plate("mu_plate", T, device=device): + q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2, device=device))) + + with pyro.plate("data", N, device=device): + z = pyro.sample("z", Categorical(phi)) diff --git a/pyro/contrib/named/infer/elbo.py b/pyro/contrib/named/infer/elbo.py new file mode 100644 index 0000000000..e1f9f208b0 --- /dev/null +++ b/pyro/contrib/named/infer/elbo.py @@ -0,0 +1,53 @@ +import torch +from functorch.dim import dims + +import pyro +from pyro import poutine +from pyro.distributions.util import is_identically_zero + + +def log_density(model, args, kwargs): + """ + (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given + latent values ``params``. + + :param model: Python callable containing NumPyro primitives. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. + :param dict params: dictionary of current parameter values keyed by site + name. 
+ :return: log of joint density and a corresponding model trace + """ + model_trace = poutine.trace(model).get_trace(*args, **kwargs) + log_joint = 0.0 + for site in model_trace.nodes.values(): + if site["type"] == "sample" and site["fn"]: + value = site["value"] + scale = site["scale"] + log_prob = site["fn"].log_prob(value) + + if scale is not None: + log_prob = scale * log_prob + + sum_dims = getattr(log_prob, "dims", ()) + tuple(range(log_prob.ndim)) + log_prob = log_prob.sum(sum_dims) + log_joint = log_joint + log_prob + return log_joint, model_trace + + +class ELBO: + def __init__(self, num_particles=1, vectorize_particles=True): + self.num_particles = num_particles + self.vectorize_particles = vectorize_particles + + def loss(self, model, guide, *args, **kwargs): + if self.num_particles > 1: + vectorize = pyro.plate("num_particles", self.num_particles, dim=dims(1)) + model = vectorize(model) + guide = vectorize(guide) + + guide_log_density, guide_trace = log_density(guide, args, kwargs) + replay_model = poutine.replay(model, trace=guide_trace) + model_log_density, model_trace = log_density(replay_model, args, kwargs) + elbo = (model_log_density - guide_log_density) / self.num_particles + return -elbo diff --git a/pyro/distributions/named.py b/pyro/distributions/named.py new file mode 100644 index 0000000000..ecc39b94f1 --- /dev/null +++ b/pyro/distributions/named.py @@ -0,0 +1,119 @@ +import inspect + +import torch + +import pyro.distributions as dist +from pyro.distributions import constraints +from pyro.distributions.torch_distribution import TorchDistributionMixin + + +def order(x, batch_dims): + batch_shape = set(getattr(x, "dims", ())) + event_shape = x.shape + if batch_shape: + x = x.order(*(dim for dim in batch_dims if dim in batch_shape)) + x = x.reshape( + tuple(dim.size if dim in batch_shape else 1 for dim in batch_dims) + + event_shape + ) + return x + + +def index_select(input, dim, index): + return input.order(dim)[index] + + +class NamedDistribution(TorchDistributionMixin): + dist_class: dist.Distribution + + def __init__(self, *args, **kwargs) -> None: + ast_fields = inspect.getfullargspec(self.dist_class.__init__)[0][1:] + kwargs.update(zip(ast_fields, args)) + self.batch_dims = tuple( + set.union( + *[ + set(getattr(kwargs[k], "dims", ())) + for k in kwargs + if k in self.dist_class.arg_constraints + ] + ) + ) + for k in self.dist_class.arg_constraints: + if k in kwargs: + kwargs[k] = order(kwargs[k], self.batch_dims) + self.base_dist = self.dist_class(**kwargs) + self.sample_shape = torch.Size() + + @property + def has_rsample(self): + return self.base_dist.has_rsample + + @property + def has_enumerate_support(self): + return self.base_dist.has_enumerate_support + + @constraints.dependent_property + def support(self): + return self.base_dist.support + + @property + def batch_shape(self): + return self.batch_dims + + @property + def event_shape(self): + return self.base_dist.event_shape + + def sample(self, sample_shape=torch.Size()): + return self.base_dist.sample(self.sample_shape + sample_shape)[self.batch_dims] + + def rsample(self, sample_shape=torch.Size()): + return self.base_dist.rsample(self.sample_shape + sample_shape)[self.batch_dims] + + def log_prob(self, value): + value_dims = set(getattr(value, "dims", ())) + extra_dims = tuple(value_dims - set(self.batch_dims)) + value = order(value, extra_dims + self.batch_dims) + return self.base_dist.log_prob(value)[extra_dims + self.batch_dims] + + def expand(self, batch_shape, _instance=None): + """ + Returns a 
new :class:`ExpandedDistribution` instance with batch + dimensions expanded to `batch_shape`. + + :param tuple batch_shape: batch shape to expand to. + :param _instance: unused argument for compatibility with + :meth:`torch.distributions.Distribution.expand` + :return: an instance of `ExpandedDistribution`. + :rtype: :class:`ExpandedDistribution` + """ + for dim in batch_shape: + if dim not in set(self.batch_dims): + self.batch_dims = self.batch_dims + (dim,) + self.sample_shape = self.sample_shape + (dim.size,) + return self + + def enumerate_support(self, expand=False): + samples = self.base_dist.enumerate_support(expand=False) + return samples + + +# class NamedDistributionMeta(type): +# pass +# def __call__(cls, *args, **kwargs): + + +def make_dist(backend_dist_class): + + dist_class = type( + backend_dist_class.__name__, + (NamedDistribution,), + {"dist_class": backend_dist_class}, + ) + return dist_class + + +Normal = make_dist(dist.Normal) +Categorical = make_dist(dist.Categorical) +LogNormal = make_dist(dist.LogNormal) +Dirichlet = make_dist(dist.Dirichlet) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 87e9dd2f7b..1fd59c50be 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -3,6 +3,9 @@ from typing import TYPE_CHECKING, List, Optional +from functorch.dim import Dim + +from pyro.distributions.named import NamedDistribution from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.messenger import Messenger from pyro.util import ignore_jit_warnings @@ -56,6 +59,10 @@ def _pyro_sample(msg: "Message") -> None: dist = msg["fn"] actual_batch_shape = dist.batch_shape + if isinstance(msg["fn"], NamedDistribution): + prefix_batch_shape = tuple(f.dim for f in msg["cond_indep_stack"]) + msg["fn"].expand(prefix_batch_shape) + return target_batch_shape = [ None if size == 1 else size for size in actual_batch_shape ] diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 408669e900..43037cfd49 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional import torch +from functorch.dim import Dim from typing_extensions import Self from pyro.distributions.torch import Categorical @@ -152,8 +153,8 @@ def __init__(self, first_available_dim: Optional[int] = None) -> None: super().__init__() def __enter__(self) -> Self: - if self.first_available_dim is not None: - _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) + # if self.first_available_dim is not None: + # _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) self._markov_depths: Dict[str, int] = ( {} ) # site name -> depth (nonnegative integer) @@ -176,61 +177,67 @@ def _pyro_sample(self, msg: Message) -> None: assert isinstance(msg["name"], str) assert msg["infer"] is not None - # Compute upstream dims in scope; these are unsafe to use for this site's target_dim. 
- scope = msg["infer"].get("_markov_scope") # site name -> markov depth - param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id - if scope is not None: - for name, depth in scope.items(): - if ( - self._markov_depths[name] == depth - ): # hide sites whose markov context has exited - param_dims.update(self._value_dims[name]) - self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] - self._param_dims[msg["name"]] = param_dims if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": return - - # Compute an enumerated value (at an arbitrary dim). - value = enumerate_site(msg) - actual_dim = -1 - len(msg["fn"].batch_shape) # the leftmost dim of log_prob - - # Move actual_dim to a safe target_dim. - target_dim, id_ = _ENUM_ALLOCATOR.allocate( - None if scope is None else set(param_dims) - ) - event_dim = msg["fn"].event_dim - categorical_support = getattr(value, "_pyro_categorical_support", None) - if categorical_support is not None: - # Preserve categorical supports to speed up Categorical.log_prob(). - # See pyro/distributions/torch.py for details. - assert target_dim < 0 - value = value.reshape(value.shape[:1] + (1,) * (-1 - target_dim)) - value._pyro_categorical_support = categorical_support # type: ignore[attr-defined] - elif actual_dim < target_dim: - assert ( - value.size(target_dim - event_dim) == 1 - ), "pyro.markov dim conflict at dim {}".format(actual_dim) - value = value.transpose(target_dim - event_dim, actual_dim - event_dim) - while value.dim() and value.size(0) == 1: - value = value.squeeze(0) - elif target_dim < actual_dim: - diff = actual_dim - target_dim - value = value.reshape(value.shape[:1] + (1,) * diff + value.shape[1:]) - - # Compute dims passed downstream through the value. - value_dims = { - dim: param_dims[dim] - for dim in range(event_dim - value.dim(), 0) - if value.size(dim - event_dim) > 1 and dim in param_dims - } - value_dims[target_dim] = id_ - - msg["infer"]["_enumerate_dim"] = target_dim - msg["infer"]["_dim_to_id"] = value_dims - msg["value"] = value + value = msg["fn"].enumerate_support(False) + dim = Dim(msg["name"]) + msg["value"] = value[dim] msg["done"] = True - - def _pyro_post_sample(self, msg: Message) -> None: + # Compute upstream dims in scope; these are unsafe to use for this site's target_dim. + # scope = msg["infer"].get("_markov_scope") # site name -> markov depth + # param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id + # if scope is not None: + # for name, depth in scope.items(): + # if ( + # self._markov_depths[name] == depth + # ): # hide sites whose markov context has exited + # param_dims.update(self._value_dims[name]) + # self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] + # self._param_dims[msg["name"]] = param_dims + # if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": + # return + + # # Compute an enumerated value (at an arbitrary dim). + # value = enumerate_site(msg) + # actual_dim = -1 - len(msg["fn"].batch_shape) # the leftmost dim of log_prob + + # # Move actual_dim to a safe target_dim. + # target_dim, id_ = _ENUM_ALLOCATOR.allocate( + # None if scope is None else set(param_dims) + # ) + # event_dim = msg["fn"].event_dim + # categorical_support = getattr(value, "_pyro_categorical_support", None) + # if categorical_support is not None: + # # Preserve categorical supports to speed up Categorical.log_prob(). + # # See pyro/distributions/torch.py for details. 
+ # assert target_dim < 0 + # value = value.reshape(value.shape[:1] + (1,) * (-1 - target_dim)) + # value._pyro_categorical_support = categorical_support # type: ignore[attr-defined] + # elif actual_dim < target_dim: + # assert ( + # value.size(target_dim - event_dim) == 1 + # ), "pyro.markov dim conflict at dim {}".format(actual_dim) + # value = value.transpose(target_dim - event_dim, actual_dim - event_dim) + # while value.dim() and value.size(0) == 1: + # value = value.squeeze(0) + # elif target_dim < actual_dim: + # diff = actual_dim - target_dim + # value = value.reshape(value.shape[:1] + (1,) * diff + value.shape[1:]) + + # # Compute dims passed downstream through the value. + # value_dims = { + # dim: param_dims[dim] + # for dim in range(event_dim - value.dim(), 0) + # if value.size(dim - event_dim) > 1 and dim in param_dims + # } + # value_dims[target_dim] = id_ + + # msg["infer"]["_enumerate_dim"] = target_dim + # msg["infer"]["_dim_to_id"] = value_dims + # msg["value"] = value + # msg["done"] = True + + def _pyro_post_sample_(self, msg: Message) -> None: # Save all dims exposed in this sample value. # Whereas all of site["_dim_to_id"] are needed to interpret a # site's log_prob tensor, only a filtered subset self._value_dims[msg["name"]] diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 69d41756f6..f128b361d8 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -5,6 +5,7 @@ from typing import Iterator, NamedTuple, Optional, Tuple import torch +from functorch.dim import Dim from typing_extensions import Self from pyro.poutine.messenger import Messenger @@ -98,14 +99,15 @@ def __enter__(self) -> Self: self._vectorized = True if self._vectorized is True: - self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) + assert self.dim is not None + # self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__() def __exit__(self, *args) -> None: if self._vectorized is True: assert self.dim is not None - _DIM_ALLOCATOR.free(self.name, self.dim) + # _DIM_ALLOCATOR.free(self.name, self.dim) return super().__exit__(*args) def __iter__(self) -> Iterator[int]: @@ -134,7 +136,7 @@ def _reset(self) -> None: def indices(self) -> torch.Tensor: if self._indices is None: self._indices = torch.arange(self.size, dtype=torch.long).to(self.device) - return self._indices + return self._indices[self.dim] def _process_message(self, msg: Message) -> None: frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter) diff --git a/torchdim.py b/torchdim.py new file mode 100644 index 0000000000..f66dc2a727 --- /dev/null +++ b/torchdim.py @@ -0,0 +1,139 @@ +import torch +import torch.distributions.constraints as constraints +from functorch.dim import Dim, dims + +import pyro +from pyro import poutine +from pyro.contrib.named.infer.elbo import ELBO +from pyro.distributions.named import ( + Categorical, + Dirichlet, + LogNormal, + Normal, + index_select, +) +from pyro.infer import config_enumerate + +# from pyro.ops.indexing import Vindex + +# make_dist(dist.Normal) +i = Dim("i") +j = Dim("j") +k = Dim("k") +# i, j = dims(2) +loc = torch.zeros(2, 3)[i, j] +scale = torch.ones(2)[i] +normal = Normal(loc, scale=scale, validate_args=True) +x = normal.sample() +log_prob_x = normal.log_prob(x) +y = torch.randn(2)[i] +log_prob_y = normal.log_prob(y) +z = torch.randn(3, 4)[j, k] +log_prob_z = normal.log_prob(z) +dir = Dirichlet(torch.ones(3)) + + +@config_enumerate +def model(i, j, k): + data_plate = 
pyro.plate("data_plate", 6, dim=i) + feature_plate = pyro.plate("feature_plate", 5, dim=j) + component_plate = pyro.plate("component_plate", 4, dim=k) + with feature_plate: + with component_plate: + p = pyro.sample("p", Dirichlet(torch.ones(3))) + with data_plate as idx: + c = pyro.sample("c", Categorical(torch.ones(4))) + with feature_plate as vdx: # Capture plate index. + pc = index_select(p, dim=k, index=c) + x = pyro.sample( + "x", + Categorical(pc), + obs=torch.zeros(5, 6, dtype=torch.long)[vdx, idx], + ) + print(f" p.shape = {p.shape}") + print(f" c.shape = {c.shape}") + print(f" vdx.shape = {vdx.shape}") + print(f" pc.shape = {pc.shape}") + print(f" x.shape = {x.shape}") + + +def guide(i, j, k): + feature_plate = pyro.plate("feature_plate", 5, dim=j) + component_plate = pyro.plate("component_plate", 4, dim=k) + with feature_plate, component_plate: + pyro.sample("p", Dirichlet(torch.ones(3))) + + +pyro.clear_param_store() +print("Sampling:") +# model() +print("Enumerated Inference:") +data, feature, component = dims(3) +elbo = ELBO() +loss = elbo.loss(model, guide, data, feature, component) +elbo_10 = ELBO(num_particles=10) +loss_10 = elbo_10.loss(model, guide, data, feature, component) +elbo_100 = ELBO(num_particles=100) +loss_100 = elbo_100.loss(model, guide, data, feature, component) +elbo_1000 = ELBO(num_particles=1000) +loss_1000 = elbo_1000.loss(model, guide, data, feature, component) +import pdb + +pdb.set_trace() +pass + +# poutine.enum(model)() + +# Examples + +def model(data): + ... + mu = pyro.sample( + "mu", + MultivariateNormal( + torch.zeros(2, device=device), 5 * torch.eye(2, device=device) + ) + .expand([T]) + .to_event(1), + ) + ... + mu_vindex = Vindex(mu)[..., z, :]. # no if/else + +def model(data): + with pyro.plate("beta_plate", T-1): + beta = pyro.sample("beta", Beta(1, alpha)) + + with pyro.plate("mu_plate", T): + mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2, device=device), 5 * torch.eye(2, device=device))) + + with pyro.plate("data", N): + #dim=-4 is an enumeration dim for z (e.g. T = 6 clusters) + #dim=-3 is a particle vectorization (e.g. 
num_particles = 10 particles) + #dim=-2 is allocated for "data" plate (1 value broadcasted over a batch) + #dim=-1 is allocated as event dimension (2 values) + z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2))) + pyro.sample("obs", MultivariateNormal(index_select(mu, "mu_plate", z), + torch.eye(2, device=device)), obs=data) + +######################################## +def guide(data): + kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0., device=device), torch.tensor(2., device=device)).sample([T-1]), + constraint=constraints.positive) + tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2, device=device), 3 * torch.eye(2, device=device)).sample([T])) + phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T, device=device)).sample([N]), constraint=constraints.simplex) + + with pyro.plate("beta_plate", T-1, device=device): + q_beta = pyro.sample("beta", Beta(torch.ones(T-1, device=device), kappa)) + + with pyro.plate("mu_plate", T, device=device): + q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2, device=device))) + + with pyro.plate("data", N, device=device): + z = pyro.sample("z", Categorical(phi)) + +def guide(data): + # no plate here, use to_event instead + q_mu = pyro.sample( + "mu", + MultivariateNormal(tau, torch.eye(2, device=device)).to_event(1), + ) \ No newline at end of file From f1bc917336d2d7e00c19881db80c55b2562f4163 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 24 Feb 2024 03:54:44 +0000 Subject: [PATCH 02/11] prototype --- dim.py | 122 ---------------- pyro/contrib/named/infer/elbo.py | 5 +- pyro/distributions/named.py | 3 + pyro/distributions/torch_distribution.py | 33 ++++- pyro/poutine/broadcast_messenger.py | 12 +- pyro/poutine/enum_messenger.py | 2 +- pyro/poutine/indep_messenger.py | 4 +- torchdim.py | 176 +++++++++-------------- 8 files changed, 113 insertions(+), 244 deletions(-) delete mode 100644 dim.py diff --git a/dim.py b/dim.py deleted file mode 100644 index db2c23dc3b..0000000000 --- a/dim.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -from functorch.dim import dims - -import pyro -import pyro.distributions as dist - -i, j = dims(2) -mu = torch.ones(3, 2) -normal = dist.Normal(mu[i, j], 1, validate_args=True) -import pdb - -pdb.set_trace() - -with pyro.plate("z_plate", 5) as z: - import pdb - - pdb.set_trace() - print(f"z = {z}") - -f, c, d = dims(3) -p = torch.ones(5, 4, 3) -i = torch.ones(3).long()[d] # c -pi = p[f, i] -pc = p[f, c] -pass - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import torch.nn.functional as F -from torch.distributions import constraints -from tqdm import tqdm - -import pyro -from pyro.distributions import * -from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate -from pyro.optim import Adam - -assert pyro.__version__.startswith("1.8.6") -pyro.set_rng_seed(0) - -device = torch.device("cuda") - -data = torch.cat( - ( - MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]), - MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]), - MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]), - MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50]), - ) -) - -data = data.to(device) -# plt.scatter(data[:, 0], data[:, 1]) -# plt.title("Data Samples from Mixture of 4 Gaussians") -# plt.show() -N = data.shape[0] -num_particles = 10 - - -######################################## -def mix_weights(beta): - beta1m_cumprod = (1 - beta).cumprod(-1) - return F.pad(beta, (0, 1), 
value=1) * F.pad(beta1m_cumprod, (1, 0), value=1) - - -######################################## - - -def model(data): - with pyro.plate("beta_plate", T - 1, device=device): - beta = pyro.sample("beta", Beta(1, alpha)) - - with pyro.plate("mu_plate", T, device=device): - mu = pyro.sample( - "mu", - MultivariateNormal( - torch.zeros(2, device=device), 5 * torch.eye(2, device=device) - ), - ) - - with pyro.plate("data", N, device=device) as idx: - # dim=-4 is an enumeration dim for z (e.g. T = 6 clusters) - # dim=-3 is a particle vectorization (e.g. num_particles = 10 particles) - # dim=-2 is allocated for "data" plate (1 value broadcasted over a batch) - # dim=-1 is allocated as event dimension (2 values) - z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2))[idx]) - pyro.sample( - "obs", MultivariateNormal(mu[z], torch.eye(2, device=device)), obs=data[idx] - ) - - -######################################## -def guide(data): - kappa = pyro.param( - "kappa", - lambda: Uniform( - torch.tensor(0.0, device=device), torch.tensor(2.0, device=device) - ).sample([T - 1]), - constraint=constraints.positive, - ) - tau = pyro.param( - "tau", - lambda: MultivariateNormal( - torch.zeros(2, device=device), 3 * torch.eye(2, device=device) - ).sample([T]), - ) - phi = pyro.param( - "phi", - lambda: Dirichlet(1 / T * torch.ones(T, device=device)).sample([N]), - constraint=constraints.simplex, - ) - - with pyro.plate("beta_plate", T - 1, device=device): - q_beta = pyro.sample("beta", Beta(torch.ones(T - 1, device=device), kappa)) - - with pyro.plate("mu_plate", T, device=device): - q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2, device=device))) - - with pyro.plate("data", N, device=device): - z = pyro.sample("z", Categorical(phi)) diff --git a/pyro/contrib/named/infer/elbo.py b/pyro/contrib/named/infer/elbo.py index e1f9f208b0..e5dcd1c96a 100644 --- a/pyro/contrib/named/infer/elbo.py +++ b/pyro/contrib/named/infer/elbo.py @@ -1,9 +1,10 @@ -import torch +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + from functorch.dim import dims import pyro from pyro import poutine -from pyro.distributions.util import is_identically_zero def log_density(model, args, kwargs): diff --git a/pyro/distributions/named.py b/pyro/distributions/named.py index ecc39b94f1..a6b3087fd3 100644 --- a/pyro/distributions/named.py +++ b/pyro/distributions/named.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import inspect import torch diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index ace02da72a..3e442dcb10 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -6,6 +6,7 @@ from typing import Callable import torch +from functorch.dim import Tensor from torch.distributions.kl import kl_divergence, register_kl import pyro.distributions.torch @@ -45,11 +46,37 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: batched). The shape of the result should be `self.shape()`. 
:rtype: torch.Tensor """ + # return self.base_dist.sample(self.sample_shape + sample_shape)[self.batch_dims] return ( - self.rsample(sample_shape) + self.rsample(self.sample_shape + sample_shape) if self.has_rsample - else self.sample(sample_shape) - ) + else self.sample(self.sample_shape + sample_shape) + )[ + self.named_batch_shape[ + len(self.named_batch_shape) - len(self.sample_shape) : + ] + ] + + @property + def named_batch_shape(self): + if not hasattr(self, "_named_batch_shape"): + self._named_batch_shape = () + for param in self.arg_constraints: + value = getattr(self, param) + if isinstance(value, Tensor): + for dim in value.dims: + if dim not in set(self._named_batch_shape): + self._named_batch_shape += (dim,) + return self._named_batch_shape + + def expand_named_shape(self, named_batch_shape): + if not hasattr(self, "sample_shape"): + self.sample_shape = torch.Size() + for dim in named_batch_shape: + if dim not in set(self.named_batch_shape): + self._named_batch_shape += (dim,) + self.sample_shape = self.sample_shape + (dim.size,) + return self @property def batch_shape(self) -> torch.Size: diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 1fd59c50be..ee2cf6ff80 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -5,7 +5,6 @@ from functorch.dim import Dim -from pyro.distributions.named import NamedDistribution from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.messenger import Messenger from pyro.util import ignore_jit_warnings @@ -59,14 +58,14 @@ def _pyro_sample(msg: "Message") -> None: dist = msg["fn"] actual_batch_shape = dist.batch_shape - if isinstance(msg["fn"], NamedDistribution): - prefix_batch_shape = tuple(f.dim for f in msg["cond_indep_stack"]) - msg["fn"].expand(prefix_batch_shape) - return target_batch_shape = [ None if size == 1 else size for size in actual_batch_shape ] + named_batch_shape = [] for f in msg["cond_indep_stack"]: + if isinstance(f.dim, Dim): + named_batch_shape.append(f.dim) + continue if f.dim is None or f.size == -1: continue assert f.dim < 0 @@ -95,6 +94,9 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape[i] = ( actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1 ) + dist = dist.expand_named_shape(named_batch_shape) msg["fn"] = dist.expand(target_batch_shape) + msg["fn"]._named_batch_shape = dist.named_batch_shape + msg["fn"].sample_shape = dist.sample_shape if msg["fn"].has_rsample != dist.has_rsample: msg["fn"].has_rsample = dist.has_rsample # copy custom attribute diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 43037cfd49..08de70a19b 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -11,7 +11,7 @@ from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.ops.indexing import Vindex from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import _ENUM_ALLOCATOR, Message +from pyro.poutine.runtime import Message from pyro.util import ignore_jit_warnings diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index f128b361d8..aae7efa11a 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -136,7 +136,9 @@ def _reset(self) -> None: def indices(self) -> torch.Tensor: if self._indices is None: self._indices = torch.arange(self.size, dtype=torch.long).to(self.device) - return self._indices[self.dim] + if isinstance(self.dim, Dim): + return 
self._indices[self.dim] + return self._indices def _process_message(self, msg: Message) -> None: frame = CondIndepStackFrame(self.name, self.dim, self.size, self.counter) diff --git a/torchdim.py b/torchdim.py index f66dc2a727..30008aaed3 100644 --- a/torchdim.py +++ b/torchdim.py @@ -1,53 +1,56 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import torch -import torch.distributions.constraints as constraints -from functorch.dim import Dim, dims +from functorch.dim import dims import pyro -from pyro import poutine +import pyro.distributions as dist from pyro.contrib.named.infer.elbo import ELBO -from pyro.distributions.named import ( - Categorical, - Dirichlet, - LogNormal, - Normal, - index_select, -) -from pyro.infer import config_enumerate - -# from pyro.ops.indexing import Vindex +from pyro.distributions.named import index_select -# make_dist(dist.Normal) -i = Dim("i") -j = Dim("j") -k = Dim("k") # i, j = dims(2) -loc = torch.zeros(2, 3)[i, j] -scale = torch.ones(2)[i] -normal = Normal(loc, scale=scale, validate_args=True) -x = normal.sample() -log_prob_x = normal.log_prob(x) -y = torch.randn(2)[i] -log_prob_y = normal.log_prob(y) -z = torch.randn(3, 4)[j, k] -log_prob_z = normal.log_prob(z) -dir = Dirichlet(torch.ones(3)) - - -@config_enumerate -def model(i, j, k): - data_plate = pyro.plate("data_plate", 6, dim=i) - feature_plate = pyro.plate("feature_plate", 5, dim=j) - component_plate = pyro.plate("component_plate", 4, dim=k) - with feature_plate: - with component_plate: - p = pyro.sample("p", Dirichlet(torch.ones(3))) +# loc = torch.zeros(2, 3)[i, j] +# scale = torch.ones(2)[i] +# import pdb + +# pdb.set_trace() +# normal = dist.Normal(loc, scale=torch.tensor(1.0), validate_args=False) +# k.size = 4 +# normal.expand_named_shape([i, j, k]) +# import pdb + +# pdb.set_trace() +# normal.named_batch_shape +# x = normal.sample() +# log_prob_x = normal.log_prob(x) +# y = torch.randn(2)[i] +# log_prob_y = normal.log_prob(y) +# z = torch.randn(3, 4)[j, k] +# log_prob_z = normal.log_prob(z) +# dir = Dirichlet(torch.ones(3)) + +pyro.enable_validation(False) + + +# @config_enumerate +def model(data_dim, feature_dim, component_dim): + data_plate = pyro.plate("data_plate", 6, dim=data_dim) + feature_plate = pyro.plate("feature_plate", 5, dim=feature_dim) + component_plate = pyro.plate("component_plate", 4, dim=component_dim) + # component_plate = pyro.plate("component_plate", 4, dim=-1) + with feature_plate, component_plate: + p = pyro.sample("p", dist.Dirichlet(torch.ones(3))) with data_plate as idx: - c = pyro.sample("c", Categorical(torch.ones(4))) + c = pyro.sample( + "c", dist.Categorical(torch.ones(4).expand([data_dim.size, 4])[data_dim]) + ) with feature_plate as vdx: # Capture plate index. 
- pc = index_select(p, dim=k, index=c) + pc = index_select(p, dim=component_dim, index=c) + # pc = p[c] x = pyro.sample( "x", - Categorical(pc), + dist.Categorical(pc), obs=torch.zeros(5, 6, dtype=torch.long)[vdx, idx], ) print(f" p.shape = {p.shape}") @@ -57,83 +60,36 @@ def model(i, j, k): print(f" x.shape = {x.shape}") -def guide(i, j, k): - feature_plate = pyro.plate("feature_plate", 5, dim=j) - component_plate = pyro.plate("component_plate", 4, dim=k) +def guide(data_dim, feature_dim, component_dim): + data_plate = pyro.plate("data_plate", 6, dim=data_dim) + feature_plate = pyro.plate("feature_plate", 5, dim=feature_dim) + component_plate = pyro.plate("component_plate", 4, dim=component_dim) + # component_plate = pyro.plate("component_plate", 4, dim=-1) with feature_plate, component_plate: - pyro.sample("p", Dirichlet(torch.ones(3))) + pyro.sample( + "p", + dist.Dirichlet( + torch.ones(3).expand([feature_dim.size, component_dim.size, 3])[ + feature_dim, component_dim + ] + ), + ) + with data_plate: + pyro.sample( + "c", dist.Categorical(torch.ones(4).expand([data_dim.size, 4])[data_dim]) + ) +data_dim, feature_dim, component_dim = dims(3) pyro.clear_param_store() print("Sampling:") -# model() print("Enumerated Inference:") -data, feature, component = dims(3) elbo = ELBO() -loss = elbo.loss(model, guide, data, feature, component) +# model(data_dim, feature_dim, component_dim) +loss = elbo.loss(model, guide, data_dim, feature_dim, component_dim) elbo_10 = ELBO(num_particles=10) -loss_10 = elbo_10.loss(model, guide, data, feature, component) +loss_10 = elbo_10.loss(model, guide, data_dim, feature_dim, component_dim) elbo_100 = ELBO(num_particles=100) -loss_100 = elbo_100.loss(model, guide, data, feature, component) +loss_100 = elbo_100.loss(model, guide, data_dim, feature_dim, component_dim) elbo_1000 = ELBO(num_particles=1000) -loss_1000 = elbo_1000.loss(model, guide, data, feature, component) -import pdb - -pdb.set_trace() -pass - -# poutine.enum(model)() - -# Examples - -def model(data): - ... - mu = pyro.sample( - "mu", - MultivariateNormal( - torch.zeros(2, device=device), 5 * torch.eye(2, device=device) - ) - .expand([T]) - .to_event(1), - ) - ... - mu_vindex = Vindex(mu)[..., z, :]. # no if/else - -def model(data): - with pyro.plate("beta_plate", T-1): - beta = pyro.sample("beta", Beta(1, alpha)) - - with pyro.plate("mu_plate", T): - mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2, device=device), 5 * torch.eye(2, device=device))) - - with pyro.plate("data", N): - #dim=-4 is an enumeration dim for z (e.g. T = 6 clusters) - #dim=-3 is a particle vectorization (e.g. 
num_particles = 10 particles) - #dim=-2 is allocated for "data" plate (1 value broadcasted over a batch) - #dim=-1 is allocated as event dimension (2 values) - z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2))) - pyro.sample("obs", MultivariateNormal(index_select(mu, "mu_plate", z), - torch.eye(2, device=device)), obs=data) - -######################################## -def guide(data): - kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0., device=device), torch.tensor(2., device=device)).sample([T-1]), - constraint=constraints.positive) - tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2, device=device), 3 * torch.eye(2, device=device)).sample([T])) - phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T, device=device)).sample([N]), constraint=constraints.simplex) - - with pyro.plate("beta_plate", T-1, device=device): - q_beta = pyro.sample("beta", Beta(torch.ones(T-1, device=device), kappa)) - - with pyro.plate("mu_plate", T, device=device): - q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(2, device=device))) - - with pyro.plate("data", N, device=device): - z = pyro.sample("z", Categorical(phi)) - -def guide(data): - # no plate here, use to_event instead - q_mu = pyro.sample( - "mu", - MultivariateNormal(tau, torch.eye(2, device=device)).to_event(1), - ) \ No newline at end of file +loss_1000 = elbo_1000.loss(model, guide, data_dim, feature_dim, component_dim) From 4dcfca2577ddb9c85ad657f3e81648d559c840ac Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 24 Feb 2024 04:02:57 +0000 Subject: [PATCH 03/11] rm named --- pyro/contrib/named/infer/elbo.py | 12 +-- pyro/distributions/named.py | 122 ------------------------------- pyro/ops/indexing.py | 4 + 3 files changed, 10 insertions(+), 128 deletions(-) delete mode 100644 pyro/distributions/named.py diff --git a/pyro/contrib/named/infer/elbo.py b/pyro/contrib/named/infer/elbo.py index e5dcd1c96a..d19875fdc9 100644 --- a/pyro/contrib/named/infer/elbo.py +++ b/pyro/contrib/named/infer/elbo.py @@ -7,21 +7,21 @@ from pyro import poutine -def log_density(model, args, kwargs): +def log_density(fn, args, kwargs): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given latent values ``params``. - :param model: Python callable containing NumPyro primitives. + :param fn: Python callable containing NumPyro primitives. :param tuple model_args: args provided to the model. :param dict model_kwargs: kwargs provided to the model. :param dict params: dictionary of current parameter values keyed by site name. 
:return: log of joint density and a corresponding model trace """ - model_trace = poutine.trace(model).get_trace(*args, **kwargs) + fn_trace = poutine.trace(fn).get_trace(*args, **kwargs) log_joint = 0.0 - for site in model_trace.nodes.values(): + for site in fn_trace.nodes.values(): if site["type"] == "sample" and site["fn"]: value = site["value"] scale = site["scale"] @@ -33,7 +33,7 @@ def log_density(model, args, kwargs): sum_dims = getattr(log_prob, "dims", ()) + tuple(range(log_prob.ndim)) log_prob = log_prob.sum(sum_dims) log_joint = log_joint + log_prob - return log_joint, model_trace + return log_joint, fn_trace class ELBO: @@ -43,7 +43,7 @@ def __init__(self, num_particles=1, vectorize_particles=True): def loss(self, model, guide, *args, **kwargs): if self.num_particles > 1: - vectorize = pyro.plate("num_particles", self.num_particles, dim=dims(1)) + vectorize = pyro.plate("num_particles", self.num_particles, dim=dims()) model = vectorize(model) guide = vectorize(guide) diff --git a/pyro/distributions/named.py b/pyro/distributions/named.py deleted file mode 100644 index a6b3087fd3..0000000000 --- a/pyro/distributions/named.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -import inspect - -import torch - -import pyro.distributions as dist -from pyro.distributions import constraints -from pyro.distributions.torch_distribution import TorchDistributionMixin - - -def order(x, batch_dims): - batch_shape = set(getattr(x, "dims", ())) - event_shape = x.shape - if batch_shape: - x = x.order(*(dim for dim in batch_dims if dim in batch_shape)) - x = x.reshape( - tuple(dim.size if dim in batch_shape else 1 for dim in batch_dims) - + event_shape - ) - return x - - -def index_select(input, dim, index): - return input.order(dim)[index] - - -class NamedDistribution(TorchDistributionMixin): - dist_class: dist.Distribution - - def __init__(self, *args, **kwargs) -> None: - ast_fields = inspect.getfullargspec(self.dist_class.__init__)[0][1:] - kwargs.update(zip(ast_fields, args)) - self.batch_dims = tuple( - set.union( - *[ - set(getattr(kwargs[k], "dims", ())) - for k in kwargs - if k in self.dist_class.arg_constraints - ] - ) - ) - for k in self.dist_class.arg_constraints: - if k in kwargs: - kwargs[k] = order(kwargs[k], self.batch_dims) - self.base_dist = self.dist_class(**kwargs) - self.sample_shape = torch.Size() - - @property - def has_rsample(self): - return self.base_dist.has_rsample - - @property - def has_enumerate_support(self): - return self.base_dist.has_enumerate_support - - @constraints.dependent_property - def support(self): - return self.base_dist.support - - @property - def batch_shape(self): - return self.batch_dims - - @property - def event_shape(self): - return self.base_dist.event_shape - - def sample(self, sample_shape=torch.Size()): - return self.base_dist.sample(self.sample_shape + sample_shape)[self.batch_dims] - - def rsample(self, sample_shape=torch.Size()): - return self.base_dist.rsample(self.sample_shape + sample_shape)[self.batch_dims] - - def log_prob(self, value): - value_dims = set(getattr(value, "dims", ())) - extra_dims = tuple(value_dims - set(self.batch_dims)) - value = order(value, extra_dims + self.batch_dims) - return self.base_dist.log_prob(value)[extra_dims + self.batch_dims] - - def expand(self, batch_shape, _instance=None): - """ - Returns a new :class:`ExpandedDistribution` instance with batch - dimensions expanded to `batch_shape`. 
- - :param tuple batch_shape: batch shape to expand to. - :param _instance: unused argument for compatibility with - :meth:`torch.distributions.Distribution.expand` - :return: an instance of `ExpandedDistribution`. - :rtype: :class:`ExpandedDistribution` - """ - for dim in batch_shape: - if dim not in set(self.batch_dims): - self.batch_dims = self.batch_dims + (dim,) - self.sample_shape = self.sample_shape + (dim.size,) - return self - - def enumerate_support(self, expand=False): - samples = self.base_dist.enumerate_support(expand=False) - return samples - - -# class NamedDistributionMeta(type): -# pass -# def __call__(cls, *args, **kwargs): - - -def make_dist(backend_dist_class): - - dist_class = type( - backend_dist_class.__name__, - (NamedDistribution,), - {"dist_class": backend_dist_class}, - ) - return dist_class - - -Normal = make_dist(dist.Normal) -Categorical = make_dist(dist.Categorical) -LogNormal = make_dist(dist.LogNormal) -Dirichlet = make_dist(dist.Dirichlet) diff --git a/pyro/ops/indexing.py b/pyro/ops/indexing.py index 2fc57aa9f8..d8f8155214 100644 --- a/pyro/ops/indexing.py +++ b/pyro/ops/indexing.py @@ -215,3 +215,7 @@ def __init__(self, tensor): def __getitem__(self, args): return vindex(self._tensor, args) + + +def index_select(input, dim, index): + return input.order(dim)[index] From 1874ead097165210dfc3ba14b2feaa43fa92f3cb Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 24 Feb 2024 04:45:17 +0000 Subject: [PATCH 04/11] clean up --- pyro/contrib/named/infer/elbo.py | 23 ++--- pyro/distributions/torch_distribution.py | 54 ++++++----- pyro/poutine/broadcast_messenger.py | 8 +- pyro/poutine/enum_messenger.py | 117 +++++++++++------------ pyro/poutine/indep_messenger.py | 18 ++-- 5 files changed, 109 insertions(+), 111 deletions(-) diff --git a/pyro/contrib/named/infer/elbo.py b/pyro/contrib/named/infer/elbo.py index d19875fdc9..cb3306697f 100644 --- a/pyro/contrib/named/infer/elbo.py +++ b/pyro/contrib/named/infer/elbo.py @@ -5,21 +5,19 @@ import pyro from pyro import poutine +from pyro.poutine.util import prune_subsample_sites def log_density(fn, args, kwargs): """ - (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given - latent values ``params``. - - :param fn: Python callable containing NumPyro primitives. - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. - :param dict params: dictionary of current parameter values keyed by site - name. + Compute log density of a stochastic function given its arguments. + + :param fn: Python callable containing Pyro primitives. + :param tuple args: args provided to the function. + :param dict kwargs: kwargs provided to the function. 
:return: log of joint density and a corresponding model trace """ - fn_trace = poutine.trace(fn).get_trace(*args, **kwargs) + fn_trace = prune_subsample_sites(poutine.trace(fn).get_trace(*args, **kwargs)) log_joint = 0.0 for site in fn_trace.nodes.values(): if site["type"] == "sample" and site["fn"]: @@ -30,9 +28,8 @@ def log_density(fn, args, kwargs): if scale is not None: log_prob = scale * log_prob - sum_dims = getattr(log_prob, "dims", ()) + tuple(range(log_prob.ndim)) - log_prob = log_prob.sum(sum_dims) - log_joint = log_joint + log_prob + sum_dims = tuple(f.dim for f in site["cond_indep_stack"]) + log_joint += log_prob.sum(sum_dims) return log_joint, fn_trace @@ -43,7 +40,7 @@ def __init__(self, num_particles=1, vectorize_particles=True): def loss(self, model, guide, *args, **kwargs): if self.num_particles > 1: - vectorize = pyro.plate("num_particles", self.num_particles, dim=dims()) + vectorize = pyro.plate("num_particles", self.num_particles, dim=dims(1)) model = vectorize(model) guide = vectorize(guide) diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 3e442dcb10..b65f5f7f48 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -6,7 +6,6 @@ from typing import Callable import torch -from functorch.dim import Tensor from torch.distributions.kl import kl_divergence, register_kl import pyro.distributions.torch @@ -46,38 +45,45 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: batched). The shape of the result should be `self.shape()`. :rtype: torch.Tensor """ - # return self.base_dist.sample(self.sample_shape + sample_shape)[self.batch_dims] + sample_shape = self.sample_shape + sample_shape + bind_dims = self.dims[len(self.dims) - len(self.sample_shape) :] return ( - self.rsample(self.sample_shape + sample_shape) + self.rsample(sample_shape) if self.has_rsample - else self.sample(self.sample_shape + sample_shape) - )[ - self.named_batch_shape[ - len(self.named_batch_shape) - len(self.sample_shape) : - ] - ] + else self.sample(sample_shape) + )[bind_dims] @property - def named_batch_shape(self): - if not hasattr(self, "_named_batch_shape"): - self._named_batch_shape = () + def dims(self): + if not hasattr(self, "_dims"): + seen = set() + result = [] for param in self.arg_constraints: value = getattr(self, param) - if isinstance(value, Tensor): - for dim in value.dims: - if dim not in set(self._named_batch_shape): - self._named_batch_shape += (dim,) - return self._named_batch_shape - - def expand_named_shape(self, named_batch_shape): - if not hasattr(self, "sample_shape"): - self.sample_shape = torch.Size() - for dim in named_batch_shape: - if dim not in set(self.named_batch_shape): - self._named_batch_shape += (dim,) + for dim in getattr(value, "dims", ()): + if dim not in seen: + seen.add(dim) + result.append(dim) + self._dims = tuple(result) + return self._dims + + def expand_dims(self, dims): + for dim in dims: + if dim not in set(self.dims): + self._dims += (dim,) self.sample_shape = self.sample_shape + (dim.size,) return self + @property + def sample_shape(self): + if not hasattr(self, "_sample_shape"): + self._sample_shape = torch.Size() + return self._sample_shape + + @sample_shape.setter + def sample_shape(self, value): + self._sample_shape = value + @property def batch_shape(self) -> torch.Size: """ diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index ee2cf6ff80..a1b5975d05 100644 --- 
a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -61,10 +61,10 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape = [ None if size == 1 else size for size in actual_batch_shape ] - named_batch_shape = [] + dims = [] for f in msg["cond_indep_stack"]: if isinstance(f.dim, Dim): - named_batch_shape.append(f.dim) + dims.append(f.dim) continue if f.dim is None or f.size == -1: continue @@ -94,9 +94,9 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape[i] = ( actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1 ) - dist = dist.expand_named_shape(named_batch_shape) + dist = dist.expand_dims(tuple(dims)) msg["fn"] = dist.expand(target_batch_shape) - msg["fn"]._named_batch_shape = dist.named_batch_shape + msg["fn"]._dims = dist.dims msg["fn"].sample_shape = dist.sample_shape if msg["fn"].has_rsample != dist.has_rsample: msg["fn"].has_rsample = dist.has_rsample # copy custom attribute diff --git a/pyro/poutine/enum_messenger.py b/pyro/poutine/enum_messenger.py index 08de70a19b..408669e900 100644 --- a/pyro/poutine/enum_messenger.py +++ b/pyro/poutine/enum_messenger.py @@ -4,14 +4,13 @@ from typing import Any, Dict, List, Optional import torch -from functorch.dim import Dim from typing_extensions import Self from pyro.distributions.torch import Categorical from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.ops.indexing import Vindex from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import Message +from pyro.poutine.runtime import _ENUM_ALLOCATOR, Message from pyro.util import ignore_jit_warnings @@ -153,8 +152,8 @@ def __init__(self, first_available_dim: Optional[int] = None) -> None: super().__init__() def __enter__(self) -> Self: - # if self.first_available_dim is not None: - # _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) + if self.first_available_dim is not None: + _ENUM_ALLOCATOR.set_first_available_dim(self.first_available_dim) self._markov_depths: Dict[str, int] = ( {} ) # site name -> depth (nonnegative integer) @@ -177,67 +176,61 @@ def _pyro_sample(self, msg: Message) -> None: assert isinstance(msg["name"], str) assert msg["infer"] is not None + # Compute upstream dims in scope; these are unsafe to use for this site's target_dim. + scope = msg["infer"].get("_markov_scope") # site name -> markov depth + param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id + if scope is not None: + for name, depth in scope.items(): + if ( + self._markov_depths[name] == depth + ): # hide sites whose markov context has exited + param_dims.update(self._value_dims[name]) + self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] + self._param_dims[msg["name"]] = param_dims if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": return - value = msg["fn"].enumerate_support(False) - dim = Dim(msg["name"]) - msg["value"] = value[dim] + + # Compute an enumerated value (at an arbitrary dim). + value = enumerate_site(msg) + actual_dim = -1 - len(msg["fn"].batch_shape) # the leftmost dim of log_prob + + # Move actual_dim to a safe target_dim. + target_dim, id_ = _ENUM_ALLOCATOR.allocate( + None if scope is None else set(param_dims) + ) + event_dim = msg["fn"].event_dim + categorical_support = getattr(value, "_pyro_categorical_support", None) + if categorical_support is not None: + # Preserve categorical supports to speed up Categorical.log_prob(). + # See pyro/distributions/torch.py for details. 
+ assert target_dim < 0 + value = value.reshape(value.shape[:1] + (1,) * (-1 - target_dim)) + value._pyro_categorical_support = categorical_support # type: ignore[attr-defined] + elif actual_dim < target_dim: + assert ( + value.size(target_dim - event_dim) == 1 + ), "pyro.markov dim conflict at dim {}".format(actual_dim) + value = value.transpose(target_dim - event_dim, actual_dim - event_dim) + while value.dim() and value.size(0) == 1: + value = value.squeeze(0) + elif target_dim < actual_dim: + diff = actual_dim - target_dim + value = value.reshape(value.shape[:1] + (1,) * diff + value.shape[1:]) + + # Compute dims passed downstream through the value. + value_dims = { + dim: param_dims[dim] + for dim in range(event_dim - value.dim(), 0) + if value.size(dim - event_dim) > 1 and dim in param_dims + } + value_dims[target_dim] = id_ + + msg["infer"]["_enumerate_dim"] = target_dim + msg["infer"]["_dim_to_id"] = value_dims + msg["value"] = value msg["done"] = True - # Compute upstream dims in scope; these are unsafe to use for this site's target_dim. - # scope = msg["infer"].get("_markov_scope") # site name -> markov depth - # param_dims = _ENUM_ALLOCATOR.dim_to_id.copy() # enum dim -> unique id - # if scope is not None: - # for name, depth in scope.items(): - # if ( - # self._markov_depths[name] == depth - # ): # hide sites whose markov context has exited - # param_dims.update(self._value_dims[name]) - # self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"] - # self._param_dims[msg["name"]] = param_dims - # if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel": - # return - - # # Compute an enumerated value (at an arbitrary dim). - # value = enumerate_site(msg) - # actual_dim = -1 - len(msg["fn"].batch_shape) # the leftmost dim of log_prob - - # # Move actual_dim to a safe target_dim. - # target_dim, id_ = _ENUM_ALLOCATOR.allocate( - # None if scope is None else set(param_dims) - # ) - # event_dim = msg["fn"].event_dim - # categorical_support = getattr(value, "_pyro_categorical_support", None) - # if categorical_support is not None: - # # Preserve categorical supports to speed up Categorical.log_prob(). - # # See pyro/distributions/torch.py for details. - # assert target_dim < 0 - # value = value.reshape(value.shape[:1] + (1,) * (-1 - target_dim)) - # value._pyro_categorical_support = categorical_support # type: ignore[attr-defined] - # elif actual_dim < target_dim: - # assert ( - # value.size(target_dim - event_dim) == 1 - # ), "pyro.markov dim conflict at dim {}".format(actual_dim) - # value = value.transpose(target_dim - event_dim, actual_dim - event_dim) - # while value.dim() and value.size(0) == 1: - # value = value.squeeze(0) - # elif target_dim < actual_dim: - # diff = actual_dim - target_dim - # value = value.reshape(value.shape[:1] + (1,) * diff + value.shape[1:]) - - # # Compute dims passed downstream through the value. - # value_dims = { - # dim: param_dims[dim] - # for dim in range(event_dim - value.dim(), 0) - # if value.size(dim - event_dim) > 1 and dim in param_dims - # } - # value_dims[target_dim] = id_ - - # msg["infer"]["_enumerate_dim"] = target_dim - # msg["infer"]["_dim_to_id"] = value_dims - # msg["value"] = value - # msg["done"] = True - - def _pyro_post_sample_(self, msg: Message) -> None: + + def _pyro_post_sample(self, msg: Message) -> None: # Save all dims exposed in this sample value. 
# Whereas all of site["_dim_to_id"] are needed to interpret a # site's log_prob tensor, only a filtered subset self._value_dims[msg["name"]] diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index aae7efa11a..78809334d4 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import numbers -from typing import Iterator, NamedTuple, Optional, Tuple +from typing import Iterator, NamedTuple, Optional, Tuple, Union import torch from functorch.dim import Dim @@ -15,7 +15,7 @@ class CondIndepStackFrame(NamedTuple): name: str - dim: Optional[int] + dim: Optional[Union[int, Dim]] size: int counter: int full_size: Optional[int] = None @@ -24,7 +24,7 @@ class CondIndepStackFrame(NamedTuple): def vectorized(self) -> bool: return self.dim is not None - def _key(self) -> Tuple[str, Optional[int], int, int]: + def _key(self) -> Tuple[str, Optional[Union[int, Dim]], int, int]: size = self.size with ignore_jit_warnings(["Converting a tensor to a Python number"]): if isinstance(size, torch.Tensor): # type: ignore[unreachable] @@ -70,7 +70,7 @@ def __init__( self, name: str, size: int, - dim: Optional[int] = None, + dim: Optional[Union[int, Dim]] = None, device: Optional[str] = None, ) -> None: if not torch._C._get_tracing_state() and size == 0: @@ -99,15 +99,16 @@ def __enter__(self) -> Self: self._vectorized = True if self._vectorized is True: - assert self.dim is not None - # self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) + if not isinstance(self.dim, Dim): + self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__() def __exit__(self, *args) -> None: if self._vectorized is True: assert self.dim is not None - # _DIM_ALLOCATOR.free(self.name, self.dim) + if not isinstance(self.dim, Dim): + _DIM_ALLOCATOR.free(self.name, self.dim) return super().__exit__(*args) def __iter__(self) -> Iterator[int]: @@ -128,7 +129,8 @@ def __iter__(self) -> Iterator[int]: def _reset(self) -> None: if self._vectorized: assert self.dim is not None - _DIM_ALLOCATOR.free(self.name, self.dim) + if not isinstance(self.dim, Dim): + _DIM_ALLOCATOR.free(self.name, self.dim) self._vectorized = None self.counter = 0 From 8cbbf2fbe37715665f7f908032bc4fea0689dc0e Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Sat, 23 Mar 2024 15:32:31 +0000 Subject: [PATCH 05/11] named_shape --- pyro/distributions/torch_distribution.py | 49 ++++++++++++------------ pyro/poutine/broadcast_messenger.py | 10 ++--- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index b65f5f7f48..a497166b03 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -3,9 +3,10 @@ import warnings from collections import OrderedDict -from typing import Callable +from typing import Callable, Tuple import torch +from functorch.dim import Dim from torch.distributions.kl import kl_divergence, register_kl import pyro.distributions.torch @@ -45,8 +46,10 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: batched). The shape of the result should be `self.shape()`. 
:rtype: torch.Tensor """ - sample_shape = self.sample_shape + sample_shape - bind_dims = self.dims[len(self.dims) - len(self.sample_shape) :] + sample_shape = self.named_sample_shape + sample_shape + bind_dims = self.named_shape[ + len(self.named_shape) - len(self.named_sample_shape) : + ] return ( self.rsample(sample_shape) if self.has_rsample @@ -54,35 +57,33 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: )[bind_dims] @property - def dims(self): - if not hasattr(self, "_dims"): - seen = set() + def named_shape(self) -> Tuple[Dim]: + if getattr(self, "_named_shape", None) is None: result = [] for param in self.arg_constraints: value = getattr(self, param) for dim in getattr(value, "dims", ()): - if dim not in seen: - seen.add(dim) + if dim not in result: result.append(dim) - self._dims = tuple(result) - return self._dims - - def expand_dims(self, dims): - for dim in dims: - if dim not in set(self.dims): - self._dims += (dim,) - self.sample_shape = self.sample_shape + (dim.size,) + self._named_shape = tuple(result) + return self._named_shape + + def expand_named_shape(self, named_shape): + for dim in named_shape: + if dim not in set(self.named_shape): + self._named_shape += (dim,) + self.named_sample_shape = self.named_sample_shape + (dim.size,) return self @property - def sample_shape(self): - if not hasattr(self, "_sample_shape"): - self._sample_shape = torch.Size() - return self._sample_shape - - @sample_shape.setter - def sample_shape(self, value): - self._sample_shape = value + def named_sample_shape(self): + if getattr(self, "_named_sample_shape", None) is None: + self._named_sample_shape = torch.Size() + return self._named_sample_shape + + @named_sample_shape.setter + def named_sample_shape(self, value): + self._named_sample_shape = value @property def batch_shape(self) -> torch.Size: diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index a1b5975d05..be7b4dc24c 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -61,10 +61,10 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape = [ None if size == 1 else size for size in actual_batch_shape ] - dims = [] + named_shape = () for f in msg["cond_indep_stack"]: if isinstance(f.dim, Dim): - dims.append(f.dim) + named_shape += (f.dim,) continue if f.dim is None or f.size == -1: continue @@ -94,9 +94,9 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape[i] = ( actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1 ) - dist = dist.expand_dims(tuple(dims)) + dist = dist.expand_named_shape(named_shape) msg["fn"] = dist.expand(target_batch_shape) - msg["fn"]._dims = dist.dims - msg["fn"].sample_shape = dist.sample_shape + msg["fn"]._named_shape = dist.named_shape + msg["fn"].named_sample_shape = dist.named_sample_shape if msg["fn"].has_rsample != dist.has_rsample: msg["fn"].has_rsample = dist.has_rsample # copy custom attribute From 252661b495e5d091cc36f3a9e295b1d4ec1aa370 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 26 Mar 2024 19:03:40 +0000 Subject: [PATCH 06/11] add test --- pyro/contrib/named/infer/__init__.py | 6 + pyro/contrib/named/infer/elbo.py | 146 ++++++++++++++++----- pyro/distributions/torch_distribution.py | 18 ++- pyro/poutine/broadcast_messenger.py | 13 +- pyro/poutine/indep_messenger.py | 15 +-- tests/contrib/named/infer/test_gradient.py | 63 +++++++++ torchdim.py | 95 -------------- 7 files changed, 208 insertions(+), 148 deletions(-) create mode 100644 
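A rough illustration of the `named_shape` bookkeeping above (assumed behavior under this patch, not a released API): parameters indexed with first-class dims contribute those dims to the distribution's named shape, while the positional batch shape stays untouched:

    import torch
    from functorch.dim import dims

    import pyro.distributions as dist

    i, j = dims(2)
    loc = torch.zeros(2, 3)[i, j]  # loc carries the first-class dims i and j
    d = dist.Normal(loc, 1.0)
    print(d.named_shape)           # expected (i, j), collected from the parameters
    print(d.batch_shape)           # expected torch.Size([]) -- dims are not positional
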
pyro/contrib/named/infer/__init__.py create mode 100644 tests/contrib/named/infer/test_gradient.py delete mode 100644 torchdim.py diff --git a/pyro/contrib/named/infer/__init__.py b/pyro/contrib/named/infer/__init__.py new file mode 100644 index 0000000000..77832766f4 --- /dev/null +++ b/pyro/contrib/named/infer/__init__.py @@ -0,0 +1,6 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from pyro.contrib.named.infer.elbo import Trace_ELBO + +__all__ = ["Trace_ELBO"] diff --git a/pyro/contrib/named/infer/elbo.py b/pyro/contrib/named/infer/elbo.py index cb3306697f..505946b896 100644 --- a/pyro/contrib/named/infer/elbo.py +++ b/pyro/contrib/named/infer/elbo.py @@ -1,51 +1,133 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from functorch.dim import dims +from typing import Any, Callable, Tuple + +import torch +from functorch.dim import Dim +from typing_extensions import ParamSpec import pyro from pyro import poutine -from pyro.poutine.util import prune_subsample_sites +from pyro.distributions.torch_distribution import TorchDistributionMixin +from pyro.infer import ELBO as _OrigELBO +from pyro.poutine.messenger import Messenger +from pyro.poutine.runtime import Message +_P = ParamSpec("_P") -def log_density(fn, args, kwargs): - """ - Compute log density of a stochastic function given its arguments. - :param fn: Python callable containing Pyro primitives. - :param tuple args: args provided to the function. - :param dict kwargs: kwargs provided to the function. - :return: log of joint density and a corresponding model trace - """ - fn_trace = prune_subsample_sites(poutine.trace(fn).get_trace(*args, **kwargs)) - log_joint = 0.0 - for site in fn_trace.nodes.values(): - if site["type"] == "sample" and site["fn"]: - value = site["value"] - scale = site["scale"] - log_prob = site["fn"].log_prob(value) +class ELBO(_OrigELBO): + def _get_trace(self, *args, **kwargs): + raise RuntimeError("shouldn't be here!") - if scale is not None: - log_prob = scale * log_prob + def differentiable_loss(self, model, guide, *args, **kwargs): + raise NotImplementedError("Must implement differentiable_loss") + + def loss(self, model, guide, *args, **kwargs): + return self.differentiable_loss(model, guide, *args, **kwargs).detach().item() - sum_dims = tuple(f.dim for f in site["cond_indep_stack"]) - log_joint += log_prob.sum(sum_dims) - return log_joint, fn_trace + def loss_and_grads(self, model, guide, *args, **kwargs): + loss = self.differentiable_loss(model, guide, *args, **kwargs) + loss.backward() + return loss.item() -class ELBO: - def __init__(self, num_particles=1, vectorize_particles=True): - self.num_particles = num_particles - self.vectorize_particles = vectorize_particles +def track_provenance(x: torch.Tensor, provenance: Dim) -> torch.Tensor: + return x.unsqueeze(0)[provenance] - def loss(self, model, guide, *args, **kwargs): + +class track_nonreparam(Messenger): + def _pyro_post_sample(self, msg: Message) -> None: + if ( + msg["type"] == "sample" + and isinstance(msg["fn"], TorchDistributionMixin) + and not msg["is_observed"] + and not msg["fn"].has_rsample + ): + provenance = Dim(msg["name"]) + msg["value"] = track_provenance(msg["value"], provenance) + + +def get_importance_trace( + model: Callable[_P, Any], + guide: Callable[_P, Any], + *args: _P.args, + **kwargs: _P.kwargs +) -> Tuple[poutine.Trace, poutine.Trace]: + """ + Returns traces from the guide and the model that is run against it. 
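The `track_provenance` helper above can be exercised on its own. A minimal sketch, relying on functorch's semantics that reductions collapse positional dims while first-class dims ride along, of how a site-named dim marks everything computed from a non-reparameterized sample:

    import torch
    from functorch.dim import Dim

    provenance = Dim("z")                # a dim named after the sample site
    x = torch.randn(3)
    tagged = x.unsqueeze(0)[provenance]  # bind a size-1 first-class dim to x
    y = (tagged * 2.0).sum()             # positional dims reduce; named dims remain
    print(provenance in set(y.dims))     # True: y visibly depends on site "z"
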
+ The returned traces also store the log probability at each site. + """ + with track_nonreparam(): + guide_trace = poutine.trace(guide).get_trace(*args, **kwargs) + replay_model = poutine.replay(model, trace=guide_trace) + model_trace = poutine.trace(replay_model).get_trace(*args, **kwargs) + + for is_guide, trace in zip((True, False), (guide_trace, model_trace)): + for site in list(trace.nodes.values()): + if site["type"] == "sample" and isinstance( + site["fn"], TorchDistributionMixin + ): + log_prob = site["fn"].log_prob(site["value"]) + site["log_prob"] = log_prob + + if is_guide and not site["fn"].has_rsample: + # importance sampling weights + site["log_measure"] = log_prob - log_prob.detach() + else: + trace.remove_node(site["name"]) + return model_trace, guide_trace + + +class Trace_ELBO(ELBO): + def differentiable_loss( + self, + model: Callable[_P, Any], + guide: Callable[_P, Any], + *args: _P.args, + **kwargs: _P.kwargs + ) -> torch.Tensor: if self.num_particles > 1: - vectorize = pyro.plate("num_particles", self.num_particles, dim=dims(1)) + vectorize = pyro.plate( + "num_particles", self.num_particles, dim=Dim("num_particles") + ) model = vectorize(model) guide = vectorize(guide) - guide_log_density, guide_trace = log_density(guide, args, kwargs) - replay_model = poutine.replay(model, trace=guide_trace) - model_log_density, model_trace = log_density(replay_model, args, kwargs) - elbo = (model_log_density - guide_log_density) / self.num_particles + model_trace, guide_trace = get_importance_trace(model, guide, *args, **kwargs) + + cost_terms = [] + # logp terms + for site in model_trace.nodes.values(): + cost = site["log_prob"] + scale = site["scale"] + batch_dims = tuple(f.dim for f in site["cond_indep_stack"]) + deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims)) + cost_terms.append((cost, scale, batch_dims, deps)) + # -logq terms + for site in guide_trace.nodes.values(): + cost = -site["log_prob"] + scale = site["scale"] + batch_dims = tuple(f.dim for f in site["cond_indep_stack"]) + deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims)) + cost_terms.append((cost, scale, batch_dims, deps)) + + elbo = 0.0 + for cost, scale, batch_dims, deps in cost_terms: + if deps: + dice_factor = 0.0 + for key in deps: + dice_factor += guide_trace.nodes[str(key)]["log_measure"] + dice_factor_dims = getattr(dice_factor, "dims", ()) + cost_dims = getattr(cost, "dims", ()) + sum_dims = tuple(set(dice_factor_dims) - set(cost_dims)) + if sum_dims: + dice_factor = dice_factor.sum(sum_dims) + cost = torch.exp(dice_factor) * cost + cost = cost.mean(deps) + if scale is not None: + cost = cost * scale + elbo += cost.sum(batch_dims) / self.num_particles + return -elbo diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index a497166b03..8f5ba3d9e0 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -8,6 +8,7 @@ import torch from functorch.dim import Dim from torch.distributions.kl import kl_divergence, register_kl +from typing_extensions import Self import pyro.distributions.torch @@ -47,14 +48,14 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: :rtype: torch.Tensor """ sample_shape = self.named_sample_shape + sample_shape - bind_dims = self.named_shape[ + bind_named_dims = self.named_shape[ len(self.named_shape) - len(self.named_sample_shape) : ] return ( self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape) - )[bind_dims] + 
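The `log_prob - log_prob.detach()` measure above, exponentiated into a dice factor in `differentiable_loss`, is the DiCE-style surrogate: it evaluates to 1 in the forward pass but injects the score-function gradient in the backward pass. A self-contained check of the identity:

    import torch

    log_q = torch.tensor(0.3, requires_grad=True)  # stands in for a site's log_prob
    cost = torch.tensor(2.0)
    surrogate = torch.exp(log_q - log_q.detach()) * cost
    assert surrogate.item() == cost.item()         # forward value is unchanged
    surrogate.backward()
    assert torch.allclose(log_q.grad, cost)        # gradient is cost * d(log q)/d(theta)
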
)[bind_named_dims] @property def named_shape(self) -> Tuple[Dim]: @@ -63,12 +64,17 @@ def named_shape(self) -> Tuple[Dim]: for param in self.arg_constraints: value = getattr(self, param) for dim in getattr(value, "dims", ()): - if dim not in result: + # Can't use `dim in result` when `result` is a list or a tuple + # RuntimeError: vmap: It looks like you're attempting to use + # a Tensor in some data-dependent control flow. We don't support + # that yet, please shout over at + # https://github.com/pytorch/functorch/issues/257 + if dim not in set(result): result.append(dim) self._named_shape = tuple(result) return self._named_shape - def expand_named_shape(self, named_shape): + def expand_named_shape(self, named_shape: Tuple[Dim]) -> Self: for dim in named_shape: if dim not in set(self.named_shape): self._named_shape += (dim,) @@ -76,13 +82,13 @@ def expand_named_shape(self, named_shape): return self @property - def named_sample_shape(self): + def named_sample_shape(self) -> torch.Size: if getattr(self, "_named_sample_shape", None) is None: self._named_sample_shape = torch.Size() return self._named_sample_shape @named_sample_shape.setter - def named_sample_shape(self, value): + def named_sample_shape(self, value: torch.Size) -> None: self._named_sample_shape = value @property diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index be7b4dc24c..a7bfd14873 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -61,10 +61,10 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape = [ None if size == 1 else size for size in actual_batch_shape ] - named_shape = () + named_shape: List[Dim] = [] for f in msg["cond_indep_stack"]: if isinstance(f.dim, Dim): - named_shape += (f.dim,) + named_shape.append(f.dim) continue if f.dim is None or f.size == -1: continue @@ -94,9 +94,10 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape[i] = ( actual_batch_shape[i] if len(actual_batch_shape) >= -i else 1 ) - dist = dist.expand_named_shape(named_shape) - msg["fn"] = dist.expand(target_batch_shape) - msg["fn"]._named_shape = dist.named_shape - msg["fn"].named_sample_shape = dist.named_sample_shape + if named_shape: + assert len(target_batch_shape) == 0 + msg["fn"] = dist.expand_named_shape(tuple(named_shape)) + else: + msg["fn"] = dist.expand(target_batch_shape) if msg["fn"].has_rsample != dist.has_rsample: msg["fn"].has_rsample = dist.has_rsample # copy custom attribute diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 78809334d4..9784fbf80c 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -98,17 +98,15 @@ def __enter__(self) -> Self: if self._vectorized is not False: self._vectorized = True - if self._vectorized is True: - if not isinstance(self.dim, Dim): - self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) + if self._vectorized is True and not isinstance(self.dim, Dim): + self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__() def __exit__(self, *args) -> None: - if self._vectorized is True: + if self._vectorized is True and not isinstance(self.dim, Dim): assert self.dim is not None - if not isinstance(self.dim, Dim): - _DIM_ALLOCATOR.free(self.name, self.dim) + _DIM_ALLOCATOR.free(self.name, self.dim) return super().__exit__(*args) def __iter__(self) -> Iterator[int]: @@ -127,10 +125,9 @@ def __iter__(self) -> Iterator[int]: yield i if isinstance(i, numbers.Number) else i.item() def _reset(self) -> None: - if 
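The `dim not in set(result)` detail called out above matters because list and tuple membership invoke `Dim.__eq__` elementwise, which functorch rejects under vmap as data-dependent control flow; set membership goes through hashing first. A sketch of the safe deduplication pattern:

    from functorch.dim import dims

    i, j = dims(2)
    result = []
    for dim in (i, j, i):
        if dim not in set(result):  # hash-based membership, vmap-friendly
            result.append(dim)
    assert len(result) == 2 and result[0] is i and result[1] is j
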
self._vectorized: + if self._vectorized and not isinstance(self.dim, Dim): assert self.dim is not None - if not isinstance(self.dim, Dim): - _DIM_ALLOCATOR.free(self.name, self.dim) + _DIM_ALLOCATOR.free(self.name, self.dim) self._vectorized = None self.counter = 0 diff --git a/tests/contrib/named/infer/test_gradient.py b/tests/contrib/named/infer/test_gradient.py new file mode 100644 index 0000000000..69ea59353f --- /dev/null +++ b/tests/contrib/named/infer/test_gradient.py @@ -0,0 +1,63 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +from functorch.dim import dims + +import pyro +import pyro.distributions as dist +from pyro.contrib.named.infer import Trace_ELBO +from pyro.distributions.testing import fakes +from pyro.infer import SVI +from pyro.optim import Adam +from tests.common import assert_equal + + +@pytest.mark.parametrize( + "reparameterized", [True, False], ids=["reparam", "nonreparam"] +) +def test_plate_elbo_vectorized_particles(reparameterized): + pyro.enable_validation(False) + pyro.clear_param_store() + data = torch.tensor([-0.5, 2.0]) + num_particles = 200000 + Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal + i = dims() + + def model(): + data_plate = pyro.plate("data", len(data), dim=i) + + pyro.sample("nuisance_a", Normal(0, 1)) + with data_plate: + z = pyro.sample("z", Normal(0, 1)) + pyro.sample("nuisance_b", Normal(2, 3)) + with data_plate as idx: + pyro.sample("x", Normal(z, torch.ones(len(data))[idx]), obs=data[idx]) + pyro.sample("nuisance_c", Normal(4, 5)) + + def guide(): + loc = pyro.param("loc", torch.zeros(len(data))) + scale = pyro.param("scale", torch.ones(len(data))) + + pyro.sample("nuisance_c", Normal(4, 5)) + with pyro.plate("data", len(data), dim=i) as idx: + pyro.sample("z", Normal(loc[idx], scale[idx])) + pyro.sample("nuisance_b", Normal(2, 3)) + pyro.sample("nuisance_a", Normal(0, 1)) + + optim = Adam({"lr": 0.1}) + loss = Trace_ELBO( + num_particles=num_particles, + vectorize_particles=True, + ) + inference = SVI(model, guide, optim, loss=loss) + inference.loss_and_grads(model, guide) + params = dict(pyro.get_param_store().named_parameters()) + actual_grads = {name: param.grad.detach() for name, param in params.items()} + + expected_grads = { + "loc": torch.tensor([0.5, -2.0]), + "scale": torch.tensor([1.0, 1.0]), + } + assert_equal(actual_grads, expected_grads, prec=0.06) diff --git a/torchdim.py b/torchdim.py deleted file mode 100644 index 30008aaed3..0000000000 --- a/torchdim.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright Contributors to the Pyro project. 
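The `expected_grads` in this test can be verified analytically. With p(z) = Normal(0, 1), p(x|z) = Normal(z, 1), and q(z) = Normal(loc, scale), the nuisance sites contribute nothing to the gradients of loc and scale, and the negative ELBO has a closed form (up to an additive constant) whose gradients at the initial parameter values match the tensors above:

    import torch

    data = torch.tensor([-0.5, 2.0])
    loc = torch.zeros(2, requires_grad=True)
    scale = torch.ones(2, requires_grad=True)
    neg_elbo = (
        0.5 * (loc**2 + scale**2)               # -E_q[log p(z)]
        + 0.5 * ((loc - data) ** 2 + scale**2)  # -E_q[log p(x|z)]
        - scale.log()                           # negative entropy of q
    ).sum()
    neg_elbo.backward()
    print(loc.grad)    # tensor([ 0.5000, -2.0000])
    print(scale.grad)  # tensor([1., 1.])
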
-# SPDX-License-Identifier: Apache-2.0 - -import torch -from functorch.dim import dims - -import pyro -import pyro.distributions as dist -from pyro.contrib.named.infer.elbo import ELBO -from pyro.distributions.named import index_select - -# i, j = dims(2) -# loc = torch.zeros(2, 3)[i, j] -# scale = torch.ones(2)[i] -# import pdb - -# pdb.set_trace() -# normal = dist.Normal(loc, scale=torch.tensor(1.0), validate_args=False) -# k.size = 4 -# normal.expand_named_shape([i, j, k]) -# import pdb - -# pdb.set_trace() -# normal.named_batch_shape -# x = normal.sample() -# log_prob_x = normal.log_prob(x) -# y = torch.randn(2)[i] -# log_prob_y = normal.log_prob(y) -# z = torch.randn(3, 4)[j, k] -# log_prob_z = normal.log_prob(z) -# dir = Dirichlet(torch.ones(3)) - -pyro.enable_validation(False) - - -# @config_enumerate -def model(data_dim, feature_dim, component_dim): - data_plate = pyro.plate("data_plate", 6, dim=data_dim) - feature_plate = pyro.plate("feature_plate", 5, dim=feature_dim) - component_plate = pyro.plate("component_plate", 4, dim=component_dim) - # component_plate = pyro.plate("component_plate", 4, dim=-1) - with feature_plate, component_plate: - p = pyro.sample("p", dist.Dirichlet(torch.ones(3))) - with data_plate as idx: - c = pyro.sample( - "c", dist.Categorical(torch.ones(4).expand([data_dim.size, 4])[data_dim]) - ) - with feature_plate as vdx: # Capture plate index. - pc = index_select(p, dim=component_dim, index=c) - # pc = p[c] - x = pyro.sample( - "x", - dist.Categorical(pc), - obs=torch.zeros(5, 6, dtype=torch.long)[vdx, idx], - ) - print(f" p.shape = {p.shape}") - print(f" c.shape = {c.shape}") - print(f" vdx.shape = {vdx.shape}") - print(f" pc.shape = {pc.shape}") - print(f" x.shape = {x.shape}") - - -def guide(data_dim, feature_dim, component_dim): - data_plate = pyro.plate("data_plate", 6, dim=data_dim) - feature_plate = pyro.plate("feature_plate", 5, dim=feature_dim) - component_plate = pyro.plate("component_plate", 4, dim=component_dim) - # component_plate = pyro.plate("component_plate", 4, dim=-1) - with feature_plate, component_plate: - pyro.sample( - "p", - dist.Dirichlet( - torch.ones(3).expand([feature_dim.size, component_dim.size, 3])[ - feature_dim, component_dim - ] - ), - ) - with data_plate: - pyro.sample( - "c", dist.Categorical(torch.ones(4).expand([data_dim.size, 4])[data_dim]) - ) - - -data_dim, feature_dim, component_dim = dims(3) -pyro.clear_param_store() -print("Sampling:") -print("Enumerated Inference:") -elbo = ELBO() -# model(data_dim, feature_dim, component_dim) -loss = elbo.loss(model, guide, data_dim, feature_dim, component_dim) -elbo_10 = ELBO(num_particles=10) -loss_10 = elbo_10.loss(model, guide, data_dim, feature_dim, component_dim) -elbo_100 = ELBO(num_particles=100) -loss_100 = elbo_100.loss(model, guide, data_dim, feature_dim, component_dim) -elbo_1000 = ELBO(num_particles=1000) -loss_1000 = elbo_1000.loss(model, guide, data_dim, feature_dim, component_dim) From 5538caacd032e6b99900da73e68c7440b7e574cc Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 27 Mar 2024 00:16:47 +0000 Subject: [PATCH 07/11] fixes --- pyro/contrib/named/__init__.py | 0 pyro/distributions/torch_distribution.py | 10 ++++++---- pyro/poutine/broadcast_messenger.py | 6 +++--- pyro/poutine/indep_messenger.py | 20 +++++++++++--------- 4 files changed, 20 insertions(+), 16 deletions(-) create mode 100644 pyro/contrib/named/__init__.py diff --git a/pyro/contrib/named/__init__.py b/pyro/contrib/named/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff 
--git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index 8f5ba3d9e0..e625317e66 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -3,10 +3,9 @@ import warnings from collections import OrderedDict -from typing import Callable, Tuple +from typing import TYPE_CHECKING, Callable, Tuple import torch -from functorch.dim import Dim from torch.distributions.kl import kl_divergence, register_kl from typing_extensions import Self @@ -17,6 +16,9 @@ from .score_parts import ScoreParts from .util import broadcast_shape, scale_and_mask +if TYPE_CHECKING: + from functorch.dim import Dim + class TorchDistributionMixin(Distribution, Callable): """ @@ -58,7 +60,7 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: )[bind_named_dims] @property - def named_shape(self) -> Tuple[Dim]: + def named_shape(self) -> Tuple["Dim"]: if getattr(self, "_named_shape", None) is None: result = [] for param in self.arg_constraints: @@ -74,7 +76,7 @@ def named_shape(self) -> Tuple[Dim]: self._named_shape = tuple(result) return self._named_shape - def expand_named_shape(self, named_shape: Tuple[Dim]) -> Self: + def expand_named_shape(self, named_shape: Tuple["Dim"]) -> Self: for dim in named_shape: if dim not in set(self.named_shape): self._named_shape += (dim,) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index a7bfd14873..51c7ddc635 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, List, Optional -from functorch.dim import Dim - from pyro.distributions.torch_distribution import TorchDistributionMixin from pyro.poutine.messenger import Messenger from pyro.util import ignore_jit_warnings if TYPE_CHECKING: + from functorch.dim import Dim + from pyro.poutine.runtime import Message @@ -63,7 +63,7 @@ def _pyro_sample(msg: "Message") -> None: ] named_shape: List[Dim] = [] for f in msg["cond_indep_stack"]: - if isinstance(f.dim, Dim): + if hasattr(f.dim, "is_bound"): named_shape.append(f.dim) continue if f.dim is None or f.size == -1: diff --git a/pyro/poutine/indep_messenger.py b/pyro/poutine/indep_messenger.py index 9784fbf80c..2ae2863501 100644 --- a/pyro/poutine/indep_messenger.py +++ b/pyro/poutine/indep_messenger.py @@ -2,20 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import numbers -from typing import Iterator, NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, Iterator, NamedTuple, Optional, Tuple, Union import torch -from functorch.dim import Dim from typing_extensions import Self from pyro.poutine.messenger import Messenger from pyro.poutine.runtime import _DIM_ALLOCATOR, Message from pyro.util import ignore_jit_warnings +if TYPE_CHECKING: + from functorch.dim import Dim + class CondIndepStackFrame(NamedTuple): name: str - dim: Optional[Union[int, Dim]] + dim: Optional[Union[int, "Dim"]] size: int counter: int full_size: Optional[int] = None @@ -24,7 +26,7 @@ class CondIndepStackFrame(NamedTuple): def vectorized(self) -> bool: return self.dim is not None - def _key(self) -> Tuple[str, Optional[Union[int, Dim]], int, int]: + def _key(self) -> Tuple[str, Optional[Union[int, "Dim"]], int, int]: size = self.size with ignore_jit_warnings(["Converting a tensor to a Python number"]): if isinstance(size, torch.Tensor): # type: ignore[unreachable] @@ -70,7 +72,7 @@ def __init__( self, name: str, size: int, - dim: Optional[Union[int, Dim]] = None, + dim: 
Optional[Union[int, "Dim"]] = None, device: Optional[str] = None, ) -> None: if not torch._C._get_tracing_state() and size == 0: @@ -98,13 +100,13 @@ def __enter__(self) -> Self: if self._vectorized is not False: self._vectorized = True - if self._vectorized is True and not isinstance(self.dim, Dim): + if self._vectorized is True and not hasattr(self.dim, "is_bound"): self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim) return super().__enter__() def __exit__(self, *args) -> None: - if self._vectorized is True and not isinstance(self.dim, Dim): + if self._vectorized is True and not hasattr(self.dim, "is_bound"): assert self.dim is not None _DIM_ALLOCATOR.free(self.name, self.dim) return super().__exit__(*args) @@ -125,7 +127,7 @@ def __iter__(self) -> Iterator[int]: yield i if isinstance(i, numbers.Number) else i.item() def _reset(self) -> None: - if self._vectorized and not isinstance(self.dim, Dim): + if self._vectorized and not hasattr(self.dim, "is_bound"): assert self.dim is not None _DIM_ALLOCATOR.free(self.name, self.dim) self._vectorized = None @@ -135,7 +137,7 @@ def _reset(self) -> None: def indices(self) -> torch.Tensor: if self._indices is None: self._indices = torch.arange(self.size, dtype=torch.long).to(self.device) - if isinstance(self.dim, Dim): + if hasattr(self.dim, "is_bound"): return self._indices[self.dim] return self._indices From 76bd18479bf18d3dd1b2f4af59e90d2058172a89 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 27 Mar 2024 00:23:45 +0000 Subject: [PATCH 08/11] minor fix --- pyro/poutine/broadcast_messenger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyro/poutine/broadcast_messenger.py b/pyro/poutine/broadcast_messenger.py index 51c7ddc635..eedd1b1364 100644 --- a/pyro/poutine/broadcast_messenger.py +++ b/pyro/poutine/broadcast_messenger.py @@ -61,7 +61,7 @@ def _pyro_sample(msg: "Message") -> None: target_batch_shape = [ None if size == 1 else size for size in actual_batch_shape ] - named_shape: List[Dim] = [] + named_shape: List["Dim"] = [] for f in msg["cond_indep_stack"]: if hasattr(f.dim, "is_bound"): named_shape.append(f.dim) From d458ec3c50a9e8d8789869cba99ad33e36cec5b3 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 27 Mar 2024 00:33:14 +0000 Subject: [PATCH 09/11] ignore --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 04c6112d47..f09c7eb694 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,8 @@ scrub: FORCE doctest: FORCE # We skip testing pyro.distributions.torch wrapper classes because # they include torch docstrings which are tested upstream. 
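The `hasattr(self.dim, "is_bound")` checks above duck-type for functorch's `Dim` so that the class import can live under `TYPE_CHECKING`; plain integer dims have no such attribute. A quick check of that assumption:

    from functorch.dim import dims

    i = dims()
    assert hasattr(i, "is_bound")       # first-class dims expose is_bound
    assert not hasattr(-1, "is_bound")  # integer plate dims do not
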
- python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py + python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py \ + --ignore=pyro/contrib/named perf-test: FORCE bash scripts/perf_test.sh ${ref} From b6e8e057e9ec85d2260b51bb089b78c51c7191b7 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 27 Mar 2024 01:31:01 +0000 Subject: [PATCH 10/11] fix test --- pyro/distributions/torch_distribution.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pyro/distributions/torch_distribution.py b/pyro/distributions/torch_distribution.py index e625317e66..9097c10668 100644 --- a/pyro/distributions/torch_distribution.py +++ b/pyro/distributions/torch_distribution.py @@ -50,14 +50,17 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor: :rtype: torch.Tensor """ sample_shape = self.named_sample_shape + sample_shape - bind_named_dims = self.named_shape[ - len(self.named_shape) - len(self.named_sample_shape) : - ] - return ( + result = ( self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape) - )[bind_named_dims] + ) + bind_named_dims = self.named_shape[ + len(self.named_shape) - len(self.named_sample_shape) : + ] + if bind_named_dims: + result = result[bind_named_dims] + return result @property def named_shape(self) -> Tuple["Dim"]: From faa2749d92c8d36d88f2a989091e8115aec81c1a Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Thu, 28 Mar 2024 14:44:47 +0200 Subject: [PATCH 11/11] Add effective sample size analytics to WeighedPredictive results. --- pyro/infer/importance.py | 96 +++++++++++++++++++--------------- pyro/infer/predictive.py | 13 +++-- tests/infer/test_predictive.py | 5 +- 3 files changed, 66 insertions(+), 48 deletions(-) diff --git a/pyro/infer/importance.py b/pyro/infer/importance.py index d25cf16680..bcb144b170 100644 --- a/pyro/infer/importance.py +++ b/pyro/infer/importance.py @@ -15,7 +15,59 @@ from .util import plate_log_prob_sum -class Importance(TracePosterior): +class WeightAnalytics: + def get_log_normalizer(self): + """ + Estimator of the normalizing constant of the target distribution. + (mean of the unnormalized weights) + """ + # ensure list is not empty + if len(self.log_weights) > 0: + log_w = ( + self.log_weights + if isinstance(self.log_weights, torch.Tensor) + else torch.tensor(self.log_weights) + ) + log_num_samples = torch.log(torch.tensor(log_w.numel() * 1.0)) + return torch.logsumexp(log_w - log_num_samples, 0) + else: + warnings.warn( + "The log_weights list is empty, can not compute normalizing constant estimate." + ) + + def get_normalized_weights(self, log_scale=False): + """ + Compute the normalized importance weights. + """ + if len(self.log_weights) > 0: + log_w = ( + self.log_weights + if isinstance(self.log_weights, torch.Tensor) + else torch.tensor(self.log_weights) + ) + log_w_norm = log_w - torch.logsumexp(log_w, 0) + return log_w_norm if log_scale else torch.exp(log_w_norm) + else: + warnings.warn( + "The log_weights list is empty. There is nothing to normalize." + ) + + def get_ESS(self): + """ + Compute (Importance Sampling) Effective Sample Size (ESS). + """ + if len(self.log_weights) > 0: + log_w_norm = self.get_normalized_weights(log_scale=True) + ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0)) + else: + warnings.warn( + "The log_weights list is empty, effective sample size is zero." 
+ ) + ess = 0 + return ess + + +class Importance(TracePosterior, WeightAnalytics): """ :param model: probabilistic model defined as a function :param guide: guide used for sampling defined as a function @@ -55,48 +107,6 @@ def _traces(self, *args, **kwargs): log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum() yield (model_trace, log_weight) - def get_log_normalizer(self): - """ - Estimator of the normalizing constant of the target distribution. - (mean of the unnormalized weights) - """ - # ensure list is not empty - if self.log_weights: - log_w = torch.tensor(self.log_weights) - log_num_samples = torch.log(torch.tensor(self.num_samples * 1.0)) - return torch.logsumexp(log_w - log_num_samples, 0) - else: - warnings.warn( - "The log_weights list is empty, can not compute normalizing constant estimate." - ) - - def get_normalized_weights(self, log_scale=False): - """ - Compute the normalized importance weights. - """ - if self.log_weights: - log_w = torch.tensor(self.log_weights) - log_w_norm = log_w - torch.logsumexp(log_w, 0) - return log_w_norm if log_scale else torch.exp(log_w_norm) - else: - warnings.warn( - "The log_weights list is empty. There is nothing to normalize." - ) - - def get_ESS(self): - """ - Compute (Importance Sampling) Effective Sample Size (ESS). - """ - if self.log_weights: - log_w_norm = self.get_normalized_weights(log_scale=True) - ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0)) - else: - warnings.warn( - "The log_weights list is empty, effective sample size is zero." - ) - ess = 0 - return ess - def vectorized_importance_weights(model, guide, *args, **kwargs): """ diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py index 6be8b5cb5f..3d020f0fdb 100644 --- a/pyro/infer/predictive.py +++ b/pyro/infer/predictive.py @@ -9,6 +9,7 @@ import pyro import pyro.poutine as poutine +from pyro.infer.importance import WeightAnalytics from pyro.infer.util import plate_log_prob_sum from pyro.poutine.trace_struct import Trace from pyro.poutine.util import prune_subsample_sites @@ -317,16 +318,20 @@ def get_vectorized_trace(self, *args, **kwargs): class WeighedPredictiveResults(NamedTuple): - """ - Return value of call to instance of :class:`WeighedPredictive`. - """ - samples: Union[dict, tuple] log_weights: torch.Tensor guide_log_prob: torch.Tensor model_log_prob: torch.Tensor +class WeighedPredictiveResults(WeighedPredictiveResults, WeightAnalytics): + """ + Return value of call to instance of :class:`WeighedPredictive`. 
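A worked example of the estimator in `get_ESS` above: for normalized weights w the effective sample size is 1 / sum(w**2), computed in log space for numerical stability:

    import torch

    def ess(log_weights):
        log_w_norm = log_weights - torch.logsumexp(log_weights, 0)
        return torch.exp(-torch.logsumexp(2 * log_w_norm, 0))

    print(ess(torch.zeros(4)))                            # tensor(4.) -- uniform weights
    print(ess(torch.tensor([0.0, -20.0, -20.0, -20.0])))  # ~1.0 -- one sample dominates
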
+ """ + + pass + + class WeighedPredictive(Predictive): """ Class used to construct a weighed predictive distribution that is based diff --git a/tests/infer/test_predictive.py b/tests/infer/test_predictive.py index 1f28e1f05c..319a1196dd 100644 --- a/tests/infer/test_predictive.py +++ b/tests/infer/test_predictive.py @@ -46,6 +46,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): num_trials = ( torch.ones(5) * 400 ) # Reduced to 400 from 1000 in order for guide optimization to converge + num_samples = 10000 num_success = dist.Binomial(num_trials, true_probs).sample() conditioned_model = poutine.condition(model, data={"obs": num_success}) elbo = Trace_ELBO(num_particles=100, vectorize_particles=True) @@ -57,7 +58,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): posterior_predictive = predictive( model, guide=beta_guide, - num_samples=10000, + num_samples=num_samples, parallel=parallel, return_sites=["_RETURN"], ) @@ -71,6 +72,8 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive): assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape # Weights should be uniform as the guide has the same distribution as the model assert weighed_samples.log_weights.std() < 0.6 + # Effective sample size should be close to actual number of samples taken from the guide + assert weighed_samples.get_ESS() > 0.8 * num_samples assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)