# model_search_cos.py -- forked from CSHaitao/JTR
from cProfile import label
import email
import sys
sys.path += ["./"]
import os
import time
import torch
import random
import faiss
import joblib
import logging
import argparse
import subprocess
import numpy as np
import pandas as pd
import pickle as pkl
from construct_tree import TreeInitialize,TreeNode
from torch import nn
from tqdm import tqdm, trange
from numba import jit
def _rank_by_similarity(query, nodes):
    """Return ``nodes`` ordered by descending inner product with ``query``.

    ``query`` is a ``(1, dim)`` array and every node must expose a 1-D
    ``embedding`` of length ``dim``.  The sort is stable, so nodes with equal
    scores keep their relative input order.
    """
    matrix = np.array([node.embedding for node in nodes])
    scores = np.dot(query, matrix.T).reshape(-1).tolist()
    ranked = sorted(zip(nodes, scores), key=lambda pair: pair[1], reverse=True)
    return [node for node, _ in ranked]


# NOTE: deliberately not wrapped in ``@jit``.  The function only manipulates
# Python objects (tree nodes, a dict, lists), which numba cannot compile in
# nopython mode, so the decorator only added compilation overhead.
def candidates_generator(embeddings, node_dict, topk):
    """Layer-wise beam retrieval over the tree for a single query.

    Starting from the children of the root (``node_dict['0']``), each layer
    is split into leaves and internal nodes.  Leaves are collected as
    candidates; the ``topk`` internal nodes with the highest inner product
    with the query are expanded into the next layer.  Finally all collected
    leaves are re-ranked against the query.

    Args:
        embeddings: the query embedding, either a torch tensor or anything
            ``np.asarray`` accepts.  It is flattened to a single row.
        node_dict: mapping that contains the tree root under the key ``'0'``.
            Nodes must expose ``children``, ``isleaf``, ``embedding`` and
            (for leaves) ``val``.
        topk: beam width per layer and maximum number of results.

    Returns:
        A list with the ``val`` of at most ``topk`` leaves, best match first.
        The list is shorter than ``topk`` (possibly empty) when the search
        reaches fewer leaves.
    """
    if topk <= 0:
        return []

    # Torch tensors may live on another device or carry autograd state.
    if hasattr(embeddings, "detach"):
        embeddings = embeddings.detach().cpu().numpy()
    query = np.asarray(embeddings).reshape(1, -1)

    root = node_dict['0']
    # Copy the child list: the frontier is consumed below and the tree itself
    # must stay intact for subsequent queries.
    frontier = list(root.children)
    leaves = []

    while frontier:
        internal = []
        for node in frontier:
            # Identity check kept on purpose: only a literal ``True`` marks
            # a leaf node.
            if node.isleaf is True:
                leaves.append(node)
            else:
                internal.append(node)
        if not internal:
            break

        beam = _rank_by_similarity(query, internal)[:topk]
        # Children of the lowest-ranked beam node come first; this keeps the
        # historical candidate order, which decides ties in the final ranking.
        frontier = [child for node in reversed(beam) for child in node.children]

    if not leaves:
        return []
    return [leaf.val for leaf in _rank_by_similarity(query, leaves)[:topk]]
# NOTE: no numba decorator here.  The previous ``@numba.jit(nopython=True)``
# referenced the name ``numba``, which this module never imports, and
# nopython mode cannot type the tree/dict arguments anyway.
def metrics_count(embeddings, node_dict, topk):
    """Retrieve candidates for every query embedding in a batch.

    Args:
        embeddings: batch of query embeddings, indexable along its first
            axis and exposing ``shape[0]`` (the number of queries).
        node_dict: mapping holding the tree, forwarded unchanged to
            ``candidates_generator``.
        topk: number of candidates requested per query.

    Returns:
        A list with one entry per query, each being the candidate list
        returned by ``candidates_generator`` for that query.
    """
    # Each query is searched on its own row; passing the whole batch would
    # make the single-query reshape inside ``candidates_generator`` fail.
    return [
        candidates_generator(embeddings[i], node_dict, topk)
        for i in range(embeddings.shape[0])
    ]