-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmessage_passing.py
142 lines (113 loc) · 5.24 KB
/
message_passing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import inspect
import sys
import torch
# from torch_geometric.utils import scatter_
import torch_scatter
# Message-argument names that ``propagate`` supplies itself instead of
# looking them up in its keyword arguments.
special_args = [
    'edge_index',
    'edge_index_i',
    'edge_index_j',
    'size',
    'size_i',
    'size_j',
]

# Raised when two tensors that index the same side of ``edge_index`` disagree
# on their number of rows.
__size_error_msg__ = (
    'All tensors which should get mapped to the same source '
    'or target nodes must be of same size in dimension 0.'
)

# Python 2 only offers ``inspect.getargspec``; Python 3 provides
# ``getfullargspec`` (and ``getargspec`` no longer exists on 3.11+).
is_python2 = sys.version_info[0] < 3
if is_python2:
    getargspec = inspect.getargspec
else:
    getargspec = inspect.getfullargspec
def scatter_(name, src, index, dim=0, dim_size=None):
    """Aggregate entries of ``src`` that share an ``index`` value.

    Dispatches to ``torch_scatter.scatter_<name>``.

    Args:
        name (str): Reduction to apply; one of ``'add'``, ``'mean'`` or
            ``'max'``.
        src (Tensor): Values to aggregate.
        index (Tensor): Group index for the entries of ``src`` along ``dim``.
        dim (int, optional): Dimension to scatter along. (default: ``0``)
        dim_size (int, optional): Output size along ``dim``, forwarded to
            ``torch_scatter`` unchanged. (default: ``None``)

    Returns:
        Tensor: The aggregated values. For ``'max'``, every output entry that
        equals the smallest value representable in ``src.dtype`` is set to 0.
    """
    assert name in ['add', 'mean', 'max']

    reduce_max = name == 'max'
    if reduce_max:
        # Smallest representable value for the dtype. It is treated as the
        # "no input reached this slot" marker below. NOTE(review): this
        # presumably mirrors torch_scatter's initial value for max; confirm.
        if torch.is_floating_point(src):
            sentinel = torch.finfo(src.dtype).min
        else:
            sentinel = torch.iinfo(src.dtype).min

    scatter_fn = getattr(torch_scatter, 'scatter_{}'.format(name))
    result = scatter_fn(src, index, dim, None, dim_size)

    # Some reductions (e.g. max) return ``(values, arg)``; keep the values.
    if isinstance(result, tuple):
        result = result[0]

    if reduce_max:
        result[result == sentinel] = 0
    return result
class MessagePassing(torch.nn.Module):
    r"""Base class for message passing layers.

    Subclasses override :meth:`message` and/or :meth:`update` and call
    :meth:`propagate`. :meth:`propagate` builds per-edge arguments for
    :meth:`message` and aggregates the resulting messages at the nodes in row
    ``i`` of ``edge_index`` with :func:`scatter_`. It then hands the
    aggregated result to :meth:`update`.

    Args:
        aggr (str, optional): Aggregation scheme; one of ``'add'``,
            ``'mean'`` or ``'max'``. (default: ``'add'``)
        flow (str, optional): Message direction; ``'source_to_target'``
            aggregates at ``edge_index[1]``, ``'target_to_source'`` at
            ``edge_index[0]``. (default: ``'source_to_target'``)
    """

    def __init__(self, aggr='add', flow='source_to_target'):
        super(MessagePassing, self).__init__()

        self.aggr = aggr
        assert self.aggr in ['add', 'mean', 'max']

        self.flow = flow
        assert self.flow in ['source_to_target', 'target_to_source']

        # Parameter names of ``message``; ``[1:]`` drops ``self``.
        self.__message_args__ = getargspec(self.message)[0][1:]
        # ``(position, name)`` pairs of arguments ``propagate`` fills itself.
        self.__special_args__ = [(i, arg)
                                 for i, arg in enumerate(self.__message_args__)
                                 if arg in special_args]
        # The remaining ``message`` arguments, resolved from kwargs.
        self.__message_args__ = [
            arg for arg in self.__message_args__ if arg not in special_args
        ]
        # Extra ``update`` parameters; ``[2:]`` drops ``self`` and the
        # aggregated output.
        self.__update_args__ = getargspec(self.update)[0][2:]

    def propagate(self, edge_index, size=None, **kwargs):
        r"""The initial call to start propagating messages.

        Args:
            edge_index (Tensor): The indices of a general (sparse) assignment
                matrix with shape :obj:`[N, M]` (can be directed or
                undirected).
            size (list or tuple, optional): The size :obj:`[N, M]` of the
                assignment matrix. If set to :obj:`None`, the size is
                inferred automatically from the tensors passed in
                :obj:`kwargs`. (default: :obj:`None`)
            **kwargs: Any additional data which is needed to construct messages
                and to update node embeddings. A tensor requested as
                :obj:`name_i` / :obj:`name_j` may also be passed as a pair
                whose entry :obj:`k` is indexed by :obj:`edge_index[k]`;
                either entry of the pair may be :obj:`None`.

        Returns:
            The value returned by :meth:`update`.

        Raises:
            ValueError: If tensors mapped to the same row of
                :obj:`edge_index` disagree in size along dimension 0.
        """
        size = [None, None] if size is None else list(size)
        assert len(size) == 2

        # Row ``i`` of ``edge_index`` holds the nodes where messages are
        # aggregated; row ``j`` holds the nodes the messages come from.
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        ij = {"_i": i, "_j": j}

        message_args = []
        for arg in self.__message_args__:
            suffix = arg[-2:]
            if suffix not in ij:
                # Plain argument: forwarded unchanged (``None`` if absent).
                message_args.append(kwargs.get(arg, None))
                continue

            tmp = kwargs.get(arg[:-2], None)
            if tmp is None:  # pragma: no cover
                message_args.append(tmp)
                continue

            idx = ij[suffix]
            if isinstance(tmp, (tuple, list)):
                assert len(tmp) == 2
                other = tmp[1 - idx]
                if other is not None:
                    # Record / validate the size of the side not used here.
                    if size[1 - idx] is None:
                        size[1 - idx] = other.size(0)
                    if size[1 - idx] != other.size(0):
                        raise ValueError(__size_error_msg__)
                tmp = tmp[idx]
                if tmp is None:
                    # The requested side of the pair was not supplied: pass
                    # ``None`` through (as for a missing kwarg) rather than
                    # failing on ``None.size(0)``.
                    message_args.append(tmp)
                    continue

            if size[idx] is None:
                size[idx] = tmp.size(0)
            if size[idx] != tmp.size(0):
                raise ValueError(__size_error_msg__)

            # Gather one row per edge from the node-level tensor.
            message_args.append(torch.index_select(tmp, 0, edge_index[idx]))

        # If only one dimension could be inferred, reuse it for the other.
        size[0] = size[1] if size[0] is None else size[0]
        size[1] = size[0] if size[1] is None else size[1]

        kwargs['edge_index'] = edge_index
        kwargs['size'] = size

        # Re-insert special arguments at their original positions. The list
        # is in ascending position order, so each insert lands correctly.
        for (idx, arg) in self.__special_args__:
            if arg[-2:] in ij:
                message_args.insert(idx, kwargs[arg[:-2]][ij[arg[-2:]]])
            else:
                message_args.insert(idx, kwargs[arg])

        update_args = [kwargs[arg] for arg in self.__update_args__]

        out = self.message(*message_args)
        # ``__init__`` only accepts these schemes. If ``self.aggr`` is later
        # reassigned to anything else, messages reach ``update`` unaggregated.
        if self.aggr in ["add", "mean", "max"]:
            out = scatter_(self.aggr, out, edge_index[i], dim_size=size[i])
        out = self.update(out, *update_args)
        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}`
        for each edge in :math:`(i,j) \in \mathcal{E}`.
        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, tensors can be mapped to node :math:`i` (where messages
        are aggregated) and node :math:`j` (where they originate) by appending
        :obj:`_i` or :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and
        :obj:`x_j`. The default implementation returns :obj:`x_j` as is."""
        return x_j

    def update(self, aggr_out):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`. The default
        implementation returns :obj:`aggr_out` unchanged."""
        return aggr_out