'''
Defines the generator and discriminator of a Generative Adversarial Network (GAN)
for semi-supervised learning, using the PyTorch library.
It is inspired by Udacity (www.udacity.com) courses and is an attempt to rewrite this
implementation in PyTorch:
https://github.com/udacity/deep-learning/blob/master/semi-supervised/semi-supervised_learning_2_solution.ipynb
'''
import torch
import torch.nn as nn


class _netG(nn.Module):
'''
The generator network
'''

    def __init__(self, nz, ngf, alpha, nc, use_gpu):
        '''
        :param nz: size of the input noise vector
        :param ngf: base multiplier for the transposed-convolution feature maps
        :param alpha: negative slope of the LeakyReLU activations
        :param nc: number of image channels
        :param use_gpu: flag indicating whether to run on the GPU
        '''
super(_netG, self).__init__()
self.use_gpu = use_gpu
self.main = nn.Sequential(
# noise is going into a convolution
nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.LeakyReLU(alpha),
# (ngf * 4) x 4 x 4
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.LeakyReLU(alpha),
# (ngf * 2) x 8 x 8
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.LeakyReLU(alpha),
# (ngf) x 16 x 16
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# (nc) x 32 x 32
)

    def forward(self, inputs):
        '''
        :param inputs: noise tensor used as the input to the generator network
        '''
        if self.use_gpu and inputs.is_cuda:
            out = nn.parallel.data_parallel(self.main, inputs, range(1))
        else:
            out = self.main(inputs)
        return out


class _ganLogits(nn.Module):
    '''
    Layer that computes the GAN logits of the discriminator.
    It takes the class logits as input and derives the GAN logits used to
    separate real from fake images in a numerically stable way.
    '''

    def __init__(self, num_classes):
        '''
        :param num_classes: number of real data classes (10 for SVHN)
        '''
        super(_ganLogits, self).__init__()
        self.num_classes = num_classes

    def forward(self, class_logits):
        '''
        :param class_logits: unscaled log probabilities (logits) over the
            num_classes real classes plus the single fake class
        '''
        # Compute gan_logits such that P(input is real | input) = sigmoid(gan_logits).
        # class_logits describes the (num_classes + 1)-way softmax distribution over the
        # real classes and the fake class; it has to be collapsed into a single binary
        # real-vs-fake logit. Numerical stability matters here, so the log-sum-exp trick
        #     log sum_i exp(a_i) = m + log sum_i exp(a_i - m),  with m = max_i a_i,
        # is used: it avoids overflow when one a_i is very large and underflow when all
        # a_i are very negative, both of which break the naive implementation (or other
        # choices of m).
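        # Derivation: let l_1, ..., l_K be the real-class logits and l_f the fake-class
        # logit. Under the (K + 1)-way softmax,
        #     P(input is real | input) = sum_i exp(l_i) / (sum_i exp(l_i) + exp(l_f))
        #                              = sigmoid(log(sum_i exp(l_i)) - l_f),
        # so gan_logits = logsumexp(l_1, ..., l_K) - l_f, which is what the code below
        # computes with the max-subtraction trick.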
        # split off the single fake-class logit from the num_classes real-class logits
        real_class_logits, fake_class_logits = torch.split(class_logits, self.num_classes, dim=1)
        fake_class_logits = torch.squeeze(fake_class_logits, 1)
        # stabilised log-sum-exp over the real-class logits
        max_val, _ = torch.max(real_class_logits, 1, keepdim=True)
        stable_class_logits = real_class_logits - max_val
        max_val = torch.squeeze(max_val, 1)
        gan_logits = torch.log(torch.sum(torch.exp(stable_class_logits), 1)) + max_val - fake_class_logits
        return gan_logits


class _netD(nn.Module):
    '''
    The discriminator network
    '''

    def __init__(self, ndf, alpha, nc, drop_rate, num_classes, use_gpu):
        '''
        :param ndf: base multiplier for the convolution feature maps
        :param alpha: negative slope of the LeakyReLU activations
        :param nc: number of image channels
        :param drop_rate: rate for the dropout layers
        :param num_classes: number of real output classes (10 for SVHN)
        :param use_gpu: flag indicating whether to run on the GPU
        '''
super(_netD, self).__init__()
self.use_gpu = use_gpu
self.main = nn.Sequential(
nn.Dropout2d(drop_rate/2.5),
            # input is (nc) x 32 x 32
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(alpha),
nn.Dropout2d(drop_rate),
# (ndf) x 16 x 16
nn.Conv2d(ndf, ndf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf),
nn.LeakyReLU(alpha),
# (ndf) x 8 x 8
nn.Conv2d(ndf, ndf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf),
nn.LeakyReLU(alpha),
nn.Dropout2d(drop_rate),
# (ndf) x 4 x 4
nn.Conv2d(ndf, ndf * 2, 3, 1, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(alpha),
# (ndf * 2) x 4 x 4
nn.Conv2d(ndf * 2, ndf * 2, 3, 1, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(alpha),
# (ndf * 2) x 4 x 4
nn.Conv2d(ndf * 2, ndf * 2, 3, 1, 0, bias=False),
nn.LeakyReLU(alpha),
# (ndf * 2) x 2 x 2
)
        # global average pooling over the 2 x 2 feature map -> (ndf * 2) x 1 x 1
        self.features = nn.AvgPool2d(kernel_size=2)
        # classifier head over the real classes plus one fake class
        self.class_logits = nn.Linear(
            in_features=(ndf * 2) * 1 * 1,
            out_features=num_classes + 1)
        self.gan_logits = _ganLogits(num_classes)
        # softmax over the class dimension to obtain per-sample class probabilities
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs):
        '''
        :param inputs: real or fake images used as the input to the discriminator network
        :return: class probabilities, class logits, GAN (real-vs-fake) logits and the
            pooled features of the last convolutional block
        '''
        if self.use_gpu and inputs.is_cuda:
            out = nn.parallel.data_parallel(self.main, inputs, range(1))
        else:
            out = self.main(inputs)
        features = self.features(out)
        features = features.view(features.size(0), -1)
        class_logits = self.class_logits(features)
        gan_logits = self.gan_logits(class_logits)
        out = self.softmax(class_logits)
        return out, class_logits, gan_logits, features
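

if __name__ == '__main__':
    # Minimal shape-check sketch of how the two networks fit together. The hyperparameter
    # values below (nz, ngf, ndf, alpha, drop_rate) are illustrative assumptions for
    # 32 x 32 RGB inputs such as SVHN, not values taken from this repository's training code.
    nz, ngf, ndf, nc, num_classes = 100, 64, 64, 3, 10
    alpha, drop_rate = 0.2, 0.5

    net_g = _netG(nz, ngf, alpha, nc, use_gpu=False)
    net_d = _netD(ndf, alpha, nc, drop_rate, num_classes, use_gpu=False)

    # a batch of 4 noise vectors shaped for the first transposed convolution
    noise = torch.randn(4, nz, 1, 1)
    fake_images = net_g(noise)
    print(fake_images.size())                     # (4, nc, 32, 32)

    probs, class_logits, gan_logits, features = net_d(fake_images)
    print(probs.size(), class_logits.size(),      # (4, num_classes + 1) each
          gan_logits.size(), features.size())     # (4,) and (4, ndf * 2)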