-
Notifications
You must be signed in to change notification settings - Fork 7
/
mcm.py
27 lines (27 loc) · 906 Bytes
/
mcm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from net import MLP,CNN #
from torchvision import datasets, transforms
from sklearn.metrics import multilabel_confusion_matrix
#
# Evaluate the saved MLP checkpoint on the FashionMNIST test split and print
# the per-class confusion matrices computed by scikit-learn.

# Samples per forward pass. The confusion matrix only depends on the set of
# (label, prediction) pairs, so batching changes speed, not the result.
BATCH_SIZE = 256

# NOTE(review): (0.1307,) / (0.3081,) look like the mean/std usually quoted for
# MNIST digits rather than FashionMNIST. They are left unchanged because the
# evaluation preprocessing must match whatever the checkpoint was trained
# with -- confirm against the training script.
test_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(
        './fashionmnist_data/',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=BATCH_SIZE,
    # Order is irrelevant for a confusion matrix; a fixed order keeps the
    # collected lists reproducible between runs.
    shuffle=False,
)

device = torch.device('cpu')
model = MLP()
model = model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a CUDA run still loads on a CPU-only machine.
model.load_state_dict(torch.load('output/MLP.pt', map_location=device))
model.eval()

pres = []    # predicted class index for every test sample
labels = []  # ground-truth class index for every test sample

# Inference only: disable autograd so no graph is built per forward pass.
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        # Index of the largest output score along the class dimension.
        pred = output.argmax(dim=1)
        pres.extend(pred.tolist())
        labels.extend(target.tolist())

# For multiclass input scikit-learn treats each class one-vs-rest and returns
# one 2x2 matrix [[TN, FP], [FN, TP]] per class.
mcm = multilabel_confusion_matrix(labels, pres)
print(mcm)