import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class HGTLayer(nn.Module):
    def __init__(self, in_dim, out_dim, num_types, num_relations, n_heads, dropout=0.2, use_norm=False):
        super(HGTLayer, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_types = num_types          # number of node types
        self.num_relations = num_relations  # number of edge (relation) types
        self.n_heads = n_heads
        self.d_k = out_dim // n_heads       # per-head dimension
        self.sqrt_dk = math.sqrt(self.d_k)
        # Type-specific projections: one K/Q/V/output linear per node type.
        self.k_linears = nn.ModuleList()
        self.q_linears = nn.ModuleList()
        self.v_linears = nn.ModuleList()
        self.a_linears = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.use_norm = use_norm
        for t in range(num_types):
            self.k_linears.append(nn.Linear(in_dim, out_dim))
            self.q_linears.append(nn.Linear(in_dim, out_dim))
            self.v_linears.append(nn.Linear(in_dim, out_dim))
            self.a_linears.append(nn.Linear(out_dim, out_dim))
            if use_norm:
                self.norms.append(nn.LayerNorm(out_dim))
        # Relation-specific parameters: a per-head prior plus attention and
        # message transformation matrices for every relation type.
        self.relation_pri = nn.Parameter(torch.ones(num_relations, self.n_heads))
        self.relation_att = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.relation_msg = nn.Parameter(torch.Tensor(num_relations, n_heads, self.d_k, self.d_k))
        self.skip = nn.Parameter(torch.ones(num_types))  # per-type skip gate
        self.drop = nn.Dropout(dropout)
        nn.init.xavier_uniform_(self.relation_att)
        nn.init.xavier_uniform_(self.relation_msg)
    def edge_attention(self, edges):
        # All edges in one apply_edges call share a relation type, so the
        # first edge's 'id' identifies the parameters for the whole batch.
        etype = edges.data['id'][0]
        relation_att = self.relation_att[etype]
        relation_pri = self.relation_pri[etype]
        relation_msg = self.relation_msg[etype]
        # (E, n_heads, d_k) -> (n_heads, E, d_k) so bmm applies the per-head
        # d_k x d_k transform, then transpose back.
        key = torch.bmm(edges.src['k'].transpose(1, 0), relation_att).transpose(1, 0)
        # Scaled dot-product score, weighted by the learned relation prior.
        att = (edges.dst['q'] * key).sum(dim=-1) * relation_pri / self.sqrt_dk
        val = torch.bmm(edges.src['v'].transpose(1, 0), relation_msg).transpose(1, 0)
        return {'a': att, 'v': val}
    def message_func(self, edges):
        return {'v': edges.data['v'], 'a': edges.data['a']}

    def reduce_func(self, nodes):
        # Softmax over incoming edges, then attention-weighted sum of messages.
        att = F.softmax(nodes.mailbox['a'], dim=1)
        h = torch.sum(att.unsqueeze(dim=-1) * nodes.mailbox['v'], dim=1)
        return {'t': h.view(-1, self.out_dim)}
    def forward(self, G, inp_key, out_key):
        node_dict, edge_dict = G.node_dict, G.edge_dict
        # Project K/V on source nodes and Q on destination nodes with the
        # linears of their respective node types, then score each edge.
        for srctype, etype, dsttype in G.canonical_etypes:
            k_linear = self.k_linears[node_dict[srctype]]
            v_linear = self.v_linears[node_dict[srctype]]
            q_linear = self.q_linears[node_dict[dsttype]]
            G.nodes[srctype].data['k'] = k_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[srctype].data['v'] = v_linear(G.nodes[srctype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.nodes[dsttype].data['q'] = q_linear(G.nodes[dsttype].data[inp_key]).view(-1, self.n_heads, self.d_k)
            G.apply_edges(func=self.edge_attention, etype=etype)
        # Aggregate across all relation types; 'mean' averages the results
        # when a node type receives messages from several relations.
        G.multi_update_all({etype: (self.message_func, self.reduce_func)
                            for etype in edge_dict}, cross_reducer='mean')
        for ntype in G.ntypes:
            # Learnable per-type gate between the new representation and the input.
            n_id = node_dict[ntype]
            alpha = torch.sigmoid(self.skip[n_id])
            trans_out = self.a_linears[n_id](G.nodes[ntype].data['t'])
            trans_out = trans_out * alpha + G.nodes[ntype].data[inp_key] * (1 - alpha)
            if self.use_norm:
                G.nodes[ntype].data[out_key] = self.drop(self.norms[n_id](trans_out))
            else:
                G.nodes[ntype].data[out_key] = self.drop(trans_out)
    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)


class HGT(nn.Module):
    def __init__(self, G, n_inp, n_hid, n_out, n_layers, n_heads, use_norm=True):
        super(HGT, self).__init__()
        self.gcs = nn.ModuleList()
        self.n_inp = n_inp
        self.n_hid = n_hid
        self.n_out = n_out
        self.n_layers = n_layers
        # One input adapter per node type, mapping raw features to n_hid.
        self.adapt_ws = nn.ModuleList()
        for t in range(len(G.node_dict)):
            self.adapt_ws.append(nn.Linear(n_inp, n_hid))
        for _ in range(n_layers):
            self.gcs.append(HGTLayer(n_hid, n_hid, len(G.node_dict), len(G.edge_dict), n_heads, use_norm=use_norm))
        self.out = nn.Linear(n_hid, n_out)
    def forward(self, G, out_key):
        # Adapt each node type's input features ('inp') into the hidden space.
        for ntype in G.ntypes:
            n_id = G.node_dict[ntype]
            G.nodes[ntype].data['h'] = torch.tanh(self.adapt_ws[n_id](G.nodes[ntype].data['inp']))
        for i in range(self.n_layers):
            self.gcs[i](G, 'h', 'h')
        # Readout: logits for the nodes of the requested type.
        return self.out(G.nodes[out_key].data['h'])
    def __repr__(self):
        return '{}(n_inp={}, n_hid={}, n_out={}, n_layers={})'.format(
            self.__class__.__name__, self.n_inp, self.n_hid,
            self.n_out, self.n_layers)
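

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: a toy
    # author/paper heterograph with random features, purely illustrative.
    # The layers above expect integer type ids attached as G.node_dict /
    # G.edge_dict and a long-tensor 'id' feature on every edge giving its
    # relation-type index, which is how edge_attention picks its parameters.
    G = dgl.heterograph({
        ('author', 'writes', 'paper'): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
        ('paper', 'written-by', 'author'): (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
    })
    G.node_dict = {ntype: i for i, ntype in enumerate(G.ntypes)}
    G.edge_dict = {etype: i for i, etype in enumerate(G.etypes)}
    for etype in G.etypes:
        G.edges[etype].data['id'] = torch.full(
            (G.number_of_edges(etype),), G.edge_dict[etype], dtype=torch.long)
    # Random input features under the 'inp' key that HGT.forward reads.
    n_inp = 16
    for ntype in G.ntypes:
        G.nodes[ntype].data['inp'] = torch.randn(G.number_of_nodes(ntype), n_inp)
    model = HGT(G, n_inp=n_inp, n_hid=32, n_out=4, n_layers=2, n_heads=4)
    logits = model(G, 'paper')  # shape: (number of 'paper' nodes, n_out)
    print(logits.shape)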