import torch
from torch import nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch.softmax import edge_softmax
import dgl.function as fn


class MicroConv(nn.Module):
    """Micro-level convolution: multi-head graph attention over a single
    relation (one canonical edge type) of a heterogeneous graph.

    Parameters
    ----------
    in_feats : pair of ints
        Input feature sizes of the (source, destination) node types.
    out_feats : int
        Output feature size.
    num_heads : int
        Number of heads in Multi-Head Attention.
    dropout : float, optional
        Dropout rate, defaults: 0.
    negative_slope : float, optional
        Negative slope of LeakyReLU, defaults: 0.2.
    """

    def __init__(self, in_feats: tuple, out_feats: int, num_heads: int,
                 dropout: float = 0.0, negative_slope: float = 0.2):
        super(MicroConv, self).__init__()
        self._in_src_feats, self._in_dst_feats = in_feats[0], in_feats[1]
        self._out_feats = out_feats
        self._num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, graph: dgl.DGLHeteroGraph, feat: tuple, dst_node_transformation_weight: nn.Parameter,
                src_node_transformation_weight: nn.Parameter, src_nodes_attention_weight: nn.Parameter):
        r"""Compute the graph attention network layer for one relation.

        Parameters
        ----------
        graph : dgl.DGLHeteroGraph
            The bipartite graph of a specific relation.
        feat : pair of torch.Tensor
            The pair contains two tensors of shape ``(N_src, D_in_src)`` and ``(N_dst, D_in_dst)``.
        dst_node_transformation_weight : Parameter, shape (input_dst_dim, n_heads * hidden_dim)
        src_node_transformation_weight : Parameter, shape (input_src_dim, n_heads * hidden_dim)
        src_nodes_attention_weight : Parameter, shape (n_heads, 2 * hidden_dim)

        Returns
        -------
        torch.Tensor
            Aggregated destination-node features of shape ``(N_dst, n_heads * hidden_dim)``.
        """
graph = graph.local_var()
# Tensor, (N_src, input_src_dim)
feat_src = self.dropout(feat[0])
# Tensor, (N_dst, input_dst_dim)
feat_dst = self.dropout(feat[1])
# Tensor, (N_src, n_heads, hidden_dim) -> (N_src, input_src_dim) * (input_src_dim, n_heads * hidden_dim)
feat_src = torch.matmul(feat_src, src_node_transformation_weight).view(-1, self._num_heads, self._out_feats)
# Tensor, (N_dst, n_heads, hidden_dim) -> (N_dst, input_dst_dim) * (input_dst_dim, n_heads * hidden_dim)
feat_dst = torch.matmul(feat_dst, dst_node_transformation_weight).view(-1, self._num_heads, self._out_feats)
        # first decompose the weight vector into [a_l || a_r], then
        # a^T [Wh_i || Wh_j] = a_l^T Wh_i + a_r^T Wh_j; this decomposition is more efficient
        # because it avoids materializing the concatenated features on every edge
# Tensor, (N_dst, n_heads, 1), (N_dst, n_heads, hidden_dim) * (n_heads, hidden_dim)
e_dst = (feat_dst * src_nodes_attention_weight[:, :self._out_feats]).sum(dim=-1, keepdim=True)
# Tensor, (N_src, n_heads, 1), (N_src, n_heads, hidden_dim) * (n_heads, hidden_dim)
e_src = (feat_src * src_nodes_attention_weight[:, self._out_feats:]).sum(dim=-1, keepdim=True)
# (N_src, n_heads, hidden_dim), (N_src, n_heads, 1)
graph.srcdata.update({'ft': feat_src, 'e_src': e_src})
# (N_dst, n_heads, 1)
graph.dstdata.update({'e_dst': e_dst})
# compute edge attention, e_src and e_dst are a_src * Wh_src and a_dst * Wh_dst respectively.
graph.apply_edges(fn.u_add_v('e_src', 'e_dst', 'e'))
# shape (edges_num, heads, 1)
e = self.leaky_relu(graph.edata.pop('e'))
# compute softmax
graph.edata['a'] = edge_softmax(graph, e)
graph.update_all(fn.u_mul_e('ft', 'a', 'msg'), fn.sum('msg', 'ft'))
        # (N_dst, n_heads, hidden_dim) -> (N_dst, n_heads * hidden_dim)
        dst_features = graph.dstdata.pop('ft').reshape(-1, self._num_heads * self._out_feats)
dst_features = F.relu(dst_features)
return dst_features
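

# A minimal smoke test for MicroConv -- an illustrative sketch, not part of the
# original module. The relation ('author', 'writes', 'paper'), the sizes below,
# and the randomly initialized weights are all assumptions for demonstration;
# in the real model the transformation/attention parameters are owned by an
# outer module and passed into forward().
def _micro_conv_demo():
    num_src, num_dst = 6, 4
    in_src, in_dst, hidden_dim, n_heads = 8, 8, 16, 2
    # a single-relation bipartite graph with arbitrary edges
    graph = dgl.heterograph({('author', 'writes', 'paper'):
                             (torch.tensor([0, 1, 2, 3, 4, 5]),
                              torch.tensor([0, 0, 1, 2, 3, 3]))})
    conv = MicroConv(in_feats=(in_src, in_dst), out_feats=hidden_dim, num_heads=n_heads)
    feat = (torch.randn(num_src, in_src), torch.randn(num_dst, in_dst))
    dst_w = nn.Parameter(torch.randn(in_dst, n_heads * hidden_dim))
    src_w = nn.Parameter(torch.randn(in_src, n_heads * hidden_dim))
    attn_w = nn.Parameter(torch.randn(n_heads, 2 * hidden_dim))
    out = conv(graph, feat, dst_w, src_w, attn_w)
    # each destination node gets a concatenation of its per-head features
    assert out.shape == (num_dst, n_heads * hidden_dim)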


class MacroConv(nn.Module):
    """Macro-level convolution: attention-based fusion of the relation-specific
    features of each destination node type.

    Parameters
    ----------
    in_feats : int
        Input feature size.
    out_feats : int
        Output feature size.
    num_heads : int
        Number of heads in Multi-Head Attention.
    dropout : float, optional
        Dropout rate, defaults: 0.
    negative_slope : float, optional
        Negative slope of LeakyReLU, defaults: 0.2.
    """

    def __init__(self, in_feats: int, out_feats: int, num_heads: int,
                 dropout: float = 0.0, negative_slope: float = 0.2):
        super(MacroConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self._num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope)

    def forward(self, graph, input_dst: dict, relation_features: dict, edge_type_transformation_weight: nn.ParameterDict,
                central_node_transformation_weight: nn.ParameterDict, edge_types_attention_weight: nn.Parameter):
        """
        :param graph: dgl.DGLHeteroGraph
        :param input_dst: dict, {ntype: features}
        :param relation_features: dict, {(stype, etype, dtype): features}
        :param edge_type_transformation_weight: ParameterDict, {etype: (n_heads * hidden_dim, n_heads * hidden_dim)}
        :param central_node_transformation_weight: ParameterDict, {ntype: (input_central_node_dim, n_heads * hidden_dim)}
        :param edge_types_attention_weight: Parameter, (n_heads, 2 * hidden_dim)
        :return: output_features: dict, {ntype: features of shape (N_ntype, n_heads * hidden_dim)}
        """
output_features = {}
for ntype in input_dst:
if graph.number_of_dst_nodes(ntype) != 0:
# (N_ntype, self._in_feats)
central_node_feature = input_dst[ntype]
# (N_ntype, n_heads, hidden_dim)
central_node_feature = torch.matmul(central_node_feature, central_node_transformation_weight[ntype]). \
view(-1, self._num_heads, self._out_feats)
types_features = []
for relation_tuple in relation_features:
stype, etype, dtype = relation_tuple
if dtype == ntype:
# (N_ntype, n_heads * hidden_dim)
types_features.append(torch.matmul(relation_features[relation_tuple],
edge_type_transformation_weight[etype]))
# TODO: another aggregation equation
# relation_features[relation_tuple] -> (N_ntype, n_heads * hidden_dim), (N_ntype, n_heads, hidden_dim)
# edge_type_transformation_weight -> (n_heads, hidden_dim, hidden_dim)
# each element -> (N_ntype, n_heads * hidden_dim)
# types_features.append(torch.einsum('abc,bcd->abd', relation_features[relation_tuple].reshape(-1, self._num_heads, self._out_feats),
# edge_type_transformation_weight[etype]).flatten(start_dim=1))
# Tensor, (relations_num, N_ntype, n_heads * hidden_dim)
types_features = torch.stack(types_features, dim=0)
                # if the central node type only interacts with one relation, the attention score is 1,
                # so directly assign the transformed feature to the central node
if types_features.shape[0] == 1:
output_features[ntype] = types_features.squeeze(dim=0)
else:
# Tensor, (relations_num, N_ntype, n_heads, hidden_dim)
types_features = types_features.view(types_features.shape[0], -1, self._num_heads, self._out_feats)
# (relations_num, N_ntype, n_heads, hidden_dim)
stacked_central_features = torch.stack([central_node_feature for _ in range(types_features.shape[0])],
dim=0)
# (relations_num, N_ntype, n_heads, 2 * hidden_dim)
concat_features = torch.cat((stacked_central_features, types_features), dim=-1)
# (relations_num, N_ntype, n_heads, 1) -> (n_heads, 2 * hidden_dim) * (relations_num, N_ntype, n_heads, 2 * hidden_dim)
attention_scores = (edge_types_attention_weight * concat_features).sum(dim=-1, keepdim=True)
attention_scores = self.leaky_relu(attention_scores)
attention_scores = F.softmax(attention_scores, dim=0)
# (N_ntype, n_heads, hidden_dim)
output_feature = (attention_scores * types_features).sum(dim=0)
output_feature = self.dropout(output_feature)
output_feature = output_feature.reshape(-1, self._num_heads * self._out_feats)
output_features[ntype] = output_feature
return output_features
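

# A minimal smoke test for MacroConv -- likewise an illustrative sketch with
# assumed node/edge types and random parameters, not part of the original
# module. Two relations point at 'paper', so the macro-level attention fuses
# two relation-specific feature sets into one representation per paper node.
def _macro_conv_demo():
    num_paper, in_dim, hidden_dim, n_heads = 4, 8, 16, 2
    graph = dgl.heterograph({
        ('author', 'writes', 'paper'): (torch.tensor([0]), torch.tensor([0])),
        ('venue', 'publishes', 'paper'): (torch.tensor([0]), torch.tensor([0])),
    }, num_nodes_dict={'author': 1, 'venue': 1, 'paper': num_paper})
    conv = MacroConv(in_feats=in_dim, out_feats=hidden_dim, num_heads=n_heads)
    input_dst = {'paper': torch.randn(num_paper, in_dim)}
    # relation-specific features, e.g. as produced by MicroConv
    relation_features = {
        ('author', 'writes', 'paper'): torch.randn(num_paper, n_heads * hidden_dim),
        ('venue', 'publishes', 'paper'): torch.randn(num_paper, n_heads * hidden_dim),
    }
    edge_type_w = nn.ParameterDict({
        etype: nn.Parameter(torch.randn(n_heads * hidden_dim, n_heads * hidden_dim))
        for etype in ('writes', 'publishes')})
    central_w = nn.ParameterDict({'paper': nn.Parameter(torch.randn(in_dim, n_heads * hidden_dim))})
    attn_w = nn.Parameter(torch.randn(n_heads, 2 * hidden_dim))
    out = conv(graph, input_dst, relation_features, edge_type_w, central_w, attn_w)
    assert out['paper'].shape == (num_paper, n_heads * hidden_dim)


if __name__ == '__main__':
    _micro_conv_demo()
    _macro_conv_demo()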