-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathmain.py
executable file
·324 lines (280 loc) · 11.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import config
from math import log2
import os
import numpy as np
import ujson as json
import re
from collections import Counter
import string
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.cuda
from torch.utils.data import Dataset
import argparse
'''
Some functions are from the official evaluation script.
'''
# The note above refers to normalize_answer / f1_score / exact_match_score /
# metric_max_over_ground_truths below.
# Module-wide compute device: CUDA when torch reports a usable GPU, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class SQuADDataset(Dataset):
    """Batch-level view over a preprocessed ``.npz`` record file.

    Every item is a whole mini-batch, not a single example: ``dataset[k]``
    returns the k-th group of ``batch_size`` examples taken from a
    pre-shuffled index schedule covering ``num_steps`` batches.

    Args:
        npz_file: path to an ``.npz`` archive containing the arrays
            ``context_idxs``, ``context_char_idxs``, ``ques_idxs``,
            ``ques_char_idxs``, ``y1s``, ``y2s`` and ``ids`` (one row per
            example).
        num_steps: number of batches to schedule. A negative value means a
            single pass, i.e. ``len(examples) // batch_size`` batches.
        batch_size: number of examples per batch.
    """

    def __init__(self, npz_file, num_steps, batch_size):
        super().__init__()
        # Close the archive once every array has been read into memory.
        with np.load(npz_file) as data:
            self.context_idxs = torch.from_numpy(data["context_idxs"]).long()
            self.context_char_idxs = torch.from_numpy(data["context_char_idxs"]).long()
            self.ques_idxs = torch.from_numpy(data["ques_idxs"]).long()
            self.ques_char_idxs = torch.from_numpy(data["ques_char_idxs"]).long()
            self.y1s = torch.from_numpy(data["y1s"]).long()
            self.y2s = torch.from_numpy(data["y2s"]).long()
            self.ids = torch.from_numpy(data["ids"]).long()
        num = len(self.ids)
        self.batch_size = batch_size
        self.num_steps = num_steps if num_steps >= 0 else num // batch_size
        # Use the resolved step count: the raw argument may be negative.
        num_items = self.num_steps * batch_size
        idxs = list(range(num))
        self.idx_map = []
        # With no examples there is nothing to schedule, and the epoch loop
        # below would never advance.
        if num > 0:
            # Append whole, freshly shuffled epochs while they still fit...
            while len(self.idx_map) + num <= num_items:
                random.shuffle(idxs)
                self.idx_map += idxs
            # ...then a shuffled partial epoch to reach exactly num_items.
            random.shuffle(idxs)
            self.idx_map += idxs[:num_items - len(self.idx_map)]

    def __len__(self):
        """Number of scheduled batches."""
        return self.num_steps

    def __getitem__(self, item):
        """Return batch ``item`` (expected ``0 <= item < len(self)``).

        The result is the tuple
        ``(context_idxs, context_char_idxs, ques_idxs, ques_char_idxs, y1s, y2s, ids)``.
        """
        # Batch k is the k-th disjoint window of the schedule, so batches do
        # not overlap.
        start = item * self.batch_size
        idxs = torch.LongTensor(self.idx_map[start:start + self.batch_size])
        return (self.context_idxs[idxs],
                self.context_char_idxs[idxs],
                self.ques_idxs[idxs],
                self.ques_char_idxs[idxs],
                self.y1s[idxs],
                self.y2s[idxs],
                self.ids[idxs])
class EMA:
    """Keeps a CPU-side running blend of named parameters.

    ``set`` snapshots a parameter, and ``update_parameter`` blends the
    parameter's current value with its snapshot. The blend is written back
    into the parameter and also stored as the new snapshot.

    NOTE(review): the blend is ``decay * param + (1 - decay) * shadow``. That is
    the reverse weighting of the textbook EMA
    (``shadow = decay * shadow + (1 - decay) * param``), and overwriting the
    live parameter is also unusual. Confirm this is intended.
    """

    def __init__(self, decay):
        self.decay = decay
        self.shadows = {}  # name -> snapshot tensor kept on the CPU
        self.devices = {}  # name -> device the parameter lived on at set()

    def __len__(self):
        return len(self.shadows)

    def get(self, name: str):
        """Return the snapshot for ``name`` moved back to its original device."""
        target = self.devices[name]
        return self.shadows[name].to(target)

    def set(self, name: str, param: nn.Parameter):
        """Record a CPU copy of ``param`` and remember which device it came from."""
        tensor = param.data
        self.shadows[name] = tensor.to('cpu').clone()
        self.devices[name] = tensor.device

    def update_parameter(self, name: str, param: nn.Parameter):
        """Blend ``param`` with its snapshot in place; unknown names are ignored."""
        if name not in self.shadows:
            return
        blended = self.decay * param.data + (1.0 - self.decay) * self.get(name)
        param.data.copy_(blended)
        self.shadows[name] = blended.to('cpu').clone()
def convert_tokens(eval_file, qa_id, pp1, pp2):
    """Map predicted start/end token positions to answer text.

    Args:
        eval_file: dict keyed by ``str(question_id)``. Each record holds
            ``"context"`` (the passage), ``"spans"`` (per-token
            ``[start, end]`` character offsets used as slice bounds into the
            context) and ``"uuid"``.
        qa_id: question ids.
        pp1: predicted start-token indices, parallel to ``qa_id``.
        pp2: predicted end-token indices, parallel to ``qa_id``.

    Returns:
        ``(answers_by_id, answers_by_uuid)``. A prediction whose start or end
        index is past the last token maps to the empty string.
    """
    by_id = {}
    by_uuid = {}
    for qid, start_tok, end_tok in zip(qa_id, pp1, pp2):
        record = eval_file[str(qid)]
        context, spans, uuid = record["context"], record["spans"], record["uuid"]
        n_tokens = len(spans)
        if start_tok < n_tokens and end_tok < n_tokens:
            text = context[spans[start_tok][0]: spans[end_tok][1]]
        else:
            text = ""
        by_id[str(qid)] = text
        by_uuid[uuid] = text
    return by_id, by_uuid
def evaluate(eval_file, answer_dict):
    """Score predictions with exact-match and token F1.

    Args:
        eval_file: dict keyed by question id whose records carry an
            ``"answers"`` list of gold answer strings.
        answer_dict: dict mapping question id to the predicted answer text.

    Returns:
        ``{'exact_match': EM, 'f1': F1}``, each a percentage averaged over the
        predictions (best score over the gold answers of each question). An
        empty ``answer_dict`` scores ``0.0`` for both metrics instead of
        raising ZeroDivisionError.
    """
    f1 = exact_match = total = 0
    for key, prediction in answer_dict.items():
        total += 1
        ground_truths = eval_file[key]["answers"]
        exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
    if total == 0:
        return {'exact_match': 0.0, 'f1': 0.0}
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {'exact_match': exact_match, 'f1': f1}
def normalize_answer(s):
    """Canonicalize answer text before comparison.

    Steps, in order: lowercase, delete ASCII punctuation, replace the
    standalone words "a", "an" and "the" with a space, then collapse all
    whitespace runs to single spaces and trim the ends.
    """
    text = s.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
def f1_score(prediction, ground_truth):
    """Token-overlap F1 between one prediction and one gold answer.

    Both strings go through ``normalize_answer`` and are split on whitespace.
    Overlap counts tokens as a multiset. Returns ``0`` when nothing overlaps.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
def exact_match_score(prediction, ground_truth):
    """True when both strings are identical after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best ``metric_fn(prediction, gt)`` over all gold answers ``gt``.

    Raises ValueError (from ``max``) when ``ground_truths`` is empty.
    """
    return max([metric_fn(prediction, gt) for gt in ground_truths])
def train(model, optimizer, scheduler, ema, dataset, start, length):
    """Run ``length`` optimization steps on batches ``start .. start + length - 1``.

    Each step runs the following, in order:
    1. Forward pass.
    2. Mean of the start- and end-pointer NLL losses.
    3. Backward pass.
    4. Gradient-norm clipping to ``config.grad_clip``.
    5. Optimizer step, then scheduler step.
    6. EMA update of every trainable parameter.

    Prints the mean loss of the chunk. Does nothing when ``length`` is 0.
    """
    model.train()
    losses = []
    for i in tqdm(range(start, start + length), total=length):
        optimizer.zero_grad()
        Cwid, Ccid, Qwid, Qcid, y1, y2, _ids = dataset[i]
        Cwid, Ccid, Qwid, Qcid = Cwid.to(device), Ccid.to(device), Qwid.to(device), Qcid.to(device)
        p1, p2 = model(Cwid, Ccid, Qwid, Qcid)
        y1, y2 = y1.to(device), y2.to(device)
        # F.nll_loss expects log-probabilities; presumably the model emits
        # log-softmax scores over context positions. TODO confirm in models.py.
        loss1 = F.nll_loss(p1, y1, reduction='mean')
        loss2 = F.nll_loss(p2, y2, reduction='mean')
        loss = (loss1 + loss2) / 2
        losses.append(loss.item())
        loss.backward()
        # Clip between backward() and step() so the clipped gradients are the
        # ones the optimizer actually applies.
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        scheduler.step()
        for name, p in model.named_parameters():
            if p.requires_grad:
                ema.update_parameter(name, p)
    if not losses:
        return
    loss_avg = np.mean(losses)
    print("STEP {:8d} loss {:8f}\n".format(start + length, loss_avg))
def valid(model, dataset, eval_file):
    """Evaluate on up to ``config.val_num_batches`` randomly sampled batches.

    Batches are drawn without replacement (``random.sample``). Each predicted
    span is the argmax start/end pair, ordered so start <= end. Predictions are
    scored against ``eval_file`` with ``evaluate``.

    Returns:
        The metrics dict (``exact_match``, ``f1``, ``loss``), which is also printed.
    """
    model.eval()
    answer_dict = {}
    losses = []
    # random.sample raises ValueError when asked for more batches than exist.
    num_batches = min(config.val_num_batches, len(dataset))
    with torch.no_grad():
        for i in tqdm(random.sample(range(0, len(dataset)), num_batches), total=num_batches):
            Cwid, Ccid, Qwid, Qcid, y1, y2, ids = dataset[i]
            Cwid, Ccid, Qwid, Qcid = Cwid.to(device), Ccid.to(device), Qwid.to(device), Qcid.to(device)
            p1, p2 = model(Cwid, Ccid, Qwid, Qcid)
            y1, y2 = y1.to(device), y2.to(device)
            loss1 = F.nll_loss(p1, y1, reduction='mean')
            loss2 = F.nll_loss(p2, y2, reduction='mean')
            loss = (loss1 + loss2) / 2
            losses.append(loss.item())
            yp1 = torch.argmax(p1, 1)
            yp2 = torch.argmax(p2, 1)
            # Order each (start, end) pair so the answer span is never reversed.
            yps = torch.stack([yp1, yp2], dim=1)
            ymin, _ = torch.min(yps, 1)
            ymax, _ = torch.max(yps, 1)
            answer_dict_, _ = convert_tokens(eval_file, ids.tolist(), ymin.tolist(), ymax.tolist())
            answer_dict.update(answer_dict_)
    loss = np.mean(losses)
    metrics = evaluate(eval_file, answer_dict)
    metrics["loss"] = loss
    print("VALID loss {:8f} F1 {:8f} EM {:8f}\n".format(loss, metrics["f1"], metrics["exact_match"]))
    return metrics
def test(model, dataset, eval_file):
    """Evaluate on the first ``config.test_num_batches`` batches, in order.

    The count is capped at ``len(dataset)``. Predicted spans are produced as in
    ``valid``. The answers are written as JSON to ``log/answers.json``; the
    ``log`` directory is created if needed.

    Returns:
        The metrics dict (``exact_match``, ``f1``, ``loss``), which is also printed.
    """
    model.eval()
    answer_dict = {}
    losses = []
    num_batches = min(config.test_num_batches, len(dataset))
    with torch.no_grad():
        for i in tqdm(range(num_batches), total=num_batches):
            Cwid, Ccid, Qwid, Qcid, y1, y2, ids = dataset[i]
            Cwid, Ccid, Qwid, Qcid = Cwid.to(device), Ccid.to(device), Qwid.to(device), Qcid.to(device)
            p1, p2 = model(Cwid, Ccid, Qwid, Qcid)
            y1, y2 = y1.to(device), y2.to(device)
            loss1 = F.nll_loss(p1, y1, reduction='mean')
            loss2 = F.nll_loss(p2, y2, reduction='mean')
            loss = (loss1 + loss2) / 2
            losses.append(loss.item())
            yp1 = torch.argmax(p1, 1)
            yp2 = torch.argmax(p2, 1)
            # Order each (start, end) pair so the answer span is never reversed.
            yps = torch.stack([yp1, yp2], dim=1)
            ymin, _ = torch.min(yps, 1)
            ymax, _ = torch.max(yps, 1)
            answer_dict_, _ = convert_tokens(eval_file, ids.tolist(), ymin.tolist(), ymax.tolist())
            answer_dict.update(answer_dict_)
    loss = np.mean(losses)
    metrics = evaluate(eval_file, answer_dict)
    os.makedirs("log", exist_ok=True)
    with open("log/answers.json", "w") as f:
        json.dump(answer_dict, f)
    metrics["loss"] = loss
    print("TEST loss {:8f} F1 {:8f} EM {:8f}\n".format(loss, metrics["f1"], metrics["exact_match"]))
    return metrics
def train_entry():
    """Train QANet from the preprocessed files named in ``config``.

    Loads the embedding matrices and eval metadata, then builds the model, an
    EMA tracker, Adam, and a log-shaped warm-up learning-rate schedule.

    Each checkpoint runs ``config.checkpoint`` training steps, then a
    validation pass on training batches and a test pass on dev batches.
    Training stops after more than ``config.early_stop`` consecutive
    checkpoints in which dev F1 and EM both fall below their best.

    The model is pickled to ``<config.save_dir>/model.pt`` whenever a
    checkpoint does not regress on both metrics.
    """
    from models import QANet

    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.train_eval_file, "r") as fh:
        train_eval_file = json.load(fh)
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)

    print("Building model...")
    train_dataset = SQuADDataset(config.train_record_file, config.num_steps, config.batch_size)
    dev_dataset = SQuADDataset(config.dev_record_file, config.test_num_batches, config.batch_size)

    lr = config.learning_rate
    base_lr = 1.0
    warm_up = config.lr_warm_up_num

    model = QANet(word_mat, char_mat).to(device)
    ema = EMA(config.ema_decay)
    for name, p in model.named_parameters():
        if p.requires_grad:
            ema.set(name, p)

    params = filter(lambda param: param.requires_grad, model.parameters())
    # base_lr is 1.0, so the LambdaLR multiplier below is the effective rate.
    optimizer = optim.Adam(lr=base_lr, betas=(config.beta1, config.beta2), eps=1e-7, weight_decay=3e-7, params=params)
    # Warm-up: cr * log2(step + 1) climbs from 0 to `lr` at step warm_up - 1,
    # then the rate stays at `lr`. For warm_up <= 1, log2(warm_up) is 0 or
    # undefined, so skip the ramp.
    cr = lr / log2(warm_up) if warm_up > 1 else 0.0

    def lr_lambda(step):
        return cr * log2(step + 1) if step < warm_up else lr

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    L = config.checkpoint
    # len(train_dataset) is config.num_steps when that is >= 0, and one pass
    # over the data when it is negative.
    N = len(train_dataset)
    fn = os.path.join(config.save_dir, "model.pt")
    best_f1 = best_em = patience = 0
    for step in range(0, N, L):
        # Clamp the final chunk so training never requests batches past the schedule.
        train(model, optimizer, scheduler, ema, train_dataset, step, min(L, N - step))
        valid(model, train_dataset, train_eval_file)
        metrics = test(model, dev_dataset, dev_eval_file)
        # Read the rate in use straight from the optimizer;
        # scheduler.get_lr() is deprecated outside step().
        print("Learning rate: {}".format([group["lr"] for group in optimizer.param_groups]))
        dev_f1 = metrics["f1"]
        dev_em = metrics["exact_match"]
        if dev_f1 < best_f1 and dev_em < best_em:
            # Both dev metrics regressed: count toward early stopping and keep
            # the previously saved model.
            patience += 1
            if patience > config.early_stop:
                break
        else:
            patience = 0
            best_f1 = max(best_f1, dev_f1)
            best_em = max(best_em, dev_em)
            os.makedirs(config.save_dir, exist_ok=True)
            torch.save(model, fn)
def test_entry():
    """Load ``<config.save_dir>/model.pt`` (as written by ``train_entry``) and run ``test`` on the dev set."""
    with open(config.dev_eval_file, "r") as fh:
        dev_eval_file = json.load(fh)
    dev_dataset = SQuADDataset(config.dev_record_file, -1, config.batch_size)
    fn = os.path.join(config.save_dir, "model.pt")
    # The checkpoint is a whole pickled module (torch.save(model, ...)), so it
    # must be fully unpickled. Only load checkpoints you produced yourself.
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    try:
        model = torch.load(fn, map_location=device, weights_only=False)
    except TypeError:
        # torch releases older than 1.13 do not accept `weights_only`.
        model = torch.load(fn, map_location=device)
    test(model, dev_dataset, dev_eval_file)
if __name__ == '__main__':
    # Command-line entry: `--mode train|debug|test` selects the pipeline to run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", action="store", dest="mode", default="train", help="train/test/debug")
    pargs = parser.parse_args()
    print("Current device is {}".format(device))
    if pargs.mode == "train":
        train_entry()
    elif pargs.mode == "debug":
        # Tiny sizes, presumably for a quick end-to-end smoke run of training.
        config.batch_size = 2
        config.num_steps = 32
        config.test_num_batches = 2
        config.val_num_batches = 2
        config.checkpoint = 2
        config.period = 1
        train_entry()
    elif pargs.mode == "test":
        test_entry()
    else:
        print("Unknown mode")
        # Non-zero status so shells and CI notice an invalid --mode. SystemExit
        # is always available, unlike the site-provided exit() helper.
        raise SystemExit(1)