Skip to content

Commit

Permalink
Add effective sample size analytics to WeighedPredictive results (#3351)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel authored Mar 28, 2024
1 parent 0e08427 commit 58080f8
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 46 deletions.
101 changes: 59 additions & 42 deletions pyro/infer/importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
import warnings
from typing import List, Union

import torch

Expand All @@ -15,55 +16,26 @@
from .util import plate_log_prob_sum


class Importance(TracePosterior):
class LogWeightsMixin:
"""
:param model: probabilistic model defined as a function
:param guide: guide used for sampling defined as a function
:param num_samples: number of samples to draw from the guide (default 10)
This method performs posterior inference by importance sampling
using the guide as the proposal distribution.
If no guide is provided, it defaults to proposing from the model's prior.
Mixin class to compute analytics from a ``.log_weights`` attribute.
"""

def __init__(self, model, guide=None, num_samples=None):
"""
Constructor. default to num_samples = 10, guide = model
"""
super().__init__()
if num_samples is None:
num_samples = 10
warnings.warn(
"num_samples not provided, defaulting to {}".format(num_samples)
)
if guide is None:
# propose from the prior by making a guide from the model by hiding observes
guide = poutine.block(model, hide_types=["observe"])
self.num_samples = num_samples
self.model = model
self.guide = guide

def _traces(self, *args, **kwargs):
"""
Generator of weighted samples from the proposal distribution.
"""
for i in range(self.num_samples):
guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
model_trace = poutine.trace(
poutine.replay(self.model, trace=guide_trace)
).get_trace(*args, **kwargs)
log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
yield (model_trace, log_weight)
log_weights: Union[List[Union[float, torch.Tensor]], torch.Tensor]

def get_log_normalizer(self):
"""
Estimator of the normalizing constant of the target distribution.
(mean of the unnormalized weights)
"""
# ensure list is not empty
if self.log_weights:
log_w = torch.tensor(self.log_weights)
log_num_samples = torch.log(torch.tensor(self.num_samples * 1.0))
if len(self.log_weights) > 0:
log_w = (
self.log_weights
if isinstance(self.log_weights, torch.Tensor)
else torch.tensor(self.log_weights)
)
log_num_samples = torch.log(torch.tensor(log_w.numel() * 1.0))
return torch.logsumexp(log_w - log_num_samples, 0)
else:
warnings.warn(
Expand All @@ -74,8 +46,12 @@ def get_normalized_weights(self, log_scale=False):
"""
Compute the normalized importance weights.
"""
if self.log_weights:
log_w = torch.tensor(self.log_weights)
if len(self.log_weights) > 0:
log_w = (
self.log_weights
if isinstance(self.log_weights, torch.Tensor)
else torch.tensor(self.log_weights)
)
log_w_norm = log_w - torch.logsumexp(log_w, 0)
return log_w_norm if log_scale else torch.exp(log_w_norm)
else:
Expand All @@ -87,7 +63,7 @@ def get_ESS(self):
"""
Compute (Importance Sampling) Effective Sample Size (ESS).
"""
if self.log_weights:
if len(self.log_weights) > 0:
log_w_norm = self.get_normalized_weights(log_scale=True)
ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0))
else:
Expand All @@ -98,6 +74,47 @@ def get_ESS(self):
return ess


class Importance(TracePosterior, LogWeightsMixin):
"""
:param model: probabilistic model defined as a function
:param guide: guide used for sampling defined as a function
:param num_samples: number of samples to draw from the guide (default 10)
This method performs posterior inference by importance sampling
using the guide as the proposal distribution.
If no guide is provided, it defaults to proposing from the model's prior.
"""

def __init__(self, model, guide=None, num_samples=None):
"""
Constructor. default to num_samples = 10, guide = model
"""
super().__init__()
if num_samples is None:
num_samples = 10
warnings.warn(
"num_samples not provided, defaulting to {}".format(num_samples)
)
if guide is None:
# propose from the prior by making a guide from the model by hiding observes
guide = poutine.block(model, hide_types=["observe"])
self.num_samples = num_samples
self.model = model
self.guide = guide

def _traces(self, *args, **kwargs):
"""
Generator of weighted samples from the proposal distribution.
"""
for i in range(self.num_samples):
guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
model_trace = poutine.trace(
poutine.replay(self.model, trace=guide_trace)
).get_trace(*args, **kwargs)
log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
yield (model_trace, log_weight)


def vectorized_importance_weights(model, guide, *args, **kwargs):
"""
:param model: probabilistic model defined as a function
Expand Down
10 changes: 7 additions & 3 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from dataclasses import dataclass
from functools import reduce
from typing import List, NamedTuple, Union
from typing import List, Union

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.importance import LogWeightsMixin
from pyro.infer.util import plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites
Expand All @@ -34,7 +36,8 @@ def _guess_max_plate_nesting(model, args, kwargs):
return max_plate_nesting


class _predictiveResults(NamedTuple):
@dataclass(frozen=True, eq=False)
class _predictiveResults:
"""
Return value of call to ``_predictive`` and ``_predictive_sequential``.
"""
Expand Down Expand Up @@ -316,7 +319,8 @@ def get_vectorized_trace(self, *args, **kwargs):
).trace


class WeighedPredictiveResults(NamedTuple):
@dataclass(frozen=True, eq=False)
class WeighedPredictiveResults(LogWeightsMixin):
"""
Return value of call to instance of :class:`WeighedPredictive`.
"""
Expand Down
5 changes: 4 additions & 1 deletion tests/infer/test_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive):
num_trials = (
torch.ones(5) * 400
) # Reduced to 400 from 1000 in order for guide optimization to converge
num_samples = 10000
num_success = dist.Binomial(num_trials, true_probs).sample()
conditioned_model = poutine.condition(model, data={"obs": num_success})
elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
Expand All @@ -57,7 +58,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive):
posterior_predictive = predictive(
model,
guide=beta_guide,
num_samples=10000,
num_samples=num_samples,
parallel=parallel,
return_sites=["_RETURN"],
)
Expand All @@ -71,6 +72,8 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive):
assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
# Weights should be uniform as the guide has the same distribution as the model
assert weighed_samples.log_weights.std() < 0.6
# Effective sample size should be close to actual number of samples taken from the guide
assert weighed_samples.get_ESS() > 0.8 * num_samples
assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)


Expand Down

0 comments on commit 58080f8

Please sign in to comment.