vae.py
import torch
import torch.nn as nn
class VAE(nn.Module):
    """Variational Autoencoder as described in https://arxiv.org/abs/1312.6114

    The variational autoencoder models the latent space with an underlying
    standard Gaussian distribution. The data space is also modelled using a
    normal distribution with mean and variance generated by the decoder network.
    """

    def __init__(self, latent_dim, encoder, decoder):
        """Init a Variational Autoencoder model.

        Args:
            latent_dim: int
                Dimensionality of the latent space.
            encoder: nn.Module
                Network mapping an input batch `x` to a pair `(mu_z, log_std_z)`
                parametrizing the approximate posterior q(z|x).
            decoder: nn.Module
                Network mapping latent codes `z` to a pair `(mu_x, log_std_x)`
                parametrizing the likelihood p(x|z).
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = encoder
        self.decoder = decoder

        # We will use a standard Gaussian distribution for the prior by default.
        # Note that if we decide to use a different prior distribution, then we
        # need to revisit the formula for the reparametrization.
        #
        # Instead of storing the prior distribution as a parameter of the model,
        # we will construct it on every `self.prior()` call. The reason for this
        # is that once the Distribution object is initialized it cannot be moved
        # to a different device. Using this workaround the prior is defined on
        # the device on which the model parameters are currently placed.
        self.register_buffer("mu_prior", torch.tensor(0.))
        self.register_buffer("std_prior", torch.tensor(1.))

    def prior(self):
        return torch.distributions.Normal(self.mu_prior, self.std_prior)

    def loss(self, x):
        """Compute the variational lower bound loss.

            log p(x) >= E[ log p(x|z) ] - KL( q(z|x) || p(z) )
                          -recon_loss           kl_loss

        The returned total loss is the negative of this lower bound, i.e.
        recon_loss + kl_loss.
        """
        mu_z, log_std_z = self.encoder(x)
        eps = self.prior().sample(mu_z.shape)
        z = mu_z + log_std_z.exp() * eps
        mu_x, log_std_x = self.decoder(z)

        # Since we are parametrizing the distributions returned by the encoder
        # and the decoder as normal distributions, we can calculate the
        # reconstruction loss and the KL loss with closed-form formulas.
        # NOTE: Observe that the reconstruction loss is essentially a squared
        # error weighted by the inverse variance, plus a log-variance penalty.
        #
        # recon_loss = 0.5 * np.log(2 * np.pi) + log_std_x + 0.5 * (x - mu_x) ** 2 / (2 * log_std_x).exp()
        # recon_loss = recon_loss.mean(dim=0).sum()
        # kl_loss = 0.5 * (mu_z ** 2 + (2 * log_std_z).exp() - 1 - 2 * log_std_z)
        # kl_loss = kl_loss.mean(dim=0).sum()
        #
        # We can also compute the reconstruction and the KL losses for any type
        # of distribution that we choose using the pytorch built-in functions.
        # NOTE: It would probably be a lot faster to use the formulas above for
        # the reconstruction and KL losses, instead of creating the distributions
        # and calling the `log_prob` and `kl_divergence` built-ins.
        q_dist = torch.distributions.Normal(mu_z, log_std_z.exp())
        p_dist = torch.distributions.Normal(mu_x, log_std_x.exp())
        recon_loss = -p_dist.log_prob(x).mean(dim=0).sum()
        kl_loss = torch.distributions.kl.kl_divergence(q_dist, self.prior()).mean(dim=0).sum()

        # NOTE: We are averaging over the batch size and summing over all other
        # dimensions. The reason for this is that the reconstruction loss has
        # dim equal to the original space, and the kl loss has dim equal to the
        # latent space. If we instead averaged over all dimensions, the two
        # losses would be on very different scales, and one might dominate over
        # the other.
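        # For example, with 1x28x28 inputs and a 16-dimensional latent space
        # (illustrative sizes, not prescribed by this file), recon_loss sums
        # 784 per-pixel terms while kl_loss sums only 16 per-dimension terms.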
        _, C, H, W = x.shape
        recon_loss /= (C * H * W)
        kl_loss /= (C * H * W)

        total_loss = recon_loss + kl_loss
        return (total_loss, recon_loss, kl_loss)

    @torch.no_grad()
    def reconstruct(self, x):
        """Encode the input into the latent space and then decode it back from
        the latent representation. Return the reconstructed object.
        """
        mu_z, log_std_z = self.encoder(x)
        eps = self.prior().sample(mu_z.shape)
        z = mu_z + log_std_z.exp() * eps
        mu_x, _ = self.decoder(z)
        return mu_x.cpu()

    @torch.no_grad()
    def sample(self, n):
        """Generate n samples using the decoder."""
        z = self.prior().sample((n, self.latent_dim))
        mu_x, _ = self.decoder(z)
        return mu_x.cpu()
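

if __name__ == "__main__":
    # Usage sketch (not part of the original module): a minimal example of how
    # this VAE could be wired together, assuming MNIST-shaped inputs of size
    # (N, 1, 28, 28) and small MLP encoder/decoder networks that return
    # (mean, log_std) pairs. All names, sizes, and hyperparameters below are
    # illustrative assumptions rather than anything prescribed by this file.
    latent_dim = 16

    class MLPEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.body = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 256), nn.ReLU())
            self.mu = nn.Linear(256, latent_dim)
            self.log_std = nn.Linear(256, latent_dim)

        def forward(self, x):
            h = self.body(x)
            return self.mu(h), self.log_std(h)

    class MLPDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU())
            self.mu = nn.Linear(256, 28 * 28)
            self.log_std = nn.Linear(256, 28 * 28)

        def forward(self, z):
            h = self.body(z)
            return (self.mu(h).view(-1, 1, 28, 28),
                    self.log_std(h).view(-1, 1, 28, 28))

    vae = VAE(latent_dim, MLPEncoder(), MLPDecoder())
    optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

    x = torch.rand(8, 1, 28, 28)  # stand-in batch of "images"
    total_loss, recon_loss, kl_loss = vae.loss(x)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    samples = vae.sample(4)       # (4, 1, 28, 28) tensor of decoder means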