-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
231 lines (183 loc) · 7.9 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import gzip
import os
import struct
import urllib.request
import configargparse
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
class Encoder(eqx.Module):
    """Map a flat input vector to the parameters of a diagonal Gaussian.

    One sigmoid hidden layer feeds two linear heads: one producing the mean
    and one producing the log-variance of the latent distribution.
    """

    hidden: eqx.nn.Linear  # input_dim -> hidden_dim
    logvar: eqx.nn.Linear  # hidden_dim -> latent_dim (log-variance head)
    mean: eqx.nn.Linear  # hidden_dim -> latent_dim (mean head)

    def __init__(self, input_dim, hidden_dim, latent_dim, *, rng):
        # Keys are drawn in the order (hidden, mean, logvar); keep this order
        # so a given seed reproduces the same initial weights.
        hidden_key, mean_key, logvar_key = jax.random.split(rng, 3)
        self.hidden = eqx.nn.Linear(input_dim, hidden_dim, key=hidden_key)
        self.mean = eqx.nn.Linear(hidden_dim, latent_dim, key=mean_key)
        self.logvar = eqx.nn.Linear(hidden_dim, latent_dim, key=logvar_key)

    def __call__(self, x):
        """Return ``(mean, logvar)`` for a single input vector ``x``."""
        features = jax.nn.sigmoid(self.hidden(x))
        return self.mean(features), self.logvar(features)
class Decoder(eqx.Module):
    """Map a latent vector to an ``output_dim`` vector with entries in (0, 1).

    A sigmoid hidden layer is followed by a linear output layer that is also
    squashed by a sigmoid, so the outputs can be used as Bernoulli
    probabilities by the reconstruction loss.
    """

    hidden: eqx.nn.Linear  # latent_dim -> hidden_dim
    output: eqx.nn.Linear  # hidden_dim -> output_dim

    def __init__(self, latent_dim, hidden_dim, output_dim, *, rng):
        keys = jax.random.split(rng)
        self.hidden = eqx.nn.Linear(latent_dim, hidden_dim, key=keys[0])
        self.output = eqx.nn.Linear(hidden_dim, output_dim, key=keys[1])

    def __call__(self, z):
        """Decode a single latent vector ``z``."""
        activations = jax.nn.sigmoid(self.hidden(z))
        return jax.nn.sigmoid(self.output(activations))
class VAE(eqx.Module):
    """(Conditional) variational autoencoder for 28x28 single-channel images.

    When ``is_conditional`` is true, a 10-dimensional label vector is appended
    both to the flattened image fed to the encoder and to the latent sample
    fed to the decoder.
    """

    encoder: Encoder
    decoder: Decoder
    is_conditional: bool

    def __init__(self, hidden_dim, latent_dim, is_conditional, *, rng):
        encoder_key, decoder_key = jax.random.split(rng)
        n_pixels = 28 * 28
        # bool * int: 10 extra features when conditional, 0 otherwise.
        label_dim = 10 * is_conditional
        self.encoder = Encoder(n_pixels + label_dim, hidden_dim, latent_dim, rng=encoder_key)
        self.decoder = Decoder(latent_dim + label_dim, hidden_dim, n_pixels, rng=decoder_key)
        self.is_conditional = is_conditional

    def __call__(self, x, label, *, rng):
        """Encode, sample and decode one (unbatched) example.

        Returns ``(x_recon, mean, logvar)``; ``x_recon`` has shape
        ``(1, 28, 28)``. ``label`` is only read when the model is conditional.
        """
        flat_x = jnp.ravel(x)
        enc_in = jnp.concatenate([flat_x, label]) if self.is_conditional else flat_x
        mean, logvar = self.encoder(enc_in)
        z = self.reparameterize(mean, logvar, rng=rng)
        dec_in = jnp.concatenate([z, label]) if self.is_conditional else z
        x_recon = self.decoder(dec_in).reshape((1, 28, 28))
        return x_recon, mean, logvar

    def reparameterize(self, mean, logvar, *, rng):
        """Sample ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        noise = jax.random.normal(rng, mean.shape)
        return mean + noise * jnp.exp(0.5 * logvar)
def _download_atomically(url, target):
    """Download ``url`` to ``target`` without ever leaving a partial file there.

    Data is written to ``target + ".part"`` first and renamed onto ``target``
    only after ``urlretrieve`` returns. An interrupted download therefore
    cannot masquerade as a valid cached file on the next run.
    """
    partial = target + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
        # os.replace overwrites the destination; on POSIX the rename is atomic.
        os.replace(partial, target)
    except BaseException:
        # Best-effort cleanup of the partial file, then propagate the error.
        if os.path.exists(partial):
            os.remove(partial)
        raise


def _read_idx(path, magic, n_dims):
    """Read a gzip-compressed IDX file into a uint8 array shaped by its header.

    Args:
        path: file to read.
        magic: expected big-endian magic number (2051 for images, 2049 for labels).
        n_dims: number of big-endian uint32 dimension sizes after the magic number.

    Raises:
        ValueError: on a short header, a wrong magic number, or a payload whose
            size does not match the header dimensions.
    """
    header_size = 4 * (1 + n_dims)
    with gzip.open(path, "rb") as fh:
        header = fh.read(header_size)
        if len(header) != header_size:
            raise ValueError(f"{path}: truncated IDX header")
        found_magic, *dims = struct.unpack(f">{1 + n_dims}I", header)
        if found_magic != magic:
            raise ValueError(f"{path}: expected IDX magic number {magic}, got {found_magic}")
        payload = fh.read()
    expected_size = int(np.prod(dims))
    if len(payload) != expected_size:
        raise ValueError(f"{path}: expected {expected_size} data bytes, got {len(payload)}")
    return jnp.frombuffer(payload, dtype=jnp.uint8).reshape(tuple(dims))


def mnist():
    """Load the MNIST training split, downloading it into ``./data/mnist`` if needed.

    Each download goes through a temporary ``.part`` file that is renamed into
    place only on success. This keeps a truncated file from being reused as a
    cache.

    Returns:
        images: uint8 array of shape (N, 1, rows, cols).
        labels: uint8 array of shape (N, 1).

    Raises:
        ValueError: if a cached file is not a well-formed IDX file of the
            expected kind, or if the image and label counts disagree.
    """
    url_dir = "https://storage.googleapis.com/cvdf-datasets/mnist"
    target_dir = os.path.join(os.getcwd(), "data", "mnist")
    filenames = {
        "images": "train-images-idx3-ubyte.gz",
        "labels": "train-labels-idx1-ubyte.gz",
    }

    # Download any file that is not cached yet.
    for filename in filenames.values():
        url = f"{url_dir}/{filename}"
        target = os.path.join(target_dir, filename)
        if not os.path.exists(target):
            os.makedirs(target_dir, exist_ok=True)
            _download_atomically(url, target)
            print(f"Downloaded {url} to {target}")

    images = _read_idx(os.path.join(target_dir, filenames["images"]), magic=2051, n_dims=3)
    labels = _read_idx(os.path.join(target_dir, filenames["labels"]), magic=2049, n_dims=1)
    if images.shape[0] != labels.shape[0]:
        raise ValueError(
            f"image count ({images.shape[0]}) does not match label count ({labels.shape[0]})"
        )
    n_examples, rows, cols = images.shape
    # Add a singleton channel axis to images and a trailing axis to labels,
    # giving the (N, 1, rows, cols) / (N, 1) layout callers expect.
    return images.reshape((n_examples, 1, rows, cols)), labels.reshape((n_examples, 1))
def infinite_dataloader(images, labels, batch_size, *, rng):
    """Return an endless iterator of shuffled ``(images, labels)`` mini-batches.

    Every epoch draws a fresh permutation from ``rng`` and walks it in chunks of
    ``batch_size``. The final chunk of an epoch may be smaller.

    Args:
        images: array indexed along axis 0 by example.
        labels: array indexed along axis 0 by example; same length as ``images``.
        batch_size: positive number of examples per batch.
        rng: JAX PRNG key driving the per-epoch shuffles.

    Raises:
        ValueError: raised immediately (not on the first ``next``) for an empty
            dataset, a non-positive ``batch_size``, or mismatched lengths.
            Previously an empty dataset made the iterator spin forever without
            yielding, and ``batch_size <= 0`` produced endless empty batches.
    """
    dataset_size = images.shape[0]
    if dataset_size == 0:
        raise ValueError("cannot build a dataloader over an empty dataset")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if labels.shape[0] != dataset_size:
        raise ValueError(
            f"images ({dataset_size}) and labels ({labels.shape[0]}) differ in length"
        )
    return _shuffled_batches(images, labels, batch_size, rng=rng)


def _shuffled_batches(images, labels, batch_size, *, rng):
    """Yield shuffled mini-batches forever; inputs must already be validated."""
    dataset_size = images.shape[0]
    indices = jnp.arange(dataset_size)
    while True:
        rng, perm_rng = jax.random.split(rng)
        perm_indices = jax.random.permutation(perm_rng, indices)
        for start in range(0, dataset_size, batch_size):
            # Slicing past the end is clamped, so the last batch may be short.
            batch_indices = perm_indices[start:start + batch_size]
            yield images[batch_indices], labels[batch_indices]
def gaussian_log_likelihood(mean, sample, logvar=None):
    """Return the *negative* Gaussian log-likelihood of ``sample`` given ``mean``.

    Despite the name, the result is a loss (lower is better). The constant
    ``0.5 * log(2 * pi)`` per element is omitted.

    Args:
        mean: predicted mean.
        sample: observed values, broadcastable against ``mean``.
        logvar: optional per-element log-variance, broadcastable against
            ``mean``. When omitted, unit variance is assumed and the result is
            half the summed squared error (the original behaviour).

    Returns:
        Scalar ``0.5 * sum((mean - sample)**2 / exp(logvar) + logvar)``.
    """
    squared_error = jnp.square(mean - sample)
    if logvar is None:
        return 0.5 * jnp.sum(squared_error)
    return 0.5 * jnp.sum(squared_error * jnp.exp(-logvar) + logvar)
def bernoulli_log_likelihood(logits, labels, eps=1e-6):
    """Return the *negative* Bernoulli log-likelihood (binary cross-entropy), summed.

    Despite the parameter name, ``logits`` must hold probabilities in [0, 1]
    (e.g. sigmoid outputs), not unnormalised logits. ``labels`` may be soft
    targets in [0, 1].

    Probabilities are clipped to ``[eps, 1 - eps]`` before taking logs. A
    saturated sigmoid can return exactly 0.0 or 1.0 in float32. Then
    ``log(0) = -inf``, and ``0 * -inf = NaN``, which would turn the loss and its
    gradients into inf/NaN. Inputs already inside the clip range are
    unaffected.
    """
    probs = jnp.clip(logits, eps, 1.0 - eps)
    return -jnp.sum(labels * jnp.log(probs) + (1.0 - labels) * jnp.log1p(-probs))
def reconstruction_error(mean, sample):
    """Per-example reconstruction loss: Bernoulli NLL of ``sample`` under ``mean``.

    ``mean`` (decoder output) is passed as the probability argument of
    ``bernoulli_log_likelihood``; ``sample`` is passed as its (soft) targets.
    """
    nll = bernoulli_log_likelihood(labels=sample, logits=mean)
    return nll
def kullback_leibler_divergence(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, I)) for a diagonal Gaussian, summed over elements.

    Closed form: ``-0.5 * sum(1 + logvar - mean**2 - exp(logvar))``.
    """
    variance = jnp.exp(logvar)
    mean_sq = jnp.square(mean)
    return -0.5 * jnp.sum(1 + logvar - mean_sq - variance)
def vae_loss(model, x, labels, *, rng):
    """Negative ELBO averaged over a batch.

    Each example gets its own PRNG key for the reparameterised sample. The
    per-example negative ELBO is ``KL + reconstruction_error``, since
    ``ELBO = -KL - reconstruction_error``.

    Args:
        model: callable ``model(x_i, label_i, rng=key) -> (x_recon, mean, logvar)``.
        x: batch of inputs, leading axis is the batch.
        labels: batch of labels aligned with ``x``.
        rng: PRNG key, split into one key per example.
    """
    per_example_keys = jax.random.split(rng, x.shape[0])
    # vmap maps keyword arguments over their leading axis too, so each
    # example receives its own key.
    x_recon, mean, logvar = jax.vmap(model)(x, labels, rng=per_example_keys)
    kl_terms = jax.vmap(kullback_leibler_divergence)(mean, logvar)
    recon_terms = jax.vmap(reconstruction_error)(x_recon, x)
    negative_elbo = kl_terms + recon_terms
    return jnp.mean(negative_elbo)
def _build_arg_parser():
    """Create the command-line / config-file parser for the training script."""
    parser = configargparse.ArgParser()
    # seed
    parser.add_argument('--seed', type=int, help="global seed", default=42)
    # architecture
    parser.add_argument('--hidden_dim', type=int, help="encoder and decoder hidden dimension", default=128)
    parser.add_argument('--latent_dim', type=int, help="latent dimension", default=6)
    # plain VAE (default) vs conditional VAE
    parser.add_argument('--conditional', action="store_true")
    # training / optimisation
    parser.add_argument('--learning_rate', type=float, help="ADAM learning rate", default=3e-4)
    parser.add_argument('--batch_size', type=int, help="mini-batch size", default=128)
    parser.add_argument('--n_epochs', type=int, help="number of epochs", default=10)
    return parser


def main():
    """Train a (conditional) VAE on MNIST and print the mean loss of each epoch."""
    parser = _build_arg_parser()
    args = parser.parse_args()
    print(parser.format_values())

    # One key for model init and one for data shuffling; the remaining key is
    # split once per training step.
    key = jax.random.PRNGKey(args.seed)
    key, model_key, loader_key = jax.random.split(key, 3)

    model = VAE(
        hidden_dim=args.hidden_dim,
        latent_dim=args.latent_dim,
        is_conditional=args.conditional,
        rng=model_key,
    )
    optimizer = optax.adam(learning_rate=args.learning_rate)
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

    raw_images, raw_labels = mnist()
    n_examples = raw_images.shape[0]
    # Ceiling division: a final, smaller batch still counts as a step.
    steps_per_epoch = (n_examples + args.batch_size - 1) // args.batch_size
    total_steps = steps_per_epoch * args.n_epochs

    # Scale uint8 pixels to [0, 1]; turn (N, 1) integer labels into (N, 10) one-hot rows.
    images = raw_images / 255
    one_hot = np.eye(10).take(raw_labels, axis=0).squeeze(1)
    batches = infinite_dataloader(images, one_hot, args.batch_size, rng=loader_key)

    @eqx.filter_jit
    def train_step(model, opt_state, x, labels, *, rng):
        """Run one optimiser update; return the new model, optimiser state and loss."""
        loss, grads = eqx.filter_value_and_grad(vae_loss)(model, x, labels, rng=rng)
        updates, opt_state = optimizer.update(grads, opt_state, model)
        return eqx.apply_updates(model, updates), opt_state, loss

    for step, (x_batch, y_batch) in zip(range(total_steps), batches):
        key, step_key = jax.random.split(key)
        model, opt_state, loss = train_step(model, opt_state, x_batch, y_batch, rng=step_key)

        epoch, batch = divmod(step, steps_per_epoch)
        if batch == 0:
            epoch_losses = np.zeros((steps_per_epoch,))
        epoch_losses[batch] = loss
        if batch == steps_per_epoch - 1:
            print(" | ".join([
                "EPOCH: {0:2d}".format(epoch),
                "BATCH: {0:3d}".format(batch),
                "LOSS: {0:.4f}".format(epoch_losses.mean()),
            ]))


if __name__ == "__main__":
    main()