Subsampling in some autoguides produces parameters with wrong shapes #3286
Comments
Hi @fritzo! I had a closer look at the issue, and it's a little more complicated than I thought... Are there already some tests for the creation of parameters in autoguides?
I think that the same phenomenon happens for

```python
import pyro
import torch


def model():
    # Plate of size 20, subsampled to 3 elements per step.
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))


guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)
elbo = pyro.infer.Trace_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoLowRankMultivariateNormal.loc").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.scale").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.cov_factor").shape)
# torch.Size([3, 2])
```

The parameters should have 20 rows, but they have 3. Following the suggestion in the docs, we can initialize the parameters ourselves:

```python
import pyro
import torch


def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))


# Pre-initialize the guide's parameters at the full plate size (20 rows).
pyro.param("AutoLowRankMultivariateNormal.loc", torch.zeros(20))
pyro.param("AutoLowRankMultivariateNormal.scale", torch.ones(20))
pyro.param("AutoLowRankMultivariateNormal.cov_factor", torch.ones(20, 2))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)
elbo = pyro.infer.Trace_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)
# ...
# AssertionError
```
AutoGuides + data subsampling requires using
Thanks @martinjankowiak! I have tried creating the plates manually in different contexts, but I didn't have any luck. Have a look at the example below: am I doing it wrong?

```python
import pyro
import torch


def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Categorical(torch.ones(1)))


def create_plate_x():
    # Build the same subsampled plate for the guide.
    return pyro.plate("dummy", 20, subsample_size=3, dim=-1)


guide = pyro.infer.autoguide.AutoDiscreteParallel(model, create_plates=create_plate_x)
elbo = pyro.infer.TraceEnum_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoDiscreteParallel.x_probs").shape)
# torch.Size([3, 1])
```

Thanks for the link to the test! It seems to run with
Issue Description
Autoguides need to create parameters in the background. The shape of those parameters is determined by the plates in the model. When plates are subsampled, the parameters should have the dimension of the full plate, not of the subsampled plate. This is the case for some autoguides, but for `AutoDiscreteParallel` the shape of the parameters is wrong.
Environment
Code Snippet
The code below shows the difference in behavior between `AutoNormal` and `AutoDiscreteParallel`. In both cases, the model creates a plate of size 20 and subsamples it to size 3. Upon gathering the parameters, `AutoNormal` produces parameters with 20 rows, whereas `AutoDiscreteParallel` produces parameters with 3 rows.
I believe that the issue is in the `_setup_prototype` functions in `pyro/infer/autoguide/guides.py`. Below is the code from `AutoNormal` (see here).
There is no equivalent in the `_setup_prototype` function of `AutoDiscreteParallel` (see here).
I will work on a pull request to fix this. I would also like to create some additional tests for this and other cases, but I am not too sure where to start. Any help would be appreciated.
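The rule described above (parameter shapes should follow each plate's full size, not its subsample size) can be sketched in plain Python. This is a hypothetical helper, not Pyro's code: `Frame` is an illustrative stand-in for the plate frames recorded on a sample site, and `full_param_shape` is an invented name. The `(3, 1)` case mirrors the `x_probs` report above, treating the trailing category dimension as one that plates do not index.

```python
from collections import namedtuple

# Hypothetical stand-in for the plate frames recorded on a sample site:
# `dim` is the plate's negative batch dimension, `size` is the subsample
# size seen while tracing, and `full_size` is the plate's declared size.
Frame = namedtuple("Frame", ["name", "dim", "size", "full_size"])


def full_param_shape(value_shape, frames, event_dim=0):
    """Widen each plate dimension of a traced shape to the plate's full size.

    `value_shape` is the shape observed under subsampling; `event_dim` is the
    number of rightmost dimensions that plates never index.
    """
    shape = list(value_shape)
    batch_rank = len(shape) - event_dim
    for frame in frames:
        # Plate dims count from the right of the batch shape.
        index = batch_rank + frame.dim
        if not 0 <= index < batch_rank:
            raise ValueError(f"plate {frame.name!r} (dim {frame.dim}) is outside the batch shape")
        if shape[index] != frame.size:
            raise ValueError(
                f"plate {frame.name!r} expects size {frame.size}, got {shape[index]}"
            )
        # Allocate the full plate size instead of the subsample size.
        shape[index] = frame.full_size
    return tuple(shape)


# The plate from the reports above: 20 elements, subsampled to 3.
dummy = Frame("dummy", dim=-1, size=3, full_size=20)
print(full_param_shape((3,), [dummy]))                 # (20,)
print(full_param_shape((3, 1), [dummy], event_dim=1))  # (20, 1)
```

Allocating the full shape is only part of a fix: the extra rows also need initial values, for example by repeating the traced values along the plate dimension.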