forked from ruotianluo/self-critical.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
123 lines (108 loc) · 5.22 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
import time
import os
from six.moves import cPickle
import opts
import models
from dataloader import *
from dataloaderraw import *
import eval_utils
import argparse
import misc.utils as utils
import torch
# Command-line interface of the evaluation script.
#
# Each option is one (flag, type, default, help) row; the rows are registered
# on the parser in the order listed, so flags, defaults and help strings live
# in a single table.
parser = argparse.ArgumentParser()

_OPTION_TABLE = (
    # Input paths
    ('--model', str, '',
     'path to model to evaluate'),
    ('--infos_path', str, '',
     'path to infos to evaluate'),
    # Basic options
    ('--batch_size', int, 0,
     'if > 0 then overrule, otherwise load from checkpoint.'),
    ('--num_images', int, -1,
     'how many images to use when periodically evaluating the loss? (-1 = all)'),
    ('--language_eval', int, 0,
     'Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.'),
    ('--dump_images', int, 1,
     'Dump images into vis/imgs folder for vis? (1=yes,0=no)'),
    ('--dump_json', int, 1,
     'Dump json with predictions into vis folder? (1=yes,0=no)'),
    ('--dump_path', int, 0,
     'Write image paths along with predictions into vis json? (1=yes,0=no)'),
    # Sampling options
    ('--sample_max', int, 1,
     '1 = sample argmax words. 0 = sample from distributions.'),
    ('--beam_size', int, 2,
     'used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.'),
    ('--temperature', float, 1.0,
     'temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.'),
    # For evaluation on a folder of images:
    ('--image_folder', str, '',
     'If this is nonempty then will predict on the images in this folder path'),
    ('--image_root', str, '',
     'In case the image paths have to be preprended with a root path to an image folder'),
    # For evaluation on MSCOCO images from some split:
    ('--input_fc_dir', str, '',
     'path to the h5file containing the preprocessed dataset'),
    ('--input_att_dir', str, '',
     'path to the h5file containing the preprocessed dataset'),
    ('--input_label_h5', str, '',
     'path to the h5file containing the preprocessed dataset'),
    ('--input_json', str, '',
     'path to the json file containing additional info and vocab. empty = fetch from model checkpoint.'),
    ('--split', str, 'test',
     'if running on MSCOCO images, which split to use: val|test|train'),
    ('--coco_json', str, '',
     'if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.'),
    # misc
    ('--id', str, 'evalscript',
     'an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files'),
)

for _flag, _kind, _default, _help in _OPTION_TABLE:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)
opt = parser.parse_args()

# Load infos
# The infos pickle is read for two entries below: 'opt' (the option namespace
# saved with the checkpoint) and 'vocab'.
# Pickle data is binary, so the file is opened in 'rb' mode; a text-mode
# handle breaks unpickling on Python 3.
# NOTE(security): unpickling can execute arbitrary code -- only point
# --infos_path at files from a trusted source.
with open(opt.infos_path, 'rb') as f:
    infos = cPickle.load(f)
saved_opt = infos['opt']

# override and collect parameters
# An empty / zero command-line value means "use what the checkpoint recorded".
# The attribute names must match the flags declared on the parser
# (--input_fc_dir / --input_att_dir / --input_label_h5); reading an attribute
# the parser never declared raises AttributeError.
# NOTE(review): infos files whose saved options lack input_fc_dir /
# input_att_dir will raise AttributeError here -- confirm the checkpoints in
# use were written with these option names.
if len(opt.input_fc_dir) == 0:
    opt.input_fc_dir = saved_opt.input_fc_dir
    opt.input_att_dir = saved_opt.input_att_dir
    opt.input_label_h5 = saved_opt.input_label_h5
if len(opt.input_json) == 0:
    opt.input_json = saved_opt.input_json
if opt.batch_size == 0:
    opt.batch_size = saved_opt.batch_size

# Options that are allowed to differ between the saved run and this run.
ignore = ["id", "batch_size", "beam_size", "start_from"]
for k, saved_value in vars(saved_opt).items():
    if k in ignore:
        continue
    if k in vars(opt):
        # Options present on both sides must agree.
        assert vars(opt)[k] == saved_value, k + ' option not consistent'
    else:
        vars(opt).update({k: saved_value})  # copy over options from model

vocab = infos['vocab']  # ix -> word mapping
# Setup the model
# Build the network from the merged options, restore the weights stored at
# opt.model, then move it to CUDA and switch it to evaluation mode.
model = models.setup(opt)
checkpoint_weights = torch.load(opt.model)
model.load_state_dict(checkpoint_weights)
model.cuda()
model.eval()

# Criterion object handed to the evaluation routine together with the model.
crit = utils.LanguageModelCriterion()
# Create the Data Loader instance
# A non-empty --image_folder selects the raw-image loader; otherwise the
# loader is built from the full option namespace.
evaluate_image_folder = len(opt.image_folder) != 0
if evaluate_image_folder:
    raw_loader_options = {
        'folder_path': opt.image_folder,
        'coco_json': opt.coco_json,
        'batch_size': opt.batch_size,
    }
    loader = DataLoaderRaw(raw_loader_options)
    # The raw loader is given the vocabulary stored in the infos file.
    # NOTE(review): kept under the raw-image branch -- confirm this assignment
    # is not meant to apply to both loaders.
    loader.ix_to_word = infos['vocab']
else:
    loader = DataLoader(opt)
# Run the evaluation over the loader; the options are passed as a plain dict.
# The call returns the loss, the per-image predictions and the language
# statistics (falsy when there is nothing to report).
# NOTE(review): eval_utils is not visible from this file -- confirm that it
# exposes `eval_eval` with this signature.
loss, split_predictions, lang_stats = eval_utils.eval_eval(model, crit, loader,
                                                           vars(opt))

print('loss: ', loss)
if lang_stats:
    print(lang_stats)

if opt.dump_json == 1:
    # dump the json
    # Create the output directory if it is missing, and write through a
    # context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous handle.
    if not os.path.isdir('vis'):
        os.makedirs('vis')
    with open('vis/vis.json', 'w') as vis_file:
        json.dump(split_predictions, vis_file)