-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathencode_nounphrase.py
33 lines (24 loc) · 925 Bytes
/
encode_nounphrase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
import json
import sys

import h5py
import numpy as np
# Width of every encoded row; phrases with fewer words are right-padded with 0.
MAX_PHRASE_LEN = 18


def encode_noun_phrases(noun_itow, input_wtoi, max_len=MAX_PHRASE_LEN):
    """Encode noun phrases as fixed-width rows of vocabulary indices.

    Args:
        noun_itow: Mapping from a phrase id (an integer or a string holding
            one) to the phrase text. Ids are treated as 1-based: phrase
            ``k`` is written to row ``k - 1``.
        input_wtoi: Mapping from a word to its vocabulary index. It must
            contain the key ``"UNK"``, whose index is used for every word
            that is not in the mapping.
        max_len: Number of columns in the result, i.e. the maximum number
            of words a phrase may have.

    Returns:
        A NumPy array of shape ``(len(noun_itow), max_len)``. Row ``k - 1``
        holds the word indices of phrase ``k`` followed by zeros.

    Raises:
        KeyError: If ``input_wtoi`` has no ``"UNK"`` entry.
        ValueError: If a phrase has more than ``max_len`` words.
    """
    # Looked up once instead of once per word.
    unk_index = input_wtoi["UNK"]
    encode = np.zeros((len(noun_itow), max_len))

    # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
    for key, phrase in noun_itow.items():
        # Split on single spaces exactly as the data was produced; an empty
        # token (e.g. from a double space) is simply mapped to UNK.
        words = phrase.split(" ")
        if len(words) > max_len:
            raise ValueError(
                "noun phrase %r has %d words, more than the maximum of %d"
                % (phrase, len(words), max_len)
            )
        row = int(key) - 1
        for col, word in enumerate(words):
            encode[row, col] = input_wtoi.get(word, unk_index)
    return encode


def main():
    """Read the two JSON mappings, encode the phrases, write the HDF5 file."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--nounphrasemapping', dest='nounphrasemapping',
                        default='', type=str,
                        help='JSON file containing a "noun_itow" mapping')
    parser.add_argument('--vocabmapping', dest='vocabmapping',
                        default='', type=str,
                        help='JSON file containing an "input_wtoi" mapping')
    parser.add_argument('--dest', dest='dest', default='', type=str,
                        help='path of the HDF5 file to write')

    # Invoked without any argument: show usage instead of failing on open('').
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()

    # Context managers close the input files even if parsing fails.
    with open(args.nounphrasemapping) as fp:
        nounphrase_mappings = json.load(fp)
    with open(args.vocabmapping) as fp:
        vocab_mappings = json.load(fp)

    encode = encode_noun_phrases(
        nounphrase_mappings["noun_itow"],
        vocab_mappings["input_wtoi"],
    )

    # The dataset is stored as uint32 regardless of the in-memory dtype.
    with h5py.File(args.dest, "w") as f:
        f.create_dataset("encode", data=encode, dtype="uint32")


if __name__ == "__main__":
    main()