InfoVAE.py
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

class InfoVAE(nn.Module):
    """InfoVAE (MMD variant) with a two-hidden-layer MLP encoder and decoder."""

    def __init__(self, nfeat=1000, ncode=5, alpha=0, lambd=10000, nhidden=128, nhidden2=35, dropout=0.2):
        super(InfoVAE, self).__init__()
        self.ncode = int(ncode)
        self.alpha = float(alpha)
        self.lambd = float(lambd)
        # encoder: nfeat -> nhidden -> nhidden2 -> (mu, logvar)
        self.encd = nn.Linear(nfeat, nhidden)
        self.d1 = nn.Dropout(p=dropout)
        self.enc2 = nn.Linear(nhidden, nhidden2)
        self.d2 = nn.Dropout(p=dropout)
        self.mu = nn.Linear(nhidden2, ncode)
        self.lv = nn.Linear(nhidden2, ncode)
        # decoder: ncode -> nhidden2 -> nhidden -> nfeat
        self.decd = nn.Linear(ncode, nhidden2)
        self.d3 = nn.Dropout(p=dropout)
        self.dec2 = nn.Linear(nhidden2, nhidden)
        self.d4 = nn.Dropout(p=dropout)
        self.outp = nn.Linear(nhidden, nfeat)
    def encode(self, x):
        """Map the input to the mean and log-variance of q(z|x)."""
        x = self.d1(F.leaky_relu(self.encd(x)))
        x = self.d2(F.leaky_relu(self.enc2(x)))
        mu = self.mu(x)
        logvar = self.lv(x)
        return mu, logvar
    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
    def decode(self, x):
        x = self.d3(F.leaky_relu(self.decd(x)))
        x = self.d4(F.leaky_relu(self.dec2(x)))
        x = self.outp(x)
        return x
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    # https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
    def compute_kernel(self, x, y):
        """Gaussian (RBF) kernel between every pair of rows of x and y."""
        x_size = x.size(0)
        y_size = y.size(0)
        dim = x.size(1)
        x = x.unsqueeze(1)  # (x_size, 1, dim)
        y = y.unsqueeze(0)  # (1, y_size, dim)
        tiled_x = x.expand(x_size, y_size, dim)
        tiled_y = y.expand(x_size, y_size, dim)
        # The tutorial code divides by dim here, making kernel_input ~ 1/dim;
        # omitting the division keeps kernel_input ~ 1.
        kernel_input = (tiled_x - tiled_y).pow(2).mean(2)  # /float(dim)
        return torch.exp(-kernel_input)  # (x_size, y_size)
    # https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
    def compute_mmd(self, x, y):
        """Squared MMD estimate: E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]."""
        xx_kernel = self.compute_kernel(x, x)
        yy_kernel = self.compute_kernel(y, y)
        xy_kernel = self.compute_kernel(x, y)
        return torch.mean(xx_kernel) + torch.mean(yy_kernel) - 2 * torch.mean(xy_kernel)
    def loss(self, x, weig, epoch):
        """Weighted reconstruction + (1 - alpha) * KL + (lambda + alpha - 1) * MMD (epoch is unused)."""
        recon_x, mu, logvar = self.forward(x)
        # weighted squared-error reconstruction term
        MSE = torch.sum(0.5 * weig * (x - recon_x).pow(2))
        # KL divergence (Kingma and Welling, https://arxiv.org/abs/1312.6114, Appendix B):
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # return MSE + self.beta*KLD, MSE
        # MMD term between prior samples and samples from q(z|x)
        # https://ermongroup.github.io/blog/a-tutorial-on-mmd-variational-autoencoders/
        true_samples = torch.randn(200, self.ncode, device=mu.device)
        # fresh z sample for the MMD term (forward() already drew one for the reconstruction)
        z = self.reparameterize(mu, logvar)
        # compute_mmd is ~1, so upweight it to match KLD, which scales as n_batch * n_code
        MMD = self.compute_mmd(true_samples, z) * x.size(0) * self.ncode
        return MSE + (1 - self.alpha) * KLD + (self.lambd + self.alpha - 1) * MMD, MSE, KLD, MMD
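

# A minimal, hypothetical training-step sketch showing how the class above could be
# used: it assumes a batch of random inputs and uniform per-feature weights for the
# MSE term; the batch size, learning rate, and optimizer choice are illustrative
# assumptions, not values taken from this repository.
if __name__ == "__main__":
    model = InfoVAE(nfeat=1000, ncode=5)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(32, 1000)      # hypothetical batch of 32 inputs with 1000 features
    weig = torch.ones_like(x)      # uniform weights for the reconstruction term

    model.train()
    optimizer.zero_grad()
    total, mse, kld, mmd = model.loss(x, weig, epoch=0)
    total.backward()
    optimizer.step()
    print(f"loss={total.item():.3f}  MSE={mse.item():.3f}  KLD={kld.item():.3f}  MMD={mmd.item():.3f}")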