-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathopts.py
116 lines (106 loc) · 7.33 KB
/
opts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
def parse_opt(argv=None):
    """Build the command-line parser, parse the options and validate them.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what this function always did before the parameter existed.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.

    Raises:
        AssertionError: if a parsed value fails one of the sanity checks at
            the end of the function.
        SystemExit: raised by argparse itself on unknown or malformed
            command-line arguments.
    """
    parser = argparse.ArgumentParser()

    # Data input settings
    parser.add_argument('--input_json', type=str, default='data/coco.json',
                        help='path to the json file containing additional info and vocab')
    # NOTE(review): this default is the same .json path as --input_json even
    # though the help text describes an h5 file; it looks like a copy-paste
    # slip. Left unchanged because the intended file name is not visible here.
    parser.add_argument('--input_h5', type=str, default='data/coco.json',
                        help='path to the h5file containing the preprocessed dataset')
    parser.add_argument('--cnn_model', type=str, default='vgg16',
                        help='vgg16 or vgg19')
    parser.add_argument('--cnn_weight', type=str, default='models/vgg16.npy',
                        help='path to CNN tf model. Note this MUST be a vgg16 right now.')
    # The continuation lines of this help text are deliberately flush-left so
    # the string literal stays unchanged; argparse's default help formatter
    # re-wraps whitespace when rendering it.
    parser.add_argument('--start_from', type=str, default=None,
                        help="""continue training from saved model at this path. Path must contain files saved by previous training process:
'infos.pkl' : configuration;
'checkpoint' : paths to model file(s) (created by tf).
Note: this file contains absolute paths, be careful when moving files around;
'model.ckpt-*' : file(s) with model definition (created by tf)
""")

    # Model settings
    parser.add_argument('--caption_model', type=str, default="show_tell",
                        help='show_tell, show_attend_tell, attention')
    parser.add_argument('--rnn_size', type=int, default=512,
                        help='size of the rnn in number of hidden nodes in each layer')
    parser.add_argument('--num_layers', type=int, default=1,
                        help='number of layers in the RNN')
    parser.add_argument('--rnn_type', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument('--input_encoding_size', type=int, default=512,
                        help='the encoding size of each token in the vocabulary, and the image.')
    parser.add_argument('--att_hid_size', type=int, default=512,
                        help='the hidden size of the attention MLP; only useful in show_attend_tell; 0 if not using hidden layer')

    # Optimization: General
    parser.add_argument('--max_epochs', type=int, default=-1,
                        help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='minibatch size')
    # NOTE(review): a commented-out `5.` sat next to this default in the
    # original line; presumably an earlier value.
    parser.add_argument('--grad_clip', type=float, default=0.1,
                        help='clip gradients at this value')
    parser.add_argument('--drop_prob_lm', type=float, default=0.5,
                        help='strength of dropout in the Language Model RNN')
    parser.add_argument('--finetune_cnn_after', type=int, default=-1,
                        help='After what iteration do we start finetuning the CNN? (-1 = disable; never finetune, 0 = finetune from start)')
    parser.add_argument('--seq_per_img', type=int, default=5,
                        help='number of captions to sample for each image during training. Done for efficiency since CNN forward pass is expensive. E.g. coco has 5 sents/image')
    parser.add_argument('--beam_size', type=int, default=1,
                        help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')

    # Optimization: for the Language Model
    parser.add_argument('--optim', type=str, default='adam',
                        help='what update to use? rmsprop|sgd|sgdmom|adagrad|adam')
    parser.add_argument('--learning_rate', type=float, default=4e-4,
                        help='learning rate')
    parser.add_argument('--learning_rate_decay_start', type=int, default=-1,
                        help='at what iteration to start decaying learning rate? (-1 = dont) (in epoch)')
    parser.add_argument('--learning_rate_decay_every', type=int, default=10,
                        help='every how many iterations thereafter to drop LR by half?(in epoch)')
    parser.add_argument('--optim_alpha', type=float, default=0.8,
                        help='alpha for adam')
    parser.add_argument('--optim_beta', type=float, default=0.999,
                        help='beta used for adam')
    parser.add_argument('--optim_epsilon', type=float, default=1e-8,
                        help='epsilon that goes into denominator for smoothing')

    # Optimization: for the CNN
    parser.add_argument('--cnn_optim', type=str, default='adam',
                        help='optimization to use for CNN')
    parser.add_argument('--cnn_optim_alpha', type=float, default=0.8,
                        help='alpha for momentum of CNN')
    parser.add_argument('--cnn_optim_beta', type=float, default=0.999,
                        help='beta for momentum of CNN')
    parser.add_argument('--cnn_learning_rate', type=float, default=1e-5,
                        help='learning rate for the CNN')
    parser.add_argument('--cnn_weight_decay', type=float, default=0,
                        help='L2 weight decay just for the CNN')

    # Evaluation/Checkpointing
    parser.add_argument('--val_images_use', type=int, default=3200,
                        help='how many images to use when periodically evaluating the validation loss? (-1 = all)')
    parser.add_argument('--save_checkpoint_every', type=int, default=2500,
                        help='how often to save a model checkpoint (in iterations)?')
    parser.add_argument('--checkpoint_path', type=str, default='save',
                        help='directory to store checkpointed models')
    parser.add_argument('--language_eval', type=int, default=0,
                        help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
    # NOTE(review): the help text advertises "0 = disable", but the check
    # further down rejects 0. The check is kept as-is because how callers use
    # this value is not visible here.
    parser.add_argument('--losses_log_every', type=int, default=25,
                        help='How often do we snapshot losses, for inclusion in the progress dump? (0 = disable)')
    parser.add_argument('--load_best_score', type=int, default=1,
                        help='Do we load previous best score when resuming training.')

    # misc
    parser.add_argument('--id', type=str, default='',
                        help='an id identifying this run/job. used in cross-val and appended when writing progress files')
    parser.add_argument('--train_only', type=int, default=0,
                        help='if true then use 80k, else use 110k')

    # argv=None makes argparse read sys.argv[1:], preserving the old behaviour.
    args = parser.parse_args(argv)

    # Check if args are valid
    assert args.rnn_size > 0, "rnn_size should be greater than 0"
    assert args.num_layers > 0, "num_layers should be greater than 0"
    assert args.input_encoding_size > 0, "input_encoding_size should be greater than 0"
    assert args.batch_size > 0, "batch_size should be greater than 0"
    assert 0 <= args.drop_prob_lm < 1, "drop_prob_lm should be between 0 and 1"
    assert args.seq_per_img > 0, "seq_per_img should be greater than 0"
    assert args.beam_size > 0, "beam_size should be greater than 0"
    assert args.save_checkpoint_every > 0, "save_checkpoint_every should be greater than 0"
    assert args.losses_log_every > 0, "losses_log_every should be greater than 0"
    assert args.language_eval in (0, 1), "language_eval should be 0 or 1"
    # The next two messages previously named language_eval by mistake.
    assert args.load_best_score in (0, 1), "load_best_score should be 0 or 1"
    assert args.train_only in (0, 1), "train_only should be 0 or 1"

    return args