Commit 2f71bc5

Authored by KounianhuaDu, zhjwy9343, and yzh119
[Example] Neural Graph Collaborative Filtering (NGCF). (dmlc#2612)
* ngcf
* ngcf
* update
* ngcf
* ngcf
* remove data
* update
* data

Co-authored-by: zhjwy9343 <[email protected]>
Co-authored-by: Zihao Ye <[email protected]>
1 parent 469088e · commit 2f71bc5

11 files changed: +812 −1 lines

examples/README.md (+8 −1)

@@ -45,9 +45,11 @@ The folder contains example implementations of selected research papers related
 | [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](#gnnfilm) | :heavy_check_mark: | | | | |
 | [Hierarchical Graph Pooling with Structure Learning](#hgp-sl) | | | :heavy_check_mark: | | |
 | [Graph Representation Learning via Hard and Channel-Wise Attention Networks](#hardgat) |:heavy_check_mark: | | | | |
+| [Neural Graph Collaborative Filtering](#ngcf) | | :heavy_check_mark: | | | |
 | [Graph Cross Networks with Vertex Infomax Pooling](#gxn) | | | :heavy_check_mark: | | |
 | [Towards Deeper Graph Neural Networks](#dagnn) | :heavy_check_mark: | | | | |
 
+
 ## 2020
 
 - <a name="grand"></a> Feng et al. Graph Random Neural Network for Semi-Supervised Learning on Graphs. [Paper link](https://arxiv.org/abs/2005.11079).
@@ -70,7 +72,7 @@ The folder contains example implementations of selected research papers related
 - Example code: [Molecule embedding](https://github.com/awslabs/dgl-lifesci/tree/master/examples/molecule_embeddings), [PyTorch for custom data](https://github.com/awslabs/dgl-lifesci/tree/master/examples/property_prediction/csv_data_configuration)
 - Tags: molecules, graph classification, unsupervised learning, self-supervised learning, molecular property prediction
 
-- <a name="GNN-FiLM"></a> Marc Brockschmidt. GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation. [Paper link](https://arxiv.org/abs/1906.12192).
+- <a name="gnnfilm"></a> Marc Brockschmidt. GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation. [Paper link](https://arxiv.org/abs/1906.12192).
 - Example code: [Pytorch](../examples/pytorch/GNN-FiLM)
 - Tags: multi-relational graphs, hypernetworks, GNN architectures
 
@@ -168,6 +170,11 @@ The folder contains example implementations of selected research papers related
 - Example code: [Pytorch](../examples/pytorch/hardgat)
 - Tags: node classification, graph attention
 
+- <a name='ngcf'></a> Wang, Xiang, et al. Neural Graph Collaborative Filtering. [Paper link](https://arxiv.org/abs/1905.08108).
+- Example code: [Pytorch](../examples/pytorch/NGCF)
+- Tags: Collaborative Filtering, Recommendation, Graph Neural Network
+
+
 ## 2018
 
 - <a name="dgmg"></a> Li et al. Learning Deep Generative Models of Graphs. [Paper link](https://arxiv.org/abs/1803.03324).
New file (+2 lines): shell script that downloads the amazon-book dataset

wget https://s3.us-west-2.amazonaws.com/dgl-data/dataset/amazon-book.zip
unzip amazon-book.zip

New file (+2 lines): shell script that downloads the gowalla dataset

wget https://s3.us-west-2.amazonaws.com/dgl-data/dataset/gowalla.zip
unzip gowalla.zip

examples/pytorch/NGCF/NGCF/main.py (new file, +110 lines)

import ast
import os
from time import time

import numpy as np
import torch
import torch.optim as optim

from model import NGCF
from utility.batch_test import *  # expected to supply args, data_generator and test
from utility.helper import early_stopping


def main(args):
    # Step 1: Prepare graph data and device ================================= #
    if args.gpu >= 0 and torch.cuda.is_available():
        device = 'cuda:{}'.format(args.gpu)
    else:
        device = 'cpu'

    g = data_generator.g
    g = g.to(device)

    # Step 2: Create model and training components ========================== #
    model = NGCF(g, args.embed_size, args.layer_size,
                 args.mess_dropout, args.regs[0]).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Step 3: Training epochs =============================================== #
    n_batch = data_generator.n_train // args.batch_size + 1
    t0 = time()
    cur_best_pre_0, stopping_step = 0, 0
    loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
    for epoch in range(args.epoch):
        t1 = time()
        loss, mf_loss, emb_loss = 0., 0., 0.
        for idx in range(n_batch):
            users, pos_items, neg_items = data_generator.sample()
            u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
                g, 'user', 'item', users, pos_items, neg_items)

            batch_loss, batch_mf_loss, batch_emb_loss = model.create_bpr_loss(
                u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss += batch_loss
            mf_loss += batch_mf_loss
            emb_loss += batch_emb_loss

        if (epoch + 1) % 10 != 0:
            if args.verbose > 0 and epoch % args.verbose == 0:
                perf_str = 'Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]' % (
                    epoch, time() - t1, loss, mf_loss, emb_loss)
                print(perf_str)
            continue  # skip to the next epoch; evaluation below runs every 10 epochs

        # Evaluate the model every 10 epochs
        t2 = time()
        users_to_test = list(data_generator.test_set.keys())
        ret = test(model, g, users_to_test)
        t3 = time()

        loss_loger.append(loss)
        rec_loger.append(ret['recall'])
        pre_loger.append(ret['precision'])
        ndcg_loger.append(ret['ndcg'])
        hit_loger.append(ret['hit_ratio'])

        if args.verbose > 0:
            perf_str = 'Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f], recall=[%.5f, %.5f], ' \
                       'precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]' % \
                       (epoch, t2 - t1, t3 - t2, loss, mf_loss, emb_loss,
                        ret['recall'][0], ret['recall'][-1],
                        ret['precision'][0], ret['precision'][-1],
                        ret['hit_ratio'][0], ret['hit_ratio'][-1],
                        ret['ndcg'][0], ret['ndcg'][-1])
            print(perf_str)

        cur_best_pre_0, stopping_step, should_stop = early_stopping(
            ret['recall'][0], cur_best_pre_0, stopping_step,
            expected_order='acc', flag_step=5)

        # Early stop once recall@k has not improved for flag_step evaluations
        if should_stop:
            break

        if ret['recall'][0] == cur_best_pre_0 and args.save_flag == 1:
            torch.save(model.state_dict(), args.weights_path + args.model_name)
            print('save the weights in path: ', args.weights_path + args.model_name)

    recs = np.array(rec_loger)
    pres = np.array(pre_loger)
    ndcgs = np.array(ndcg_loger)
    hit = np.array(hit_loger)

    best_rec_0 = max(recs[:, 0])
    idx = list(recs[:, 0]).index(best_rec_0)

    final_perf = "Best Iter=[%d]@[%.1f]\trecall=[%s], precision=[%s], hit=[%s], ndcg=[%s]" % \
                 (idx, time() - t0,
                  '\t'.join(['%.5f' % r for r in recs[idx]]),
                  '\t'.join(['%.5f' % r for r in pres[idx]]),
                  '\t'.join(['%.5f' % r for r in hit[idx]]),
                  '\t'.join(['%.5f' % r for r in ndcgs[idx]]))
    print(final_perf)


if __name__ == '__main__':
    if not os.path.exists(args.weights_path):
        os.mkdir(args.weights_path)
    # list-valued hyperparameters arrive as strings such as "[0.1, 0.1, 0.1]";
    # ast.literal_eval parses them without the risks of a bare eval
    args.mess_dropout = ast.literal_eval(args.mess_dropout)
    args.layer_size = ast.literal_eval(args.layer_size)
    args.regs = ast.literal_eval(args.regs)
    print(args)
    main(args)
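
Note on wiring: main() gets its hyperparameters through the wildcard import from utility.batch_test, whose parser is not part of this excerpt. As a minimal, hypothetical sketch of driving main() directly, the namespace below carries every attribute the script reads; the values are illustrative assumptions, not the example's tuned settings, and data_generator/test must still be importable for anything to run.

# Hypothetical driver for main(); attribute names mirror what main.py reads,
# the concrete values are illustrative assumptions.
from types import SimpleNamespace

hypothetical_args = SimpleNamespace(
    gpu=-1,                        # < 0 forces CPU even if CUDA is available
    embed_size=64,                 # size of the initial user/item embeddings
    layer_size=[64, 64, 64],       # output size of each NGCF layer (already parsed)
    mess_dropout=[0.1, 0.1, 0.1],  # message dropout per layer (already parsed)
    regs=[1e-5],                   # regs[0] is the embedding-regularization weight
    lr=1e-4,                       # Adam learning rate
    batch_size=1024,               # BPR triples per batch
    epoch=400,                     # number of training epochs
    verbose=1,                     # print training stats every `verbose` epochs
    save_flag=0,                   # 1 to checkpoint the best model
    weights_path='weights/',       # checkpoint directory
    model_name='ngcf.pt',          # checkpoint file name
)
# main(hypothetical_args)  # requires data_generator and test from utility.batch_test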

examples/pytorch/NGCF/NGCF/model.py (new file, +117 lines)

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn


class NGCFLayer(nn.Module):
    def __init__(self, in_size, out_size, norm_dict, dropout):
        super(NGCFLayer, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # weights for the two kinds of messages
        self.W1 = nn.Linear(in_size, out_size, bias=True)
        self.W2 = nn.Linear(in_size, out_size, bias=True)

        # leaky relu
        self.leaky_relu = nn.LeakyReLU(0.2)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

        # initialization
        torch.nn.init.xavier_uniform_(self.W1.weight)
        torch.nn.init.constant_(self.W1.bias, 0)
        torch.nn.init.xavier_uniform_(self.W2.weight)
        torch.nn.init.constant_(self.W2.bias, 0)

        # symmetric degree normalization, precomputed per edge type
        self.norm_dict = norm_dict

    def forward(self, g, feat_dict):
        funcs = {}  # one (message, reduce) pair per canonical edge type
        # for each edge type, compute messages; they are reduced together below
        for srctype, etype, dsttype in g.canonical_etypes:
            if srctype == dsttype:  # self loops
                messages = self.W1(feat_dict[srctype])
                g.nodes[srctype].data[etype] = messages  # store in ndata
                funcs[(srctype, etype, dsttype)] = (fn.copy_u(etype, 'm'), fn.sum('m', 'h'))
            else:
                src, dst = g.edges(etype=(srctype, etype, dsttype))
                norm = self.norm_dict[(srctype, etype, dsttype)]
                # message = norm * (W1 e_src + W2 (e_src ⊙ e_dst))
                messages = norm * (self.W1(feat_dict[srctype][src])
                                   + self.W2(feat_dict[srctype][src] * feat_dict[dsttype][dst]))
                g.edges[(srctype, etype, dsttype)].data[etype] = messages  # store in edata
                funcs[(srctype, etype, dsttype)] = (fn.copy_e(etype, 'm'), fn.sum('m', 'h'))

        # update all: reduce within each edge type first, then sum across edge types
        g.multi_update_all(funcs, 'sum')
        feature_dict = {}
        for ntype in g.ntypes:
            h = self.leaky_relu(g.nodes[ntype].data['h'])  # leaky relu
            h = self.dropout(h)  # message dropout
            h = F.normalize(h, dim=1, p=2)  # l2 normalize
            feature_dict[ntype] = h
        return feature_dict

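# Reference: the messages built in NGCFLayer.forward above implement the
# propagation rule of the NGCF paper (Wang et al., 2019). With src as the
# neighbor j and dst as the destination i, each layer computes
#
#   m_{j->i} = (1 / sqrt(|N_i| * |N_j|)) * (W1 e_j + W2 (e_j ⊙ e_i))
#   m_{i->i} = W1 e_i
#   e_i'     = LeakyReLU(m_{i->i} + sum_{j in N_i} m_{j->i})
#
# where the degree factor is the precomputed `norm`, e_j ⊙ e_i corresponds to
# feat_dict[srctype][src] * feat_dict[dsttype][dst], and the summation over
# incoming messages (self loop included) is what multi_update_all(funcs, 'sum')
# carries out.
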
class NGCF(nn.Module):
    def __init__(self, g, in_size, layer_size, dropout, lmbd=1e-5):
        super(NGCF, self).__init__()
        self.lmbd = lmbd
        self.norm_dict = dict()
        for srctype, etype, dsttype in g.canonical_etypes:
            src, dst = g.edges(etype=(srctype, etype, dsttype))
            dst_degree = g.in_degrees(dst, etype=(srctype, etype, dsttype)).float()  # obtain degrees
            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
            self.norm_dict[(srctype, etype, dsttype)] = norm

        self.layers = nn.ModuleList()
        self.layers.append(
            NGCFLayer(in_size, layer_size[0], self.norm_dict, dropout[0])
        )
        self.num_layers = len(layer_size)
        for i in range(self.num_layers - 1):
            self.layers.append(
                NGCFLayer(layer_size[i], layer_size[i + 1], self.norm_dict, dropout[i + 1])
            )
        self.initializer = nn.init.xavier_uniform_

        # learnable embeddings for each node type
        self.feature_dict = nn.ParameterDict({
            ntype: nn.Parameter(self.initializer(torch.empty(g.num_nodes(ntype), in_size)))
            for ntype in g.ntypes
        })

    def create_bpr_loss(self, users, pos_items, neg_items):
        pos_scores = (users * pos_items).sum(1)
        neg_scores = (users * neg_items).sum(1)

        mf_loss = nn.LogSigmoid()(pos_scores - neg_scores).mean()
        mf_loss = -1 * mf_loss

        regularizer = (torch.norm(users) ** 2 + torch.norm(pos_items) ** 2 + torch.norm(neg_items) ** 2) / 2
        emb_loss = self.lmbd * regularizer / users.shape[0]

        return mf_loss + emb_loss, mf_loss, emb_loss

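    # Reference: create_bpr_loss above is the standard Bayesian Personalized
    # Ranking objective used in the NGCF paper: maximize ln σ(ŷ_ui − ŷ_uj) for
    # each (user u, positive item i, sampled negative j) triple, plus an L2
    # penalty on the batch embeddings:
    #
    #   L = −(1/B) Σ_{(u,i,j)} ln σ(ŷ_ui − ŷ_uj) + (λ/B) Σ ||e||² / 2
    #
    # where ŷ is the inner product of the concatenated per-layer embeddings,
    # λ is self.lmbd and B the batch size.
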
    def rating(self, u_g_embeddings, pos_i_g_embeddings):
        return torch.matmul(u_g_embeddings, pos_i_g_embeddings.t())

    def forward(self, g, user_key, item_key, users, pos_items, neg_items):
        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
        # obtain the features of each layer and concatenate them all
        user_embeds = []
        item_embeds = []
        user_embeds.append(h_dict[user_key])
        item_embeds.append(h_dict[item_key])
        for layer in self.layers:
            h_dict = layer(g, h_dict)
            user_embeds.append(h_dict[user_key])
            item_embeds.append(h_dict[item_key])
        user_embd = torch.cat(user_embeds, 1)
        item_embd = torch.cat(item_embeds, 1)

        u_g_embeddings = user_embd[users, :]
        pos_i_g_embeddings = item_embd[pos_items, :]
        neg_i_g_embeddings = item_embd[neg_items, :]

        return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
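
The bipartite training graph itself is built by utility code that is not part of this excerpt. As a minimal smoke test of the model's API, here is a sketch on a toy graph, assuming one relation per direction plus a self-loop relation for each node type (which NGCFLayer.forward expects whenever srctype == dsttype); the edge-type names and all sizes are illustrative assumptions.

# Toy user-item graph: 3 users, 3 items, 4 interactions
import dgl
import torch

u = torch.tensor([0, 0, 1, 2])
i = torch.tensor([0, 1, 1, 2])
g = dgl.heterograph({
    ('user', 'ui', 'item'): (u, i),                                     # user -> item
    ('item', 'iu', 'user'): (i, u),                                     # item -> user
    ('user', 'user_self', 'user'): (torch.arange(3), torch.arange(3)),  # self loops
    ('item', 'item_self', 'item'): (torch.arange(3), torch.arange(3)),
})

model = NGCF(g, in_size=8, layer_size=[8, 8], dropout=[0.1, 0.1], lmbd=1e-5)
users = torch.tensor([0, 1])  # batch of users
pos = torch.tensor([0, 1])    # observed items
neg = torch.tensor([2, 2])    # sampled negatives
u_emb, pos_emb, neg_emb = model(g, 'user', 'item', users, pos, neg)
loss, mf_loss, emb_loss = model.create_bpr_loss(u_emb, pos_emb, neg_emb)
loss.backward()
print(loss.item())

Because forward() concatenates the initial embeddings with the output of every layer, u_emb here has size 8 + 8 + 8 = 24 per user.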
