-
Notifications
You must be signed in to change notification settings - Fork 245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TraceEnum_ELBO
: Subsample local variables that depend on a global model-enumerated variable
#1572
base: master
Are you sure you want to change the base?
Conversation
Hi @ordabayevy, I don't understand how you can move prod and sum around. In particular, I'm not sure if your first equation makes sense: |
I think you are right @fehiepsi . Let me think more about this. |
So the actual equation should be (same in the code): This seems intuitive to me - subsample within a plate and then scale the product before summing it up. I did some tests and it seems to be unbiased. However, I can't figure out how to prove unbiasedness mathematically. |
Code I used to check unbiasedness: import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
a = torch.tensor([0, 1])
logits_a = torch.log(torch.tensor([0.3, 0.7]))
# values are sampled from N(0, 1) and N(1, 1)
b = torch.rand(1000)
b[500:] += 1
d = dist.Normal(a, 1)
log_b = d.log_prob(b.reshape(-1,1))
expected = torch.logsumexp(log_b.sum(0) + logits_a, 0)
results = []
for _ in range(50000):
idx = torch.randperm(1000)[:100] # subsample 100 samples
scale = 10 # 1000 / 100
results.append(torch.logsumexp(d.log_prob(b[idx].reshape(-1,1)).sum(0) * scale + logits_a, 0))
print(expected)
print(torch.mean(torch.tensor(results)))
plt.plot(results)
plt.hlines(expected, 0, 50000, "C1")
plt.show()
>>> tensor(-1087.9426)
>>> tensor(-1087.8069) |
I think it's easier to see the issue if we use a smaller number of data (e.g. just 2). Assume we are using subsample to estimate |
One of the features not supported by
TraceEnum_ELBO
is that you cannot subsample a local variable when it depends on a global variable that is enumerated in the model because it requires a common scale:This has been asked on the forum as well: https://forum.pyro.ai/t/enumeration-and-subsampling-expected-all-enumerated-sample-sites-to-share-common-poutine-scale/4938
Proposed solution here is to scale log factors as follows ($N$ - total size, $M$ - subsample size):
$\log \sum_a p(a) {\prod_i}^{N} p(b_i | a) \approx \frac{N}{M}\log \sum_a p(a) {\prod_i}^{M} p(b_i | a)$
Expectation of the left hand side:
$\mathbb{E} [ \log \sum_a p(a) {\prod_i}^{N} p(b_i | a) ] = \mathbb{E} [ \log {\prod_i}^{N} \sum_a p(a) p(b_i | a) ]= \mathbb{E} [ \log {\prod_i}^{N} p(b_i) ]$
$= \mathbb{E} [{\sum_i}^N \log p(b_i) ] = {\sum_i}^N \mathbb{E} [ \log p(b_i) ]$
$= N \mathbb{E} [ \log p(b_i) ]$
Expectation of the right hand side:
$\mathbb{E} [ \frac{N}{M} \log \sum_a p(a) {\prod_i}^{M} p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log {\prod_i}^{M} \sum_a p(a) p(b_i | a) ] = \frac{N}{M} \mathbb{E} [ \log {\prod_i}^{M} p(b_i) ]$
$= \frac{N}{M} \mathbb{E} [{\sum_i}^M \log p(b_i) ] = \frac{N}{M} {\sum_i}^M \mathbb{E} [ \log p(b_i) ]$
$= N \mathbb{E} [ \log p(b_i) ]$