From 5cebc446b1689481ec92ab41aa38b379f4be2863 Mon Sep 17 00:00:00 2001 From: Ben Zickel <35469979+BenZickel@users.noreply.github.com> Date: Sun, 4 Aug 2024 19:39:43 +0300 Subject: [PATCH] Support for transformed distributions, based on stacking or concatenation transforms, in SplitReparam (#3390) --- pyro/infer/reparam/split.py | 68 ++++++++++++++++++++++++++++++- tests/infer/reparam/test_split.py | 50 +++++++++++++++++++++++ 2 files changed, 116 insertions(+), 2 deletions(-) diff --git a/pyro/infer/reparam/split.py b/pyro/infer/reparam/split.py index d5a389bc0e..83f2224263 100644 --- a/pyro/infer/reparam/split.py +++ b/pyro/infer/reparam/split.py @@ -6,10 +6,61 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine +from pyro.distributions.torch_distribution import TorchDistributionMixin from .reparam import Reparam +def same_support(fn: TorchDistributionMixin, *args): + """ + Returns support of the `fn` distribution. Used in :class:`SplitReparam` in + order to determine the support of the split value. + + :param fn: distribution class + :returns: distribution support + """ + return fn.support + + +def real_support(fn: TorchDistributionMixin, *args): + """ + Returns real support with same event dimension as that of the `fn` distribution. + Used in :class:`SplitReparam` in order to determine the support of the split value. + + :param fn: distribution class + :returns: distribution support + """ + return dist.constraints.independent(dist.constraints.real, fn.event_dim) + + +def default_support(fn: TorchDistributionMixin, slice, dim): + """ + Returns support of the `fn` distribution, corrected for split stacking and + concatenation transforms. Used in :class:`SplitReparam` in + order to determine the support of the split value. + + :param fn: distribution class + :param slice: slice for which to return support + :param dim: dimension for which to return support + :returns: distribution support + """ + support = fn.support + # Unwrap support + reinterpreted_batch_ndims_vec = [] + while isinstance(support, dist.constraints.independent): + reinterpreted_batch_ndims_vec.append(support.reinterpreted_batch_ndims) + support = support.base_constraint + # Slice concatenation and stacking transforms + if isinstance(support, dist.constraints.stack) and support.dim == dim: + support = dist.constraints.stack(support.cseq[slice], dim) + elif isinstance(support, dist.constraints.cat) and support.dim == dim: + support = dist.constraints.cat(support.cseq[slice], dim, support.lengths[slice]) + # Wrap support + for reinterpreted_batch_ndims in reinterpreted_batch_ndims_vec[::-1]: + support = dist.constraints.independent(support, reinterpreted_batch_ndims) + return support + + class SplitReparam(Reparam): """ Reparameterizer to split a random variable along a dimension, similar to @@ -28,14 +79,21 @@ class SplitReparam(Reparam): each chunk. :type: list(int) :param int dim: Dimension along which to split. Defaults to -1. + :param callable support_fn: Function which derives the split support + from the site's sampling function, split size, and split dimension. + Default is :func:`default_support` which correctly handles stacking + and concatenation transforms. Other options are :func:`same_support` + which returns the same support as that of the sampling function, and + :func:`real_support` which returns a real support. """ - def __init__(self, sections, dim): + def __init__(self, sections, dim, support_fn=default_support): assert isinstance(dim, int) and dim < 0 assert isinstance(sections, list) assert all(isinstance(size, int) for size in sections) self.event_dim = -dim self.sections = sections + self.support_fn = support_fn def apply(self, msg): name = msg["name"] @@ -53,14 +111,20 @@ def apply(self, msg): dim = fn.event_dim - self.event_dim left_shape = fn.event_shape[:dim] right_shape = fn.event_shape[1 + dim :] + start = 0 for i, size in enumerate(self.sections): event_shape = left_shape + (size,) + right_shape value_split[i] = pyro.sample( f"{name}_split_{i}", - dist.ImproperUniform(fn.support, fn.batch_shape, event_shape), + dist.ImproperUniform( + self.support_fn(fn, slice(start, start + size), -self.event_dim), + fn.batch_shape, + event_shape, + ), obs=value_split[i], infer={"is_observed": is_observed}, ) + start += size # Combine parts into value. if value is None: diff --git a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py index fb450c4220..b3f43bc5f6 100644 --- a/tests/infer/reparam/test_split.py +++ b/tests/infer/reparam/test_split.py @@ -91,6 +91,56 @@ def model(): check_init_reparam(model, SplitReparam(splits, dim)) +@batch_shape +def test_transformed_distribution(batch_shape): + num_samples = 10 + + transform = dist.transforms.StackTransform( + [ + dist.transforms.OrderedTransform(), + dist.transforms.DiscreteCosineTransform(), + dist.transforms.HaarTransform(), + ], + dim=-1, + ) + + num_transforms = len(transform.transforms) + + def model(): + scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1)) + with pyro.plate_stack("plates", batch_shape): + x_dist = dist.TransformedDistribution( + dist.MultivariateNormal( + torch.zeros(num_samples, num_transforms), scale_tril=scale_tril + ).to_event(1), + [transform], + ) + return pyro.sample("x", x_dist) + + assert model().shape == batch_shape + (num_samples, num_transforms) + + pyro.clear_param_store() + guide = pyro.infer.autoguide.AutoMultivariateNormal(model) + guide_sites = guide() + + assert guide_sites["x"].shape == batch_shape + (num_samples, num_transforms) + + for sections in [[1, 1, 1], [1, 2], [2, 1]]: + split_model = pyro.poutine.reparam( + model, config={"x": SplitReparam(sections, -1)} + ) + + pyro.clear_param_store() + guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model) + guide_sites = guide() + + for n, section in enumerate(sections): + assert guide_sites[f"x_split_{n}"].shape == batch_shape + ( + num_samples, + section, + ) + + @event_shape_splits_dim @batch_shape def test_predictive(batch_shape, event_shape, splits, dim):