-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcam_images.py
72 lines (57 loc) · 2.27 KB
/
cam_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import os
import csv
import utils
import opts
from train import grad_cam
from model import resnet, resnext, resnext_wsl, vgg_bn, densenet, inception_v3
from dataloader import TestDataset, my_transform
def _build_model(opt):
    """Instantiate the backbone selected by ``opt.network``.

    For ``'resnext_wsl'`` this also rewrites ``opt.network`` to the
    ``'resnext_wsl_32x<width>d'`` form, because that string is what ``main``
    uses when building the checkpoint and result filenames.

    Args:
        opt: parsed options; reads ``network``, ``classes``, ``layers`` and,
            for ``'resnext_wsl'``, ``battleneck_width``.

    Returns:
        The constructed (unwrapped) model.

    Raises:
        ValueError: if ``opt.network`` is not a supported network name.
    """
    if opt.network == 'resnet':
        return resnet(opt.classes, opt.layers)
    if opt.network == 'resnext':
        return resnext(opt.classes, opt.layers)
    if opt.network == 'resnext_wsl':
        # resnext_wsl must specify the opt.battleneck_width parameter
        opt.network = 'resnext_wsl_32x' + str(opt.battleneck_width) + 'd'
        return resnext_wsl(opt.classes, opt.battleneck_width)
    if opt.network == 'vgg':
        return vgg_bn(opt.classes, opt.layers)
    if opt.network == 'densenet':
        return densenet(opt.classes, opt.layers)
    if opt.network == 'inception_v3':
        return inception_v3(opt.classes, opt.layers)
    raise ValueError('unsupported network: {!r}'.format(opt.network))


def _predict(model, loader, device):
    """Return the argmax class index for every sample yielded by *loader*.

    Puts *model* in eval mode and runs it without gradient tracking. The
    returned list follows the loader's iteration order.
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in loader:
            # NOTE(review): assumes the dataset yields either a bare image
            # tensor or a tuple/list whose first item is the image tensor —
            # confirm against dataloader.TestDataset.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            outputs = model(images.to(device))
            predictions.extend(outputs.argmax(dim=1).cpu().tolist())
    return predictions


def main(opt):
    """Predict a class for every training image and write the results to CSV.

    Builds the network named by ``opt.network``, loads
    ``<model_dir>/<network>-<layers>-<crop_size>_model.ckpt``, runs it over
    the training split returned by ``utils.read_data`` and writes
    ``filename,type`` rows to
    ``<results_dir>/<network>-<layers>-<crop_size>_result.csv``.

    Raises:
        ValueError: if ``opt.network`` is not a supported network name.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.set_device(opt.gpu_id)
        # Replicate onto the GPU selected above, so DataParallel's primary
        # device matches the device the parameters are moved to.
        device_ids = [opt.gpu_id]
    else:
        device = torch.device('cpu')
        device_ids = None  # DataParallel is a pass-through without CUDA.

    model = _build_model(opt)
    # Keep the DataParallel wrapper: the checkpoint is loaded into the
    # wrapped model (presumably 'module.'-prefixed keys — confirm).
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(device)

    ckpt_path = os.path.join(
        opt.model_dir,
        opt.network + '-' + str(opt.layers) + '-' + str(opt.crop_size) + '_model.ckpt')
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

    train_data, _ = utils.read_data(
        os.path.join(opt.root_dir, opt.train_dir),
        os.path.join(opt.root_dir, opt.train_label),
        val_num=1)
    val_transforms = my_transform(False, opt.crop_size)
    # TestDataset is the dataset class this file imports.
    # NOTE(review): confirm it accepts (images, labels, transform).
    dataset = TestDataset(train_data[0], train_data[1], val_transforms)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=opt.batch_size,
                                         shuffle=False, num_workers=2)

    predictions = _predict(model, loader, device)
    # shuffle=False keeps the loader in train_data[0] order, so names and
    # predictions pair up one-to-one.
    # NOTE(review): assumes train_data[0] holds image names/paths — confirm
    # against utils.read_data.
    im_labels = [[name, label] for name, label in zip(train_data[0], predictions)]

    header = ['filename', 'type']
    utils.mkdir(opt.results_dir)
    result = opt.network + '-' + str(opt.layers) + '-' + str(opt.crop_size) + '_result.csv'
    filename = os.path.join(opt.results_dir, result)
    # newline='' lets the csv module control line endings (no blank rows on Windows).
    with open(filename, 'w', encoding='utf-8', newline='') as f:
        f_csv = csv.writer(f)
        f_csv.writerow(header)
        f_csv.writerows(im_labels)
if __name__ == '__main__':
    # Parse CLI options and run only when executed as a script, so importing
    # this module (e.g. to reuse main) has no side effects.
    opt = opts.parse_args()
    main(opt)