-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_tr.py
95 lines (78 loc) · 2.55 KB
/
dataset_tr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import operator
import numpy as np
from six.moves import cPickle as pickle
from collections import defaultdict
from external.vqa.vqa import VQA
# Hard-coded locations of the images, questions and annotations this script
# reads (the path names indicate the MS COCO / VQA 2014 training split).
img_prefix = "COCO_train2014_"
image_dir = "/projectnb/statnlp/shawnlin/dataset/mscoco_vqa_2014/train2014"
ajson = (
    "/projectnb/statnlp/shawnlin/dataset/mscoco_vqa_2014/"
    "Annotations_Train_mscoco/mscoco_train2014_annotations.json"
)
qjson = (
    "/projectnb/statnlp/shawnlin/dataset/mscoco_vqa_2014/"
    "Questions_Train_mscoco/OpenEnded_mscoco_train2014_questions.json"
)

# The VQA helper takes the annotation file first, then the question file.
vqa = VQA(ajson, qjson)
# Collect the JPEG file names found in image_dir.
# - endswith() is used instead of a substring test so that names which merely
#   contain ".jpg" (e.g. "x.jpg.tmp") are not picked up.
# - os.listdir() returns entries in arbitrary order; sorting makes everything
#   derived from this list reproducible from one run to the next.
img_names = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

# The integer image id is whatever follows img_prefix in the part of the file
# name before the first dot.
img_ids = [
    int(name.split('.')[0].rpartition(img_prefix)[-1]) for name in img_names
]

# Question ids for those images, as reported by the VQA helper
# (presumably every question attached to any of the images -- TODO confirm).
ques_ids = vqa.getQuesIds(img_ids)
# Question vocabulary: the first lookup of an unseen token assigns it the next
# free integer id, i.e. the size of the mapping at that moment.
q2i = defaultdict(lambda: len(q2i))

# Register the special tokens first so that they receive ids 0, 1, 2 and 3,
# in exactly this order.
pad, start, end, UNK = (
    q2i[token] for token in ("<pad>", "<sos>", "<eos>", "<unk>")
)
# Single pass over every question:
#   * register each question token in q2i (the question vocabulary);
#   * count how often each lower-cased answer occurs, considering only
#     answers whose 'answer_confidence' field is 'yes'.
a2i_count = {}
for ques_id in ques_ids:
    qa = vqa.loadQA(ques_id)[0]
    qqa = vqa.loadQQA(ques_id)[0]

    # The last character of the question text is dropped before tokenizing
    # (presumably the trailing '?' -- TODO confirm against the question file).
    ques = qqa['question'][:-1]

    # Tokenization is a plain split on single spaces; keep it in sync with
    # whatever code later encodes questions with q2i.
    for token in ques.lower().strip().split(" "):
        # The lookup itself is the point: q2i's default factory assigns an id
        # to any token that has not been seen before.
        _ = q2i[token]

    for entry in qa['answers']:
        if entry['answer_confidence'] != 'yes':
            continue
        text = entry['answer'].lower()
        a2i_count[text] = a2i_count.get(text, 0) + 1
# Size of the answer vocabulary: only the most frequent answers are kept.
num_answers = 1000

# Answers ordered from most to least frequent.  sorted() is stable, so answers
# with equal counts keep the order in which they were first counted.
a_sort = sorted(a2i_count.items(), key=operator.itemgetter(1), reverse=True)

# a2i maps answer text -> class index, i2a is the inverse mapping.
# a2i hands out the next free index on first lookup of a key; membership
# tests (`x in a2i`) do not insert anything.
a2i = defaultdict(lambda: len(a2i))
i2a = {}
for word, _ in a_sort[:num_answers]:
    i2a[a2i[word]] = word
# Keep only the questions that have at least one annotated answer inside the
# retained answer vocabulary (a2i).  Unlike the counting pass, every answer is
# considered here, whatever its 'answer_confidence' value.
ques_ids_modif = []
for ques_id in ques_ids:
    answers = vqa.loadQA(ques_id)[0]['answers']

    # First lower-cased answer that belongs to a2i, or "" when there is none.
    lowered = (entry['answer'].lower() for entry in answers)
    answer = next((text for text in lowered if text in a2i), "")

    # NOTE: "" doubles as the "no match" marker, so a question whose first
    # matching answer is itself the empty string is skipped as well.
    if answer == "":
        continue
    ques_ids_modif.append(ques_id)

# Report how many questions survived the filter versus how many there were.
print(len(ques_ids_modif), len(ques_ids))
# Write every artifact under ./data.  The directory is created first so that
# the script does not fail on the first open() when it is missing.
# (isdir + makedirs rather than exist_ok=True: the file imports six, so it may
# still be run under Python 2.)
out_dir = './data'
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)

# Each mapping is stored twice: as <name>.pkl and as <name>.npy.
# q2i and a2i are defaultdicts built with lambda factories, which cannot be
# pickled, hence the conversion to plain dicts.
mappings = [
    ('q2i', dict(q2i)),
    ('a2i', dict(a2i)),
    ('i2a', i2a),
    ('a2i_count', a2i_count),
]
for name, mapping in mappings:
    with open(os.path.join(out_dir, name + '.pkl'), 'wb') as f:
        pickle.dump(mapping, f)
    # np.save wraps a dict in a pickled object array; reading it back needs
    # np.load(..., allow_pickle=True) followed by .item().
    np.save(os.path.join(out_dir, name + '.npy'), mapping)

# Image file names, image ids and the filtered question ids.
np.save(os.path.join(out_dir, 'tr_img_names.npy'), img_names)
np.save(os.path.join(out_dir, 'tr_img_ids.npy'), img_ids)
np.save(os.path.join(out_dir, 'tr_ques_ids.npy'), ques_ids_modif)