-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
305 lines (261 loc) · 10 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import sys
import os
import yaml
from absl import flags, logging
from absl import logging
import ml_collections
from ml_collections import config_flags
from clu import metric_writers
import wandb
sys.path.append("./")
sys.path.append("../")
from tqdm import trange
import jax
import jax.numpy as np
import optax
import flax
from flax.core import FrozenDict
from flax.training import checkpoints, common_utils, train_state
import tensorflow as tf
from eval import eval_generation, eval_likelihood
from models.diffusion import VariationalDiffusionModel
from models.diffusion_utils import loss_vdm
from models.train_utils import (
create_input_iter,
param_count,
train_step,
to_wandb_config,
)
from datasets import load_data, augment_data
# Short aliases for copying a pytree onto every local device and taking a
# single copy back off again.
replicate, unreplicate = flax.jax_utils.replicate, flax.jax_utils.unreplicate

# Emit absl INFO-level messages (dataset / model / training progress logs).
logging.set_verbosity(logging.INFO)
def _prepare_workdir(config: ml_collections.ConfigDict, workdir: str) -> str:
    """Start the wandb run (if enabled) and create the run's output directory.

    On process 0 with ``config.wandb.log_train`` set, a wandb run is started
    and the output directory becomes ``workdir/<run.group>/<run.name>``.
    The directory is then created and the full config is dumped to
    ``config.yaml`` inside it.

    Args:
        config: Experiment configuration; the ``wandb`` section is read here.
        workdir: Base output directory.

    Returns:
        The directory that all run outputs should be written to.
    """
    if config.wandb.log_train and jax.process_index() == 0:
        run = wandb.init(
            entity=config.wandb.entity,
            project=config.wandb.project,
            job_type=config.wandb.job_type,
            group=config.wandb.group,
            config=to_wandb_config(config),
        )
        # Set default x-axis as 'train/step' for every logged metric.
        wandb.define_metric("*", step_metric="train/step")
        workdir = os.path.join(workdir, run.group, run.name)

    # Create the directory unconditionally: config.yaml and the checkpoints are
    # written here even when wandb logging is disabled, and opening a file in a
    # missing directory raises FileNotFoundError.
    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, "config.yaml"), "w") as f:
        yaml.dump(config.to_dict(), f)
    return workdir


def _make_optimizer(config: ml_collections.ConfigDict) -> optax.GradientTransformation:
    """Build AdamW with the configured LR schedule and optional gradient clipping.

    Args:
        config: Experiment configuration; reads ``optim`` and ``training``.

    Returns:
        The optax gradient transformation.

    Raises:
        ValueError: If ``config.optim.lr_schedule`` is not "cosine"/"constant".
    """
    # Read optional fields with ``.get`` instead of ``hasattr`` + assignment:
    # the config is locked when loaded via ``__main__`` (lock_config=True), and
    # adding a new field to a locked ConfigDict raises.
    lr_schedule = config.optim.get("lr_schedule", "cosine")
    if lr_schedule == "cosine":
        schedule = optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=config.optim.learning_rate,
            warmup_steps=config.training.warmup_steps,
            decay_steps=config.training.n_train_steps,
        )
    elif lr_schedule == "constant":
        schedule = optax.constant_schedule(config.optim.learning_rate)
    else:
        raise ValueError(f"Invalid learning rate schedule: {lr_schedule}")

    tx = optax.adamw(learning_rate=schedule, weight_decay=config.optim.weight_decay)

    grad_clip = config.optim.get("grad_clip", None)
    if grad_clip is not None:
        # optax.clip clips every gradient element to [-grad_clip, grad_clip].
        tx = optax.chain(optax.clip(grad_clip), tx)
    return tx


def _evaluate(
    vdm,
    state: train_state.TrainState,
    rng,
    *,
    config: ml_collections.ConfigDict,
    x_batch,
    conditioning_batch,
    mask_batch,
    norm_dict,
) -> None:
    """Run likelihood (conditional models only) and generation evals on a fixed batch.

    Args:
        vdm: The diffusion model module.
        state: Un-replicated training state to evaluate.
        rng: PRNG key reserved for this evaluation.
        config: Experiment configuration.
        x_batch, conditioning_batch, mask_batch: The fixed evaluation batch as
            produced by the input iterator (``conditioning_batch`` may be None).
        norm_dict: Normalisation statistics returned by ``load_data``.
    """

    def flatten(a):
        # Merge the two leading axes into one batch axis (presumably
        # local devices x per-device batch -- TODO confirm against
        # create_input_iter).
        return a.reshape((-1, *a.shape[2:]))

    true_samples = flatten(x_batch)
    mask = flatten(mask_batch)
    rng_likelihood, rng_generation = jax.random.split(rng)

    if conditioning_batch is not None:
        conditioning = flatten(conditioning_batch)
        eval_likelihood(
            vdm=vdm,
            pstate=state,
            rng=rng_likelihood,
            true_samples=true_samples,
            conditioning=conditioning,
            mask=mask,
        )
    else:
        conditioning = None

    eval_generation(
        vdm=vdm,
        pstate=state,
        rng=rng_generation,
        n_samples=config.training.batch_size,
        n_particles=x_batch.shape[-2],
        true_samples=true_samples,
        conditioning=conditioning,
        mask=mask,
        norm_dict=norm_dict,
        steps=500,
        boxsize=config.data.box_size,
    )


def train(
    config: ml_collections.ConfigDict, workdir: str = "./logging/"
) -> train_state.TrainState:
    """Train a ``VariationalDiffusionModel`` on the configured dataset.

    Sets up the output directory (and wandb run, if enabled), loads the
    training split, initialises the model from a real batch, then runs
    ``config.training.n_train_steps`` steps of ``train_step`` on the replicated
    state with periodic metric logging, evaluation and checkpointing.

    Args:
        config: Experiment configuration (``data``, ``vdm``, ``score``,
            ``encoder``, ``decoder``, ``optim``, ``training`` and ``wandb``
            sections).
        workdir: Base output directory. With wandb logging enabled on process
            0, outputs go to ``workdir/<wandb group>/<wandb run name>``.

    Returns:
        The final training state, un-replicated from the local devices.

    Raises:
        ValueError: If ``config.optim.lr_schedule`` is not "cosine"/"constant".
    """
    is_main_process = jax.process_index() == 0

    workdir = _prepare_workdir(config, workdir)
    writer = metric_writers.create_default_writer(
        logdir=workdir, just_logging=not is_main_process
    )

    # Load the dataset
    train_ds, norm_dict = load_data(
        config.data.dataset,
        config.data.n_features,
        config.data.n_particles,
        config.training.batch_size,
        config.seed,
        shuffle=True,
        split="train",
        simulation_set=config.data.simulation_set,
        conditioning_parameters=config.data.conditioning_parameters,
    )
    add_augmentations = bool(
        config.data.add_rotations or config.data.add_translations
    )
    batches = create_input_iter(train_ds)
    logging.info("Loaded the %s dataset", config.data.dataset)

    # Model configuration: score network plus (optional) encoder/decoder, and
    # the normalisation statistics baked into the model.
    norm_dict_input = FrozenDict(
        {
            "x_mean": tuple(map(float, norm_dict["mean"])),
            "x_std": tuple(map(float, norm_dict["std"])),
            "box_size": config.data.box_size,
        }
    )
    vdm = VariationalDiffusionModel(
        d_feature=config.data.n_features,
        timesteps=config.vdm.timesteps,
        noise_schedule=config.vdm.noise_schedule,
        noise_scale=config.vdm.noise_scale,
        d_t_embedding=config.vdm.d_t_embedding,
        gamma_min=config.vdm.gamma_min,
        gamma_max=config.vdm.gamma_max,
        score=config.score.score,
        score_dict=FrozenDict(config.score),
        embed_context=config.vdm.embed_context,
        d_context_embedding=config.vdm.d_context_embedding,
        n_classes=config.vdm.n_classes,
        use_encdec=config.vdm.use_encdec,
        encoder_dict=FrozenDict(config.encoder),
        decoder_dict=FrozenDict(config.decoder),
        norm_dict=norm_dict_input,
    )

    # Dedicated keys for parameter init and the init-time "sample" stream, so
    # neither is reused later as the training-loop carry key.
    rng = jax.random.PRNGKey(config.seed)
    rng, rng_params, rng_init_sample = jax.random.split(rng, 3)

    # Initialise from the first slice of a real batch; the same batch is kept
    # as the fixed evaluation batch below.
    # TODO: Make so we don't have to pass an entire batch (slow)
    x_batch, conditioning_batch, mask_batch = next(batches)
    # init_with_output returns (output, variables); the whole variables
    # collection is what TrainState stores as ``params`` below.
    _, params = vdm.init_with_output(
        {"sample": rng_init_sample, "params": rng_params},
        x_batch[0],
        conditioning_batch[0] if conditioning_batch is not None else None,
        mask_batch[0],
    )
    logging.info("Instantiated the model")
    logging.info("Number of parameters: %d", param_count(params))

    ## Training config and loop
    state = train_state.TrainState.create(
        apply_fn=vdm.apply, params=params, tx=_make_optimizer(config)
    )
    pstate = replicate(state)

    logging.info("Starting training...")
    n_local_devices = jax.local_device_count()
    train_metrics = []
    with trange(config.training.n_train_steps) as steps:
        for step in steps:
            # Derive every key used this step from one split so that no key is
            # both consumed by a random op and split again next step:
            # [carry, augmentation, evaluation, one per local device].
            keys = jax.random.split(rng, n_local_devices + 3)
            rng, rng_aug, rng_eval = keys[0], keys[1], keys[2]
            train_step_rng = keys[3:]

            x, conditioning, mask = next(batches)
            if add_augmentations:
                x, conditioning, mask = augment_data(
                    x=x,
                    mask=mask,
                    conditioning=conditioning,
                    rng=rng_aug,
                    norm_dict=norm_dict,
                    n_pos_dim=config.data.n_pos_features,
                    box_size=config.data.box_size,
                    rotations=config.data.add_rotations,
                    translations=config.data.add_translations,
                )

            pstate, metrics = train_step(
                pstate,
                (x, conditioning, mask),
                train_step_rng,
                vdm,
                loss_vdm,
                config.training.unconditional_dropout,
                config.training.p_uncond,
            )
            steps.set_postfix(val=unreplicate(metrics["loss"]))
            train_metrics.append(metrics)

            # None of the periodic work below runs on the very first step.
            if step == 0:
                continue

            # Log periodically
            if step % config.training.log_every_steps == 0:
                if is_main_process:
                    stacked = common_utils.get_metrics(train_metrics)
                    summary = {
                        f"train/{k}": v
                        for k, v in jax.tree_util.tree_map(
                            lambda m: m.mean(), stacked
                        ).items()
                    }
                    writer.write_scalars(step, summary)
                    if config.wandb.log_train:
                        wandb.log({"train/step": step, **summary})
                # Reset on every process, not only the main one, so non-main
                # hosts do not accumulate per-step metrics for the whole run.
                train_metrics = []

            # Eval periodically
            if (
                is_main_process
                and config.wandb.log_train
                and step % config.training.eval_every_steps == 0
            ):
                _evaluate(
                    vdm,
                    unreplicate(pstate),
                    rng_eval,
                    config=config,
                    x_batch=x_batch,
                    conditioning_batch=conditioning_batch,
                    mask_batch=mask_batch,
                    norm_dict=norm_dict,
                )

            # Save checkpoints periodically
            if is_main_process and step % config.training.save_every_steps == 0:
                checkpoints.save_checkpoint(
                    # Absolute path: orbax-backed checkpointing in newer flax
                    # versions may reject relative directories.
                    ckpt_dir=os.path.abspath(workdir),
                    target=unreplicate(pstate),
                    step=step,
                    overwrite=True,
                    keep=np.inf,
                )

    writer.flush()
    logging.info("All done! Have a great day.")
    return unreplicate(pstate)
if __name__ == "__main__":
    # Script entry point: parse flags, hide GPUs from TensorFlow, report the
    # JAX topology, then train with the config given by --config.
    FLAGS = flags.FLAGS
    config_flags.DEFINE_config_file(
        "config",
        None,
        "File path to the training or sampling hyperparameter configuration.",
        lock_config=True,
    )
    # Without --config, FLAGS.config would be None and the train() call below
    # would fail with an opaque AttributeError; fail at parse time instead.
    flags.mark_flags_as_required(["config"])

    # Parse flags
    FLAGS(sys.argv)

    # Ensure TF does not see GPU and grab all GPU memory
    tf.config.experimental.set_visible_devices([], "GPU")

    logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
    logging.info("JAX local devices: %r", jax.local_devices())
    logging.info("JAX total visible devices: %r", jax.device_count())

    # Start training run
    train(config=FLAGS.config, workdir=FLAGS.config.wandb.workdir)