# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
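"""
Variational Autoencoder (VAE) on MNIST, implemented with funsor.

An Encoder network parameterizes a diagonal-Normal guide over a
20-dimensional latent code, and a Decoder network parameterizes a
Bernoulli likelihood over 28x28 images. The ELBO is built lazily as a
funsor expression and estimated under a Monte Carlo interpretation.
"""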
import argparse
import os
from collections import OrderedDict

import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms

import funsor
import funsor.ops as ops
import funsor.torch.distributions as dist
from funsor.domains import Bint, Reals

REPO_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_PATH = os.path.join(REPO_PATH, 'data')

class Encoder(nn.Module):
    """Amortized guide: maps an image to the loc and scale of q(z | x)."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)  # latent loc
        self.fc22 = nn.Linear(400, 20)  # latent log-scale, exponentiated below

    def forward(self, image):
        # Flatten the trailing 28x28 image dims into a single 784 dim.
        image = image.reshape(image.shape[:-2] + (-1,))
        h1 = F.relu(self.fc1(image))
        loc = self.fc21(h1)
        scale = self.fc22(h1).exp()
        return loc, scale

class Decoder(nn.Module):
    """Likelihood network: maps a latent code to per-pixel Bernoulli probs."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def forward(self, z):
        h3 = F.relu(self.fc3(z))
        out = torch.sigmoid(self.fc4(h3))
        # Unflatten the 784 output dim back into a 28x28 image.
        return out.reshape(out.shape[:-1] + (28, 28))

def main(args):
    funsor.set_backend("torch")
    encoder = Encoder()
    decoder = Decoder()

    # Wrap the networks as funsor functions with typed inputs and outputs.
    encode = funsor.function(Reals[28, 28], (Reals[20], Reals[20]))(encoder)
    decode = funsor.function(Reals[20], Reals[28, 28])(decoder)

    @funsor.interpretation(funsor.montecarlo.MonteCarlo())
    def loss_function(data, subsample_scale):
        # Lazily sample from the guide.
        loc, scale = encode(data)
        q = funsor.Independent(
            dist.Normal(loc['i'], scale['i'], value='z_i'),
            'z', 'i', 'z_i')

        # Evaluate the model likelihood at the lazy value z.
        probs = decode('z')
        p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
        p = p.reduce(ops.add, {'x', 'y'})

        # Construct an ELBO. This is where sampling happens.
        elbo = funsor.Integrate(q, p - q, 'z')
        elbo = elbo.reduce(ops.add, 'batch') * subsample_scale
        loss = -elbo
        return loss

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(DATA_PATH, train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size=args.batch_size, shuffle=True)

    encoder.train()
    decoder.train()
    optimizer = optim.Adam(list(encoder.parameters()) +
                           list(decoder.parameters()), lr=1e-3)
    for epoch in range(args.num_epochs):
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            # Scale the minibatch ELBO up to the size of the full dataset.
            subsample_scale = float(len(train_loader.dataset) / len(data))
            data = data[:, 0, :, :]  # drop the singleton channel dim
            data = funsor.Tensor(data, OrderedDict(batch=Bint[len(data)]))

            optimizer.zero_grad()
            loss = loss_function(data, subsample_scale)
            assert isinstance(loss, funsor.Tensor), loss.pretty()
            loss.data.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('  loss = {}'.format(loss.item()))
            if batch_idx and args.smoke_test:
                return
        print('epoch {} train_loss = {}'.format(epoch, train_loss))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='VAE MNIST Example')
    parser.add_argument('-n', '--num-epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--smoke-test', action='store_true')
    args = parser.parse_args()
    main(args)
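
# Example invocation (flags as defined above; --smoke-test returns early
# after the second minibatch, which is useful for CI):
#   python vae.py --num-epochs 10 --batch-size 8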