Commit

add ray
showkeyjar committed Dec 10, 2018
1 parent 3c62186 commit cf0f1d5
Showing 2 changed files with 104 additions and 53 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
misc/__pycache__/
scripts/__pycache__/
misc/resnet101.pth
count.txt
153 changes: 100 additions & 53 deletions scripts/prepro_ai_challenger.py
@@ -1,3 +1,5 @@
import jieba
from misc.resnet_utils import myResnet
"""
Preprocess a raw json dataset into hdf5/json files for use in data_loader.lua
@@ -32,9 +34,12 @@
import h5py
import numpy as np
import torch
import torch.multiprocessing as mp
from torch.multiprocessing import Pool
import torchvision.models as models
from torch.autograd import Variable
import skimage.io
import ray

from torchvision import transforms as trn

@@ -43,8 +48,6 @@
trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

from misc.resnet_utils import myResnet
import jieba
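For reference, a minimal standalone sketch of how this Normalize-only pipeline gets applied (illustrative code; the random array stands in for the script's skimage load, which is already scaled by 255):

import numpy as np
import torch
from torchvision import transforms as trn

preprocess = trn.Compose([
    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet channel stats
])

I = np.random.rand(224, 224, 3).astype('float32')  # stand-in for skimage.io.imread(...) / 255.0
x = torch.from_numpy(I.transpose([2, 0, 1]))       # HWC -> CHW, as the script does
x = preprocess(x)                                  # per-channel normalization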

def prepro_captions(imgs):
# preprocess all the captions
@@ -57,7 +60,8 @@ def prepro_captions(imgs):
txt.append("".join(jieba.cut(sentence)))
#txt = str(s).lower().translate(None, string.punctuation).strip().split()
img['processed_tokens'].append(txt)
if i < 10 and j == 0: print(txt)
if i < 10 and j == 0:
print(txt)
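For context, a small standalone sketch of what jieba.cut yields (the caption string here is made up):

import jieba

caption = u'一个穿着红色衣服的女人在打伞'  # hypothetical caption text
tokens = list(jieba.cut(caption))          # accurate mode; yields word strings
print(tokens)
# ''.join(tokens) simply reconstructs the original sentence; joining with a
# space instead would keep the word boundaries visible downstream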


def build_vocab(imgs, params):
@@ -80,9 +84,11 @@ def build_vocab(imgs, params):
bad_words = [w for w, n in counts.items() if n <= count_thr]
vocab = [w for w, n in counts.items() if n > count_thr]
bad_count = sum(counts[w] for w in bad_words)
print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words) * 100.0 / len(counts)))
print('number of bad words: %d/%d = %.2f%%' %
(len(bad_words), len(counts), len(bad_words) * 100.0 / len(counts)))
print('number of words in vocab would be %d' % (len(vocab),))
print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count * 100.0 / total_words))
print('number of UNKs: %d/%d = %.2f%%' %
(bad_count, total_words, bad_count * 100.0 / total_words))

# lets look at the distribution of lengths as well
sent_lengths = {}
@@ -96,7 +102,8 @@ def build_vocab(imgs, params):
print('sentence length distribution (count, number of words):')
sum_len = sum(sent_lengths.values())
for i in range(max_len + 1):
print('%2d: %10d %f%%' % (i, sent_lengths.get(i, 0), sent_lengths.get(i, 0) * 100.0 / sum_len))
print('%2d: %10d %f%%' % (i, sent_lengths.get(
i, 0), sent_lengths.get(i, 0) * 100.0 / sum_len))

# lets now produce the final annotations
if bad_count > 0:
@@ -123,16 +130,16 @@ def build_vocab(imgs, params):


def assign_splits(imgs, params):
num_val = params['num_val']
num_test = params['num_test']
num_val = params['num_val']
num_test = params['num_test']

for i,img in enumerate(imgs):
if 'val' in img['file_path']:
img['split'] = 'val'
else:
img['split'] = 'train'
for i, img in enumerate(imgs):
if 'val' in img['file_path']:
img['split'] = 'val'
else:
img['split'] = 'train'

print('assigned %d to val, %d to test.' % (num_val, num_test))
print('assigned %d to val, %d to test.' % (num_val, num_test))
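As a toy illustration of the frequency thresholding build_vocab applies above (tokens and threshold invented):

from collections import Counter

counts = Counter(['的', '的', '的', '女人', '打伞'])            # toy token counts
count_thr = 1
vocab = [w for w, n in counts.items() if n > count_thr]       # ['的']
bad_words = [w for w, n in counts.items() if n <= count_thr]  # replaced by 'UNK' later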


def encode_captions(imgs, params, wtoi):
@@ -146,12 +153,13 @@ def encode_captions(imgs, params, wtoi):
max_length = params['max_length']
N = len(imgs)
M = sum(len(img['final_captions']) for img in imgs)
#M = sum(sum(len(s) for s in img['final_captions']) for img in imgs) # total number of captions
# M = sum(sum(len(s) for s in img['final_captions']) for img in imgs) # total number of captions
print('Total number of caption sentences: ' + str(M))
label_arrays = []
label_start_ix = np.zeros(N, dtype='uint32') # note: these will be one-indexed
# note: these will be one-indexed
label_start_ix = np.zeros(N, dtype='uint32')
label_end_ix = np.zeros(N, dtype='uint32')
label_length = np.zeros(M, dtype='uint32') # wrong
label_length = np.zeros(M, dtype='uint32') # wrong
#label_length = np.zeros(dtype='uint32')
caption_counter = 0
counter = 1
@@ -163,12 +171,14 @@ def encode_captions(imgs, params, wtoi):
for j, s in enumerate(img['final_captions']):
sentence_len = []
for sentence in s:
sentence_len.append(min(max_length, len(sentence))) # record the length of this sequence
# record the length of this sequence
sentence_len.append(min(max_length, len(sentence)))
#caption_counter += 1
for k, w in enumerate(sentence):
if k < max_length:
Li[j, k] = wtoi[w]
label_length[caption_counter] = min(max_length, sum(sentence_len)) # record the length of this sequence
label_length[caption_counter] = min(max_length, sum(
sentence_len)) # record the length of this sequence
caption_counter += 1

# note: word indices are 1-indexed, and captions are padded with zeros
@@ -185,31 +195,69 @@ def main(params):
print('encoded captions to array of size ', L.shape)
return L, label_start_ix, label_end_ix, label_length
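For reference, a toy picture of the arrays encode_captions returns (values invented; two images, three captions, max_length = 4):

import numpy as np

L = np.array([[2, 5, 7, 0],   # caption 1, zero-padded
              [3, 1, 0, 0],   # caption 2
              [4, 4, 9, 6]],  # caption 3
             dtype='uint32')
label_start_ix = np.array([1, 3], dtype='uint32')   # 1-indexed: image 0 owns rows 1..2
label_end_ix = np.array([2, 3], dtype='uint32')     # image 1 owns row 3
label_length = np.array([3, 2, 4], dtype='uint32')  # unpadded caption lengths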


dset_fc, dset_att = None, None

params = None
import misc.resnet as resnet
# resnet101 runs out of memory on a Tesla M4; try resnet50
#resnet = resnet.resnet101()
resnet = resnet.resnet50(True)
#resnet.load_state_dict(torch.load('misc/resnet101.pth'))
my_resnet = myResnet(resnet)
#my_resnet.cuda()
my_resnet.eval()
#my_resnet.share_memory()

@ray.remote
def perpare_img(i, img):
global dset_fc, dset_att, my_resnet
# load the image (images root hard-coded below to E:/image_caption/)
#I = skimage.io.imread(os.path.join(params['images_root'], img['file_path']))
I = skimage.io.imread(os.path.join("E:/image_caption/", img['file_path']))
# handle grayscale input images
if len(I.shape) == 2:
I = I[:, :, np.newaxis]
I = np.concatenate((I, I, I), axis=2)

I = I.astype('float32') / 255.0
#I = torch.from_numpy(I.transpose([2, 0, 1])).cuda()
I = torch.from_numpy(I.transpose([2, 0, 1]))
I = Variable(preprocess(I), volatile=True)
tmp_fc, tmp_att = my_resnet(I)
# write to h5
#dset_fc[i] = tmp_fc.data.cpu().float().numpy()
#dset_att[i] = tmp_att.data.cpu().float().numpy()
dset_fc[i] = tmp_fc.data.float().numpy()
dset_att[i] = tmp_att.data.float().numpy()
print('processing %d done' % i)
# if i % 1000 == 0:
# print('processing %d/%d (%.2f%% done)' % (i, N, i * 100.0 / N))
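A minimal sketch of the usual Ray task pattern, assuming this era's ray.init/ray.get API (dummy arrays stand in for the resnet forward pass). One caveat worth noting: Ray tasks run in separate worker processes, so assignments to driver globals such as dset_fc inside a task will not reach the driver's open h5 file; returning the features and writing them in the driver avoids that.

import numpy as np
import ray

ray.init()

@ray.remote
def extract(i):
    fc = np.zeros(2048, dtype='float32')             # stand-in for tmp_fc
    att = np.zeros((14, 14, 2048), dtype='float32')  # stand-in for tmp_att
    return i, fc, att

# gather in the driver and do the h5 writes there
for i, fc, att in ray.get([extract.remote(i) for i in range(4)]):
    pass  # e.g. dset_fc[i] = fc; dset_att[i] = att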


def main(params):
#global params
imgs = json.load(open(params['input_json'], 'r'))
#imgs = imgs['images']

# assign the splits
assign_splits(imgs, params)

seed(123) # make reproducible
#shuffle(imgs) # shuffle the order
# shuffle(imgs) # shuffle the order
prepro_captions(imgs)

# create the vocab
vocab = build_vocab(imgs, params)
itow = {i + 1: w for i, w in enumerate(vocab)} # a 1-indexed vocab translation table
# a 1-indexed vocab translation table
itow = {i + 1: w for i, w in enumerate(vocab)}
wtoi = {w: i + 1 for i, w in enumerate(vocab)} # inverse table

# encode captions in large arrays, ready to ship to hdf5 file
L, label_start_ix, label_end_ix, label_length = encode_captions(imgs, params, wtoi)
L, label_start_ix, label_end_ix, label_length = encode_captions(
imgs, params, wtoi)

import misc.resnet as resnet
resnet = resnet.resnet101(True)
resnet.load_state_dict(torch.load('misc/resnet101.pth'))
my_resnet = myResnet(resnet)
my_resnet.cuda()
my_resnet.eval()
global dset_fc, dset_att

# create output h5 file
N = len(imgs)
@@ -224,23 +272,17 @@ def main(params):

dset_fc = f_fc.create_dataset("fc", (N, 2048), dtype='float32')
dset_att = f_att.create_dataset("att", (N, 14, 14, 2048), dtype='float32')
for i, img in enumerate(imgs):
# load the image
I = skimage.io.imread(os.path.join(params['images_root'], img['file_path']))
# handle grayscale input images
if len(I.shape) == 2:
I = I[:, :, np.newaxis]
I = np.concatenate((I, I, I), axis=2)

I = I.astype('float32') / 255.0
I = torch.from_numpy(I.transpose([2, 0, 1])).cuda()
I = Variable(preprocess(I), volatile=True)
tmp_fc, tmp_att = my_resnet(I)
# write to h5
dset_fc[i] = tmp_fc.data.cpu().float().numpy()
dset_att[i] = tmp_att.data.cpu().float().numpy()
if i % 1000 == 0:
print('processing %d/%d (%.2f%% done)' % (i, N, i * 100.0 / N))

# pool = Pool()
# pool.starmap(perpare_img, [(i, img) for i, img in enumerate(imgs)])
# pool.close()
#pool.join()
#pool.close()
ray.init()
ray.get([perpare_img.remote(i, img) for i, img in enumerate(imgs[35933:])])
#for i, img in enumerate(imgs):
# if i > 35932:
# perpare_img(i, img)
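# note: enumerate(imgs[35933:]) restarts i at 0, so features land in the
# wrong h5 rows; enumerate(imgs[35933:], start=35933) would keep the indices
# aligned (assuming 35933 is where a previous run stopped)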
f_fc.close()
f_att.close()
print('wrote ', params['output_h5'])
@@ -249,12 +291,14 @@ def main(params):
out = {}
out['ix_to_word'] = itow # encode the (1-indexed) vocab
out['images'] = []
for i, img in enumerate(imgs):

for img in imgs:
jimg = {}
jimg['split'] = img['split']
if 'file_path' in img: jimg['file_path'] = img['file_path'] # copy it over, might need it
if 'id' in img: jimg['id'] = img['id'] # copy over & maintain an id, if present (e.g. coco ids, useful)
if 'file_path' in img:
jimg['file_path'] = img['file_path'] # copy it over, might need it
if 'id' in img:
# copy over & maintain an id, if present (e.g. coco ids, useful)
jimg['id'] = img['id']

out['images'].append(jimg)

@@ -270,21 +314,24 @@
help='input json file to process into hdf5')
parser.add_argument('--num_val', default=30000, type=int,
help='number of images to assign to validation data (for CV etc)')
parser.add_argument('--output_json', default='E:/image_caption/coco_ai_challenger_talk.json', help='output json file')
parser.add_argument('--output_h5', default='E:/image_caption/coco_ai_challenger_talk', help='output h5 file')
parser.add_argument(
'--output_json', default='E:/image_caption/coco_ai_challenger_talk.json', help='output json file')
parser.add_argument(
'--output_h5', default='E:/image_caption/coco_ai_challenger_talk/', help='output h5 file')

# options
parser.add_argument('--max_length', default=32, type=int,
help='max length of a caption, in number of words. captions longer than this get clipped.')
parser.add_argument('--images_root', default='ai_challenger', # Note that all images are saved in the `images` folder under this folder
parser.add_argument('--images_root', default='E:/image_caption/', # Note that all images are saved in the `images` folder under this folder
help='root location in which images are stored, to be prepended to file_path in input json')
parser.add_argument('--word_count_threshold', default=5, type=int,
help='only words that occur more than this number of times will be put in vocab')
parser.add_argument('--num_test', default=0, type=int,
help='number of test images (to withold until very very end)')

args = parser.parse_args()
#global params
params = vars(args) # convert to ordinary dict
print('parsed input parameters:')
print(json.dumps(params, indent=2))
main(params)
main(params)
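A hypothetical invocation (the input json path is illustrative; the other values echo the argparse defaults above):

python scripts/prepro_ai_challenger.py \
    --input_json E:/image_caption/ai_challenger_caption.json \
    --output_json E:/image_caption/coco_ai_challenger_talk.json \
    --output_h5 E:/image_caption/coco_ai_challenger_talk \
    --max_length 32 --word_count_threshold 5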
