Add self critical.
ruotianluo committed May 5, 2017
1 parent 9c93471 commit 6a5db51
Showing 9 changed files with 744 additions and 14 deletions.
35 changes: 27 additions & 8 deletions README.md
@@ -1,12 +1,12 @@
# Neuraltalk2-pytorch
# Self-critical Sequence Training for Image Captioning

There are some differences compared to neuraltalk2.
- Instead of using a random split, we use [karpathy's split](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip).
- Instead of including the convnet in the model, we use preprocessed features. (A finetunable CNN version is in the branch **with_finetune**.)
- Use resnet101, the same way as in self-critical (the preprocessing code may have bugs; not tested yet).
This is an unofficial implementation of [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563). The results of the FC model can be replicated. (The Att2in results could not be replicated.)

# TODO:
- Other models
The author helped me a lot when I was trying to replicate the results. Many thanks.

This is based on my [neuraltalk2.pytorch](https://github.com/ruotianluo/neuraltalk2.pytorch) repository. The modifications are:
- Add FC model (as in the paper)
- Add self-critical training.
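
In brief, self-critical training is REINFORCE with the model's own greedy decode used as the reward baseline, so no separate baseline network needs to be learned. Schematically, the policy-gradient estimate from the paper is (r is the CIDEr-D reward, w^s a sampled caption, and ŵ the greedy-decoded caption):

```latex
\nabla_\theta L(\theta) \approx -\big(r(w^s) - r(\hat{w})\big)\,\nabla_\theta \log p_\theta(w^s)
```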

# Requirements
Python 2.7 (may work for Python 3), PyTorch
@@ -39,7 +39,7 @@ This is telling the script to read in all the data (the images and the captions)
**(Copy end.)**

```bash
$ python train.py --input_json coco/cocotalk.json --input_json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --beam_size 1 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --save_checkpoint_every 6000 --val_images_use 5000
$ python train.py --input_json coco/cocotalk.json --input_fc_dir data/cocotalk_fc --input_att_dir data/cocotalk_att --input_label_h5 data/cocotalk_label.h5 --id fc --caption_model fc --beam_size 1 --learning_rate 5e-4 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --save_checkpoint_every 6000 --val_images_use 5000 --checkpoint_path log_fc
```

The train script will take over, and start dumping checkpoints into the folder specified by `checkpoint_path` (default = current folder). For more options, see `opts.py`.
@@ -52,6 +52,25 @@ If you'd like to evaluate BLEU/METEOR/CIDEr scores during training in addition t

**A few notes on training.** To give you an idea, with the default settings one epoch of MS COCO images is about 7500 iterations. 1 epoch of training (with no finetuning - notice this is the default) takes about 15 minutes and results in validation loss ~2.7 and CIDEr score of ~0.5. ~~By iteration 50,000 CIDEr climbs up to about 0.65 (validation loss at about 2.4).~~

# Train using self-critical

First, preprocess the dataset and build the cache used for computing the CIDEr score:
```bash
$ python scripts/prepro_ngrams.py --input_json .../dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train
```

You also need to clone my forked [cider](https://github.com/ruotianluo/cider) repository.

Then copy the pretrained model (trained with cross entropy):
```bash
$ bash scripts/copy_model.sh fc fc_rl
```

Then run:
```bash
python train_rl_tb.py --caption_model fc --rnn_size 512 --batch_size 10 --seq_per_img 5 --input_encoding_size 512 --train_only 0 --id fc_rl --input_json data/cocotalk.json --input_fc_h5 data/cocotalk_fc.h5 --input_att_h5 data/cocotalk_att.h5 --input_label_h5 data/cocotalk_label.h5 --beam_size 1 --learning_rate 5e-5 --optim adam --optim_alpha 0.9 --optim_beta 0.999 --checkpoint_path log_fc_rl --start_from log_fc_rl --save_checkpoint_every 5000 --language_eval 1 --val_images_use 5000
```
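
For orientation, one self-critical update in `train_rl_tb.py` boils down to roughly the sketch below. It calls `get_self_critical_reward` from `get_rewards.py` (added in this commit, shown further down); the function `self_critical_step`, the masked reward-weighted loss, and the surrounding variable names are illustrative rather than the script's exact code, and GPU tensors are assumed, as in the rest of the repo.

```python
import torch
from torch.autograd import Variable
from get_rewards import get_self_critical_reward

def self_critical_step(model, optimizer, fc_feats, att_feats, data):
    # fc_feats/att_feats: Variables holding the precomputed features of one batch
    # data['gts']: ground-truth caption arrays, as produced by the data loader

    # sample_max=0 -> sample captions from the model instead of greedy decoding
    gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max': 0})

    # reward[i, t] = CIDEr-D(sampled caption i) - CIDEr-D(greedy caption i),
    # repeated along the time dimension
    reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
    reward = Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False)

    # keep rewards up to (and including) the first end token of each sampled caption
    mask = (gen_result > 0).float()
    mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
    mask = Variable(mask, requires_grad=False)

    # REINFORCE with the greedy decode as baseline: raise the log-probs of
    # sampled words whose caption beats the greedy caption, lower them otherwise
    loss = -(sample_logprobs * reward * mask).sum() / mask.sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.data[0]
```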

### Caption images after training

## Evaluate on raw images (not ready yet)
8 changes: 2 additions & 6 deletions eval_utils.py
@@ -18,12 +18,8 @@

def language_eval(dataset, preds, model_id, split):
    import sys
    if 'coco' in dataset:
        sys.path.append("coco-caption")
        annFile = 'coco-caption/annotations/captions_val2014.json'
    else:
        sys.path.append("f30k-caption")
        annFile = 'f30k-caption/annotations/dataset_flickr30k.json'
    sys.path.append("coco-caption")
    annFile = 'coco-caption/annotations/captions_val2014.json'
    from pycocotools.coco import COCO
    from pycocoevalcap.eval import COCOEvalCap

67 changes: 67 additions & 0 deletions get_rewards.py
@@ -0,0 +1,67 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import time
import misc.utils as utils
from collections import OrderedDict
import torch
from torch.autograd import Variable

import sys
sys.path.append("cider")
from pyciderevalcap.ciderD.ciderD import CiderD
#from pyciderevalcap.cider.cider import Cider

CiderD_scorer = CiderD(df='coco-train-idxs')
#CiderD_scorer = CiderD(df='corpus')

def array_to_str(arr):
    out = ''
    for i in range(len(arr)):
        out += str(arr[i]) + ' '
        if arr[i] == 0:
            break
    return out.strip()

def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result):
    batch_size = gen_result.size(0) # batch_size = sample_size * seq_per_img
    seq_per_img = batch_size // len(data['gts'])

    # get greedy decoding baseline
    greedy_res, _ = model.sample(Variable(fc_feats.data, volatile=True), Variable(att_feats.data, volatile=True))

    res = OrderedDict()

    gen_result = gen_result.cpu().numpy()
    greedy_res = greedy_res.cpu().numpy()
    for i in range(batch_size):
        res[i] = [array_to_str(gen_result[i])]
    for i in range(batch_size):
        res[batch_size + i] = [array_to_str(greedy_res[i])]

    gts = OrderedDict()
    for i in range(len(data['gts'])):
        gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))]

    #_, scores = Bleu(4).compute_score(gts, res)
    #scores = np.array(scores[3])
    res = [{'image_id': i, 'caption': res[i]} for i in range(2 * batch_size)]
    gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
    _, scores = CiderD_scorer.compute_score(gts, res)
    print('Cider scores:', _)

    scores = scores[:batch_size] - scores[batch_size:]

    #rewards = np.ones((batch_size, gen_result.shape[1])) * np.inf

    rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)

    # for i in range(batch_size):
    #     for j in range(gen_result.shape[1]):
    #         rewards[i, j] = scores[i]
    #         if gen_result[i, j] == 0:
    #             break

    return rewards
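
One detail that is easy to misread above is the `i % batch_size // seq_per_img` indexing: entries `0 .. batch_size-1` of `res` are the sampled captions and entries `batch_size .. 2*batch_size-1` are their greedy counterparts, and both must be scored against the ground-truth captions of the image they came from. A tiny worked example with illustrative numbers:

```python
# 2 images, seq_per_img = 5 -> batch_size = 10 sampled captions (+ 10 greedy ones)
batch_size, seq_per_img = 10, 5
for i in [0, 4, 5, 12]:
    print(i, '->', i % batch_size // seq_per_img)
# 0  -> 0   (sampled caption 0 belongs to image 0)
# 4  -> 0   (sampled caption 4 belongs to image 0)
# 5  -> 1   (sampled caption 5 belongs to image 1)
# 12 -> 0   (greedy counterpart of sampled caption 2, i.e. image 0)
```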
252 changes: 252 additions & 0 deletions misc/FCModel.py
@@ -0,0 +1,252 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
import misc.utils as utils

class LSTMCore(nn.Module):
    def __init__(self, opt):
        super(LSTMCore, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_size = opt.rnn_size
        self.drop_prob_lm = opt.drop_prob_lm

        # Build a LSTM
        self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
        self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.dropout = nn.Dropout(self.drop_prob_lm)

    def forward(self, xt, state):

        all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
        sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
        sigmoid_chunk = F.sigmoid(sigmoid_chunk)
        in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
        forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
        out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

        in_transform = torch.max(\
            all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
            all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * F.tanh(next_c)

        next_h = self.dropout(next_h)

        output = next_h
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state

class FCModel(nn.Module):
    def __init__(self, opt):
        super(FCModel, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.ss_prob = 0.0 # Schedule sampling probability

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.core = LSTMCore(opt)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)
        self.logit.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'lstm':
            return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()),
                    Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()))
        else:
            return Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())

    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                        it = Variable(it, requires_grad=False)
                else:
                    it = seq[:, i-1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt, state)
            output = F.log_softmax(self.logit(output))
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()

    def sample_beam(self, fc_feats, att_feats, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)

            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
            beam_logprobs_sum = torch.zeros(beam_size) # running sum of logprobs for each beam
            for t in range(self.seq_length + 2):
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                elif t == 1: # input <bos>
                    it = fc_feats.data.new(beam_size).long().zero_()
                    xt = self.embed(Variable(it, requires_grad=False))
                else:
                    """perform a beam merge. that is,
                    for every previous beam we now have many new possibilities to branch out
                    we need to resort our beams to maintain the loop invariant of keeping
                    the top beam_size most likely sequences."""
                    logprobsf = logprobs.float() # lets go to CPU for more efficiency in indexing operations
                    ys, ix = torch.sort(logprobsf, 1, True) # sorted array of logprobs along each previous beam (last true = descending)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 2: # at first time step only the first beam is active
                        rows = 1
                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[q] + local_logprob
                            candidates.append({'c': ix.data[q, c], 'q': q, 'p': candidate_logprob.data[0], 'r': local_logprob.data[0]})
                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 2:
                        # we'll need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t-2].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t-2].clone()
                    for vix in range(beam_size):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 2:
                            beam_seq[:t-2, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t-2, vix] = beam_seq_logprobs_prev[:, v['q']]

                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][0, v['q']] # dimension one is time step

                        # append new end terminal at the end of this beam
                        beam_seq[t-2, vix] = v['c'] # c'th word is the continuation
                        beam_seq_logprobs[t-2, vix] = v['r'] # the raw logprob here
                        beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam

                        if v['c'] == 0 or t == self.seq_length + 1:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            self.done_beams[k].append({'seq': beam_seq[:, vix].clone(),
                                                       'logps': beam_seq_logprobs[:, vix].clone(),
                                                       'p': beam_logprobs_sum[vix]
                                                       })

                    # encode as vectors
                    it = beam_seq[t-2]
                    xt = self.embed(Variable(it.cuda()))

                if t >= 2:
                    state = new_state

                output, state = self.core(xt, state)
                logprobs = F.log_softmax(self.logit(output))

            self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)

    def sample(self, fc_feats, att_feats, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = []
        seqLogprobs = []
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1: # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                elif sample_max:
                    sampleLogprobs, it = torch.max(logprobs.data, 1)
                    it = it.view(-1).long()
                else:
                    if temperature == 1.0:
                        prob_prev = torch.exp(logprobs.data).cpu() # fetch prev distribution: shape Nx(M+1)
                    else:
                        # scale logprobs by temperature
                        prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                    it = torch.multinomial(prob_prev, 1).cuda()
                    sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False)) # gather the logprobs at sampled positions
                    it = it.view(-1).long() # and flatten indices for downstream processing

                xt = self.embed(Variable(it, requires_grad=False))

            if t >= 2:
                # stop when all finished
                if t == 2:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it) #seq[t] the input of t+2 time step
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core(xt, state)
            logprobs = F.log_softmax(self.logit(output))

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
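
For reference, a minimal sketch of how `FCModel` is driven. The option values below are illustrative (in the repo they come from `opts.py` and the preprocessed vocabulary), `att_feats` is part of the interface even though the FC model ignores it, and the snippet assumes the 2017-era `Variable` API this file is written against:

```python
from argparse import Namespace
import torch
from torch.autograd import Variable
from misc.FCModel import FCModel

# Illustrative hyper-parameters; real values come from opts.py / the prepro scripts.
opt = Namespace(vocab_size=9487, input_encoding_size=512, rnn_type='lstm',
                rnn_size=512, num_layers=1, drop_prob_lm=0.5,
                seq_length=16, fc_feat_size=2048)
model = FCModel(opt)

fc_feats = Variable(torch.randn(10, opt.fc_feat_size))     # pooled CNN features
att_feats = Variable(torch.randn(10, 14, 14, 2048))        # unused by the FC model
seq = Variable(torch.ones(10, opt.seq_length + 2).long())  # dummy caption indices

logprobs = model(fc_feats, att_feats, seq)                 # per-step word log-probabilities
greedy_seq, greedy_logprobs = model.sample(fc_feats, att_feats, {'sample_max': 1})
```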

