forked from ruotianluo/self-critical.pytorch
Commit 4db8d20 (1 parent: 3b7f1fe)
Showing 2 changed files with 176 additions and 1 deletion.
@@ -0,0 +1,174 @@
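"""Self-critical sequence training (SCST) loop for an image-captioning model.

Each iteration samples captions from the model, scores them against a
greedy-decoding baseline, and takes a gradient step on the reward-weighted
sequence loss.
"""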
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

import numpy as np

import time
import os
from six.moves import cPickle

import opts
import models
from dataloader import *
import eval_utils
import misc.utils as utils
import get_rewards

def train(opt):
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # Open old infos and check that the saved model is compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s'" % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None  # initialize so the checkpoint logic below never sees an unbound name
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Make sure the model is in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer state when resuming
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

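    # One SCST step per iteration (Rennie et al., CVPR 2017): sample captions,
    # score them against a greedy-decoded baseline from the same model, and
    # apply REINFORCE with the score difference as the reward.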
    while True:
        if update_lr_flag:
            # Assign the learning rate, decayed every
            # opt.learning_rate_decay_every epochs past the decay start
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            update_lr_flag = False

        start = time.time()
        # Load data from the train split
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats = tmp

        optimizer.zero_grad()

        # Sample captions stochastically (sample_max=0 selects multinomial
        # sampling rather than greedy argmax decoding)
        gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max': 0})

        # Self-critical reward for each sampled caption, then the
        # reward-weighted sequence loss
        rewards = get_rewards.get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
        loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(rewards).float().cuda(), requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
              .format(iteration, epoch, np.mean(rewards[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = np.mean(rewards[:, 0])
            lr_history[iteration] = opt.current_lr

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['val_result_history'] = val_result_history
            infos['loss_history'] = loss_history
            infos['lr_history'] = lr_history
            infos['ss_prob_history'] = ss_prob_history
            infos['vocab'] = loader.get_vocab()
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break

if __name__ == '__main__':
    opt = opts.parse_opt()
    train(opt)
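The reward computation is delegated to get_rewards.get_self_critical_reward, whose source is not shown in this hunk. As a minimal sketch of what a self-critical reward function looks like, following Rennie et al. (CVPR 2017) and this fork's upstream: the cider_score helper below is hypothetical, standing in for a CIDEr scorer that rates each caption against the batch's ground-truth references.

import numpy as np
from torch.autograd import Variable

def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result):
    # Greedy decoding (sample_max=1) from the same model gives the baseline
    # captions; volatile=True disables gradient tracking in old PyTorch.
    greedy_res, _ = model.sample(Variable(fc_feats.data, volatile=True),
                                 Variable(att_feats.data, volatile=True),
                                 {'sample_max': 1})

    # cider_score is a hypothetical helper returning one CIDEr score per
    # caption, computed against the ground-truth captions in `data`.
    sampled_scores = cider_score(gen_result, data)
    greedy_scores = cider_score(greedy_res, data)

    # Self-critical reward: the advantage of sampling over greedy decoding,
    # broadcast to every timestep so each token of a caption shares its
    # sequence-level reward. Shape: (batch_size, seq_length), as expected by
    # torch.from_numpy(rewards) in the training loop above.
    rewards = sampled_scores - greedy_scores
    return np.repeat(rewards[:, np.newaxis], gen_result.size(1), 1)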