diff --git a/README.md b/README.md
index 4d89bd60..279c4e7e 100644
--- a/README.md
+++ b/README.md
@@ -68,8 +68,9 @@ $ bash scripts/copy_model.sh fc fc_rl
 
 Then
 
 ```bash
-python train_rl_tb.py --caption_model fc --rnn_size 512 --batch_size 10 --seq_per_img 5 --input_encoding_size 512 --train_only 0 --id fc_rl --input_json data/cocotalk.json --input_fc_h5 data/cocotalk_fc.h5 --input_att_h5 data/cocotalk_att.h5 --input_label_h5 data/cocotalk_label.h5 --beam_size 1 --learning_rate 5e-5 --optim adam --optim_alpha 0.9 --optim_beta 0.999 --checkpoint_path log_fc_rl --start_from log_fc_rl --save_checkpoint_every 5000 --language_eval 1 --val_images_use 5000
+python train_rl.py --caption_model fc --rnn_size 512 --batch_size 10 --seq_per_img 5 --input_encoding_size 512 --train_only 0 --id fc_rl --input_json data/cocotalk.json --input_fc_h5 data/cocotalk_fc.h5 --input_att_h5 data/cocotalk_att.h5 --input_label_h5 data/cocotalk_label.h5 --beam_size 1 --learning_rate 5e-5 --optim adam --optim_alpha 0.9 --optim_beta 0.999 --checkpoint_path log_fc_rl --start_from log_fc_rl --save_checkpoint_every 5000 --language_eval 1 --val_images_use 5000
 ```
+You can also use `train_rl_tb.py`, which additionally logs training progress to TensorBoard.
 
 ### Caption images after training
diff --git a/train_rl.py b/train_rl.py
new file mode 100644
index 00000000..d211d04f
--- /dev/null
+++ b/train_rl.py
@@ -0,0 +1,174 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.optim as optim
+
+import numpy as np
+
+import time
+import os
+from six.moves import cPickle
+
+import opts
+import models
+from dataloader import *
+import eval_utils
+import misc.utils as utils
+import get_rewards
+
+
+def train(opt):
+    loader = DataLoader(opt)
+    opt.vocab_size = loader.vocab_size
+    opt.seq_length = loader.seq_length
+
+    infos = {}
+    if opt.start_from is not None:
+        # open old infos and check if models are compatible
+        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
+            infos = cPickle.load(f)
+            saved_model_opt = infos['opt']
+            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
+            for checkme in need_be_same:
+                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
+
+    iteration = infos.get('iter', 0)
+    epoch = infos.get('epoch', 0)
+    val_result_history = infos.get('val_result_history', {})
+    loss_history = infos.get('loss_history', {})
+    lr_history = infos.get('lr_history', {})
+    ss_prob_history = infos.get('ss_prob_history', {})
+
+    loader.iterators = infos.get('iterators', loader.iterators)
+    loader.split_ix = infos.get('split_ix', loader.split_ix)
+    best_val_score = None
+    if opt.load_best_score == 1:
+        best_val_score = infos.get('best_val_score', None)
+
+    model = models.setup(opt)
+    model.cuda()
+
+    update_lr_flag = True
+    # Make sure the model is in training mode
+    model.train()
+
+    crit = utils.LanguageModelCriterion()
+    rl_crit = utils.RewardCriterion()
+
+    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
+
+    # Load the optimizer
+    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
+        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
+
+    while True:
+        if update_lr_flag:
+            # Assign the learning rate
+            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
+                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
+                decay_factor = opt.learning_rate_decay_rate ** frac
+                opt.current_lr = opt.learning_rate * decay_factor
+                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
+            else:
+                opt.current_lr = opt.learning_rate
+            update_lr_flag = False
+
+        start = time.time()
+        # Load data from the train split
+        data = loader.get_batch('train')
+        print('Read data:', time.time() - start)
+
+        torch.cuda.synchronize()
+        start = time.time()
+
+        tmp = [data['fc_feats'], data['att_feats']]
+        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
+        fc_feats, att_feats = tmp
+
+        optimizer.zero_grad()
+
+        gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max':0})
+
+        rewards = get_rewards.get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
+        loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(rewards).float().cuda(), requires_grad=False))
+
+        loss.backward()
+        utils.clip_gradient(optimizer, opt.grad_clip)
+        optimizer.step()
+        train_loss = loss.data[0]
+        torch.cuda.synchronize()
+        end = time.time()
+        print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
+            .format(iteration, epoch, np.mean(rewards[:,0]), end - start))
+
+        # Update the iteration and epoch
+        iteration += 1
+        if data['bounds']['wrapped']:
+            epoch += 1
+            update_lr_flag = True
+
+        # Write the training loss summary
+        if (iteration % opt.losses_log_every == 0):
+            loss_history[iteration] = np.mean(rewards[:,0])
+            lr_history[iteration] = opt.current_lr
+
+        # Evaluate on the validation set and save the model
+        if (iteration % opt.save_checkpoint_every == 0):
+            # eval model
+            eval_kwargs = {'split': 'val',
+                           'dataset': opt.input_json}
+            eval_kwargs.update(vars(opt))
+            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)
+
+            # Write validation result into summary
+            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
+
+            # Save model if it is improving on the validation result
+            if opt.language_eval == 1:
+                current_score = lang_stats['CIDEr']
+            else:
+                current_score = - val_loss
+
+            best_flag = False
+            if True: # always dump a checkpoint at this interval; best_flag additionally tracks improvement
+                if best_val_score is None or current_score > best_val_score:
+                    best_val_score = current_score
+                    best_flag = True
+                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
+                torch.save(model.state_dict(), checkpoint_path)
+                print("model saved to {}".format(checkpoint_path))
+                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
+                torch.save(optimizer.state_dict(), optimizer_path)
+
+                # Dump miscellaneous information
+                infos['iter'] = iteration
+                infos['epoch'] = epoch
+                infos['iterators'] = loader.iterators
+                infos['split_ix'] = loader.split_ix
+                infos['best_val_score'] = best_val_score
+                infos['opt'] = opt
+                infos['val_result_history'] = val_result_history
+                infos['loss_history'] = loss_history
+                infos['lr_history'] = lr_history
+                infos['ss_prob_history'] = ss_prob_history
+                infos['vocab'] = loader.get_vocab()
+                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
+                    cPickle.dump(infos, f)
+
+                if best_flag:
+                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
+                    torch.save(model.state_dict(), checkpoint_path)
+                    print("model saved to {}".format(checkpoint_path))
+                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
+                        cPickle.dump(infos, f)
+
+        # Stop if we have reached max epochs
+        if epoch >= opt.max_epochs and opt.max_epochs != -1:
+            break
+
+opt = opts.parse_opt()
+train(opt)
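
Note: the patch calls `get_rewards.get_self_critical_reward` and `utils.RewardCriterion`, but neither is included in this diff. The sketch below illustrates the self-critical idea they implement (Rennie et al., "Self-Critical Sequence Training for Image Captioning"): the sampled caption's sentence-level score (e.g. CIDEr) is baselined against the score of the greedy decode, and that advantage weights each sampled token's log-probability. The function names, shapes, and toy scores here are illustrative assumptions, not the repository's actual implementation.

```python
# Hypothetical sketch only -- not the repo's get_rewards.py / misc/utils.py.
import numpy as np
import torch


def self_critical_reward(sample_scores, greedy_scores, seq_length):
    """Advantage = score(sampled caption) - score(greedy caption), one value per
    image, broadcast over every generated token (shape: batch x seq_length)."""
    advantage = np.asarray(sample_scores, dtype=np.float32) - np.asarray(greedy_scores, dtype=np.float32)
    return np.repeat(advantage[:, np.newaxis], seq_length, axis=1)


def reward_criterion(logprobs, gen_result, rewards):
    """REINFORCE-style loss: -logprob * advantage, masked so tokens generated
    after the end-of-sentence token (id 0 here) do not contribute."""
    mask = (gen_result > 0).float()
    # shift the mask right by one step so the <eos> token itself still receives gradient
    mask = torch.cat([torch.ones(mask.size(0), 1), mask[:, :-1]], dim=1)
    return -(logprobs * rewards * mask).sum() / mask.sum()


# Toy CPU usage: 2 captions of length 4; the sampled captions beat / trail the greedy baseline.
rewards = self_critical_reward([0.9, 0.5], [0.7, 0.6], seq_length=4)
loss = reward_criterion(torch.zeros(2, 4), torch.ones(2, 4).long(), torch.from_numpy(rewards))
```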