Add effective sample size analytics to WeighedPredictive results #3350

Closed
wants to merge 15 commits
3 changes: 2 additions & 1 deletion Makefile
@@ -40,7 +40,8 @@ scrub: FORCE
doctest: FORCE
# We skip testing pyro.distributions.torch wrapper classes because
# they include torch docstrings which are tested upstream.
	python -m pytest -p tests.doctest_fixtures --doctest-modules -o filterwarnings=ignore pyro --ignore=pyro/distributions/torch.py \
		--ignore=pyro/contrib/named

perf-test: FORCE
bash scripts/perf_test.sh ${ref}
Empty file added pyro/contrib/named/__init__.py
6 changes: 6 additions & 0 deletions pyro/contrib/named/infer/__init__.py
@@ -0,0 +1,6 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from pyro.contrib.named.infer.elbo import Trace_ELBO

__all__ = ["Trace_ELBO"]
133 changes: 133 additions & 0 deletions pyro/contrib/named/infer/elbo.py
@@ -0,0 +1,133 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Tuple

import torch
from functorch.dim import Dim
from typing_extensions import ParamSpec

import pyro
from pyro import poutine
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.infer import ELBO as _OrigELBO
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message

_P = ParamSpec("_P")


class ELBO(_OrigELBO):
    def _get_trace(self, *args, **kwargs):
        raise RuntimeError("Named ELBO implementations do not use _get_trace.")

def differentiable_loss(self, model, guide, *args, **kwargs):
raise NotImplementedError("Must implement differentiable_loss")

def loss(self, model, guide, *args, **kwargs):
return self.differentiable_loss(model, guide, *args, **kwargs).detach().item()

def loss_and_grads(self, model, guide, *args, **kwargs):
loss = self.differentiable_loss(model, guide, *args, **kwargs)
loss.backward()
return loss.item()


def track_provenance(x: torch.Tensor, provenance: Dim) -> torch.Tensor:
    # Attach the site's first-class dim to the value: unsqueeze a size-1 axis
    # and bind it to `provenance`, so downstream tensors carry the dim along.
    return x.unsqueeze(0)[provenance]

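For intuition, a sketch of how a tracked dim rides along through downstream arithmetic (the values are hypothetical):

import torch
from functorch.dim import Dim

z = Dim("z")
value = track_provenance(torch.randn(2), z)
cost = (value * 3.0).sum()  # positional dims reduce; the named dim `z` survives
print(z in set(cost.dims))  # True -- provenance is recoverable from .dims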

class track_nonreparam(Messenger):
def _pyro_post_sample(self, msg: Message) -> None:
if (
msg["type"] == "sample"
and isinstance(msg["fn"], TorchDistributionMixin)
and not msg["is_observed"]
and not msg["fn"].has_rsample
):
provenance = Dim(msg["name"])
msg["value"] = track_provenance(msg["value"], provenance)


def get_importance_trace(
model: Callable[_P, Any],
guide: Callable[_P, Any],
*args: _P.args,
**kwargs: _P.kwargs
) -> Tuple[poutine.Trace, poutine.Trace]:
"""
    Returns traces from the guide and from the model replayed against it.
    Each trace stores the log probability at every sample site, and the
    guide trace additionally stores a zero-valued ``log_measure`` term at
    each non-reparameterized site.
"""
with track_nonreparam():
guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
replay_model = poutine.replay(model, trace=guide_trace)
model_trace = poutine.trace(replay_model).get_trace(*args, **kwargs)

for is_guide, trace in zip((True, False), (guide_trace, model_trace)):
for site in list(trace.nodes.values()):
if site["type"] == "sample" and isinstance(
site["fn"], TorchDistributionMixin
):
log_prob = site["fn"].log_prob(site["value"])
site["log_prob"] = log_prob

                if is_guide and not site["fn"].has_rsample:
                    # Zero-valued term carrying the score-function gradient
                    # for non-reparameterized sites (used as the DiCE factor).
                    site["log_measure"] = log_prob - log_prob.detach()
else:
trace.remove_node(site["name"])
return model_trace, guide_trace

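The ``log_measure`` term stored above is zero-valued but carries a score-function gradient. A minimal sketch of that identity in plain PyTorch, independent of the surrounding Pyro machinery:

import torch

log_q = torch.tensor(-1.2, requires_grad=True)
dice = torch.exp(log_q - log_q.detach())  # evaluates to 1.0, but is differentiable
(dice * 3.0).backward()
print(log_q.grad)  # tensor(3.) -- the score-function (DiCE) gradient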

class Trace_ELBO(ELBO):
def differentiable_loss(
self,
model: Callable[_P, Any],
guide: Callable[_P, Any],
*args: _P.args,
**kwargs: _P.kwargs
) -> torch.Tensor:
if self.num_particles > 1:
vectorize = pyro.plate(
"num_particles", self.num_particles, dim=Dim("num_particles")
)
model = vectorize(model)
guide = vectorize(guide)

model_trace, guide_trace = get_importance_trace(model, guide, *args, **kwargs)

cost_terms = []
# logp terms
for site in model_trace.nodes.values():
cost = site["log_prob"]
scale = site["scale"]
batch_dims = tuple(f.dim for f in site["cond_indep_stack"])
deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims))
cost_terms.append((cost, scale, batch_dims, deps))
# -logq terms
for site in guide_trace.nodes.values():
cost = -site["log_prob"]
scale = site["scale"]
batch_dims = tuple(f.dim for f in site["cond_indep_stack"])
deps = tuple(set(getattr(cost, "dims", ())) - set(batch_dims))
cost_terms.append((cost, scale, batch_dims, deps))

        elbo = 0.0
        for cost, scale, batch_dims, deps in cost_terms:
            if deps:
                # DiCE factor: sum the zero-valued log_measure terms of the
                # non-reparameterized sites this cost depends on; exp() of the
                # sum equals 1 but carries their score-function gradients.
                dice_factor = 0.0
                for key in deps:
                    dice_factor += guide_trace.nodes[str(key)]["log_measure"]
                dice_factor_dims = getattr(dice_factor, "dims", ())
                cost_dims = getattr(cost, "dims", ())
                sum_dims = tuple(set(dice_factor_dims) - set(cost_dims))
                if sum_dims:
                    dice_factor = dice_factor.sum(sum_dims)
                cost = torch.exp(dice_factor) * cost
                # Average out the provenance dims added by track_provenance.
                cost = cost.mean(deps)
            if scale is not None:
                cost = cost * scale
            elbo += cost.sum(batch_dims) / self.num_particles

        return -elbo
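A minimal usage sketch, assuming this named Trace_ELBO is a drop-in replacement for pyro.infer.Trace_ELBO (the toy model and guide below are hypothetical):

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.named.infer import Trace_ELBO

def model():
    pyro.sample("z", dist.Normal(0.0, 1.0))

def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))

elbo = Trace_ELBO()
loss = elbo.differentiable_loss(model, guide)  # scalar torch.Tensor
loss.backward()  # gradients reach the "loc" parameter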
49 changes: 47 additions & 2 deletions pyro/distributions/torch_distribution.py
@@ -3,10 +3,11 @@

import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING, Callable, Tuple

import torch
from torch.distributions.kl import kl_divergence, register_kl
from typing_extensions import Self

import pyro.distributions.torch

@@ -15,6 +16,9 @@
from .score_parts import ScoreParts
from .util import broadcast_shape, scale_and_mask

if TYPE_CHECKING:
from functorch.dim import Dim


class TorchDistributionMixin(Distribution, Callable):
"""
@@ -45,11 +49,52 @@ def __call__(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
batched). The shape of the result should be `self.shape()`.
:rtype: torch.Tensor
"""
        # Prepend the sizes of the named sample dims so they are drawn
        # as ordinary positional dims first ...
        sample_shape = self.named_sample_shape + sample_shape
        result = (
            self.rsample(sample_shape)
            if self.has_rsample
            else self.sample(sample_shape)
        )
        # ... then bind the leading positional draws back to their
        # first-class dims (the dims appended by expand_named_shape).
        bind_named_dims = self.named_shape[
            len(self.named_shape) - len(self.named_sample_shape) :
        ]
        if bind_named_dims:
            result = result[bind_named_dims]
        return result

@property
    def named_shape(self) -> Tuple["Dim", ...]:
if getattr(self, "_named_shape", None) is None:
result = []
for param in self.arg_constraints:
value = getattr(self, param)
for dim in getattr(value, "dims", ()):
# Can't use `dim in result` when `result` is a list or a tuple
# RuntimeError: vmap: It looks like you're attempting to use
# a Tensor in some data-dependent control flow. We don't support
# that yet, please shout over at
# https://github.com/pytorch/functorch/issues/257
if dim not in set(result):
result.append(dim)
self._named_shape = tuple(result)
return self._named_shape

    def expand_named_shape(self, named_shape: Tuple["Dim", ...]) -> Self:
for dim in named_shape:
if dim not in set(self.named_shape):
self._named_shape += (dim,)
self.named_sample_shape = self.named_sample_shape + (dim.size,)
return self

@property
def named_sample_shape(self) -> torch.Size:
if getattr(self, "_named_sample_shape", None) is None:
self._named_sample_shape = torch.Size()
return self._named_sample_shape

@named_sample_shape.setter
def named_sample_shape(self, value: torch.Size) -> None:
self._named_sample_shape = value

@property
def batch_shape(self) -> torch.Size:
96 changes: 53 additions & 43 deletions pyro/infer/importance.py
@@ -15,7 +15,59 @@
from .util import plate_log_prob_sum


class WeightAnalytics:
    def get_log_normalizer(self):
        """
        Estimate of the log normalizing constant of the target distribution,
        i.e. the log of the mean of the unnormalized importance weights.
        """
# ensure list is not empty
if len(self.log_weights) > 0:
log_w = (
self.log_weights
if isinstance(self.log_weights, torch.Tensor)
else torch.tensor(self.log_weights)
)
log_num_samples = torch.log(torch.tensor(log_w.numel() * 1.0))
return torch.logsumexp(log_w - log_num_samples, 0)
else:
            warnings.warn(
                "The log_weights collection is empty; cannot compute a normalizing constant estimate."
            )

def get_normalized_weights(self, log_scale=False):
"""
Compute the normalized importance weights.
"""
if len(self.log_weights) > 0:
log_w = (
self.log_weights
if isinstance(self.log_weights, torch.Tensor)
else torch.tensor(self.log_weights)
)
log_w_norm = log_w - torch.logsumexp(log_w, 0)
return log_w_norm if log_scale else torch.exp(log_w_norm)
else:
warnings.warn(
"The log_weights list is empty. There is nothing to normalize."
)

    def get_ESS(self):
        """
        Compute the (importance sampling) effective sample size:
        ESS = 1 / sum(w ** 2) for normalized weights w.
        """
        if len(self.log_weights) > 0:
            log_w_norm = self.get_normalized_weights(log_scale=True)
            # exp(-logsumexp(2 * log_w_norm)) == 1 / sum(w ** 2)
            ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0))
        else:
            warnings.warn(
                "The log_weights collection is empty, effective sample size is zero."
            )
            ess = 0
        return ess

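For intuition: with normalized weights w the ESS equals 1 / sum(w_i ** 2), ranging from 1 (one weight dominates) to N (uniform weights). A standalone sketch, using a hypothetical _Weights holder to supply the log_weights attribute the mixin expects:

import torch

class _Weights(WeightAnalytics):
    def __init__(self, log_weights):
        self.log_weights = log_weights

print(_Weights(torch.zeros(4)).get_ESS())  # tensor(4.) -- uniform weights: ESS == N
print(_Weights(torch.tensor([0.0, -10.0, -10.0, -10.0])).get_ESS())  # ~1.0 -- one weight dominates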

class Importance(TracePosterior, WeightAnalytics):
"""
:param model: probabilistic model defined as a function
:param guide: guide used for sampling defined as a function
@@ -55,48 +107,6 @@ def _traces(self, *args, **kwargs):
log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
yield (model_trace, log_weight)



def vectorized_importance_weights(model, guide, *args, **kwargs):
"""
13 changes: 9 additions & 4 deletions pyro/infer/predictive.py
@@ -9,6 +9,7 @@

import pyro
import pyro.poutine as poutine
from pyro.infer.importance import WeightAnalytics
from pyro.infer.util import plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites
@@ -317,16 +318,20 @@ def get_vectorized_trace(self, *args, **kwargs):


class WeighedPredictiveResults(NamedTuple):
samples: Union[dict, tuple]
log_weights: torch.Tensor
guide_log_prob: torch.Tensor
model_log_prob: torch.Tensor


class WeighedPredictiveResults(WeighedPredictiveResults, WeightAnalytics):
"""
Return value of call to instance of :class:`WeighedPredictive`.
"""

pass
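With WeightAnalytics mixed in, the analytics methods become available directly on returned results. A hedged sketch (model and guide are placeholders; the call signature follows the existing WeighedPredictive API):

predictive = WeighedPredictive(model, guide=guide, num_samples=1000)
results = predictive()            # a WeighedPredictiveResults instance
print(results.log_weights.shape)  # torch.Size([1000])
print(results.get_ESS())          # effective sample size of the 1000 draws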


class WeighedPredictive(Predictive):
"""
Class used to construct a weighed predictive distribution that is based
4 changes: 4 additions & 0 deletions pyro/ops/indexing.py
@@ -215,3 +215,7 @@ def __init__(self, tensor):

def __getitem__(self, args):
return vindex(self._tensor, args)


def index_select(input, dim, index):
    # Move the first-class dim `dim` into the leading positional slot,
    # then select `index` along it.
    return input.order(dim)[index]
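A minimal usage sketch, assuming the functorch.dim convention that tensor.order(dim) moves a first-class dim into the leading positional slot:

import torch
from functorch.dim import Dim

d = Dim("d")
x = torch.randn(3, 4)[d]     # bind the leading positional dim to `d`
row = index_select(x, d, 1)  # same as x.order(d)[1]
assert row.shape == (4,)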