forked from ruotianluo/self-critical.pytorch
Commit 6a5db51 (1 parent: 9c93471)
Showing 9 changed files with 744 additions and 14 deletions.
@@ -0,0 +1,67 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import time
import misc.utils as utils
from collections import OrderedDict
import torch
from torch.autograd import Variable

import sys
sys.path.append("cider")
from pyciderevalcap.ciderD.ciderD import CiderD
#from pyciderevalcap.cider.cider import Cider

CiderD_scorer = CiderD(df='coco-train-idxs')
#CiderD_scorer = CiderD(df='corpus')

def array_to_str(arr):
    out = ''
    for i in range(len(arr)):
        out += str(arr[i]) + ' '
        if arr[i] == 0:
            break
    return out.strip()

def get_self_critical_reward(model, fc_feats, att_feats, data, gen_result):
    batch_size = gen_result.size(0)  # batch_size = sample_size * seq_per_img
    seq_per_img = batch_size // len(data['gts'])

    # get greedy decoding baseline
    greedy_res, _ = model.sample(Variable(fc_feats.data, volatile=True), Variable(att_feats.data, volatile=True))

    res = OrderedDict()

    gen_result = gen_result.cpu().numpy()
    greedy_res = greedy_res.cpu().numpy()
    for i in range(batch_size):
        res[i] = [array_to_str(gen_result[i])]
    for i in range(batch_size):
        res[batch_size + i] = [array_to_str(greedy_res[i])]

    gts = OrderedDict()
    for i in range(len(data['gts'])):
        gts[i] = [array_to_str(data['gts'][i][j]) for j in range(len(data['gts'][i]))]

    #_, scores = Bleu(4).compute_score(gts, res)
    #scores = np.array(scores[3])
    res = [{'image_id': i, 'caption': res[i]} for i in range(2 * batch_size)]
    gts = {i: gts[i % batch_size // seq_per_img] for i in range(2 * batch_size)}
    _, scores = CiderD_scorer.compute_score(gts, res)
    print('Cider scores:', _)

    # self-critical reward: sampled-caption score minus greedy (baseline) score
    scores = scores[:batch_size] - scores[batch_size:]

    #rewards = np.ones((batch_size, gen_result.shape[1])) * np.inf

    # broadcast each sequence-level reward to every time step of that sequence
    rewards = np.repeat(scores[:, np.newaxis], gen_result.shape[1], 1)

    # for i in range(batch_size):
    #     for j in range(gen_result.shape[1]):
    #         rewards[i, j] = scores[i]
    #         if gen_result[i, j] == 0:
    #             break

    return rewards
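The rewards returned above are typically plugged into a REINFORCE-style loss that weights the log-probabilities of the sampled words. Below is a minimal sketch of such a criterion; it is not part of this commit, and the name reward_criterion and the exact masking convention are illustrative assumptions.

# Illustrative sketch only, not part of this commit: a REINFORCE-style loss
# that weights sampled log-probabilities by the per-token rewards computed by
# get_self_critical_reward above.
import torch
from torch.autograd import Variable

def reward_criterion(sample_logprobs, gen_result, rewards):
    # sample_logprobs: Variable (batch, seq_len), log-probs of the sampled words
    # gen_result:      LongTensor (batch, seq_len), sampled word ids (0 = end token)
    # rewards:         numpy array (batch, seq_len) from get_self_critical_reward
    rewards = torch.from_numpy(rewards).float().type_as(sample_logprobs.data)
    rewards = Variable(rewards.view(-1), requires_grad=False)
    logprobs = sample_logprobs.contiguous().view(-1)
    # keep gradients up to and including the first end token of each sequence
    mask = (gen_result > 0).float()
    mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
    mask = Variable(mask.contiguous().view(-1).type_as(rewards.data), requires_grad=False)
    return -(logprobs * rewards * mask).sum() / mask.sum()

In a self-critical training step, gen_result and sample_logprobs would typically come from model.sample(fc_feats, att_feats, {'sample_max': 0}), and rewards from get_self_critical_reward above.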
@@ -0,0 +1,252 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import *
import misc.utils as utils

class LSTMCore(nn.Module):
    def __init__(self, opt):
        super(LSTMCore, self).__init__()
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_size = opt.rnn_size
        self.drop_prob_lm = opt.drop_prob_lm

        # Build an LSTM
        self.i2h = nn.Linear(self.input_encoding_size, 5 * self.rnn_size)
        self.h2h = nn.Linear(self.rnn_size, 5 * self.rnn_size)
        self.dropout = nn.Dropout(self.drop_prob_lm)

    def forward(self, xt, state):

        all_input_sums = self.i2h(xt) + self.h2h(state[0][-1])
        sigmoid_chunk = all_input_sums.narrow(1, 0, 3 * self.rnn_size)
        sigmoid_chunk = F.sigmoid(sigmoid_chunk)
        in_gate = sigmoid_chunk.narrow(1, 0, self.rnn_size)
        forget_gate = sigmoid_chunk.narrow(1, self.rnn_size, self.rnn_size)
        out_gate = sigmoid_chunk.narrow(1, self.rnn_size * 2, self.rnn_size)

        in_transform = torch.max(\
            all_input_sums.narrow(1, 3 * self.rnn_size, self.rnn_size),
            all_input_sums.narrow(1, 4 * self.rnn_size, self.rnn_size))
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * F.tanh(next_c)

        next_h = self.dropout(next_h)

        output = next_h
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state
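
# Note on the gate layout above: the 5 * self.rnn_size projection produced by
# i2h/h2h packs, in order, the input gate, the forget gate, the output gate,
# and two candidate slices for in_transform. The first three chunks pass
# through a sigmoid, while the two candidate slices are combined with an
# element-wise max (a maxout-style non-linearity) instead of the usual single
# tanh candidate.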

class FCModel(nn.Module):
    def __init__(self, opt):
        super(FCModel, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_type = opt.rnn_type
        self.rnn_size = opt.rnn_size
        self.num_layers = opt.num_layers
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.ss_prob = 0.0  # Schedule sampling probability

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.core = LSTMCore(opt)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)
        self.logit.weight.data.uniform_(-initrange, initrange)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'lstm':
            return (Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()),
                    Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_()))
        else:
            return Variable(weight.new(self.num_layers, bsz, self.rnn_size).zero_())

    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if i >= 2 and self.ss_prob > 0.0:  # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data)  # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                        it = Variable(it, requires_grad=False)
                else:
                    it = seq[:, i-1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt, state)
            output = F.log_softmax(self.logit(output))
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()

    def sample_beam(self, fc_feats, att_feats, opt={}):
        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        # lets process every image independently for now, for simplicity

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)

            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
            beam_logprobs_sum = torch.zeros(beam_size)  # running sum of logprobs for each beam
            for t in range(self.seq_length + 2):
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                elif t == 1:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().zero_()
                    xt = self.embed(Variable(it, requires_grad=False))
                else:
                    """perform a beam merge. that is,
                    for every previous beam we now have many new possibilities to branch out
                    we need to resort our beams to maintain the loop invariant of keeping
                    the top beam_size most likely sequences."""
                    logprobsf = logprobs.float()  # lets go to CPU for more efficiency in indexing operations
                    ys, ix = torch.sort(logprobsf, 1, True)  # sorted array of logprobs along each previous beam (last true = descending)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 2:  # at first time step only the first beam is active
                        rows = 1
                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[q] + local_logprob
                            candidates.append({'c': ix.data[q, c], 'q': q, 'p': candidate_logprob.data[0], 'r': local_logprob.data[0]})
                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 2:
                        # we'll need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t-2].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t-2].clone()
                    for vix in range(beam_size):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 2:
                            beam_seq[:t-2, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t-2, vix] = beam_seq_logprobs_prev[:, v['q']]

                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][0, v['q']]  # dimension one is time step

                        # append new end terminal at the end of this beam
                        beam_seq[t-2, vix] = v['c']  # c'th word is the continuation
                        beam_seq_logprobs[t-2, vix] = v['r']  # the raw logprob here
                        beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam

                        if v['c'] == 0 or t == self.seq_length + 1:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            self.done_beams[k].append({'seq': beam_seq[:, vix].clone(),
                                                       'logps': beam_seq_logprobs[:, vix].clone(),
                                                       'p': beam_logprobs_sum[vix]
                                                       })

                    # encode as vectors
                    it = beam_seq[t-2]
                    xt = self.embed(Variable(it.cuda()))

                if t >= 2:
                    state = new_state

                output, state = self.core(xt, state)
                logprobs = F.log_softmax(self.logit(output))

            self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
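
    # sample_beam above is reached through sample() below whenever the caller
    # passes {'beam_size': k} with k > 1; otherwise sample() decodes greedily
    # (sample_max=1) or by multinomial sampling (sample_max=0).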

    def sample(self, fc_feats, att_feats, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        seq = []
        seqLogprobs = []
        for t in range(self.seq_length + 2):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:  # input <bos>
                    it = fc_feats.data.new(batch_size).long().zero_()
                elif sample_max:
                    sampleLogprobs, it = torch.max(logprobs.data, 1)
                    it = it.view(-1).long()
                else:
                    if temperature == 1.0:
                        prob_prev = torch.exp(logprobs.data).cpu()  # fetch prev distribution: shape Nx(M+1)
                    else:
                        # scale logprobs by temperature
                        prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
                    it = torch.multinomial(prob_prev, 1).cuda()
                    sampleLogprobs = logprobs.gather(1, Variable(it, requires_grad=False))  # gather the logprobs at sampled positions
                    it = it.view(-1).long()  # and flatten indices for downstream processing

                xt = self.embed(Variable(it, requires_grad=False))

            if t >= 2:
                # stop when all finished
                if t == 2:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)  # seq[t] is the input of the t+2 time step
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core(xt, state)
            logprobs = F.log_softmax(self.logit(output))

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
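FCModel reads all of its hyperparameters from an opt namespace. The sketch below is not part of this commit; the concrete sizes, the batch of random features, and the dummy word ids are illustrative assumptions, chosen only to show how the model is instantiated and called.

# Illustrative sketch only, not part of this commit: build the opt namespace
# FCModel expects and run one forward pass and one greedy sampling pass on
# random inputs (all sizes are assumptions for the example).
import argparse
import torch
from torch.autograd import Variable

opt = argparse.Namespace(vocab_size=9487, input_encoding_size=512,
                         rnn_type='lstm', rnn_size=512, num_layers=1,
                         drop_prob_lm=0.5, seq_length=16, fc_feat_size=2048)
model = FCModel(opt)

fc_feats = Variable(torch.randn(10, opt.fc_feat_size))
att_feats = None  # FCModel does not use att_feats; the argument only keeps the interface uniform
seq = Variable(torch.LongTensor(10, opt.seq_length + 1).random_(1, opt.vocab_size + 1))

logprobs = model(fc_feats, att_feats, seq)  # (10, opt.seq_length, opt.vocab_size + 1)
# Greedy decoding; the multinomial ({'sample_max': 0}) and beam-search paths
# call .cuda() internally and therefore assume a GPU.
sample, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max': 1})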