-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpredict.py
117 lines (92 loc) · 3.02 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# torch imports
import torch
from torch import nn
from torch.autograd import Variable as var
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack
import torchvision.transforms as T
from torch.utils.data import DataLoader
# standard imports
import json
import pandas as pd
import numpy as np
from collections import namedtuple
from AmazonDataset import AmazonDataset
import os
import sys
# from models.cnn_rnn import EncoderCNN,DecoderRNN
from models.cnn_rnn_attn import EncoderCNN,DecoderRNN
from utils import draw_image
# Per-channel mean and standard deviation handed to T.Normalize in main().
# NOTE(review): presumably the usual ImageNet statistics in RGB order - confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def parse():
    """Load the JSON configuration named by the first command-line argument.

    The path is taken from ``sys.argv[1]``.  Every JSON object in the file is
    converted into a namedtuple, so nested settings are reached with
    attribute access (for example ``config.paths.test_dir``).

    Returns:
        The decoded configuration; JSON objects appear as namedtuples.

    Raises:
        SystemExit: if no configuration path was given on the command line.
    """
    if len(sys.argv) < 2:
        program = sys.argv[0] if sys.argv else 'predict.py'
        raise SystemExit("usage: %s <config.json>" % program)
    config_file = sys.argv[1]

    def as_namedtuple(d):
        # Keys become field names, so they must be valid Python identifiers.
        return namedtuple('X', d.keys())(*d.values())

    # The context manager closes the file even when decoding fails.
    with open(config_file, 'r') as f:
        return json.load(f, object_hook=as_namedtuple)
def to_var(x, volatile=False):
    """Wrap a tensor in an autograd Variable, moving it to the GPU first when CUDA is available.

    Args:
        x: tensor to wrap.
        volatile: forwarded unchanged to the Variable constructor.

    Returns:
        The Variable built from ``x``.
    """
    # The GPU index comes from the module-level ``config`` object; it is only
    # looked up when CUDA is available.
    tensor = x.cuda(config.training.cuda_device) if torch.cuda.is_available() else x
    return var(tensor, volatile=volatile)
def find_classes(label_list_file):
    """Read the label vocabulary, one class name per line.

    Args:
        label_list_file: path to a text file holding one label per line.

    Returns:
        A NumPy array of the label strings in file order, with surrounding
        whitespace removed from each line.
    """
    with open(label_list_file) as f:
        # strip() must be *called*: appending the bound method ``line.strip``
        # would fill the array with method objects instead of label strings.
        classes = [line.strip() for line in f]
    return np.array(classes)
def filename_clean(target):
    """Return the part of ``target`` that precedes its first '.'.

    A string that contains no '.' is returned unchanged.
    """
    stem, _, _ = target.partition('.')
    return stem
def _tokens_to_tags(token_ids, classes, end_token):
    """Turn one predicted token-id list into a space-separated tag string."""
    # Everything from the first end token onwards is discarded.
    if end_token in token_ids:
        token_ids = token_ids[:token_ids.index(end_token)]
    # Token ids are offset by one relative to positions in ``classes``.
    return ' '.join([classes[k - 1] for k in token_ids])


def main(config):
    """Run the saved encoder/decoder over the test set and write a submission CSV.

    For every batch the CNN encoder produces features, the RNN decoder samples
    a token sequence per image, and each sequence is converted to a
    space-separated tag string.  The result is written to
    ``config.paths.submission_file`` with the columns ``image_name`` and
    ``tags`` (no index column).

    Args:
        config: parsed configuration.  This function reads ``config.paths``
            (``test_dir``, ``cnn_save_path``, ``rnn_save_path``,
            ``label_list_file``, ``submission_file``), ``config.training``
            (``batch_size``, ``n_workers``, ``attention``, ``cuda_device``
            when CUDA is available, ``draw_image`` when attention is on) and
            ``config.model`` (``embed_size``, ``hidden_size``, ``total_size``).
    """
    # Take the flag from the config instead of a module-level global so that
    # main() also works when it is imported and called from another module.
    attention = config.training.attention
    end_token = 18  # predictions are cut at the first occurrence of this id

    test_transform = T.Compose([
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    test_dataset = AmazonDataset(config.paths.test_dir,
                                 transform=test_transform,
                                 target_transform=filename_clean,
                                 dtype="test")
    test_loader = DataLoader(test_dataset,
                             batch_size=config.training.batch_size,
                             num_workers=config.training.n_workers)

    # define model
    encoder = EncoderCNN()
    # NOTE(review): the literal 19 is presumably the decoder vocabulary size -
    # confirm against DecoderRNN's signature.
    decoder = DecoderRNN(config.model.embed_size, config.model.hidden_size,
                         encoder.output_size, 19, config.model.total_size)

    def keep_on_cpu(storage, location):
        # Load checkpoint storages onto the CPU so weights saved on a GPU can
        # be restored on a machine without CUDA; the models are moved to the
        # configured device right below.
        return storage

    encoder.load_state_dict(
        torch.load(config.paths.cnn_save_path, map_location=keep_on_cpu))
    decoder.load_state_dict(
        torch.load(config.paths.rnn_save_path, map_location=keep_on_cpu))
    if torch.cuda.is_available():
        encoder.cuda(config.training.cuda_device)
        decoder.cuda(config.training.cuda_device)
    encoder.eval()
    decoder.eval()

    filenames = []
    predictions = []
    classes = find_classes(config.paths.label_list_file)
    n_batches = len(test_loader)
    print("Running Predictions :")
    for i, (images, filename) in enumerate(test_loader):
        print("Batch [%d/%d]" % ((i + 1), n_batches))
        # NOTE(review): on PyTorch >= 0.4 the ``volatile`` flag has no effect;
        # there, inference should run under ``torch.no_grad()`` instead.
        images = to_var(images, volatile=True)
        cnn_features = encoder(images)
        if attention:
            attn, preds = decoder.sample(cnn_features)
        else:
            preds = decoder.sample(cnn_features)
        prediction = [
            _tokens_to_tags(preds[j].data.cpu().numpy().tolist(),
                            classes, end_token)
            for j in range(preds.size(0))
        ]
        if attention and config.training.draw_image:
            draw_image(attn, filename, prediction)
        filenames += list(filename)
        predictions += prediction

    submission = pd.DataFrame(
        {'image_name': filenames, 'tags': predictions},
        columns=['image_name', 'tags'],
    )
    submission.to_csv(config.paths.submission_file, index=False)
if __name__ == '__main__':
    # Script entry point.  ``config`` and ``attention`` are deliberately bound
    # at module level: to_var() looks up ``config`` as a global, so both names
    # must exist before main() runs.
    config = parse()
    attention = config.training.attention
    main(config)