Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jan 12, 2024
1 parent 910243c commit ae105d4
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 45 deletions.
12 changes: 8 additions & 4 deletions pyro/contrib/zuko.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
import pyro


class Zuko2Pyro(pyro.distributions.TorchDistribution):
class ZukoToPyro(pyro.distributions.TorchDistribution):
r"""Wraps a Zuko distribution as a Pyro distribution.
If ``dist`` has an ``rsample_and_log_prob`` method, like Zuko's flows, it will be
used when sampling instead of ``rsample``. The returned log density will be cached
for later scoring.
:param dist: A distribution instance.
:type dist: torch.distributions.Distribution
Expand All @@ -31,9 +35,9 @@ class Zuko2Pyro(pyro.distributions.TorchDistribution):
x = dist.sample((2, 3))
log_p = dist.log_prob(x)
# Zuko2Pyro(flow()) is a pyro.distributions.Distribution
# ZukoToPyro(flow()) is a pyro.distributions.Distribution
dist = Zuko2Pyro(flow())
dist = ZukoToPyro(flow())
x = dist((2, 3))
log_p = dist.log_prob(x)
Expand Down Expand Up @@ -74,4 +78,4 @@ def log_prob(self, x: Tensor) -> Tensor:
return self.dist.log_prob(x)

def expand(self, *args, **kwargs):
return Zuko2Pyro(self.dist.expand(*args, **kwargs))
return ZukoToPyro(self.dist.expand(*args, **kwargs))
24 changes: 16 additions & 8 deletions tests/contrib/test_zuko.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import torch

import pyro
from pyro.contrib.zuko import Zuko2Pyro
from pyro.contrib.zuko import ZukoToPyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


@pytest.mark.parametrize("multivariate", [True, False])
def test_Zuko2Pyro(multivariate: bool):
@pytest.mark.parametrize("rsample_and_log_prob", [True, False])
def test_ZukoToPyro(multivariate: bool, rsample_and_log_prob: bool):
# Distribution
if multivariate:
normal = torch.distributions.MultivariateNormal
Expand All @@ -25,32 +26,39 @@ def test_Zuko2Pyro(multivariate: bool):

dist = normal(mu, sigma)

if rsample_and_log_prob:
def dummy(self, shape):
x = self.rsample(x)
return x, self.log_prob(x)

dist.rsample_and_log_prob = dummy

# Sample
x1 = pyro.sample("x1", Zuko2Pyro(dist))
x1 = pyro.sample("x1", ZukoToPyro(dist))

assert x1.shape == dist.event_shape

# Sample within plate
with pyro.plate("data", 4):
x2 = pyro.sample("x2", Zuko2Pyro(dist))
x2 = pyro.sample("x2", ZukoToPyro(dist))

assert x2.shape == (4, *dist.event_shape)

# SVI
def model():
pyro.sample("a", Zuko2Pyro(dist))
pyro.sample("a", ZukoToPyro(dist))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(dist))
pyro.sample("b", ZukoToPyro(dist))

def guide():
mu_ = pyro.param("mu", mu)
sigma_ = pyro.param("sigma", sigma)

pyro.sample("a", Zuko2Pyro(normal(mu_, sigma_)))
pyro.sample("a", ZukoToPyro(normal(mu_, sigma_)))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(normal(mu_, sigma_)))
pyro.sample("b", ZukoToPyro(normal(mu_, sigma_)))

svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO())
svi.step()
60 changes: 30 additions & 30 deletions tutorial/source/svi_flow_guide.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions tutorial/source/vae_flow_prior.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"import torch.utils.data as data\n",
"import zuko\n",
"\n",
"from pyro.contrib.zuko import Zuko2Pyro\n",
"from pyro.contrib.zuko import ZukoToPyro\n",
"from pyro.optim import Adam\n",
"from pyro.infer import SVI, Trace_ELBO\n",
"from torch import Tensor\n",
Expand Down Expand Up @@ -132,7 +132,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"However, we choose a [masked autoregressive flow](https://arxiv.org/abs/1705.07057) (MAF) as prior $p_\\phi(z)$ instead of the typical standard Gaussian $\\mathcal{N}(0, I)$. Instead of implementing the MAF ourselves, we borrow it from the [Zuko](https://github.com/probabilists/zuko) library. Because Zuko distributions are very similar to Pyro distributions, a thin wrapper (`Zuko2Pyro`) is sufficient to make Zuko and Pyro 100% compatible."
"However, we choose a [masked autoregressive flow](https://arxiv.org/abs/1705.07057) (MAF) as prior $p_\\phi(z)$ instead of the typical standard Gaussian $\\mathcal{N}(0, I)$. Instead of implementing the MAF ourselves, we borrow it from the [Zuko](https://github.com/probabilists/zuko) library. Because Zuko distributions are very similar to Pyro distributions, a thin wrapper (`ZukoToPyro`) is sufficient to make Zuko and Pyro 100% compatible."
]
},
{
Expand Down Expand Up @@ -227,7 +227,7 @@
" pyro.module(\"decoder\", self.decoder)\n",
"\n",
" with pyro.plate(\"batch\", len(x)):\n",
" z = pyro.sample(\"z\", Zuko2Pyro(self.prior()))\n",
" z = pyro.sample(\"z\", ZukoToPyro(self.prior()))\n",
" x = pyro.sample(\"x\", self.decoder(z), obs=x)\n",
"\n",
" def guide(self, x: Tensor):\n",
Expand Down

0 comments on commit ae105d4

Please sign in to comment.