-
Notifications
You must be signed in to change notification settings - Fork 0
/
Social_Encoders_v2.py
35 lines (28 loc) · 1.25 KB
/
Social_Encoders_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
class Social_Encoder(nn.Module):
    """Encode nodes (users) from their own features plus their social neighbours.

    For each node id in a batch, the neighbour collection is looked up in
    ``social_adj_lists`` and handed, together with the ids, to ``aggregator``.
    The node's own features (from ``features``) are concatenated with the
    aggregated neighbour features and passed through
    ``Linear(2 * embed_dim -> embed_dim)``, ReLU and dropout.

    Args:
        features: Callable taking a CPU ``LongTensor`` of node ids. Its result
            is transposed with ``.t()`` in ``forward``, so it is presumably laid
            out as ``(embed_dim, batch)`` -- TODO confirm against the caller.
        embed_dim: Output width; ``linear1`` maps ``2 * embed_dim`` inputs to
            ``embed_dim`` outputs.
        social_adj_lists: Indexable by ``int`` node id; each entry is passed
            unchanged to ``aggregator`` as that node's neighbours.
        aggregator: Object exposing ``forward(nodes, to_neighs)``. Presumably
            returns ``(batch, embed_dim)`` so the concatenation has
            ``2 * embed_dim`` columns -- TODO confirm.
        base_model: Optional. Stored as ``self.base_model`` only when not
            ``None``; if it is an ``nn.Module`` the assignment registers it as a
            submodule.
        cuda: Device the self features are moved to via ``Tensor.to`` (the
            name is kept for backward compatibility).
        dropout: Drop probability of the dropout applied to the output.
            Defaults to 0.5, the value that was previously hard-coded.
    """

    def __init__(self, features, embed_dim, social_adj_lists, aggregator,
                 base_model=None, cuda="cpu", dropout=0.5):
        super().__init__()
        self.features = features
        self.social_adj_lists = social_adj_lists
        self.aggregator = aggregator
        # Only set when supplied, so ``hasattr(self, "base_model")`` keeps its
        # previous meaning for callers that probe it.
        if base_model is not None:
            self.base_model = base_model
        self.embed_dim = embed_dim
        self.device = cuda
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, nodes):
        """Return the social embedding for each node id in ``nodes``.

        Args:
            nodes: 1-D tensor of node ids; each element is converted with
                ``int`` to index ``social_adj_lists``.

        Returns:
            Tensor whose last dimension is ``embed_dim``, produced by
            Linear -> ReLU -> Dropout (dropout is a no-op in ``eval()`` mode).
        """
        to_neighs = [self.social_adj_lists[int(node)] for node in nodes]
        # ``.forward`` is called directly rather than ``self.aggregator(...)``,
        # so hooks on an nn.Module aggregator do not run; kept so that plain
        # objects exposing only a ``forward`` method keep working.
        neigh_feats = self.aggregator.forward(nodes, to_neighs)  # user-user network

        # ``features`` receives a CPU LongTensor of ids; the result is moved to
        # ``self.device`` and transposed (see the ``features`` note above).
        self_feats = self.features(torch.LongTensor(nodes.cpu().numpy())).to(self.device)
        self_feats = self_feats.t()

        # A self-connection could also be considered here.
        combined = torch.cat([self_feats, neigh_feats], dim=1)
        combined = F.relu(self.linear1(combined))
        return self.dropout(combined)