-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathiclr17_data.py
47 lines (33 loc) · 1.32 KB
/
iclr17_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import random
# Builds a random test/train split over the papers in the iclr17_dataset
# directory, keeping only papers whose annotation row has no missing value
# in the three checked columns.
#
# Reads  : <base_dir>/annotation_aggregated.tsv and, for each of the
#          'train', 'test' and 'val' splits, <base_dir>/<split>.ids and
#          <base_dir>/<split>.text.
# Writes : dim.all.mod.neu.para.1.{test.val,train}.{text,ids} in the
#          current working directory.

base_dir = '/Users/saarkuzi/iclr17_dataset/'

# Map the first tab-separated column of every row (used below as the paper
# id) to the list of the remaining columns.
grades = {}
with open(base_dir + 'annotation_aggregated.tsv', 'r') as input_file:
    for line in input_file:
        args = line.rstrip('\n').split('\t')
        grades[args[0]] = args[1:]

# A paper is kept only when none of the values at positions 1, 2 and 3 of its
# row is the '-' placeholder (presumably "no annotation" -- confirm against
# the TSV). Rejected paper ids are printed so they can be inspected.
paper_ids = []
for paper in grades:
    if all(grades[paper][position] != '-' for position in (1, 2, 3)):
        paper_ids.append(paper)
    else:
        print(paper)

# Collect ids and texts of all three input splits, in the same order, so
# that ids[k] belongs to texts[k].
texts = []
ids = []
for data_type in ['train', 'test', 'val']:
    with open(base_dir + data_type + '.ids', 'r') as ids_file:
        split_ids = [line.rstrip('\n') for line in ids_file]
    # NOTE(review): the original read the '.ids' file here as well, which
    # made every "text" equal to its id. Reading '<split>.text' mirrors the
    # '.text'/'.ids' pairing of the output files -- confirm the input name.
    with open(base_dir + data_type + '.text', 'r') as text_file:
        split_texts = [line.rstrip('\n') for line in text_file]
    # zip() below would silently drop the surplus lines and misalign the
    # pairs, so refuse to continue on a length mismatch.
    if len(split_ids) != len(split_texts):
        raise ValueError(
            'line count mismatch for split %r: %d ids vs %d texts'
            % (data_type, len(split_ids), len(split_texts))
        )
    ids += split_ids
    texts += split_texts

# Pair every id with its text and drop the papers that failed the filter.
# A set makes each membership test O(1) instead of a scan of the list.
kept_ids = set(paper_ids)
distilled = [element for element in zip(ids, texts) if element[0] in kept_ids]
texts = distilled

# Unseeded shuffle: the split differs from run to run.
random.shuffle(texts)

# The first `test_size` shuffled papers form the test/validation set and the
# rest the training set. Slicing (instead of indexing 0..test_size-1) keeps
# the script from failing when fewer than `test_size` papers remain.
test_size = 30

with open('dim.all.mod.neu.para.1.test.val.text', 'w') as test_file, \
        open('dim.all.mod.neu.para.1.test.val.ids', 'w') as test_id_file:
    for paper_id, paper_text in texts[:test_size]:
        test_file.write(paper_text + '\n')
        test_id_file.write(paper_id + '\n')

with open('dim.all.mod.neu.para.1.train.text', 'w') as train_file, \
        open('dim.all.mod.neu.para.1.train.ids', 'w') as train_id_file:
    for paper_id, paper_text in texts[test_size:]:
        train_file.write(paper_text + '\n')
        train_id_file.write(paper_id + '\n')