-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathvqrec.py
262 lines (208 loc) · 10.9 KB
/
vqrec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import torch
import torch.nn as nn
import torch.nn.functional as F
from recbole.model.layers import TransformerEncoder
from recbole.model.abstract_recommender import SequentialRecommender
def log(t, eps=1e-6):
    """Return the natural log of ``t + eps``.

    The ``eps`` shift keeps the result finite where ``t`` is exactly zero.

    Args:
        t: input tensor.
        eps: constant added to ``t`` before taking the log.

    Returns:
        Tensor of the same shape as ``t``.
    """
    shifted = t + eps
    return shifted.log()
def sample_gumbel(shape, device, dtype, eps=1e-6):
    """Draw (approximately) standard Gumbel noise via ``-log(-log(U))``.

    ``U`` is sampled uniformly on [0, 1). ``eps`` is passed to both log calls
    so that the extremes of ``U`` do not produce infinities.

    Args:
        shape: shape of the returned tensor.
        device: device to allocate the sample on.
        dtype: dtype of the sample.
        eps: stabilising constant forwarded to ``log``.

    Returns:
        Tensor of Gumbel noise with the requested shape, device and dtype.
    """
    uniform = torch.empty(shape, device=device, dtype=dtype)
    uniform.uniform_(0, 1)
    inner = log(uniform, eps)
    return -log(-inner, eps)
def sinkhorn_sorting_operator(r, n_iters=8):
    """Sinkhorn-normalise log-domain scores and return them exponentiated.

    Each round subtracts the log-sum-exp over dim 2 and then over dim 1. After
    ``exp`` the result approaches a doubly-stochastic matrix. Sums over dim 1
    are exactly one after the final step; sums over dim 2 are approximately
    one.

    Args:
        r: log-domain scores, assumed ``(batch, n, n)`` (dims 1 and 2 are the
            ones normalised).
        n_iters: number of row/column normalisation rounds. With 0 rounds the
            function simply returns ``exp(r)``.

    Returns:
        Tensor of the same shape as ``r``.
    """
    for _ in range(n_iters):
        # Work in log space: subtracting logsumexp == dividing by the sum.
        r = r - torch.logsumexp(r, dim=2, keepdim=True)
        r = r - torch.logsumexp(r, dim=1, keepdim=True)
    return torch.exp(r)
def gumbel_sinkhorn(r, n_iters=8, temperature=0.7):
    """Sample a soft permutation-like matrix with the Gumbel-Sinkhorn trick.

    ``r`` is moved to log space (presumably it holds positive scores), Gumbel
    noise of matching shape/device/dtype is added, the sum is divided by
    ``temperature`` and the result is Sinkhorn-normalised for ``n_iters``
    rounds.

    Args:
        r: score tensor passed to ``log``.
        n_iters: Sinkhorn rounds.
        temperature: divisor applied to the noisy log-scores.

    Returns:
        Output of ``sinkhorn_sorting_operator`` (same shape as ``r``).
    """
    log_r = log(r)
    noise = sample_gumbel(log_r.shape, log_r.device, log_r.dtype)
    return sinkhorn_sorting_operator((log_r + noise) / temperature, n_iters)
def differentiable_topk(x, k, temperature=1.):
    """Soft top-k selection along the last axis.

    The function runs ``k`` rounds. In each round it takes the softmax of
    ``x / temperature`` and keeps only the single largest probability (all
    other entries become zero). It then masks that position with ``-inf``
    before the next round.

    Args:
        x: tensor of shape ``(*batch, n, dim)``. It is not modified.
        k: number of selections per row.
        temperature: softmax temperature.

    Returns:
        Tensor of shape ``(*batch, k * n, dim)``. Row ``i * k + j`` holds the
        ``j``-th selection for input row ``i``. That row is one-hot-like: its
        only non-zero entry is the softmax probability at the selected
        position.
    """
    *batch_dims, n, dim = x.shape
    selections = []
    for i in range(k):
        probs = (x / temperature).softmax(dim=-1)
        values, indices = probs.topk(1, dim=-1)
        selections.append(torch.zeros_like(x).scatter_(-1, indices, values))
        if i < k - 1:
            # Out-of-place so the caller's tensor is left untouched; an
            # in-place write would also fail under autograd for leaf
            # tensors that require grad.
            x = x.scatter(-1, indices, float('-inf'))
    # (*batch, n, k*dim) -> (*batch, k*n, dim): selections of a row become
    # k consecutive rows.
    stacked = torch.cat(selections, dim=-1)
    return stacked.reshape(*batch_dims, k * n, dim)
class VQRec(SequentialRecommender):
    """Sequential recommender over product-quantised (PQ) item codes.

    Each item id is mapped through ``dataset.pq_codes`` to a row of code
    indices. An item's representation is the mean of that row's code
    embeddings. A Transformer encoder over the item sequence produces the
    sequence representation.

    ``train_stage`` selects the objective:

    * ``'pretrain'``: contrastive loss with in-batch negatives plus "fake"
      negatives built by corrupting codes.
    * ``'inductive_ft'``: cross-entropy over all items.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # VQRec args
        self.code_dim = config['code_dim']
        self.code_cap = config['code_cap']
        # Indexed by item id; each row holds code indices into the embedding table.
        self.pq_codes = dataset.pq_codes
        self.temperature = config['temperature']
        # When True, code embeddings come from ``reassigned_code_embedding``
        # instead of ``pq_code_embedding`` (see ``_pq_code_to_emb``).
        self.index_assignment_flag = False
        self.sinkhorn_iter = config['sinkhorn_iter']
        self.fake_idx_ratio = config['fake_idx_ratio']

        self.train_stage = config['train_stage']
        assert self.train_stage in [
            'pretrain', 'inductive_ft'
        ], f'Unknown train stage: [{self.train_stage}]'

        # load parameters info
        self.n_layers = config['n_layers']
        self.n_heads = config['n_heads']
        self.hidden_size = config['hidden_size']  # same as embedding_size
        self.inner_size = config['inner_size']  # the dimensionality in feed-forward layer
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']
        self.hidden_act = config['hidden_act']
        self.layer_norm_eps = config['layer_norm_eps']

        self.initializer_range = config['initializer_range']
        self.loss_type = config['loss_type']

        # define layers and loss
        # Each code dimension owns a block of ``code_cap + 1`` consecutive rows
        # (layout relied on by ``code_projection`` and
        # ``generate_fake_neg_item_emb``). Row 0 is the padding row.
        self.pq_code_embedding = nn.Embedding(
            self.code_dim * (1 + self.code_cap), self.hidden_size, padding_idx=0)
        # Only initialised to None in this class.
        # NOTE(review): presumably assigned externally before
        # ``index_assignment_flag`` is switched on.
        self.reassigned_code_embedding = None
        self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
        self.trm_encoder = TransformerEncoder(
            n_layers=self.n_layers,
            n_heads=self.n_heads,
            hidden_size=self.hidden_size,
            inner_size=self.inner_size,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attn_dropout_prob=self.attn_dropout_prob,
            hidden_act=self.hidden_act,
            layer_norm_eps=self.layer_norm_eps
        )
        # One (code_cap + 1) x (code_cap + 1) score matrix per code dimension;
        # ``code_projection`` turns it into a re-assignment of code embeddings.
        self.trans_matrix = nn.Parameter(torch.randn(self.code_dim, self.code_cap + 1, self.code_cap + 1))

        self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

        if self.loss_type == 'BPR':
            raise NotImplementedError()
        elif self.loss_type == 'CE':
            self.loss_fct = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['CE']!")

        # parameters initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def code_projection(self):
        """Return code embeddings re-assigned by a sampled matrix per code dimension.

        A Gumbel-Sinkhorn sample of ``exp(trans_matrix)`` is hardened with
        ``differentiable_topk(..., 1)`` followed by ``ceil``. The result is one
        0/1 matrix per code dimension with exactly one 1 per row (an
        approximate permutation). Each of these matrices left-multiplies its
        dimension's block of ``code_cap + 1`` embedding rows.

        NOTE(review): ``torch.ceil`` has zero gradient, so no gradient reaches
        ``trans_matrix`` through this path. Confirm that this is intended.

        Returns:
            Tensor ``[code_dim * (code_cap + 1), hidden_size]`` with the same
            row layout as ``pq_code_embedding.weight``.
        """
        doubly_stochastic_matrix = gumbel_sinkhorn(torch.exp(self.trans_matrix), n_iters=self.sinkhorn_iter)
        trans = differentiable_topk(doubly_stochastic_matrix.reshape(-1, self.code_cap + 1), 1)
        trans = torch.ceil(trans.reshape(-1, self.code_cap + 1, self.code_cap + 1))
        raw_embed = self.pq_code_embedding.weight.reshape(self.code_dim, self.code_cap + 1, -1)
        trans_embed = torch.bmm(trans, raw_embed).reshape(-1, self.hidden_size)
        return trans_embed

    def _pq_code_to_emb(self, pq_code):
        """Embed PQ code indices and mean-pool over the codes axis (dim -2 of the embedding output).

        The embedding table is chosen in one place: ``reassigned_code_embedding``
        when ``index_assignment_flag`` is set, ``pq_code_embedding`` otherwise.
        This keeps sequences and candidate items scored in the same space.
        """
        if self.index_assignment_flag:
            code_emb = F.embedding(pq_code, self.reassigned_code_embedding, padding_idx=0)
        else:
            code_emb = self.pq_code_embedding(pq_code)
        return code_emb.mean(dim=-2)

    def forward(self, item_seq, item_seq_len):
        """Encode item sequences and return the hidden state at the last valid position.

        Args:
            item_seq: item id tensor ``[B L]``.
            item_seq_len: number of valid items per sequence ``[B]``.

        Returns:
            Tensor ``[B H]``.
        """
        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
        position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
        position_embedding = self.position_embedding(position_ids)

        pq_code_seq = self.pq_codes[item_seq]
        pq_code_emb = self._pq_code_to_emb(pq_code_seq)

        input_emb = pq_code_emb + position_embedding
        input_emb = self.LayerNorm(input_emb)
        input_emb = self.dropout(input_emb)

        extended_attention_mask = self.get_attention_mask(item_seq)

        trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True)
        output = trm_output[-1]
        output = self.gather_indexes(output, item_seq_len - 1)
        return output  # [B H]

    def calculate_item_emb(self):
        """Return the embedding of every item, one row per row of ``pq_codes``."""
        return self._pq_code_to_emb(self.pq_codes)  # [n_items H]

    def generate_fake_neg_item_emb(self, item_index):
        """Build "fake" negative item embeddings by randomly corrupting PQ codes.

        Each code in ``item_index`` is replaced with probability
        ``fake_idx_ratio``. The replacement is a random code in slots
        ``1..code_cap`` of the same code dimension's row block. The corrupted
        codes are embedded with ``pq_code_embedding`` and mean-pooled.

        Args:
            item_index: PQ code rows, presumably ``[B code_dim]`` (the
                ``[1 code_dim]`` offset below must broadcast against it).

        Returns:
            Tensor of mean-pooled embeddings, one per row of ``item_index``.
        """
        rand_idx = torch.randint_like(input=item_index, high=self.code_cap)
        # flatten pq codes: offset column d into its own block of code_cap + 1 rows, skipping slot 0
        base_id = (torch.arange(self.code_dim).to(item_index.device) * (self.code_cap + 1)).unsqueeze(0)
        rand_idx = rand_idx + base_id + 1

        mask = torch.bernoulli(torch.full_like(item_index, self.fake_idx_ratio, dtype=torch.float))
        fake_item_idx = torch.where(mask > 0, rand_idx, item_index)
        return self.pq_code_embedding(fake_item_idx).mean(dim=-2)

    def seq_item_contrastive_task(self, seq_output, same_pos_id, interaction):
        """Contrastive (InfoNCE-style) loss between sequences and their target items.

        The denominator contains every in-batch target (including the
        positive) plus one code-corrupted fake negative per sequence.

        Args:
            seq_output: L2-normalised sequence representations ``[B H]``.
            same_pos_id: bool ``[B B]`` marking other sequences that share the
                same target item.
            interaction: batch providing ``'item_id'`` targets ``[B]``.

        Returns:
            Scalar mean loss.
        """
        pos_id = interaction['item_id']
        pos_pq_code = self.pq_codes[pos_id]
        pos_items_emb = F.normalize(self._pq_code_to_emb(pos_pq_code), dim=1)  # [B H]

        pos_logits = (seq_output * pos_items_emb).sum(dim=1, keepdim=True) / self.temperature  # [B 1]

        neg_logits = torch.matmul(seq_output, pos_items_emb.transpose(0, 1)) / self.temperature  # [B B]
        # NOTE(review): masked entries get logit 0, so each still contributes
        # exp(0) = 1 to the denominator rather than being excluded entirely.
        neg_logits = neg_logits.masked_fill(same_pos_id, 0.)

        fake_item_emb = self.generate_fake_neg_item_emb(pos_pq_code)
        fake_item_emb = F.normalize(fake_item_emb, dim=-1)
        fake_logits = (seq_output * fake_item_emb).sum(dim=1, keepdim=True) / self.temperature  # [B 1]

        # -log(exp(pos) / (sum(exp(neg)) + exp(fake))) evaluated in log space so
        # that a large 1 / temperature cannot overflow exp().
        log_denominator = torch.logsumexp(torch.cat([neg_logits, fake_logits], dim=1), dim=1, keepdim=True)
        loss = log_denominator - pos_logits
        return loss.mean()

    def pretrain(self, interaction):
        """Pre-training loss: contrastive task with same-target pairs flagged."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        seq_output = F.normalize(seq_output, dim=1)

        # Remove sequences with the same next item
        pos_id = interaction['item_id']
        same_pos_id = (pos_id.unsqueeze(1) == pos_id.unsqueeze(0))
        # XOR with the identity drops each sequence's match with itself.
        same_pos_id = torch.logical_xor(same_pos_id, torch.eye(pos_id.shape[0], dtype=torch.bool, device=pos_id.device))

        return self.seq_item_contrastive_task(seq_output, same_pos_id, interaction)

    def calculate_loss(self, interaction):
        """Training loss for the configured ``train_stage``.

        ``'pretrain'`` delegates to ``pretrain``. Otherwise the loss is
        cross-entropy over all items. When ``temperature > 0`` the logits are
        cosine similarities divided by ``temperature``.
        """
        if self.train_stage == 'pretrain':
            return self.pretrain(interaction)

        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        pos_items = interaction[self.POS_ITEM_ID]
        if self.loss_type == 'BPR':
            raise NotImplementedError()
        else:  # self.loss_type = 'CE'
            test_item_emb = self.calculate_item_emb()

            if self.temperature > 0:
                seq_output = F.normalize(seq_output, dim=-1)
                test_item_emb = F.normalize(test_item_emb, dim=-1)

            logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))

            if self.temperature > 0:
                logits /= self.temperature

            loss = self.loss_fct(logits, pos_items)
            return loss

    def full_sort_predict(self, interaction):
        """Score every item for each sequence (cosine when ``temperature > 0``)."""
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        seq_output = self.forward(item_seq, item_seq_len)
        test_items_emb = self.calculate_item_emb()

        if self.temperature > 0:
            seq_output = F.normalize(seq_output, dim=-1)
            test_items_emb = F.normalize(test_items_emb, dim=-1)

        scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
        return scores

    def predict_on_subsets(self, interaction, idxs):
        """Score only the candidate items listed in ``idxs``.

        Side effect: moves ``self.pq_codes`` to ``self.device``.

        Args:
            interaction: batch providing ``ITEM_SEQ`` and ``ITEM_SEQ_LEN``.
            idxs: integer array-like ``[B C]`` of candidate item ids per sequence.

        Returns:
            CPU float tensor ``[B n_items]``. Each candidate holds its cosine
            score, and every other item holds ``-10000.``.
        """
        item_seq = interaction[self.ITEM_SEQ]
        item_seq_len = interaction[self.ITEM_SEQ_LEN]
        self.pq_codes = self.pq_codes.to(self.device)
        seq_output = self.forward(item_seq, item_seq_len)
        seq_output = F.normalize(seq_output, dim=-1)

        candidate_ids = torch.as_tensor(idxs, dtype=torch.long, device=self.device)  # [B C]
        # Same embedding table as forward(), and normalised per item over the
        # hidden axis (dim=-1), matching full_sort_predict.
        candidate_emb = self._pq_code_to_emb(self.pq_codes[candidate_ids])  # [B C H]
        candidate_emb = F.normalize(candidate_emb, dim=-1)
        candidate_scores = (seq_output.unsqueeze(1) * candidate_emb).sum(dim=-1)  # [B C]

        scores = torch.full((item_seq.shape[0], self.n_items), -10000.)
        scores.scatter_(1, candidate_ids.cpu(), candidate_scores.detach().cpu().to(scores.dtype))
        return scores