-
Notifications
You must be signed in to change notification settings - Fork 108
/
Copy patheval.py
30 lines (24 loc) · 1.04 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# -*- coding: utf-8 -*-
import numpy as np
import torch
from sklearn import metrics
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM

from load_data import valdataloader
# Evaluate a masked-LM "prompt" classifier: at one fixed sequence position the
# model's scores for two candidate tokens are compared, and the higher-scoring
# token decides the predicted class. Accuracy over the whole validation loader
# is printed at the end.

# Tokenizer and masked-LM weights are loaded from the local checkpoint folder.
tokenizer = BertTokenizer.from_pretrained('./bert-base-chinese')
model = BertForMaskedLM.from_pretrained('./bert-base-chinese')
model.eval()  # evaluation mode (disables dropout)

# Vocabulary ids of the two candidate tokens: '很' is class 1, '不' is class 0.
pos_id = tokenizer.convert_tokens_to_ids('很')
neg_id = tokenizer.convert_tokens_to_ids('不')

# Sequence position whose scores are compared.
# NOTE(review): presumably the position of the [MASK] token produced by
# load_data -- confirm against the data pipeline.
mask_idx = 1

# Per-batch results are collected here and joined once after the loop; this
# avoids re-copying the whole accumulated array on every batch (np.append).
# Seeding with an empty int array keeps the result well-defined (and of the
# same dtype as before) even if the loader yields no batches.
batch_predictions = [np.array([], dtype=int)]
batch_labels = [np.array([], dtype=int)]

pbar = tqdm(valdataloader)
# Inference only: no_grad stops autograd from recording a graph per batch.
with torch.no_grad():
    for batch_data in pbar:
        outputs = model(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
        )
        # With no labels passed, the first output holds the vocabulary scores.
        prediction_scores = outputs[0]

        # Keep only the two candidate-token scores at the compared position.
        # Column 0 is neg_id and column 1 is pos_id, so the argmax index is
        # directly the predicted class (0 or 1).
        candidate_scores = prediction_scores[:, mask_idx, [neg_id, pos_id]]
        y_pred = candidate_scores.argmax(dim=1)
        batch_predictions.append(y_pred.cpu().numpy())

        # Gold class is 1 when the label at that position equals pos_id.
        # NOTE(review): assumes batch_data["labels"] holds token ids with
        # shape (batch, seq) -- confirm against load_data.
        y_true = (batch_data["labels"][:, mask_idx] == pos_id).long()
        batch_labels.append(y_true.cpu().numpy())

predict_all = np.concatenate(batch_predictions)
labels_all = np.concatenate(batch_labels)

acc = metrics.accuracy_score(labels_all, predict_all)
print(acc)