-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathconvert2emb.py
78 lines (65 loc) · 2.7 KB
/
convert2emb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
'''
Convert article's text in the dataset to word embeddings using a pretrained word2vec dictionary.
'''
import h5py
import numpy as np
from nltk.tokenize import wordpunct_tokenize
import nltk
import utils
import cPickle as pkl
import os
import parameters as prm
import time
# Accepted values (compared case-insensitively) of prm.att_segment_type.
_SEGMENT_TYPES = ('section', 'subsection', 'sentence', 'word')
_SEGMENT_TYPE_ERROR = ('Not a valid value for the attention segment type (att_segment_type) parameter. Valid options are "section", "subsection", "sentence" or "word".')


def _build_vocab_matrix(wemb, vocab):
    """Return a float32 matrix with one pretrained vector per vocabulary entry.

    :param wemb: mapping word -> 1-D vector (all vectors assumed to share the
        length of the first one -- TODO confirm for the pickled dictionary).
    :param vocab: mapping word -> row index in the returned matrix.
    :returns: array of shape ``(len(vocab), dim)``. Rows for words found in
        ``wemb`` hold that vector; the rest keep ``0.01 * randn`` noise.
    :raises ValueError: if ``wemb`` is empty (the dimension is unknowable).
    """
    if not wemb:
        raise ValueError('word embedding dictionary is empty')
    dim_emb = wemb[next(iter(wemb))].shape[0]
    W = 0.01 * np.random.randn(len(vocab), dim_emb).astype(np.float32)
    for word, pos in vocab.items():
        if word in wemb:
            W[pos, :] = wemb[word]
    return W


def _split_segments(text, seg_type, tokenizer=None):
    """Split one article into attention segments.

    :param text: raw article text as stored in the HDF5 ``'text'`` dataset.
    :param seg_type: lower-cased segment type, one of ``_SEGMENT_TYPES``.
    :param tokenizer: loaded punkt sentence tokenizer; required only for
        ``'sentence'``.
    :returns: list of segment strings.
    :raises ValueError: for an unknown ``seg_type``.
    """
    if seg_type in ('section', 'subsection'):
        # A new segment starts at every line that looks like '== heading =='.
        segments = [[]]
        for line in text.split('\n'):
            if seg_type == 'section':
                # Drop '===' markers so subsection headings do not start a
                # new segment and only top-level '== ... ==' headings do.
                line = line.replace('===', '')
            stripped = line.strip()
            if stripped.startswith('==') and stripped.endswith('=='):
                segments.append([])
            segments[-1].append(line + '\n')
        return [''.join(lines) for lines in segments]
    if seg_type == 'sentence':
        return tokenizer.tokenize(text.decode('ascii', 'ignore'))
    if seg_type == 'word':
        return wordpunct_tokenize(text.decode('ascii', 'ignore'))
    raise ValueError(_SEGMENT_TYPE_ERROR)


def compute_emb(pages_path_in, pages_path_out, vocab):
    """Encode every article of an HDF5 file as embeddings and save them.

    Reads the ``'text'`` dataset of ``pages_path_in`` and writes a fresh HDF5
    file at ``pages_path_out`` (an existing file there is deleted first) with:

    * ``'emb'`` (float32). If ``prm.att_doc``: shape
      ``(n_docs, prm.max_segs_doc, prm.dim_emb)``, one row per segment as
      returned by ``utils.Word2Vec_encode``; rows past the article's segment
      count are not written. Otherwise: shape ``(n_docs, prm.dim_emb)``, the
      ``utils.BOW``-weighted sum of the vocabulary vectors.
    * ``'mask'`` (float32, shape ``(n_docs,)``): number of segments written
      per article in ``att_doc`` mode; not written otherwise.

    :param pages_path_in: path of the input HDF5 file.
    :param pages_path_out: path of the output HDF5 file.
    :param vocab: mapping word -> row index of the vocabulary matrix.
    :raises ValueError: if ``prm.att_doc`` is set and ``prm.att_segment_type``
        is not one of section/subsection/sentence/word (checked before any
        file is touched), or if the embedding dictionary is empty.
    """
    # Normalise once so every check below agrees (the original mixed
    # case-sensitive and case-insensitive comparisons).
    seg_type = prm.att_segment_type.lower() if prm.att_doc else None
    if prm.att_doc and seg_type not in _SEGMENT_TYPES:
        raise ValueError(_SEGMENT_TYPE_ERROR)

    # NOTE(review): unpickling can execute arbitrary code; wordemb_path must
    # point to a trusted file.
    with open(prm.wordemb_path, 'rb') as fwemb:
        wemb = pkl.load(fwemb)
    W = _build_vocab_matrix(wemb, vocab)

    with h5py.File(pages_path_in, 'r') as f:
        texts = f['text']
        n_docs = texts.shape[0]

        tokenizer = None
        if seg_type == 'sentence':
            nltk.download('punkt')
            tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

        if os.path.exists(pages_path_out):
            os.remove(pages_path_out)

        # Save to HDF5
        with h5py.File(pages_path_out, 'a') as fout:
            # NOTE(review): the width is prm.dim_emb, while W's width comes from
            # the embedding file; the two are assumed equal -- confirm.
            if prm.att_doc:
                shape = (n_docs, prm.max_segs_doc, prm.dim_emb)
            else:
                shape = (n_docs, prm.dim_emb)
            embs = fout.create_dataset('emb', shape=shape, dtype=np.float32)
            mask = fout.create_dataset('mask', shape=(n_docs,), dtype=np.float32)

            for i, text in enumerate(texts):
                st = time.time()
                if prm.att_doc:
                    segs = _split_segments(text, seg_type, tokenizer)
                    segs = segs[:prm.max_segs_doc]
                    emb_ = utils.Word2Vec_encode(segs, wemb)
                    embs[i, :len(emb_), :] = emb_
                    mask[i] = len(emb_)
                else:
                    bow0, bow1 = utils.BOW(wordpunct_tokenize(text.lower()), vocab)
                    embs[i, :] = (W[bow0] * bow1[:, None]).sum(0)

                n_done = i + 1
                if n_done % prm.dispFreq == 0:
                    # The reported time covers only the current article.
                    print('processing article %s time %s' % (n_done, time.time() - st))