forked from morningmoni/HiLAP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreadData_rcv1.py
80 lines (73 loc) · 2.97 KB
/
readData_rcv1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from collections import defaultdict
from tqdm import tqdm
def read_rcv1_ids(filepath):
    """Collect document ids from an LYRL2004 token file.

    Documents are separated by blank lines; each document starts with a
    header line of the form ``.I <doc_id>``. Returns the set of ids found.
    """
    doc_ids = set()
    with open(filepath) as fin:
        expecting_header = True
        for raw_line in fin:
            tokens = raw_line.strip().split()
            if not tokens:
                # Blank line: the next non-empty line is a new document header.
                expecting_header = True
            elif expecting_header and len(tokens) == 2:
                marker, doc_id = tokens
                if marker == '.I':
                    doc_ids.add(doc_id)
                    expecting_header = False
                else:
                    # Two-token line where a header was expected but the
                    # marker is wrong — report it, keep scanning.
                    print(tokens)
                    print('maybe error')
    print('{} samples in {}'.format(len(doc_ids), filepath))
    return doc_ids
def read_rcv1():
    """Load the RCV1 topic hierarchy, topic assignments, and document texts.

    Reads (paths are hard-coded relative to the working directory):
      - rcv1/rcv1.topics.hier.orig.txt: the topic hierarchy.
      - rcv1/rcv1-v2.topics.qrels: doc-id -> category assignments.
      - ../datasets/rcv1_token/lyrl2004_tokens_*.dat: the official
        train/test split (only the ids are used).
      - rcv1/docs.txt: one document per line, "<doc_id> <text>".

    Returns:
        X_train, X_test: lists of raw document texts.
        train_ids, test_ids: doc-id lists parallel to X_train / X_test.
        id2doc: dict mapping doc id -> {'categories': [labels, ...]}.
        nodes: dict mapping label -> {'children': [...], 'parent': [...]}.
    """
    p2c = defaultdict(list)
    id2doc = defaultdict(lambda: defaultdict(list))
    nodes = defaultdict(lambda: defaultdict(list))
    # Each hierarchy line looks like:
    #   "parent: P child: C child-description: ..."
    # The description field is not needed downstream, so it is not parsed.
    with open('rcv1/rcv1.topics.hier.orig.txt') as f:
        for line in f:
            start = line.find('parent: ') + len('parent: ')
            end = line.find(' ', start)
            parent = line[start:end]
            start = line.find('child: ') + len('child: ')
            end = line.find(' ', start)
            child = line[start:end]
            p2c[parent].append(child)
    # Build bidirectional links; 'None' is the virtual root and gets no node.
    for parent in p2c:
        if parent == 'None':
            continue
        for child in p2c[parent]:
            nodes[parent]['children'].append(child)
            nodes[child]['parent'].append(parent)
    # qrels lines: "<category> <doc_id> 1"
    with open('rcv1/rcv1-v2.topics.qrels') as f:
        for line in f:
            cat, doc_id, _ = line.strip().split()
            id2doc[doc_id]['categories'].append(cat)
    X_train = []
    X_test = []
    train_ids = []
    test_ids = []
    # The official LYRL2004 split: one train file, four test shards.
    train_id_set = read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_train.dat')
    test_id_set = read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt0.dat')
    test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt1.dat')
    test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt2.dat')
    test_id_set |= read_rcv1_ids('../datasets/rcv1_token/lyrl2004_tokens_test_pt3.dat')
    print('len(test) total={}'.format(len(test_id_set)))
    n_not_found = 0
    with open('rcv1/docs.txt') as f:
        for line in tqdm(f):
            doc_id, text = line.strip().split(maxsplit=1)
            if doc_id in train_id_set:
                train_ids.append(doc_id)
                X_train.append(text)
            elif doc_id in test_id_set:
                test_ids.append(doc_id)
                X_test.append(text)
            else:
                n_not_found += 1
    print('there are {} that cannot be found in official tokenized rcv1'.format(n_not_found))
    print('len(train_ids)={} len(test_ids)={}'.format(len(train_ids), len(test_ids)))
    return X_train, X_test, train_ids, test_ids, dict(id2doc), dict(nodes)
# Script entry point: run the full RCV1 loading pipeline (results discarded;
# the run is useful for its printed statistics / as a smoke test).
if __name__ == '__main__':
    read_rcv1()