-
Notifications
You must be signed in to change notification settings - Fork 52
/
train.py
256 lines (217 loc) · 10.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import json
import logging
import math
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel
from training.distributed import is_master
from training.precision import get_autocast
try:
import wandb
except ImportError:
wandb = None
from open_clip import get_input_dtype, CLIP, CustomTextCLIP
class AverageMeter:
    """Keep the latest value of a metric plus its running weighted mean.

    Attributes:
        val: the value most recently passed to ``update``.
        sum: weighted total of every recorded value (``val * n`` per call).
        count: total weight recorded so far.
        avg: ``sum / count`` as of the last ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics so a new measurement window can start."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
def postprocess_clip_output(model_out):
    """Give names to the first three positional entries of a model output.

    Args:
        model_out: indexable container; entries 0, 1 and 2 are treated as the
            image features, text features and logit scale respectively.

    Returns:
        dict with keys ``image_features``, ``text_features`` and ``logit_scale``.
    """
    field_names = ("image_features", "text_features", "logit_scale")
    return {name: model_out[position] for position, name in enumerate(field_names)}
def unwrap_model(model):
    """Return ``model.module`` when the model is a wrapper exposing one
    (e.g. DistributedDataParallel), otherwise return ``model`` unchanged."""
    return getattr(model, 'module', model)
def backward(total_loss, scaler):
    """Backpropagate ``total_loss``.

    When ``scaler`` is given, the loss is first passed through
    ``scaler.scale`` and the scaled result is backpropagated instead.
    """
    loss_to_backprop = total_loss if scaler is None else scaler.scale(total_loss)
    loss_to_backprop.backward()
def _optimizer_step(model, optimizer, scaler, args):
    """Apply accumulated gradients: optional norm clipping, then one optimizer step.

    Three paths:
      * ``scaler is None``: clip (if requested) and call ``optimizer.step()``.
      * scaler + Horovod: synchronize gradients explicitly, unscale, clip
        (if requested), then step inside ``optimizer.skip_synchronize()`` so the
        step does not synchronize a second time.
      * scaler without Horovod: unscale only when clipping, then ``scaler.step``.
    In both scaler paths ``scaler.update()`` is called afterwards.
    """
    if scaler is None:
        if args.grad_clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
        optimizer.step()
        return

    if args.horovod:
        optimizer.synchronize()
        scaler.unscale_(optimizer)
        if args.grad_clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
    else:
        if args.grad_clip_norm is not None:
            # Unscale first so the clipping threshold applies to the real gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
        scaler.step(optimizer)
    scaler.update()


def _log_scalars(log_data, step, args, tb_writer):
    """Write every entry of ``log_data`` under a ``train/`` prefix to TensorBoard and/or wandb."""
    prefixed = {"train/" + name: val for name, val in log_data.items()}
    if tb_writer is not None:
        for name, val in prefixed.items():
            tb_writer.add_scalar(name, val, step)
    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        # Log everything in a single call: each wandb.log() call without an
        # explicit step advances wandb's internal step, so per-metric calls
        # would scatter one training step across many wandb steps.
        wandb.log({**prefixed, 'step': step})


def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=None):
    """Train ``model`` for one epoch on ``data[f'{args.clip_type}_pt']``.

    Args:
        model: called as ``model(images, input_ids, attention_mask)`` and expected
            to return a dict that includes ``"logit_scale"``; the unwrapped
            model's ``logit_scale`` parameter is clamped to ``[0, ln(100)]``
            after every optimizer step.
        data: mapping whose ``f'{args.clip_type}_pt'`` entry provides
            ``set_epoch(epoch)`` and a ``dataloader`` with ``num_batches`` and
            ``num_samples`` attributes; batches unpack to
            ``(images, input_ids, attention_mask)``.
        loss: callable taking the model outputs as keyword arguments plus
            ``output_dict=True`` and returning a dict of loss tensors, which are
            summed into the total loss.
        epoch: current epoch index, used for ``set_epoch`` and the global step.
        optimizer: the optimizer to step.
        scaler: gradient scaler (``scale``/``unscale_``/``step``/``update``) or
            ``None`` when training without one.
        scheduler: called with the global step unless ``args.skip_scheduler``.
        dist_model: teacher model, evaluated without gradients when
            ``args.distill`` is set and ``args.accum_freq == 1``.
        args: run configuration; reads ``device``, ``precision``, ``distill``,
            ``clip_type``, ``accum_freq``, ``skip_scheduler``, ``horovod``,
            ``grad_clip_norm``, ``log_every_n_steps``, ``world_size``,
            ``batch_size`` and ``wandb``.
        tb_writer: optional object with ``add_scalar(name, value, step)``.

    With ``args.accum_freq > 1`` the features of ``accum_freq`` batches are
    first cached without gradients; then each batch is re-run with gradients,
    using the cached features of the other batches as extra negatives, and the
    optimizer steps once per group.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    input_dtype = get_input_dtype(args.precision)

    model.train()
    if args.distill:
        dist_model.eval()
        if args.accum_freq > 1 and is_master(args):
            # The accumulation path below never calls dist_model.
            logging.warning(
                "Distillation is not applied when accum_freq > 1: "
                "the gradient-accumulation path does not run dist_model."
            )

    train_data = data[f'{args.clip_type}_pt']
    train_data.set_epoch(epoch)  # set epoch in process safe manner via sampler or shared_epoch
    dataloader = train_data.dataloader
    num_batches_per_epoch = dataloader.num_batches // args.accum_freq
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    if args.accum_freq > 1:
        accum_images, accum_input_ids, accum_attention_mask, accum_features = [], [], [], {}

    losses_m = {}
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        i_accum = i // args.accum_freq
        step = num_batches_per_epoch * epoch + i_accum

        if not args.skip_scheduler:
            scheduler(step)

        images, input_ids, attention_mask = batch
        images = images.to(device=device, dtype=input_dtype, non_blocking=True)
        input_ids = input_ids.to(device=device, non_blocking=True)
        attention_mask = attention_mask.to(device=device, non_blocking=True)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        if args.accum_freq == 1:
            with autocast():
                model_out = model(images, input_ids, attention_mask)
                logit_scale = model_out["logit_scale"]
                if args.distill:
                    with torch.no_grad():
                        dist_model_out = dist_model(images, input_ids, attention_mask)
                    model_out.update({f'dist_{k}': v for k, v in dist_model_out.items()})
                losses = loss(**model_out, output_dict=True)

                total_loss = sum(losses.values())
                losses["loss"] = total_loss

            backward(total_loss, scaler)
        else:
            # First, cache the features without any gradient tracking.
            with torch.no_grad():
                with autocast():
                    model_out = model(images, input_ids, attention_mask)
                    model_out.pop("logit_scale")
                    for key, val in model_out.items():
                        accum_features.setdefault(key, []).append(val)

                accum_images.append(images)
                accum_input_ids.append(input_ids)
                accum_attention_mask.append(attention_mask)

            # If (i + 1) % accum_freq is not zero, move on to the next batch.
            if ((i + 1) % args.accum_freq) > 0:
                # FIXME this makes data time logging unreliable when accumulating
                continue

            # Now, ready to take gradients for the last accum_freq batches.
            # Re-do the forward pass for those batches, and use the cached features
            # from the other batches as negatives.
            # Call backwards each time, but only step optimizer at the end.
            optimizer.zero_grad()
            for j in range(args.accum_freq):
                images = accum_images[j]
                input_ids = accum_input_ids[j]
                attention_mask = accum_attention_mask[j]
                with autocast():
                    model_out = model(images, input_ids, attention_mask)
                    logit_scale = model_out.pop("logit_scale")
                    # Swap in the fresh (grad-tracking) features for slot j; the
                    # other slots keep their cached no-grad features.
                    inputs = {
                        key: torch.cat(accumulated[:j] + [model_out[key]] + accumulated[j + 1:])
                        for key, accumulated in accum_features.items()
                    }
                    losses = loss(**inputs, logit_scale=logit_scale, output_dict=True)
                    del inputs
                    total_loss = sum(losses.values())
                    losses["loss"] = total_loss
                backward(total_loss, scaler)

        _optimizer_step(model, optimizer, scaler, args)

        # reset gradient accum, if enabled
        if args.accum_freq > 1:
            accum_images, accum_input_ids, accum_attention_mask, accum_features = [], [], [], {}

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i_accum + 1
        if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(images)
            num_samples = batch_count * batch_size * args.accum_freq * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            for key, val in losses.items():
                if key not in losses_m:
                    losses_m[key] = AverageMeter()
                losses_m[key].update(val.item(), batch_size)

            logit_scale_scalar = logit_scale.item()
            loss_log = " ".join(
                f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
                for loss_name, loss_m in losses_m.items()
            )
            samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
            )

            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": samples_per_second,
                "samples_per_second_per_gpu": samples_per_second_per_gpu,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"],
            }
            log_data.update({name: meter.val for name, meter in losses_m.items()})
            _log_scalars(log_data, step, args, tb_writer)

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for