forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_builder.py
332 lines (282 loc) · 11.9 KB
/
model_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
This file handles model creation: it consults the options
and builds each encoder and decoder accordingly.
"""
import re
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
import onmt.modules
from onmt.encoders import str2enc
from onmt.decoders import str2dec
from onmt.inputters.inputter import dict_to_vocabs
from onmt.modules import Embeddings, CopyGenerator
from onmt.utils.misc import use_gpu
from onmt.utils.logging import logger
from onmt.utils.parse import ArgumentParser
from onmt.constants import DefaultTokens, ModelTask
def build_embeddings(opt, vocabs, for_encoder=True):
    """Build the :class:`Embeddings` module for one side of the model.

    Args:
        opt: the model options in the current environment.
        vocabs (dict): vocabularies keyed by name. ``vocabs['src']`` (encoder)
            or ``vocabs['tgt']`` (decoder) supplies the word vocabulary; on
            the encoder side ``vocabs['src_feats']`` (a list of feature
            vocabularies) is also used when present.
        for_encoder (bool): build embeddings for the encoder (source side)
            if True, otherwise for the decoder (target side).

    Returns:
        Embeddings: the configured embedding module.
    """
    feat_pad_indices = []
    num_feat_embeddings = []
    if for_encoder:
        emb_dim = opt.src_word_vec_size
        word_padding_idx = vocabs['src'][DefaultTokens.PAD]
        num_word_embeddings = len(vocabs['src'])
        if 'src_feats' in vocabs:
            feat_pad_indices = \
                [fv[DefaultTokens.PAD] for fv in vocabs['src_feats']]
            num_feat_embeddings = \
                [len(fv) for fv in vocabs['src_feats']]
        freeze_word_vecs = opt.freeze_word_vecs_enc
    else:
        # Only word embeddings are built on the target side (no features).
        emb_dim = opt.tgt_word_vec_size
        word_padding_idx = vocabs['tgt'][DefaultTokens.PAD]
        num_word_embeddings = len(vocabs['tgt'])
        freeze_word_vecs = opt.freeze_word_vecs_dec

    # opt.dropout may be a list of values (presumably a per-step schedule);
    # the embeddings only use its first entry.
    dropout = opt.dropout[0] if isinstance(opt.dropout, list) \
        else opt.dropout

    emb = Embeddings(
        word_vec_size=emb_dim,
        position_encoding=opt.position_encoding,
        position_encoding_type=opt.position_encoding_type,
        feat_merge=opt.feat_merge,
        feat_vec_exponent=opt.feat_vec_exponent,
        feat_vec_size=opt.feat_vec_size,
        dropout=dropout,
        word_padding_idx=word_padding_idx,
        feat_padding_idx=feat_pad_indices,
        word_vocab_size=num_word_embeddings,
        feat_vocab_sizes=num_feat_embeddings,
        sparse=opt.optim == "sparseadam",
        freeze_word_vecs=freeze_word_vecs
    )
    return emb
def build_encoder(opt, embeddings):
    """Instantiate the encoder class registered for the configured type.

    Text models look up ``opt.encoder_type`` in ``str2enc``; for any other
    ``opt.model_type`` the model type itself is the lookup key.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this encoder.

    Returns:
        The encoder produced by the registered class's ``from_opt``.
    """
    if opt.model_type == "text":
        encoder_key = opt.encoder_type
    else:
        encoder_key = opt.model_type
    encoder_cls = str2enc[encoder_key]
    return encoder_cls.from_opt(opt, embeddings)
def build_decoder(opt, embeddings):
    """Instantiate the decoder class registered for the configured type.

    An RNN decoder with ``opt.input_feed`` enabled is registered under the
    separate key ``"ifrnn"``; every other case uses ``opt.decoder_type``.

    Args:
        opt: the option in current environment.
        embeddings (Embeddings): vocab embeddings for this decoder.

    Returns:
        The decoder produced by the registered class's ``from_opt``.
    """
    input_feed_rnn = opt.decoder_type == "rnn" and opt.input_feed
    decoder_key = "ifrnn" if input_feed_rnn else opt.decoder_type
    decoder_cls = str2dec[decoder_key]
    return decoder_cls.from_opt(opt, embeddings)
def load_test_model(opt, model_path=None):
    """Load a checkpoint and build the model for inference.

    Args:
        opt: inference options; ``opt.models``, ``opt.gpu``, ``opt.fp32``
            and ``opt.int8`` are read here.
        model_path (str, optional): checkpoint path. Defaults to
            ``opt.models[0]``.

    Returns:
        tuple: ``(vocabs, model, model_opt)``. The model and its generator
        are switched to eval mode.

    Raises:
        ValueError: if int8 dynamic quantization is requested while
            ``opt.gpu >= 0``.
    """
    if model_path is None:
        model_path = opt.models[0]
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path,
                            map_location=lambda storage, loc: storage)

    model_opt = ArgumentParser.ckpt_model_opts(checkpoint['opt'])
    # Patch for NLLB200 model loading: a checkpoint whose encoder lacks the
    # positional-encoding buffer is switched to 'SinusoidalConcat'.
    # Only apply it when the checkpoint actually has an encoder: encoder-less
    # models (e.g. language models) never contain that key and must keep the
    # position_encoding_type they were trained with.
    model_state = checkpoint['model']
    has_encoder = any(key.startswith('encoder.') for key in model_state)
    if (has_encoder and
            'encoder.embeddings.make_embedding.pe.pe' not in model_state):
        model_opt.position_encoding_type = 'SinusoidalConcat'
    ArgumentParser.update_model_opts(model_opt)
    ArgumentParser.validate_model_opts(model_opt)
    vocabs = dict_to_vocabs(checkpoint['vocab'])

    # Vocabulary updating is not used at inference time.
    model_opt.update_vocab = False

    model = build_base_model(model_opt, vocabs, use_gpu(opt), checkpoint,
                             opt.gpu)
    if opt.fp32:
        model.float()
    elif opt.int8:
        if opt.gpu >= 0:
            raise ValueError(
                "Dynamic 8-bit quantization is not supported on GPU")
        torch.quantization.quantize_dynamic(model, inplace=True)
    model.eval()
    model.generator.eval()
    return vocabs, model, model_opt
def build_src_emb(model_opt, vocabs):
    """Return source-side embeddings, or ``None`` for non-text models.

    Args:
        model_opt: model options; ``model_opt.model_type`` selects the path.
        vocabs (dict): vocabularies passed through to ``build_embeddings``.
    """
    if model_opt.model_type != "text":
        return None
    return build_embeddings(model_opt, vocabs)
def build_encoder_with_embeddings(model_opt, vocabs):
    """Build the source embeddings and the encoder that consumes them.

    Returns:
        tuple: ``(encoder, src_emb)``; ``src_emb`` is ``None`` for
        non-text models.
    """
    source_embeddings = build_src_emb(model_opt, vocabs)
    return build_encoder(model_opt, source_embeddings), source_embeddings
def build_decoder_with_embeddings(
    model_opt, vocabs, share_embeddings=False, src_emb=None
):
    """Build the target embeddings and the decoder that consumes them.

    Args:
        model_opt: model options.
        vocabs (dict): vocabularies passed through to ``build_embeddings``.
        share_embeddings (bool): if True, the target word lookup table reuses
            the weight Parameter of ``src_emb``.
        src_emb (Embeddings, optional): source embeddings to share with;
            required when ``share_embeddings`` is True.

    Returns:
        tuple: ``(decoder, tgt_emb)``.
    """
    target_embeddings = build_embeddings(model_opt, vocabs, for_encoder=False)
    if share_embeddings:
        # Assign the same Parameter object so both sides stay tied.
        shared_weight = src_emb.word_lut.weight
        target_embeddings.word_lut.weight = shared_weight
    built_decoder = build_decoder(model_opt, target_embeddings)
    return built_decoder, target_embeddings
def build_task_specific_model(model_opt, vocabs):
    """Build the bare model (without generator) for ``model_opt.model_task``.

    Seq2seq tasks get an ``NMTModel`` (encoder + decoder); language-model
    tasks get a ``LanguageModel`` whose decoder always shares its word
    embeddings with the source-side embeddings.

    Raises:
        AssertionError: if ``share_embeddings`` is set but the source and
            target vocabularies differ.
        ValueError: if ``model_opt.model_task`` is not supported.
    """
    if model_opt.share_embeddings:
        # Sharing the embedding matrix requires one shared vocabulary.
        assert (
            vocabs['src'] == vocabs['tgt']
        ), "preprocess with -share_vocab if you use share_embeddings"

    task = model_opt.model_task
    if task == ModelTask.SEQ2SEQ:
        encoder, src_emb = build_encoder_with_embeddings(model_opt, vocabs)
        decoder, _ = build_decoder_with_embeddings(
            model_opt,
            vocabs,
            share_embeddings=model_opt.share_embeddings,
            src_emb=src_emb,
        )
        return onmt.models.NMTModel(encoder=encoder, decoder=decoder)

    if task == ModelTask.LANGUAGE_MODEL:
        src_emb = build_src_emb(model_opt, vocabs)
        decoder, _ = build_decoder_with_embeddings(
            model_opt, vocabs, share_embeddings=True, src_emb=src_emb
        )
        return onmt.models.LanguageModel(decoder=decoder)

    raise ValueError(f"No model defined for {model_opt.model_task} task")
def use_embeddings_from_checkpoint(vocabs, model, generator, checkpoint):
    """Copy checkpoint embeddings into a model built with an updated vocab.

    For each token of the new vocabulary that also exists in the checkpoint
    vocabulary, the matching row of the checkpoint word-embedding matrix is
    copied in place into ``model``. On the target side, the generator weight
    row and bias entry are copied into ``generator`` as well. Tokens that are
    not in the checkpoint keep their current (freshly initialized) values.
    Afterwards the old-vocabulary embedding tensors and the generator
    ``weight``/``bias`` are deleted from ``checkpoint``.

    Args:
        vocabs (dict): the new vocabularies, keyed ``'src'``/``'tgt'``.
        model: freshly built model; its embedding tensors are updated in place.
        generator: freshly built generator; ``weight``/``bias`` are updated
            in place when target embeddings are present in the checkpoint.
        checkpoint (dict): loaded checkpoint; mutated (keys are removed).
    """
    logger.info("Updating vocabulary embeddings with checkpoint embeddings")
    # Embedding layers
    enc_emb_name = 'encoder.embeddings.make_embedding.emb_luts.0.weight'
    dec_emb_name = 'decoder.embeddings.make_embedding.emb_luts.0.weight'

    ckp_model = checkpoint['model']
    ckp_generator = checkpoint['generator']
    # Loop invariants, hoisted: state_dict() rebuilds a full mapping on every
    # call, so calling it per token cost O(vocab * params). The returned
    # tensors share storage with the parameters, so in-place writes still
    # update the live modules.
    model_state = model.state_dict()
    generator_state = generator.state_dict()
    ckp_vocabs = None  # built lazily, at most once

    for side, emb_name in [('src', enc_emb_name), ('tgt', dec_emb_name)]:
        if emb_name not in ckp_model:
            continue
        if ckp_vocabs is None:
            ckp_vocabs = dict_to_vocabs(checkpoint['vocab'])
        ckp_side_vocab = ckp_vocabs[side]
        new_tokens = []
        for i, tok in enumerate(vocabs[side].ids_to_tokens):
            if tok in ckp_side_vocab:
                old_i = ckp_side_vocab.lookup_token(tok)
                model_state[emb_name][i] = ckp_model[emb_name][old_i]
                if side == 'tgt':
                    generator_state['weight'][i] = \
                        ckp_generator['weight'][old_i]
                    generator_state['bias'][i] = \
                        ckp_generator['bias'][old_i]
            else:
                # Just for debugging purposes
                new_tokens.append(tok)
        logger.info("%s: %d new tokens", side, len(new_tokens))
        # Remove old vocabulary associated embeddings
        del ckp_model[emb_name]
    del ckp_generator['weight'], ckp_generator['bias']
def _select_device(gpu, gpu_id):
    """Return the torch device for the ``gpu`` flag and optional ``gpu_id``."""
    if not gpu:
        return torch.device("cpu")
    if gpu_id is not None:
        return torch.device("cuda", gpu_id)
    return torch.device("cuda")


def _build_generator(model_opt, vocabs, model):
    """Build the output generator, optionally tied to decoder embeddings."""
    vocab_size = len(vocabs['tgt'])
    if not model_opt.copy_attn:
        generator = nn.Linear(model_opt.dec_hid_size, vocab_size)
        if model_opt.share_decoder_embeddings:
            generator.weight = model.decoder.embeddings.word_lut.weight
    else:
        pad_idx = vocabs['tgt'][DefaultTokens.PAD]
        generator = CopyGenerator(model_opt.dec_hid_size, vocab_size, pad_idx)
        if model_opt.share_decoder_embeddings:
            generator.linear.weight = model.decoder.embeddings.word_lut.weight
    return generator


def _init_parameters(model_opt, model, generator):
    """Randomly initialize parameters and load pre-trained word vectors."""
    # Model first, then generator, matching the original RNG consumption
    # order so seeded runs are reproducible.
    modules = (model, generator)
    if model_opt.param_init != 0.0:
        for module in modules:
            for p in module.parameters():
                p.data.uniform_(-model_opt.param_init, model_opt.param_init)
    if model_opt.param_init_glorot:
        for module in modules:
            for p in module.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)
    if hasattr(model, "encoder") and hasattr(model.encoder, "embeddings"):
        model.encoder.embeddings.load_pretrained_vectors(
            model_opt.pre_word_vecs_enc)
    if hasattr(model.decoder, 'embeddings'):
        model.decoder.embeddings.load_pretrained_vectors(
            model_opt.pre_word_vecs_dec)


def _patch_legacy_checkpoint(checkpoint):
    """Rename legacy checkpoint keys in place (backward compatibility).

    Custom layer-norm parameters ``a_2``/``b_2`` become ``weight``/``bias``,
    and ``0.weight``/``0.bias`` generator keys become ``weight``/``bias``.
    """
    def fix_key(s):
        s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2',
                   r'\1.layer_norm\2.bias', s)
        s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2',
                   r'\1.layer_norm\2.weight', s)
        return s

    checkpoint['model'] = {fix_key(k): v
                           for k, v in checkpoint['model'].items()}
    generator_state = checkpoint['generator']
    if '0.weight' in generator_state:
        generator_state['weight'] = generator_state.pop('0.weight')
    if '0.bias' in generator_state:
        generator_state['bias'] = generator_state.pop('0.bias')


def _apply_freezing(model_opt, model):
    """Freeze encoder/decoder as requested, keeping embeddings trainable.

    The blanket ``embeddings.requires_grad_()`` would also re-enable the
    word lookup table that ``freeze_word_vecs_enc``/``freeze_word_vecs_dec``
    asked to keep frozen, so that request is re-applied afterwards.
    """
    if model_opt.freeze_encoder:
        model.encoder.requires_grad_(False)
        model.encoder.embeddings.requires_grad_()
        if getattr(model_opt, 'freeze_word_vecs_enc', False):
            model.encoder.embeddings.word_lut.weight.requires_grad_(False)
    if model_opt.freeze_decoder:
        model.decoder.requires_grad_(False)
        model.decoder.embeddings.requires_grad_()
        if getattr(model_opt, 'freeze_word_vecs_dec', False):
            model.decoder.embeddings.word_lut.weight.requires_grad_(False)


def build_base_model(model_opt, vocabs, gpu, checkpoint=None, gpu_id=None):
    """Build a model from opts.

    Args:
        model_opt: the option loaded from checkpoint. It's important that
            the opts have been updated and validated. See
            :class:`onmt.utils.parse.ArgumentParser`.
        vocabs (dict): vocabularies for the model, keyed ``'src'``/``'tgt'``
            (plus optional ``'src_feats'``).
        gpu (bool): whether to use gpu.
        checkpoint: the model generated by train phase, or a resumed snapshot
            model from a stopped training. Mutated in place (legacy keys are
            renamed; with ``update_vocab`` old embeddings are removed).
        gpu_id (int or NoneType): Which GPU to use.

    Returns:
        The task-specific model (``NMTModel`` or ``LanguageModel``) with its
        generator attached as ``model.generator``, moved to the device.
    """
    # for back compat when attention_dropout was not defined
    if not hasattr(model_opt, 'attention_dropout'):
        model_opt.attention_dropout = model_opt.dropout

    device = _select_device(gpu, gpu_id)
    model = build_task_specific_model(model_opt, vocabs)
    generator = _build_generator(model_opt, vocabs, model)

    # Fresh initialization when training from scratch, or when the vocab is
    # being updated (known tokens are then copied over from the checkpoint).
    if checkpoint is None or model_opt.update_vocab:
        _init_parameters(model_opt, model, generator)

    if checkpoint is not None:
        _patch_legacy_checkpoint(checkpoint)
        if model_opt.update_vocab:
            # Update model embeddings with those from the checkpoint
            # after initialization
            use_embeddings_from_checkpoint(vocabs, model, generator,
                                           checkpoint)
        model.load_state_dict(checkpoint['model'], strict=False)
        generator.load_state_dict(checkpoint['generator'], strict=False)

    model.generator = generator
    _apply_freezing(model_opt, model)
    model.to(device)
    # Cast to half only for fp16 + fusedadam when no apex opt level
    # (O0-O3) is configured.
    if model_opt.model_dtype == 'fp16' and \
            model_opt.apex_opt_level not in ['O0', 'O1', 'O2', 'O3'] and \
            model_opt.optim == 'fusedadam':
        model.half()
    return model
def build_model(model_opt, opt, vocabs, checkpoint):
    """Build the model for training and log its architecture.

    Args:
        model_opt: validated model options.
        opt: runtime options; only used to decide GPU usage via ``use_gpu``.
        vocabs (dict): vocabularies for the model.
        checkpoint: checkpoint to restore from, or ``None``.

    Returns:
        The model returned by ``build_base_model``.
    """
    logger.info('Building model...')
    use_cuda = use_gpu(opt)
    built_model = build_base_model(model_opt, vocabs, use_cuda, checkpoint)
    logger.info(built_model)
    return built_model