Add train_rl without tb
ruotianluo committed May 5, 2017
1 parent 3b7f1fe commit 4db8d20
Showing 2 changed files with 176 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -68,8 +68,9 @@ $ bash scripts/copy_model.sh fc fc_rl

Then
```bash
-python train_rl_tb.py --caption_model fc --rnn_size 512 --batch_size 10 --seq_per_img 5 --input_encoding_size 512 --train_only 0 --id fc_rl --input_json data/cocotalk.json --input_fc_h5 data/cocotalk_fc.h5 --input_att_h5 data/cocotalk_att.h5 --input_label_h5 data/cocotalk_label.h5 --beam_size 1 --learning_rate 5e-5 --optim adam --optim_alpha 0.9 --optim_beta 0.999 --checkpoint_path log_fc_rl --start_from log_fc_rl --save_checkpoint_every 5000 --language_eval 1 --val_images_use 5000
+python train_rl.py --caption_model fc --rnn_size 512 --batch_size 10 --seq_per_img 5 --input_encoding_size 512 --train_only 0 --id fc_rl --input_json data/cocotalk.json --input_fc_h5 data/cocotalk_fc.h5 --input_att_h5 data/cocotalk_att.h5 --input_label_h5 data/cocotalk_label.h5 --beam_size 1 --learning_rate 5e-5 --optim adam --optim_alpha 0.9 --optim_beta 0.999 --checkpoint_path log_fc_rl --start_from log_fc_rl --save_checkpoint_every 5000 --language_eval 1 --val_images_use 5000
```
+You can also use train_rl_tb.py, which uses TensorBoard for logging.

### Caption images after training

174 changes: 174 additions & 0 deletions train_rl.py
@@ -0,0 +1,174 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

import numpy as np

import time
import os
from six.moves import cPickle

import opts
import models
from dataloader import *
import eval_utils
import misc.utils as utils
import get_rewards


def train(opt):
loader = DataLoader(opt)
opt.vocab_size = loader.vocab_size
opt.seq_length = loader.seq_length

    infos = {}
    best_val_score = None  # initialized up front so the checkpoint logic below never hits an unbound name
if opt.start_from is not None:
# open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
infos = cPickle.load(f)
saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
for checkme in need_be_same:
assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

iteration = infos.get('iter', 0)
epoch = infos.get('epoch', 0)
val_result_history = infos.get('val_result_history', {})
loss_history = infos.get('loss_history', {})
lr_history = infos.get('lr_history', {})
ss_prob_history = infos.get('ss_prob_history', {})

loader.iterators = infos.get('iterators', loader.iterators)
loader.split_ix = infos.get('split_ix', loader.split_ix)
if opt.load_best_score == 1:
best_val_score = infos.get('best_val_score', None)

model = models.setup(opt)
model.cuda()

update_lr_flag = True
    # Make sure the model is in training mode
model.train()

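    # crit is the usual cross-entropy criterion, used only for validation below;
    # rl_crit is a REINFORCE-style criterion that weights the log-probability of
    # each sampled word by its reward.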
crit = utils.LanguageModelCriterion()
rl_crit = utils.RewardCriterion()

optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

# Load the optimizer
if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
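        # Restoring the optimizer state keeps Adam's running moment estimates,
        # so the RL fine-tuning resumes smoothly from the pretrained run.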

while True:
if update_lr_flag:
# Assign the learning rate
if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
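                # Step decay. For example (illustrative values, not the defaults):
                # with learning_rate 5e-5, decay_start 0, decay_every 3 and
                # decay_rate 0.8, epoch 7 gives frac = (7 - 0) // 3 = 2 and
                # current_lr = 5e-5 * 0.8**2 = 3.2e-5.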
frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
decay_factor = opt.learning_rate_decay_rate ** frac
opt.current_lr = opt.learning_rate * decay_factor
utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
else:
opt.current_lr = opt.learning_rate
update_lr_flag = False

start = time.time()
# Load data from train split (0)
data = loader.get_batch('train')
print('Read data:', time.time() - start)

torch.cuda.synchronize()
start = time.time()

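        # Wrap the numpy batches as no-grad CUDA Variables (the pre-0.4 PyTorch
        # idiom this script is written against).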
tmp = [data['fc_feats'], data['att_feats']]
tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
fc_feats, att_feats = tmp

optimizer.zero_grad()

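        # 'sample_max': 0 requests stochastic sampling, so gen_result is drawn
        # from the model's own distribution (its "policy") rather than decoded
        # greedily.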
gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max':0})

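        # Self-critical reward (Rennie et al., 2017): the sampled captions are
        # scored against a baseline the model produces itself (greedy decoding
        # in the SCST paper), so only samples that beat the baseline are
        # reinforced.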
rewards = get_rewards.get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(rewards).float().cuda(), requires_grad=False))

loss.backward()
utils.clip_gradient(optimizer, opt.grad_clip)
optimizer.step()
train_loss = loss.data[0]
torch.cuda.synchronize()
end = time.time()
print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
.format(iteration, epoch, np.mean(rewards[:,0]), end - start))

# Update the iteration and epoch
iteration += 1
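        # The loader sets wrapped = True once it has cycled through the whole
        # training split, which is what this script counts as an epoch.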
if data['bounds']['wrapped']:
epoch += 1
update_lr_flag = True

        # Record the running summary (the mean self-critical reward stands in for the training loss)
if (iteration % opt.losses_log_every == 0):
loss_history[iteration] = np.mean(rewards[:,0])
lr_history[iteration] = opt.current_lr

        # Evaluate on the validation set and save the model
if (iteration % opt.save_checkpoint_every == 0):
# eval model
eval_kwargs = {'split': 'val',
'dataset': opt.input_json}
eval_kwargs.update(vars(opt))
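            # Validation still uses the cross-entropy criterion for val_loss;
            # with --language_eval 1, eval_split also reports caption metrics
            # such as CIDEr (used below for model selection).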
val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

# Write validation result into summary
val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous bookkeeping info
infos['iter'] = iteration
infos['epoch'] = epoch
infos['iterators'] = loader.iterators
infos['split_ix'] = loader.split_ix
infos['best_val_score'] = best_val_score
infos['opt'] = opt
infos['val_result_history'] = val_result_history
infos['loss_history'] = loss_history
infos['lr_history'] = lr_history
infos['ss_prob_history'] = ss_prob_history
infos['vocab'] = loader.get_vocab()
with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
cPickle.dump(infos, f)

if best_flag:
checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
torch.save(model.state_dict(), checkpoint_path)
print("model saved to {}".format(checkpoint_path))
with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
cPickle.dump(infos, f)

# Stop if reaching max epochs
if epoch >= opt.max_epochs and opt.max_epochs != -1:
break

if __name__ == '__main__':
    opt = opts.parse_opt()
    train(opt)
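utils.RewardCriterion is used above but is not part of this diff. As a reading aid, here is a minimal sketch of the reward-weighted loss such a criterion is commonly expected to compute; it is an assumption written in current PyTorch style (no Variable wrappers), not the repository's actual implementation:

```python
import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    """Hypothetical stand-in for utils.RewardCriterion (not from this commit)."""

    def forward(self, logprobs, seq, reward):
        # logprobs: (batch, T) log-probabilities of the sampled words
        # seq:      (batch, T) sampled word ids, zero-padded after the caption ends
        # reward:   (batch, T) per-word advantage from the self-critical baseline
        mask = (seq > 0).float()
        # shift the mask right by one step so the decision to stop (the first
        # zero-padded position) is rewarded as well
        mask = torch.cat([torch.ones(mask.size(0), 1, device=mask.device), mask[:, :-1]], dim=1)
        # REINFORCE: maximize the reward-weighted log-likelihood of the samples
        loss = -(logprobs * reward * mask).sum() / mask.sum()
        return loss
```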
