vbpr.py
# coding: utf-8
# @email: [email protected]
r"""
VBPR -- Recommended version
################################################
Reference:
    VBPR: Visual Bayesian Personalized Ranking from Implicit Feedback. Ruining He, Julian McAuley. AAAI'16
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from common.abstract_recommender import GeneralRecommender
from common.loss import BPRLoss, EmbLoss
from common.init import xavier_normal_initialization


class VBPR(GeneralRecommender):
    r"""VBPR extends matrix-factorization BPR with a projection of frozen
    multimodal item features, trained in a pairwise fashion."""

    def __init__(self, config, dataloader):
        super(VBPR, self).__init__(config, dataloader)

        # load parameters info
        self.u_embedding_size = self.i_embedding_size = config["embedding_size"]
        self.reg_weight = config[
            "reg_weight"
        ]  # float32 type: the weight decay for l2 normalization

        # define layers and loss
        self.u_embedding = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.empty(self.n_users, self.u_embedding_size * 2)
            )
        )
        self.i_embedding = nn.Parameter(
            nn.init.xavier_uniform_(torch.empty(self.n_items, self.i_embedding_size))
        )

        if self.v_feat is not None and self.t_feat is not None:
            self.item_raw_features = torch.cat((self.t_feat, self.v_feat), -1)
        elif self.v_feat is not None:
            self.item_raw_features = self.v_feat
        else:
            self.item_raw_features = self.t_feat

        self.item_linear = nn.Linear(
            self.item_raw_features.shape[1], self.i_embedding_size
        )
        self.loss = BPRLoss()
        self.reg_loss = EmbLoss()

        # parameters initialization
        self.apply(xavier_normal_initialization)

    def get_user_embedding(self, user):
        r"""Get a batch of user embedding tensors according to the input user ids.

        Args:
            user (torch.LongTensor): The input tensor that contains user ids, shape: [batch_size, ]

        Returns:
            torch.FloatTensor: The embedding tensor of a batch of users, shape: [batch_size, embedding_size * 2]
        """
        return self.u_embedding[user, :]

    def get_item_embedding(self, item):
        r"""Get a batch of item ID embedding tensors according to the input item ids.

        Args:
            item (torch.LongTensor): The input tensor that contains item ids, shape: [batch_size, ]

        Returns:
            torch.FloatTensor: The embedding tensor of a batch of items, shape: [batch_size, embedding_size]
        """
        return self.i_embedding[item, :]

    def forward(self, dropout=0.0):
        # project the frozen multimodal features into the latent space and
        # concatenate them with the learned item ID embeddings
        item_embeddings = self.item_linear(self.item_raw_features)
        item_embeddings = torch.cat((self.i_embedding, item_embeddings), -1)

        user_e = F.dropout(self.u_embedding, dropout)
        item_e = F.dropout(item_embeddings, dropout)
        return user_e, item_e

    def calculate_loss(self, interaction):
        """
        Loss on one batch.

        :param interaction:
            batch data format: tensor(3, batch_size)
            [0]: user list; [1]: positive items; [2]: negative items
        :return: the batch loss (BPR loss plus weighted L2 regularization)
        """
        user = interaction[0]
        pos_item = interaction[1]
        neg_item = interaction[2]

        user_embeddings, item_embeddings = self.forward()
        user_e = user_embeddings[user, :]
        pos_e = item_embeddings[pos_item, :]
        neg_e = item_embeddings[neg_item, :]
        pos_item_score, neg_item_score = torch.mul(user_e, pos_e).sum(dim=1), torch.mul(
            user_e, neg_e
        ).sum(dim=1)

        mf_loss = self.loss(pos_item_score, neg_item_score)
        reg_loss = self.reg_loss(user_e, pos_e, neg_e)
        loss = mf_loss + self.reg_weight * reg_loss
        return loss

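    # Note: BPRLoss and EmbLoss come from common.loss and are not shown here.
    # BPR loss is the standard pairwise objective, roughly
    # -log(sigmoid(pos_item_score - neg_item_score)) averaged over the batch,
    # and EmbLoss is an L2 penalty over the embeddings used in the batch;
    # see common/loss.py for the exact definitions.
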
    def full_sort_predict(self, interaction):
        user = interaction[0]
        user_embeddings, item_embeddings = self.forward()
        user_e = user_embeddings[user, :]
        all_item_e = item_embeddings
        # score every item for each user in the batch: [batch_size, n_items]
        score = torch.matmul(user_e, all_item_e.transpose(0, 1))
        return score
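
# A minimal, self-contained sketch of the scoring path above, with random
# tensors standing in for the framework's config/dataloader. All sizes here
# (8 users, 12 items, 4-dim factors, 6-dim raw features) are illustrative
# assumptions, not values from this repository.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_users, n_items, d, feat_dim = 8, 12, 4, 6
    u_emb = torch.randn(n_users, 2 * d)         # [gamma_u ; theta_u]
    i_emb = torch.randn(n_items, d)             # gamma_i (ID factors)
    raw_feats = torch.randn(n_items, feat_dim)  # frozen multimodal features
    proj = nn.Linear(feat_dim, d)               # plays the role of item_linear
    item_e = torch.cat((i_emb, proj(raw_feats)), -1)  # [gamma_i ; E @ f_i]
    scores = u_emb @ item_e.t()                 # [n_users, n_items]
    print(scores.shape)                         # torch.Size([8, 12])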