-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBertTrainModel.py
75 lines (61 loc) · 3.01 KB
/
BertTrainModel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from utils.compute_jaccard_score import compute_jaccard_score
import tqdm
def bert_train_model(model, dataloaders_dict, loss, optimizer, num_epochs, filename):
model.cuda()
for epoch in range(num_epochs):
# Mỗi epoch sẽ thực hiện 2 phase
for phase in ['train', 'val']:
# Nếu phase train thì huấn luyện, phase val thì tính loss và jaccard
if phase == 'train':
model.train()
else:
model.eval()
# Khởi tạo loss và jaccard
epoch_loss = 0.0
epoch_jaccard = 0.0
for data in tqdm.tqdm((dataloaders_dict[phase])):
# Lấy thông tin dữ liệu
ids = data['ids'].cuda()
masks = data['masks'].cuda()
tweet = data['tweet']
offsets = data['offsets'].numpy()
token_type_id = data['token_type_ids'].cuda()
start_idx = data['start_idx'].cuda()
end_idx = data['end_idx'].cuda()
# Reset tích lũy đạo hàm
optimizer.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
start_logits, end_logits = model(ids, masks, token_type_id)
loss_value = loss(
start_logits, end_logits, start_idx, end_idx)
# nếu là phase train thì thực hiện lan truyền ngược
# và cập nhật tham số
if phase == 'train':
loss_value.backward()
optimizer.step()
epoch_loss += loss_value.item() * len(ids)
start_idx = start_idx.cpu().detach().numpy()
end_idx = end_idx.cpu().detach().numpy()
start_logits = torch.softmax(
start_logits, dim=1).cpu().detach().numpy()
end_logits = torch.softmax(
end_logits, dim=1).cpu().detach().numpy()
# Tính toán jaccard cho tất cả các câu
for i in range(len(ids)):
jaccard_score = compute_jaccard_score(
tweet[i],
start_idx[i],
end_idx[i],
start_logits[i],
end_logits[i],
offsets[i]
)
epoch_jaccard += jaccard_score
# Trung bình loss và jaccard
epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
epoch_jaccard = epoch_jaccard / \
len(dataloaders_dict[phase].dataset)
print("Epoch {}/{} | {:^5} | Loss: {:.4f} | Jaccard: {:.4f}".format(epoch +
1, num_epochs, phase, epoch_loss, epoch_jaccard))
torch.save(model.state_dict(), filename)