forked from yangjianxin1/SimCSE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
84 lines (70 loc) · 3.72 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertConfig, BertTokenizer
# from utils.modeling import BertModel as SimBertModel
# from utils.modeling import BertConfig as SimBertConfig
from transformers import BertTokenizer
class SimcseModel(nn.Module):
    """BERT encoder with a selectable pooling strategy for SimCSE embeddings.

    Args:
        pretrained_model: Name or path handed to ``BertConfig.from_pretrained``
            and ``BertModel.from_pretrained``.
        pooling: How the encoder output is reduced to one vector per input.
            One of ``'cls'``, ``'pooler'``, ``'last-avg'`` or
            ``'first-last-avg'``.
        dropout: Value written to both ``attention_probs_dropout_prob`` and
            ``hidden_dropout_prob`` of the loaded config.
    """

    def __init__(self, pretrained_model, pooling, dropout=0.3):
        super().__init__()
        config = BertConfig.from_pretrained(pretrained_model)
        # Override both dropout rates of the pretrained config.
        config.attention_probs_dropout_prob = dropout
        config.hidden_dropout_prob = dropout
        self.bert = BertModel.from_pretrained(pretrained_model, config=config)
        self.pooling = pooling

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode a batch and pool it according to ``self.pooling``.

        Returns:
            A ``[batch, hidden]`` tensor:

            * ``'cls'``: last-layer hidden state at sequence position 0.
            * ``'pooler'``: the encoder's ``pooler_output``.
            * ``'last-avg'``: mean of the last layer over the sequence axis.
            * ``'first-last-avg'``: mean of the sequence-averaged
              ``hidden_states[1]`` and ``hidden_states[-1]``.

        Raises:
            ValueError: If ``self.pooling`` is not one of the supported names.
        """
        out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
            return_dict=True,
        )

        if self.pooling == 'cls':
            return out.last_hidden_state[:, 0]

        if self.pooling == 'pooler':
            return out.pooler_output

        # NOTE: the averages below run over every sequence position; the
        # attention mask is not applied, so padding positions are included.
        if self.pooling == 'last-avg':
            return out.last_hidden_state.mean(dim=1)

        if self.pooling == 'first-last-avg':
            first_avg = out.hidden_states[1].mean(dim=1)
            last_avg = out.hidden_states[-1].mean(dim=1)
            return (first_avg + last_avg) / 2

        # Previously an unknown strategy fell through and returned None.
        raise ValueError(
            f"Unsupported pooling strategy: {self.pooling!r}; expected one of "
            "'cls', 'pooler', 'last-avg', 'first-last-avg'"
        )
def simcse_unsup_loss(y_pred, device, temp=0.05):
    """Compute the unsupervised SimCSE contrastive loss.

    Rows of ``y_pred`` must come in adjacent pairs: row ``2k`` and row
    ``2k + 1`` are each other's positive, and every other row in the batch
    is treated as a negative.

    Args:
        y_pred (tensor): Encoder output, ``[batch_size * 2, hidden]``.
        device: Device on which the label and diagonal-mask tensors are
            created; it should be the device that holds ``y_pred``.
        temp: Temperature the similarity scores are divided by.

    Returns:
        Scalar tensor: the cross-entropy averaged over all rows.

    Raises:
        ValueError: If ``y_pred`` has an odd number of rows, since the last
            row would then have no partner.
    """
    num_rows = y_pred.shape[0]
    if num_rows % 2 != 0:
        raise ValueError(
            f"y_pred must contain an even number of rows (adjacent positive "
            f"pairs), got {num_rows}"
        )

    # Label of each row is the index of its partner: [1, 0, 3, 2, ...].
    y_true = torch.arange(num_rows, device=device)
    y_true = (y_true - y_true % 2 * 2) + 1

    # Pairwise cosine similarity between all rows: a symmetric [N, N] matrix.
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)

    # Push the diagonal to a very negative value so a row never picks itself.
    sim = sim - torch.eye(num_rows, device=device) * 1e12

    # Temperature scaling.
    sim = sim / temp

    # Each row is a score vector over the batch; cross-entropy pushes the
    # partner's score up relative to all others. The default reduction
    # already averages over rows, so the result is a scalar.
    return F.cross_entropy(sim, y_true)
def simcse_sup_loss(y_pred, device, lamda=0.05):
    """Compute the supervised SimCSE contrastive loss.

    Rows of ``y_pred`` must come in consecutive groups of three. Row ``3k``
    is the anchor and row ``3k + 1`` is its positive (the target class).
    Every other non-anchor row in the batch, including row ``3k + 2``,
    competes as a negative.

    Args:
        y_pred (tensor): Encoder output, ``[num_triplets * 3, hidden]``.
        device: Device on which the label tensor is created; it should be
            the device that holds ``y_pred``.
        lamda: Temperature the similarity scores are divided by.

    Returns:
        Scalar tensor: the cross-entropy averaged over all anchors.

    Raises:
        ValueError: If the number of rows is not a multiple of 3.
    """
    num_rows, hidden = y_pred.shape[0], y_pred.shape[1]
    if num_rows % 3 != 0:
        raise ValueError(
            f"y_pred must contain a multiple of 3 rows (consecutive "
            f"triplets), got {num_rows}"
        )
    num_triplets = num_rows // 3

    # View the batch as [num_triplets, 3, hidden] to split anchors from the
    # remaining rows without building index tensors.
    triplets = y_pred.reshape(num_triplets, 3, hidden)
    anchors = triplets[:, 0]  # rows 0, 3, 6, ...
    # Rows 1, 2, 4, 5, ... in their original order.
    candidates = triplets[:, 1:].reshape(2 * num_triplets, hidden)

    # Only anchor-vs-candidate similarities are needed:
    # [num_triplets, 2 * num_triplets].
    similarities = F.cosine_similarity(
        anchors.unsqueeze(1), candidates.unsqueeze(0), dim=2
    )
    similarities = similarities / lamda

    # The positive of anchor k sits at candidate column 2k.
    y_true = torch.arange(0, 2 * num_triplets, 2, device=device)
    return F.cross_entropy(similarities, y_true)
if __name__ == '__main__':
    # Smoke test: 30 random 16-dimensional rows (10 triplets) through the
    # supervised loss on CPU.
    demo_embeddings = torch.rand(30, 16)
    print(simcse_sup_loss(demo_embeddings, 'cpu', lamda=0.05))