Skip to content

Commit 397f1a8

Browse files
committed
Update dataloader raw.
1 parent 503d2f5 commit 397f1a8

File tree

4 files changed

+28
-11
lines changed

4 files changed

+28
-11
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ There are some differences compared to neuraltalk2.
66
- Use resnet101, in the same way as in self-critical (the preprocessing code may have bugs; it hasn't been tested yet)
77

88
# TODO:
9-
- eval code for arbitrary images
109
- Other models
1110

1211
# Requirements

dataloaderraw.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,27 @@
77
import os
88
import numpy as np
99
import random
10+
import torch
11+
from torch.autograd import Variable
1012
import skimage
1113
import skimage.io
1214
import scipy.misc
1315

16+
from torchvision import transforms as trn
17+
preprocess = trn.Compose([
18+
#trn.ToTensor(),
19+
trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20+
])
21+
22+
from misc.resnet_utils import myResnet
23+
import misc.resnet as resnet
24+
25+
resnet = resnet.resnet101()
26+
resnet.load_state_dict(torch.load('/home-nfs/rluo/rluo/model/pytorch-resnet/resnet101.pth'))
27+
my_resnet = myResnet(resnet)
28+
my_resnet.cuda()
29+
my_resnet.eval()
30+
1431
class DataLoaderRaw():
1532

1633
def __init__(self, opt):
@@ -65,7 +82,8 @@ def get_batch(self, split, batch_size=None):
6582
batch_size = batch_size or self.batch_size
6683

6784
# pick an index of the datapoint to load next
68-
img_batch = np.ndarray([batch_size, 224,224,3], dtype = 'float32')
85+
fc_batch = np.ndarray((batch_size, 2048), dtype = 'float32')
86+
att_batch = np.ndarray((batch_size, 14, 14, 2048), dtype = 'float32')
6987
max_index = self.N
7088
wrapped = False
7189
infos = []
@@ -85,15 +103,22 @@ def get_batch(self, split, batch_size=None):
85103
img = img[:,:,np.newaxis]
86104
img = img.concatenate((img, img, img), axis=2)
87105

88-
img_batch[i] = img[16:240, 16:240, :].astype('float32')/255.0
106+
img = img.astype('float32')/255.0
107+
img = torch.from_numpy(img.transpose([2,0,1])).cuda()
108+
img = Variable(preprocess(img), volatile=True)
109+
tmp_fc, tmp_att = my_resnet(img)
110+
111+
fc_batch[i] = tmp_fc.data.cpu().float().numpy()
112+
att_batch[i] = tmp_att.data.cpu().float().numpy()
89113

90114
info_struct = {}
91115
info_struct['id'] = self.ids[ri]
92116
info_struct['file_path'] = self.files[ri]
93117
infos.append(info_struct)
94118

95119
data = {}
96-
data['images'] = img_batch
120+
data['fc_feats'] = fc_batch
121+
data['att_feats'] = att_batch
97122
data['bounds'] = {'it_pos_now': self.iterator, 'it_max': self.N, 'wrapped': wrapped}
98123
data['infos'] = infos
99124

train.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
import os
2323

24-
#from ipdb import set_trace
25-
2624
def train(opt):
2725
loader = DataLoader(opt)
2826
opt.vocab_size = loader.vocab_size

train_tb.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,6 @@
2222
import misc.utils as utils
2323
import tensorflow as tf
2424

25-
import os
26-
NUM_THREADS = 2 #int(os.environ['OMP_NUM_THREADS'])
27-
28-
#from ipdb import set_trace
29-
3025
def add_summary_value(writer, key, value, iteration):
3126
summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)])
3227
writer.add_summary(summary, iteration)

0 commit comments

Comments
 (0)