import sys
from transformers import BertTokenizer, BertModel
sys.path.append("..")
import torch
from tqdm import tqdm
import faiss
import torch.nn.functional as F
import numpy as np
from data.dataprocess import load_data, save_data, save_json, load_json
from ranking.train_sup import SimcseModel
def get_embedding(text, model, tokenizer, device):
    """
    Single-sentence version: compute the sentence embedding.
    text: input text (str)
    """
    source = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        # source tensors: [batch, 1, seq_len] -> [batch, seq_len]
        source_input_ids = source['input_ids'].squeeze(1).to(device)
        source_attention_mask = source['attention_mask'].squeeze(1).to(device)
        source_token_type_ids = source['token_type_ids'].squeeze(1).to(device)
        embedding = model(source_input_ids, source_attention_mask, source_token_type_ids)
        # A plain BertModel returns a ModelOutput: take the [CLS] vector.
        # A SimcseModel already returns the pooled tensor, so keep it as is.
        if hasattr(embedding, 'last_hidden_state'):
            embedding = embedding.last_hidden_state[:, 0]
    return F.normalize(embedding, dim=-1).cpu().numpy(), embedding
def _get_embedding(data, model, tokenizer, device, bsz=1, verbose=False):
    """
    Batch version; also accepts a single sentence (str).
    """
    if isinstance(data, str):
        assert bsz == 1, "batch size must be 1 when passing a single sentence"
        return get_embedding(data, model, tokenizer, device)
    # 0831: checked against the single-sentence version -- the first 5 significant
    # digits agree, while a fair share of results differ from the 7th digit on.
    # One earlier run looked wrong and an afternoon of debugging did not resolve it.
    # TODO: is the padding / attention mask the cause?
    with torch.no_grad():
        norm_embed, embed = [], []
        for i in tqdm(range(len(data) // bsz + 1), disable=not verbose):
            text = data[bsz * i: min(bsz * (i + 1), len(data))]
            if len(text) == 0:
                continue
            # Pad each batch to its own longest sentence.
            length = max(len(t) for t in text)
            source = tokenizer.batch_encode_plus(
                text, padding="max_length", max_length=length,
                truncation=True, return_tensors="pt")
            for k, v in source.items():
                source[k] = v.to(device)
            outputs = model(**source)
            # Take the [CLS] vector. Do not call squeeze() unless a dimension
            # really has to be collapsed.
            embeddings = outputs.last_hidden_state[:, 0, :]
            norm_embed.extend(F.normalize(embeddings, dim=-1).cpu().numpy())
            embed.extend(embeddings)
            torch.cuda.empty_cache()
    return norm_embed, embed
def setup_faiss(dim, vec, measure=faiss.METRIC_INNER_PRODUCT, param='Flat'):
    """Build a GPU faiss index from an embedding matrix saved as a .npy file."""
    index = faiss.index_factory(dim, param, measure)
    res = faiss.StandardGpuResources()
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    vec = np.load(vec)
    gpu_index.add(vec.squeeze())
    return gpu_index
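
# Minimal usage sketch for the helper above. Assumptions (not part of this repo):
# "emb.npy" holds an (N, 768) float32 matrix previously saved with np.save, and
# a GPU build of faiss is installed.
#
#   gpu_index = setup_faiss(768, "emb.npy")
#   scores, ids = gpu_index.search(query_vecs, 3)   # query_vecs: (M, 768) float32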
class Index():
    def __init__(self, dim: int, model, tokenizer=None, measure=faiss.METRIC_INNER_PRODUCT, param='Flat', sup=False):
        """
        `model` may be either a checkpoint path or an already-loaded model.
        """
        if isinstance(model, str):
            if sup or "sup" in model:
                simcse = SimcseModel(pretrained_model='/hg/model/roberta-large-zh', pooling='cls')
                simcse.load_state_dict(torch.load(model))
                simcse.to('cuda:0')
            else:
                simcse = BertModel.from_pretrained(model)
                simcse.to('cuda:0')
            self.model = simcse
            tokenizer = BertTokenizer.from_pretrained(model)
        else:
            # TODO: when a model object is passed in, the tokenizer must be passed in as well
            self.model = model
        self.index = faiss.index_factory(dim, param, measure)
        res = faiss.StandardGpuResources()
        self.gpu_index = faiss.index_cpu_to_gpu(res, 0, self.index)
        self.tok = tokenizer
    def setup_faiss(self, vec_path=None, id2data_path=None, data=None, save_path=None):
        """
        Provide either (vec_path, id2data_path) or data, not both.
        data: [str, str, ...] -- the corpus to build the index over.
        """
        if vec_path:
            vec = np.load(vec_path)
            self.id2data = load_json(id2data_path)
        else:
            vec = []
            self.id2data = {}
            for idx, text in enumerate(tqdm(data)):
                # Use string keys to stay consistent with a dict loaded from JSON:
                # JSON object keys are always strings.
                self.id2data[str(idx)] = text
                vec.append(get_embedding(text, model=self.model, tokenizer=self.tok, device="cuda:0")[0])
            vec = np.array(vec)
            if save_path:
                np.save(save_path + "/embedding.npy", vec)
                save_json(save_path + "/id2data.json", self.id2data)
        self.gpu_index.add(vec.squeeze())
        # TODO: handle the case where texts contain both input and label
        # TODO: batch the embedding computation?
    def search(self, texts: list, return_num=3, threshold=0):
        """
        texts: a list whose elements are single sentences (str).
        Returns "query \t retrieved sentence \t distance" strings.
        """
        res = []
        for t in tqdm(texts):
            embed = get_embedding(t, model=self.model, tokenizer=self.tok, device="cuda:0")[0]
            dis, idx = self.gpu_index.search(embed, return_num)
            for i in range(len(dis[0])):
                # With a positive threshold, keep only hits scoring at least that high.
                if threshold > 0 and dis[0][i] < threshold:
                    continue
                res.append("\t".join([t, self.id2data[str(idx[0][i])], str(dis[0][i])]))
        return res
def postmatch(pattern, match):
    """Return True if `match` shares at least one alphabetic character (e.g. a Chinese character) with `pattern`."""
    hanzi = {p for p in pattern if p.isalpha()}
    return bool(set(match) & hanzi)
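

# A minimal end-to-end sketch of the Index class. Assumptions (not part of the
# original module): "bert-base-chinese" is available locally or from the
# Hugging Face hub, a CUDA device is present, and faiss is the GPU build.
# Since get_embedding returns L2-normalized vectors and the index uses
# METRIC_INNER_PRODUCT, the reported score is cosine similarity.
if __name__ == "__main__":
    corpus = ["今天天气很好", "明天好像要下雨", "我喜欢自然语言处理"]
    # 768 matches the hidden size of bert-base-chinese.
    index = Index(dim=768, model="bert-base-chinese")
    # Embed the corpus sentence by sentence and add the vectors to the GPU index.
    index.setup_faiss(data=corpus)
    # Each result line is "query \t retrieved sentence \t score".
    for line in index.search(["今天天气怎么样"], return_num=2):
        print(line)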