test.py (forked from taiyipan/TPSNet)

import argparse
import traceback

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary

from models import ResNet, BasicBlock


# calculate the block count per residual stage from the total network depth
def block_count(depth: int) -> int:
    assert (depth - 4) % 6 == 0, 'depth must have the form 6n + 4'
    return (depth - 4) // 6


def get_num_blocks(depth: int) -> list:
    # three residual stages, each with the same number of blocks
    return [block_count(depth), block_count(depth), block_count(depth)]
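
# For the default depth of 82, block_count gives (82 - 4) // 6 = 13 blocks in
# each of the three stages; the 6n + 4 form is the usual WideResNet-style depth
# formula for networks built from two-conv BasicBlocks (an assumption inferred
# from the ResNet/BasicBlock imports, not stated in this script).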

def make_model(k=2, d=82):
    # instantiate model
    model = ResNet(BasicBlock, get_num_blocks(d), k=k)
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
        print('cuda')
        if torch.cuda.device_count() > 1:
            print('cuda: {}'.format(torch.cuda.device_count()))
            model = nn.DataParallel(model)
    model.to(device)
    # load best model (lowest validation loss); map_location allows a
    # checkpoint saved on GPU to load on a CPU-only machine
    try:
        model.load_state_dict(torch.load('./top_models/tpsnet.pt', map_location=device))
        print('Model weights loaded')
    except Exception:
        traceback.print_exc()
    return model
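
# Note (general PyTorch behavior, not specific to this repo): a state_dict saved
# from an nn.DataParallel-wrapped model carries a 'module.' prefix on every key,
# so the checkpoint loads cleanly only when the wrapping at save time matches
# the wrapping applied above.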

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Deep Learning Project-1")
    parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
    parser.add_argument('--resume', '-r', action='store_true', help='resume from latest checkpoint')
    parser.add_argument('--epochs', '-e', type=int, default=200, help='no. of epochs')
    parser.add_argument('-w', '--num_workers', type=int, default=12, help='number of workers')
    parser.add_argument('-b', '--batch_size', type=int, default=128, help='batch size')
    args = parser.parse_args()

    # hyperparams (--lr, --resume and --epochs are parsed but unused when only testing)
    num_workers = args.num_workers
    batch_size = args.batch_size

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
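    # the tuples above are the per-channel mean / std commonly quoted for the
    # CIFAR-10 training set (a community convention, not computed in this script)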

    test_data = torchvision.datasets.CIFAR10(root='./data', train=False,
                                             download=True, transform=transform_test)
    # DataLoader defaults to shuffle=False, which keeps evaluation order deterministic
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
        num_workers=num_workers
    )

    model = make_model()
    summary(model, (3, 32, 32))
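    # note: torchsummary's summary() takes the input size as (C, H, W) without the
    # batch dimension and defaults to device='cuda'; on a CPU-only machine you may
    # need summary(model, (3, 32, 32), device='cpu')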

    # define loss function (no optimizer is needed for evaluation only)
    criterion = nn.CrossEntropyLoss()
    # weights (lowest validation loss) were already loaded inside make_model()

    # test model
    test_loss = 0.0
    total_correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data, target in test_loader:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            output = model(data)
            loss = criterion(output, target)
            # accumulate the sum of per-sample losses; dividing by the sample
            # count afterwards gives the average test loss
            test_loss += loss.item() * data.size(0)
            # count correct top-1 predictions; summing on the tensor avoids the
            # 0-d array that np.squeeze would yield for a final batch of one sample
            _, pred = torch.max(output, 1)
            total_correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    # report results
    print('total:', total)
    print('total correct:', total_correct)
    print('Average test loss: {:.4f}'.format(test_loss / total))
    # calculate overall accuracy
    print('Model accuracy on test dataset: {:.2f}%'.format(total_correct / total * 100))
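
# Example invocation (flags as defined in the parser above):
#   python test.py --batch_size 128 --num_workers 4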