-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtest_gpus.py
70 lines (60 loc) · 2.25 KB
/
test_gpus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# -----------------------------------------------------------
# Stacked Cross Attention Network implementation based on
# https://arxiv.org/abs/1803.08024.
# "Stacked Cross Attention for Image-Text Matching"
# Kuang-Huei Lee, Xi Chen, Gang Hua, Houdong Hu, Xiaodong He
#
# Written by Kuang-Huei Lee, 2018
# ---------------------------------------------------------------
"""Training script"""
import evaluation as evaluation
from vocab import Vocabulary, deserialize_vocab
from model import SCAN
import data
import argparse
import os
import torch
import torch.nn as nn
def _parse_args():
    """Parse the command-line options for evaluating a saved checkpoint.

    Returns:
        argparse.Namespace with ``data_path``, ``model_path``, ``split``,
        ``gpuid`` (str) and ``fold5`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default='./data/',
                        help='path to datasets')
    parser.add_argument('--model_path', default='./data/',
                        help='path to model')
    parser.add_argument('--split', default='test',
                        help='val/test')
    # The default must itself be a string: argparse applies `type` only to
    # string defaults, so a float default (0.) would end up in the
    # environment as CUDA_VISIBLE_DEVICES="0.0" instead of "0".
    parser.add_argument('--gpuid', default='0', type=str,
                        help='gpuid')
    parser.add_argument('--fold5', action='store_true',
                        help='fold5')
    return parser.parse_args()


def main():
    """Load a trained SCAN checkpoint and evaluate it on the requested split.

    Steps:
      1. Parse CLI options and restrict CUDA to the single requested GPU.
      2. Load the checkpoint and override its saved options with the CLI
         values (split, data path, fold5).
      3. Rebuild the vocabulary and model, then restore the model weights.
      4. Build the test loader and run ``evaluation.evalrank``.
    """
    opts = _parse_args()

    print("use GPU:", opts.gpuid)
    # Must happen before CUDA is initialised. After this, the chosen GPU is
    # the only visible device, and it is addressed as index 0.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpuid)
    torch.cuda.set_device(0)

    # Load the checkpoint onto the CPU. Only one GPU is visible now, so
    # tensors saved on any other device index could not be restored in
    # place. load_state_dict below copies the values onto the GPU model.
    # NOTE(review): on newer PyTorch releases torch.load defaults to
    # weights_only=True, which may refuse the pickled `opt` object stored
    # in the checkpoint; confirm against the installed torch version.
    checkpoint = torch.load(opts.model_path, map_location='cpu')

    # Reuse the training-time options, overriding what this run controls.
    opt = checkpoint['opt']
    opt.loss_verbose = False
    opt.split = opts.split
    opt.data_path = opts.data_path
    opt.fold5 = opts.fold5

    # Load the vocabulary used by the model.
    vocab = deserialize_vocab(
        os.path.join(opt.vocab_path, '%s_vocab.json' % opt.data_name))
    opt.vocab_size = len(vocab)

    # Construct the model. It is wrapped in DataParallel before loading,
    # presumably because the checkpoint was saved from a DataParallel model
    # ("module."-prefixed keys). TODO confirm against the training script.
    model = SCAN(opt)
    model.cuda()
    model = nn.DataParallel(model)
    model.load_state_dict(checkpoint['model'])

    print('Loading dataset')
    data_loader = data.get_test_loader(opt.split, opt.data_name, vocab,
                                       opt.batch_size, opt.workers, opt)

    print(opt)
    print('Computing results...')
    # Evaluate the unwrapped module.
    evaluation.evalrank(model.module, data_loader, opt,
                        split=opt.split, fold5=opt.fold5)
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()