@@ -1,5 +1,5 @@
 from os import stat_result
-from typing import Optional
+from typing import List, Optional
 import time

 import torch
@@ -12,7 +12,7 @@

 from transformers import AutoConfig, AutoModel, AdamW

-from paraphrasegen.loss import ContrastiveLoss
+from paraphrasegen.loss import ContrastiveLoss, Similarity
 from paraphrasegen.constants import (
     AVAIL_GPUS,
     BATCH_SIZE,
@@ -72,25 +72,48 @@ def forward(self, attention_mask, outputs):


 class MLPLayer(nn.Module):
-    def __init__(self, in_dims: int = 768, hidden_dims: int = 768):
+    def __init__(
+        self, in_dims: int = 768, hidden_dims: List[int] = 768, activation: str = "GELU"
+    ):
         super(MLPLayer, self).__init__()
-        self.fc1 = nn.Linear(in_dims, hidden_dims)
-        self.layer_norm = nn.LayerNorm(hidden_dims)
-        self.activation = nn.Tanh()
+
+        if activation == "GELU":
+            activation_fn = nn.GELU()
+        elif activation == "ReLU":
+            activation_fn = nn.ReLU()
+        elif activation == "mish":
+            activation_fn = nn.Mish()
+        elif activation == "leaky_relu":
+            activation_fn = nn.LeakyReLU()
+
+        layers = [
+            nn.Linear(in_dims, hidden_dims[0]),
+            nn.LayerNorm(hidden_dims[0]),
+            activation_fn,
+        ]
+
+        for i in range(1, len(hidden_dims)):
+            layers += [
+                nn.Linear(hidden_dims[i - 1], hidden_dims[i]),
+                nn.LayerNorm(hidden_dims[i]),
+                activation_fn,
+            ]
+
+        self.net = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor):
-        out = self.fc1(x)
-        out = self.layer_norm(out)
-        return self.activation(out)
+        return self.net(x)


 class Encoder(pl.LightningModule):
     def __init__(
         self,
         model_name_or_path: str,
         input_mask_rate: float = 0.1,
-        embedding_from: str = "single",
         pooler_type: str = "cls",
+        mlp_layers: List[int] = [768],
+        temp: float = 0.05,
+        hard_negative_weight: float = 0,
         learning_rate: float = 3e-5,
         weight_decay: float = 0,
     ) -> None:
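For context, the reworked MLPLayer stacks Linear -> LayerNorm -> activation blocks from a list of hidden sizes instead of a single fixed projection. A minimal usage sketch follows; the batch size, hidden sizes, and random input are illustrative assumptions, not values from the training config. A list is passed for hidden_dims because the constructor indexes it, and the activation string stays within the four names the if/elif chain handles.

import torch

# Illustrative only: assumes MLPLayer is importable from the module edited above.
mlp = MLPLayer(in_dims=768, hidden_dims=[768, 512], activation="GELU")
pooled = torch.randn(32, 768)   # e.g. a batch of pooled [CLS] embeddings
projected = mlp(pooled)         # -> shape (32, 512)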
@@ -99,17 +122,19 @@ def __init__(
         self.save_hyperparameters()
         self.config = AutoConfig.from_pretrained(model_name_or_path)
         self.input_mask_rate = input_mask_rate
-        self.embedding_from = embedding_from
         self.bert_model = AutoModel.from_pretrained(
             model_name_or_path, config=self.config, cache_dir=PATH_BASE_MODELS
         )

         self.pooler_type = pooler_type
         self.pooler = Pooler(pooler_type)

-        self.net = MLPLayer()
+        self.net = MLPLayer(in_dims=768, hidden_dims=mlp_layers)

-        self.loss_fn = ContrastiveLoss()
+        self.loss_fn = ContrastiveLoss(
+            temp=self.hparams.temp,
+            hard_negative_weight=self.hparams.hard_negative_weight,
+        )

     def forward(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, do_mlm: bool = True
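ContrastiveLoss itself is not part of this diff, so the sketch below is only an assumption of the usual SimCSE-style objective implied by the new temp and hard_negative_weight arguments and by the third, shuffled negative_outputs tensor passed to it in training_step below: in-batch cross-entropy over temperature-scaled cosine similarities, with each row's hard-negative logit optionally up-weighted.

import torch
import torch.nn as nn
import torch.nn.functional as F

def contrastive_loss(anchor, target, negative, temp=0.05, hard_negative_weight=0.0):
    # Hedged sketch; the real ContrastiveLoss in paraphrasegen.loss may differ.
    cos = nn.CosineSimilarity(dim=-1)
    anchor_target = cos(anchor.unsqueeze(1), target.unsqueeze(0)) / temp      # (B, B)
    anchor_negative = cos(anchor.unsqueeze(1), negative.unsqueeze(0)) / temp  # (B, B)
    logits = torch.cat([anchor_target, anchor_negative], dim=1)               # (B, 2B)

    # Optionally boost the logit of each row's own hard negative.
    batch_size = anchor.size(0)
    weights = torch.cat(
        [
            torch.zeros_like(anchor_target),
            hard_negative_weight * torch.eye(batch_size, device=anchor.device),
        ],
        dim=1,
    )
    logits = logits + weights

    # The i-th anchor should score highest against the i-th target.
    labels = torch.arange(batch_size, device=anchor.device)
    return F.cross_entropy(logits, labels)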
@@ -167,42 +192,59 @@ def training_step(self, batch, batch_idx):
             attention_mask=batch["target_attention_mask"],
         )

-        loss = self.loss_fn(anchor_outputs, target_outputs)
+        negative_index = torch.randperm(batch["anchor_input_ids"].size(0))
+
+        negative_outputs = self(
+            input_ids=batch["anchor_input_ids"][negative_index],
+            attention_mask=batch["anchor_attention_mask"][negative_index],
+        )
+
+        loss = self.loss_fn(anchor_outputs, target_outputs, negative_outputs)
         self.log("loss/train", loss)

         return loss

-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+    def _evaluate(self, batch):
         anchor_outputs = self(
             input_ids=batch["anchor_input_ids"],
             attention_mask=batch["anchor_attention_mask"],
+            do_mlm=False,
         )

         target_outputs = self(
             input_ids=batch["target_input_ids"],
             attention_mask=batch["target_attention_mask"],
+            do_mlm=False,
         )

-        loss = self.loss_fn(anchor_outputs, target_outputs)
+        pos_anchor_emb = anchor_outputs[batch["labels"] == 1]
+        pos_target_emb = target_outputs[batch["labels"] == 1]

-        self.log("loss/val", loss, prog_bar=True)
-        self.log("hp_metric", loss)
+        neg_anchor_emb = anchor_outputs[batch["labels"] == 0]
+        neg_target_emb = target_outputs[batch["labels"] == 0]

-    def test_step(self, batch, batch_idx):
-        anchor_outputs = self(
-            input_ids=batch["anchor_input_ids"],
-            attention_mask=batch["anchor_attention_mask"],
-        )
+        pos_diff = torch.norm(pos_anchor_emb - pos_target_emb).mean()
+        neg_diff = torch.norm(neg_anchor_emb - neg_target_emb).mean()

-        target_outputs = self(
-            input_ids=batch["target_input_ids"],
-            attention_mask=batch["target_attention_mask"],
+        sim = Similarity(temp=self.hparams.temp)
+        pos_sim = sim(pos_anchor_emb, pos_target_emb).mean()
+        neg_sim = sim(neg_anchor_emb, neg_target_emb).mean()
+
+        self.log_dict(
+            {
+                "diff/pos": pos_diff,
+                "diff/neg": neg_diff,
+                "sim/pos": pos_sim,
+                "sim/neg": neg_sim,
+            }
         )
+        self.log("hp_metric", pos_sim - neg_sim)

-        loss = self.loss_fn(anchor_outputs, target_outputs)
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        self._evaluate(batch)

-        self.log("loss/test", loss, prog_bar=True)
-        self.log("hp_metric", loss)
+    def test_step(self, batch, batch_idx):
+        self._evaluate(batch)

     def configure_optimizers(self):
         """Prepare optimizer and schedule (linear warmup and decay)"""
@@ -256,7 +298,7 @@ def configure_optimizers(self):
     trainer = Trainer(
         max_epochs=1,
         gpus=AVAIL_GPUS,
-        log_every_n_steps=10,
+        log_every_n_steps=2,
         precision=16,
         stochastic_weight_avg=True,
         logger=TensorBoardLogger("runs/"),