-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmain.py
414 lines (331 loc) · 18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
import os
import torch
import json
import time
import logging
import random
import argparse
import numpy as np
import itertools
from typing import List
from datetime import datetime
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from transformers import get_linear_schedule_with_warmup
from arguments import get_args
from policy import Policy
from data_pool import DataPool
from reward import Reward, reward_to_toxicity
from utils.utils import ensure_dir, ceil_div, reduce_mean, reduce_sum, distinctness
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
log = logging.getLogger(__name__)
class PromptDataset(Dataset):
def __init__(self, path):
self.prompts = [json.loads(s.strip())["prompt"]["text"].strip() for s in open(path, 'r').readlines()]
def __len__(self):
return len(self.prompts)
def __getitem__(self, idx):
return {'prompt': self.prompts[idx]}
class PromptCollator(object):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, sequences):
prompts = [sequence['prompt'] for sequence in sequences]
encodings_dict = self.tokenizer(prompts, return_tensors="pt", padding=True)
input_ids = encodings_dict['input_ids']
attention_mask = encodings_dict['attention_mask']
return input_ids, attention_mask
class SequenceDataset(Dataset):
def __init__(self, data_pool: DataPool):
self.queries, self.responses, self.cat_tokens = data_pool.get_data()
def __len__(self):
return len(self.queries)
def __getitem__(self, idx):
return {'query': self.queries[idx],
'response': self.responses[idx],
'cat_tokens': self.cat_tokens[idx]
}
class SequenceCollator(object):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, sequences):
queries = [sequence['query'] for sequence in sequences]
responses = [sequence['response'] for sequence in sequences]
cat_ids = [self.tokenizer.convert_tokens_to_ids(sequence['cat_tokens']) for sequence in sequences]
query_encodings_dict = self.tokenizer(queries, return_tensors="pt", padding=True)
query_input_ids = query_encodings_dict['input_ids']
query_mask = query_encodings_dict['attention_mask']
query_input_ids = torch.cat([query_input_ids.new(cat_ids)[:, None], query_input_ids], dim=1)
query_mask = torch.cat([query_mask.new([1] * len(query_mask))[:, None], query_mask], dim=1)
response_encodings_dict = self.tokenizer(responses, return_tensors="pt", padding=True)
response_input_ids = response_encodings_dict['input_ids']
response_mask = response_encodings_dict['attention_mask']
return query_input_ids, query_mask, response_input_ids, response_mask
class FixedController:
def __init__(self, coef):
self.value = coef
def update(self, current, n_steps, lower_bound):
pass
class AdaptiveController:
def __init__(self, init_coef, target, horizon):
self.value = init_coef
self.target = target
self.horizon = horizon
def update(self, current, n_steps, lower_bound):
proportional_error = np.clip(current / self.target - 1, -0.2, 0.2)
if lower_bound:
mult = 1 + proportional_error * n_steps / self.horizon
else:
mult = 1 - proportional_error * n_steps / self.horizon
self.value *= mult
class ConditionTrainer:
def __init__(self,
params: argparse.Namespace,
policy: Policy,
ref_policy: Policy,
data_pool: DataPool,
score_model: Reward,
tree_tokens: List[str],
train_dataloader: DataLoader,
val_dataloader: DataLoader,
optimizer: Optimizer,
scheduler: LambdaLR):
self.params = params
self.policy = policy
self.ref_policy = ref_policy
self.data_pool = data_pool
self.score_model = score_model
self.optimizer = optimizer
self.scheduler = scheduler
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.writer = SummaryWriter()
if self.params.adaptive_kl:
self.kl_ctl = AdaptiveController(self.params.kl_coef, self.params.target_kl, self.params.horizon)
else:
self.kl_ctl = FixedController(self.params.kl_coef)
self.kl_loss = torch.nn.KLDivLoss(reduction="none")
if self.params.adaptive_entropy:
self.entropy_ctl = AdaptiveController(self.params.entropy_coef, self.params.target_entropy,
self.params.horizon)
else:
self.entropy_ctl = FixedController(self.params.entropy_coef)
self.tree_tokens = tree_tokens
self.best_cat = self.tree_tokens[0]
self.best_cat_id = self.policy.tokenizer.convert_tokens_to_ids(self.best_cat)
self.sample_dataloader, self.sampler = None, None
self.seq_collator = SequenceCollator(tokenizer=policy.tokenizer)
def add_control_code(self, input_ids, attention_mask):
input_ids = torch.cat([input_ids.new([self.best_cat_id] * len(input_ids))[:, None], input_ids], dim=1)
attention_mask = torch.cat([attention_mask.new([1] * len(attention_mask))[:, None], attention_mask], dim=1)
return input_ids, attention_mask
def decode(self, query_input_ids, response_input_ids=None):
query = [self.policy.tokenizer.decode(p, skip_special_tokens=True, clean_up_tokenization_spaces=True)
for p in query_input_ids]
if response_input_ids is None:
return query
response = [self.policy.tokenizer.decode(r, skip_special_tokens=True, clean_up_tokenization_spaces=True)
for r in response_input_ids]
return query, response
def sample(self, step):
if step % self.params.sample_interval != 0:
return
log.info(f"[step {step}] Sampling ...")
prompts, responses = [], []
for i, batch in enumerate(tqdm(self.train_dataloader, total=len(self.train_dataloader),
desc='Sampling from current policy')):
input_ids, attention_mask = batch
if step == 0:
rollouts = self.ref_policy.sample(input_ids=input_ids, attention_mask=attention_mask, top_p=self.params.top_p)
prompt, response = rollouts['query/text'], rollouts['response/text']
else:
input_ids, attention_mask = self.add_control_code(input_ids, attention_mask)
rollouts = self.policy.sample(input_ids=input_ids, attention_mask=attention_mask, top_p=self.params.top_p)
response = rollouts['response/text']
prompt = self.decode(rollouts['query/input_ids'][:, 1:])
prompts.extend(prompt)
responses.extend(response)
scores = self.score_model.get_reward(prompts, responses, f'step{step}')
self.data_pool.add(prompts=prompts, responses=responses, scores=scores)
sample_dataset = SequenceDataset(data_pool=self.data_pool)
self.sample_dataloader = DataLoader(sample_dataset, batch_size=self.params.batch_size,
shuffle=True, drop_last=True, collate_fn=self.seq_collator)
self.sampler = iter(self.sample_dataloader)
def step(self, step_num):
step_started_at = time.time()
self.sample(step=step_num)
try:
batch = next(self.sampler)
assert len(batch[0]) == self.params.batch_size, 'insufficient batch'
except (StopIteration, AssertionError):
self.sampler = iter(self.sample_dataloader)
batch = next(self.sampler)
self.optimizer.zero_grad()
ppo_loss, stats = self.loss(step_num, *batch)
ppo_loss.backward()
if self.params.clip_grad:
torch.nn.utils.clip_grad_norm_(self.policy.model.parameters(), self.params.max_grad_norm)
self.optimizer.step()
self.scheduler.step()
for metric in ['kl', 'entropy']:
self.writer.add_scalar(f'Objective/{metric}', stats[f'objective/{metric}'], step_num)
for metric in ['lm', 'kl', 'entropy', 'total']:
self.writer.add_scalar(f'Loss/{metric}', stats[f'loss/{metric}'], step_num)
self.writer.add_scalar(f'Params/lr', self.optimizer.param_groups[0]['lr'], step_num)
self.writer.add_scalar(f'Params/kl_coef', self.kl_ctl.value, step_num)
self.writer.add_scalar(f'Params/entropy_coef', self.entropy_ctl.value, step_num)
self.kl_ctl.update(stats['objective/kl'], self.params.batch_size, True)
self.entropy_ctl.update(stats['objective/entropy'], self.params.batch_size, False)
step_time = time.time() - step_started_at
eps_per_second = float(self.params.batch_size) / step_time
log.info(f"[step {step_num}] step_time={step_time:.2f}s, eps/s={eps_per_second:.2f}")
self.save(step=step_num)
self.eval(step=step_num)
def loss(self, step, query_input_ids, query_mask, response_input_ids, response_mask):
outputs = self.policy.forward_pass(query_input_ids, query_mask, response_input_ids, response_mask)
lm_loss, logprobs, entropy, logits = outputs['response/lm_loss'], outputs['response/log_prob'], \
outputs['response/entropy'], outputs['response/logits']
logits = outputs['response/logits'][:, :, :-len(self.tree_tokens)]
masks = response_mask.to(self.policy.device)
with torch.no_grad():
ref_outputs = self.ref_policy.forward_pass(query_input_ids[:, 1:], query_mask[:, 1:],
response_input_ids, response_mask)
ref_logprobs, ref_logits = ref_outputs['response/log_prob'], ref_outputs['response/logits']
kl = torch.sum(self.kl_loss(F.log_softmax(ref_logits, dim=-1), F.softmax(logits, dim=-1)), dim=-1)
loss = reduce_mean(lm_loss + self.kl_ctl.value * kl - self.entropy_ctl.value * entropy, masks)
data = {'logprobs': logprobs, 'ref_logprobs': ref_logprobs, 'masks': masks,
'logits': logits, 'ref_logits': ref_logits,
'lm_loss': reduce_mean(lm_loss, masks), 'kl_loss': reduce_mean(kl, masks),
'entropy': reduce_mean(entropy, masks), 'total_loss': loss}
stats = self.record_step_stats(data)
queries, responses = self.decode(query_input_ids, response_input_ids)
self.print_samples(queries=queries, responses=responses, lm_loss=reduce_mean(lm_loss, masks, axis=1),
logprobs=logprobs, ref_logprobs=ref_logprobs, masks=masks, step=step)
return loss, stats
def record_step_stats(self, data):
masks = data['masks']
kl = torch.sum(self.kl_loss(F.log_softmax(data['ref_logits'], dim=-1), F.softmax(data['logits'], dim=-1)), dim=-1)
mean_kl = torch.mean(reduce_sum(kl, masks, axis=1))
mean_entropy = torch.mean(reduce_sum(-data['logprobs'], masks, axis=1))
stats = {
'objective/kl': mean_kl.item(),
'objective/entropy': mean_entropy.item(),
}
stats.update({
'loss/total': data['total_loss'].item(),
'loss/kl': data['kl_loss'].item(),
'loss/lm': data['lm_loss'].item(),
'loss/entropy': data['entropy'].item(),
})
return stats
def print_samples(self, queries, responses, lm_loss, logprobs, ref_logprobs, masks, step):
if step % self.params.log_interval != 0:
return
# Log samples
for i in range(min(3, len(queries))):
sample_kl = torch.sum((logprobs[i] - ref_logprobs[i]) * masks[i]).item()
print(queries[i] + responses[i])
print(f" lm_loss = {lm_loss[i].item():+.2f}")
print(f" kl = {sample_kl:+.2f}")
print(f" total = {lm_loss[i].item() + self.params.kl_coef * sample_kl:+.2f}")
def save(self, step):
if step % self.params.save_interval != 0:
return
torch.save({
'policy_model': self.policy.model.state_dict(),
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()
}, f'{self.params.model_dir}/ckp_{step}.pth')
log.info(f"[step {step}] model checkpoint saved")
def eval(self, step):
if step % self.params.eval_interval != 0:
return
log.info(f"[step {step}] evaluating ...")
generations, perplexities, toxicities = [], [], []
for i, (input_ids, attention_mask) in enumerate(tqdm(self.val_dataloader)):
with torch.no_grad():
input_ids, attention_mask = self.add_control_code(input_ids, attention_mask)
rollouts = self.policy.sample(input_ids=input_ids, attention_mask=attention_mask, top_p=self.params.top_p)
forward_inputs = {'query_input_ids': rollouts['query/input_ids'][:, 1:],
'query_mask': rollouts['query/mask'][:, 1:],
'response_input_ids': rollouts['response/input_ids'],
'response_mask': rollouts['response/mask']}
ref_logprobs = self.ref_policy.forward_pass(**forward_inputs)['response/log_prob']
perplexity = -1. * reduce_sum(ref_logprobs, rollouts['response/mask'].float(), axis=1)
perplexities.extend(perplexity.cpu().detach().numpy().tolist())
prompt = self.decode(rollouts['query/input_ids'][:, 1:])
response = rollouts['response/text']
score = self.score_model.get_reward(prompt, response, f'step{step}_eval{i}')
toxicity = [reward_to_toxicity(x) for x in score if x is not None]
toxicities.extend(toxicity)
generations.extend(rollouts['response/text'])
ppl_score, toxicity_score = np.mean(perplexities), np.mean(toxicities)
dist_1, dist_2, dist_3 = distinctness(generations)
print(f" perplexity = {ppl_score:+.2f}")
print(f" toxicity = {toxicity_score:+.2f}")
print(f'dist-1={dist_1:.3f}, dist-2={dist_2:.3f}, dist-3={dist_3:.3f}')
self.writer.add_scalar('Evaluation/perplexity', ppl_score, step)
self.writer.add_scalar('Evaluation/toxicity', toxicity_score, step)
self.writer.add_scalar('Evaluation/Dist-1', dist_1, step)
self.writer.add_scalar('Evaluation/Dist-2', dist_2, step)
self.writer.add_scalar('Evaluation/Dist-3', dist_3, step)
def main():
args = get_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
num_gpus = torch.cuda.device_count()
log.info(f'Detect {num_gpus} GPUS')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
time = datetime.now()
date_time = time.strftime("%m-%d-%Y_%H:%M:%S")
args.save_dir = os.path.join(args.output_dir, date_time)
args.reward_dir = os.path.join(args.save_dir, 'reward')
args.model_dir = os.path.join(args.save_dir, 'model')
args.tensorboard_dir = os.path.join(args.save_dir, 'tensorboard')
for d in [args.output_dir, args.save_dir, args.reward_dir, args.model_dir, args.tensorboard_dir]:
ensure_dir(d)
log.info(f'Write to output directory: {args.save_dir}')
with open(os.path.join(args.save_dir, 'args.json'), 'w') as f:
json.dump(args.__dict__, f, indent=2)
tree_tokens = [' _TREE_TOKEN_{}'.format(str(idx).zfill(5)) for idx in range(args.n_extra_tokens)] + \
[' _TREE_TOKEN_ZERO_COMMENTS']
log.info(f'Initializing models ...')
ref_policy = Policy(model_name=args.init_model, temperature=args.temperature, device=device)
policy = Policy(model_name=args.ref_model, temperature=args.temperature, device=device,
reward_cond=True, tree_tokens=tree_tokens)
reward = Reward(save_path=args.reward_dir, rate_limit=args.perspective_rate_limit, batch_size=args.batch_size)
data_pool = DataPool(tree_tokens=tree_tokens, n_extra_tokens=args.n_extra_tokens)
log.info(f'Initialization done!')
prompt_collator = PromptCollator(tokenizer=policy.tokenizer)
train_dataset = PromptDataset(path=args.dataset_train)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=prompt_collator)
log.info(f'Load train set with {len(train_dataset)} examples')
val_dataset = PromptDataset(path=args.dataset_val)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=prompt_collator)
log.info(f'Load val set with {len(val_dataset)} examples')
# set up optimizer and scheduler
optimizer = Adam(policy.model.parameters(), lr=args.lr, eps=1e-5)
args.total_steps = ceil_div(args.total_episodes, args.batch_size)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.num_warmup_steps, num_training_steps=args.total_steps)
trainer = ConditionTrainer(params=args, policy=policy, ref_policy=ref_policy, data_pool=data_pool,
score_model=reward, tree_tokens=tree_tokens,
train_dataloader=train_dataloader, val_dataloader=val_dataloader,
optimizer=optimizer, scheduler=scheduler)
for step_num in range(args.total_steps):
try:
trainer.step(step_num)
except RuntimeError:
torch.cuda.empty_cache()
continue
if __name__ == "__main__":
main()