-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
48 lines (40 loc) · 1.17 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
from dataloader import test_data
import matplotlib.pyplot as plt
from model import net
# Binary cross-entropy on probabilities. The default reduction='mean' averages
# the loss over all elements of a batch.
criterion = torch.nn.BCELoss()
# Per-batch loss / accuracy histories (saved and plotted later) and running sums.
eval_losses = []
eval_acces = []
eval_loss = 0
eval_acc = 0
# Load the trained parameters. `net` is presumably an nn.Module instance exported
# by model.py (load_state_dict is called on it directly) -- verify.
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine. load_state_dict copies the values into the model's own
# parameters either way.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model = net
model.load_state_dict(torch.load("ReLU.pth", map_location="cpu"))
# Switch layers such as dropout to inference behaviour. This does not by itself
# disable gradient tracking. Parameters are never updated in this script only
# because no optimizer is used.
model.eval()
# Evaluate batch by batch. torch.no_grad() stops autograd from recording the
# forward pass (model.eval() does not do this), which saves memory. No gradients
# are needed during evaluation.
with torch.no_grad():
    for im, label in test_data:
        # Variable has been a no-op wrapper since PyTorch 0.4, so plain tensors
        # are passed directly.
        out = model(im)
        # BCELoss requires `out` and `label` to have the same shape, with `out`
        # holding probabilities in [0, 1] and `label` holding float targets.
        loss = criterion(out, label)
        batch_loss = loss.item()
        eval_loss += batch_loss
        # Threshold each probability at 0.5 to get a 0/1 prediction with the same
        # shape as `label`. An argmax over dim 1 would be an index, not a class
        # decision, for a single-probability output. Comparing an (N,) index
        # tensor with an (N, 1) label would also broadcast to (N, N).
        pred = (out > 0.5).to(label.dtype)
        num_correct = (pred == label).sum().item()
        # Fraction of correctly predicted targets. This equals
        # num_correct / batch_size when there is one target per sample.
        acc = num_correct / label.numel()
        eval_acc += acc
        # With reduction='mean' the loss is already a per-sample average, so it
        # is recorded as-is rather than divided by the batch size again.
        eval_losses.append(batch_loss)
        eval_acces.append(acc)
# Average the per-batch metrics over the number of batches actually evaluated.
# This does not require test_data to support len(). The max(..., 1) floor avoids
# ZeroDivisionError when test_data yields no batches.
num_batches = max(len(eval_losses), 1)
print('整体损失值:Eval Loss: {:.6f}, Eval Acc:{:.6f}'.format(eval_loss / num_batches, eval_acc / num_batches))
# Persist the per-batch histories for later comparison / plotting.
np.save('b_64_test_losses.npy', eval_losses)
np.save('b_64_test_acces.npy', eval_acces)
# Show one figure per metric: the per-batch value against the batch index.
for series, caption in ((eval_losses, 'b_64_test loss'),
                        (eval_acces, 'b_64_test acc')):
    batch_index = np.arange(len(series))
    plt.plot(batch_index, series)
    plt.title(caption)
    plt.show()