-
Notifications
You must be signed in to change notification settings - Fork 0
/
vocab.py
146 lines (114 loc) · 3.26 KB
/
vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright © 2018 LeonTao
#
# Distributed under terms of the MIT license.
import pickle
# Surface strings of the reserved special tokens.
PAD = 'PAD'
UNK = 'UNK'
SOS = 'SOS'
EOS = 'EOS'

# Ids reserved for the special tokens; regular words are numbered after EOS_ID.
PAD_ID = 0
UNK_ID = 1
SOS_ID = 2
EOS_ID = 3


class Vocab(object):
    """Bidirectional mapping between words and integer ids.

    Four ids are reserved for the special tokens PAD, UNK, SOS and EOS.
    ``word2idx`` holds the forward mapping and ``idx2word`` the reverse one.
    ``idx2word`` stays empty until ``build_from_freq()`` or ``load()`` has
    been called; ``save()`` relies on that to detect an unbuilt vocabulary.
    """

    def __init__(self):
        self.init_vocab()

    def init_vocab(self):
        """Reset the vocabulary so it contains only the special tokens."""
        self.word2idx = {
            PAD: PAD_ID,
            UNK: UNK_ID,
            SOS: SOS_ID,
            EOS: EOS_ID,
        }
        # Deliberately left empty: it is filled by build_from_freq()/load().
        self.idx2word = {}

    @property
    def size(self):
        """Number of entries in ``word2idx`` (special tokens included)."""
        return len(self.word2idx)

    # ---- word -> id ----

    def word_to_id(self, word):
        """Return the id of ``word``, or the UNK id when it is not known."""
        return self.word2idx.get(word, self.unkid)

    def words_to_id(self, words):
        """Return a list with the id of every word in ``words``.

        Unknown words are kept and mapped to the UNK id.
        """
        unk_id = self.unkid
        lookup = self.word2idx.get
        return [lookup(cur_word, unk_id) for cur_word in words]

    # ---- id -> word ----

    def id_to_word(self, id):
        """Return the word stored for ``id``, or the UNK string if absent."""
        return self.idx2word.get(id, self.unk)

    def ids_to_word(self, ids):
        """Return a list with the word of every id in ``ids``."""
        return [self.id_to_word(cur_id) for cur_id in ids]

    def build_from_freq(self, freq_list):
        """Add the words of ``freq_list`` and rebuild ``idx2word``.

        Args:
            freq_list: iterable of ``(word, frequency)`` pairs. Only the word
                is used; ids are handed out in iteration order.

        Words that are already in the vocabulary are skipped. This covers the
        special tokens (a corpus word spelled like ``UNK`` must not steal the
        reserved id) and duplicates (which would otherwise leave gaps in the
        id range). Numbering continues after the highest id in use, so a
        repeated call extends the vocabulary instead of producing colliding
        ids. On a fresh vocabulary the first new word still receives id 4.
        """
        next_id = max(self.word2idx.values(), default=EOS_ID) + 1
        for word, _ in freq_list:
            if word in self.word2idx:
                continue
            self.word2idx[word] = next_id
            next_id += 1

        # Rebuild the reverse mapping from scratch.
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

    # ---- save and restore ----

    def save(self, path='vocab.idx2word.dict'):
        """Pickle ``word2idx`` to ``path``.

        NOTE(review): the default file name differs from the default of
        ``load()`` ('vocab.idx2word.dict' vs 'vocab_idx2word.dict'), so a
        default save followed by a default load does not round-trip. The
        defaults are kept for backward compatibility; pass ``path``
        explicitly.

        Raises:
            RuntimeError: if ``idx2word`` is empty, i.e. neither
                ``build_from_freq()`` nor ``load()`` has run yet.
        """
        if not self.idx2word:
            raise RuntimeError("Save vocab after call build_from_freq()")

        # Context manager so the handle is flushed and closed deterministically.
        with open(path, 'wb') as handle:
            pickle.dump(self.word2idx, handle)

    def load(self, path='vocab_idx2word.dict'):
        """Replace the vocabulary with the ``word2idx`` pickled at ``path``.

        NOTE: unpickling can execute arbitrary code; only load files that
        come from a trusted source.

        Raises:
            RuntimeError: if ``path`` does not exist.
        """
        try:
            with open(path, 'rb') as handle:
                self.word2idx = pickle.load(handle)
        except FileNotFoundError as err:
            # Report the path that was actually requested.
            raise RuntimeError(
                "Make sure {} exists.".format(path)) from err

        self.idx2word = {idx: word for word, idx in self.word2idx.items()}

    # ---- ids of the special tokens ----

    @property
    def padid(self):
        """Return the id of the padding token."""
        return self.word2idx.get(PAD, PAD_ID)

    @property
    def unkid(self):
        """Return the id of the unknown-word token."""
        return self.word2idx.get(UNK, UNK_ID)

    @property
    def sosid(self):
        """Return the id of the start-of-sequence token."""
        return self.word2idx.get(SOS, SOS_ID)

    @property
    def eosid(self):
        """Return the id of the end-of-sequence token."""
        return self.word2idx.get(EOS, EOS_ID)

    # ---- strings of the special tokens ----

    @property
    def unk(self):
        """Return the string of the unknown-word token."""
        return UNK

    @property
    def pad(self):
        """Return the string of the padding token."""
        return PAD

    @property
    def sos(self):
        """Return the string of the start-of-sequence token."""
        return SOS

    @property
    def eos(self):
        """Return the string of the end-of-sequence token."""
        return EOS

    def ids_to_text(self, ids):
        """Convert ``ids`` into a space-joined string of words.

        PAD and SOS ids are dropped, and conversion stops at the first EOS id
        (the EOS itself is not emitted). Ids missing from ``idx2word`` are
        rendered as the UNK string.
        """
        # Resolve the special ids once instead of on every iteration.
        skipped_ids = (self.padid, self.sosid)
        stop_id = self.eosid

        kept_ids = []
        for cur_id in ids:
            if cur_id in skipped_ids:
                continue
            if cur_id == stop_id:
                break
            kept_ids.append(cur_id)

        return ' '.join(self.ids_to_word(kept_ids))