-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathModel.py
154 lines (134 loc) · 6.79 KB
/
Model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
import numpy as np
import torch.nn as nn
class Propagator(nn.Module):
    """GRU-style gated update used by GGNN to propagate node states.

    Each step fuses a node's current state with its aggregated incoming
    and outgoing messages through reset/update gates, as in Gated Graph
    Neural Networks (Li et al., 2016).
    """

    def __init__(self, state_dim, dropout_rate):
        super(Propagator, self).__init__()
        # Every gate reads the concatenation [a_in, a_out, state] -> 3*state_dim.
        self.reset_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Sigmoid(),
            nn.Dropout(dropout_rate),
        )
        self.update_gate = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Sigmoid(),
            nn.Dropout(dropout_rate),
        )
        # NOTE(review): keeps the original attribute spelling "tansform" (sic)
        # so state_dict keys stay unchanged and existing checkpoints still load.
        self.tansform = nn.Sequential(
            nn.Linear(state_dim * 3, state_dim),
            nn.Tanh(),
        )

    def forward(self, state_cur, a_in, a_out):
        """Run one propagation step.

        Args:
            state_cur: [batch_size, n_node, state_dim] current node states.
            a_in: [batch_size, n_node, state_dim] aggregated incoming messages.
            a_out: [batch_size, n_node, state_dim] aggregated outgoing messages.

        Returns:
            [batch_size, n_node, state_dim] updated node states.
        """
        gate_input = torch.cat((a_in, a_out, state_cur), 2)
        reset = self.reset_gate(gate_input)
        update = self.update_gate(gate_input)
        # Candidate state is computed from the reset-gated current state.
        candidate = self.tansform(torch.cat((a_in, a_out, reset * state_cur), 2))
        # Convex blend of old state and candidate, weighted by the update gate.
        return (1 - update) * state_cur + update * candidate
class GGNN(nn.Module):
    """Gated Graph Neural Network over an adjacency-list graph encoding.

    Edge transforms are stored as learned embeddings: each edge type maps to
    a flat vector that is reshaped to a [state_dim, state_dim] matrix. After
    `time_steps` rounds of message passing via `Propagator`, a soft-attention
    readout produces one scalar per graph, squashed through a sigmoid.
    """

    def __init__(self, n_node, num_edge_types, opt):
        """Build the model.

        Args:
            n_node: number of nodes per graph (fixed for all graphs).
            num_edge_types: total edge-type count; the forward pass indexes
                the first half as "outgoing" types and offsets by
                num_edge_types // 2 for the "incoming" direction.
            opt: options object providing state_dim, n_steps, use_bias,
                annotation_dim, cuda, dropout_rate.
        """
        super(GGNN, self).__init__()
        self.n_node = n_node
        self.num_edge_types = num_edge_types
        self.state_dim = opt.state_dim
        self.time_steps = opt.n_steps          # rounds of message passing
        self.use_bias = opt.use_bias
        self.annotation_dim = opt.annotation_dim
        self.use_cuda = opt.cuda
        self.dropout_rate = opt.dropout_rate
        # embedding for different type of edges. To use it as matrix, view each vector as [state_dim, state_dim]
        self.edgeEmbed = nn.Embedding(num_edge_types, opt.state_dim * opt.state_dim, sparse=False)
        if self.use_bias:
            # Per-edge-type additive bias, applied after the matrix product.
            self.edgeBias = nn.Embedding(num_edge_types, opt.state_dim, sparse=False)
        self.propagator = Propagator(self.state_dim, self.dropout_rate)
        # output
        # Soft-attention readout (Li et al. GGNN): per-node gate times per-node
        # score, summed over nodes to a single scalar per graph.
        self.attention = nn.Sequential(
            nn.Linear(self.state_dim + self.annotation_dim, 1),
            nn.Sigmoid()
        )
        self.out = nn.Sequential(
            nn.Linear(self.state_dim + self.annotation_dim, 1),
            nn.Tanh()
        )
        self.result = nn.Sigmoid()

    def forward(self, prop_state, annotation, A):
        """Run message passing and return one scalar score per graph.

        Args:
            prop_state: [batch_size, n_node, state_dim] initial node states.
            annotation: [batch_size, n_node, annotation_dim] node annotations
                concatenated to the final states for the readout.
            A: nested adjacency lists (see comments below), not a tensor.

        Returns:
            [batch_size] sigmoid scores.
        """
        # prop_state: [batch_size, n_node, state_dim]
        # annotation: [batch_size, n_node, annotation_dim]
        # A: [[[(edge_type, node_id)]]]
        # len(A): batch_size, len(A[i]): n_node, len(A[i][j]): out degree of node_id j in graph i
        for t in range(self.time_steps):
            a_in = []
            a_out = []
            for i in range(len(A)):  # have to process the graph one by one
                # A[i]: List(List((edge_type, neighbour)))
                # Per-node message accumulators.
                # NOTE(review): these are created as double while edgeEmbed
                # products are (by default) float32 — presumably the caller
                # supplies double prop_state / a double-cast model; confirm,
                # otherwise the cat() in Propagator would mix dtypes.
                a_in_i = [torch.zeros(self.state_dim).double() for k in range(self.n_node)]
                a_out_i = [torch.zeros(self.state_dim).double() for k in range(self.n_node)]
                if self.use_cuda:
                    a_in_i = [in_i.cuda() for in_i in a_in_i]
                    a_out_i = [out_i.cuda() for out_i in a_out_i]
                for j in range(len(A[i])):  # len(A[i]) should be n_node
                    # print(i, ': ', len(A[i][j]))
                    if len(A[i][j]) > 0:
                        # both edge_type and node_id should start from zero
                        # NOTE(review): despite the comment above, indexing
                        # uses `edge_type - 1`, i.e. edge types appear to be
                        # 1-based in A — verify against the data loader.
                        vector_j = prop_state[i][j]
                        vector_j = vector_j.view(self.state_dim, 1)
                        for edge_type, neighbour_id in A[i][j]:  # A[i][j]: (edge_type(out), neighbour)
                            # --- outgoing message j -> neighbour, accumulated at j ---
                            edge_idx = torch.LongTensor([edge_type - 1])
                            if self.use_cuda:
                                edge_idx = edge_idx.cuda()
                            # [state_dim*state_dim]
                            edge_embed = self.edgeEmbed(edge_idx)
                            # [state_dim, state_dim]
                            edge_embed = edge_embed.view(self.state_dim, self.state_dim)
                            neighbour = prop_state[i][neighbour_id]
                            # [state_dim, 1]
                            neighbour = neighbour.view(self.state_dim, 1)
                            # print('neighbour: ', neighbour)
                            # [state_dim, 1]
                            product = torch.mm(edge_embed, neighbour)
                            # [state_dim]
                            product = product.view(self.state_dim)
                            if self.use_bias:
                                edge_idx = torch.LongTensor([edge_type - 1])
                                if self.use_cuda:
                                    edge_idx = edge_idx.cuda()
                                product += self.edgeBias(edge_idx).view(self.state_dim)
                            a_out_i[j] += product
                            # compute incoming information for neighbour_id
                            # The reverse direction uses the edge type shifted
                            # by num_edge_types // 2 (second half of the table).
                            edge_idx0 = torch.LongTensor([edge_type + self.num_edge_types // 2 - 1])
                            if self.use_cuda:
                                edge_idx0 = edge_idx0.cuda()
                            edge_embed0 = self.edgeEmbed(edge_idx0)
                            edge_embed0 = edge_embed0.view(self.state_dim, self.state_dim)
                            product0 = torch.mm(edge_embed0, vector_j)
                            product0 = product0.view(self.state_dim)
                            if self.use_bias:
                                edge_idx0 = torch.LongTensor([edge_type + self.num_edge_types // 2 - 1])
                                if self.use_cuda:
                                    edge_idx0 = edge_idx0.cuda()
                                product0 += self.edgeBias(edge_idx0)\
                                    .view(self.state_dim)
                            a_in_i[neighbour_id] += product0
                # [n_node, state_dim]
                a_in_i = torch.stack(a_in_i)
                # [n_node, state_dim]
                a_out_i = torch.stack(a_out_i)
                a_in.append(a_in_i)
                a_out.append(a_out_i)
            # [batch_size, n_node, state_dim]
            a_in = torch.stack(a_in)
            a_out = torch.stack(a_out)
            # print(a_in)
            # GRU-style update of all node states for this time step.
            prop_state = self.propagator(prop_state, a_in, a_out)
        # --- readout: gated sum over nodes -> one scalar per graph ---
        join_state = torch.cat((prop_state, annotation), 2)  # [batch_size, n_node, state_dim+annotation_dim]
        atten = self.attention(join_state)  # [batch_size, n_node, 1]
        ou = self.out(join_state)  # [batch_size, n_node, 1]
        mul = atten * ou  # [batch_size, n_node, 1]
        mul = mul.view(-1, self.n_node)  # [batch_size, n_node]
        w_sum = torch.sum(mul, dim=1)
        res = self.result(w_sum)
        return res