-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
core_model.py
120 lines (102 loc) · 4.87 KB
/
core_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Core learned graph net model."""
import collections
import functools
import sonnet as snt
import tensorflow.compat.v1 as tf
EdgeSet = collections.namedtuple('EdgeSet', ['name', 'features', 'senders',
'receivers'])
MultiGraph = collections.namedtuple('Graph', ['node_features', 'edge_sets'])
class GraphNetBlock(snt.AbstractModule):
"""Multi-Edge Interaction Network with residual connections."""
def __init__(self, model_fn, name='GraphNetBlock'):
super(GraphNetBlock, self).__init__(name=name)
self._model_fn = model_fn
def _update_edge_features(self, node_features, edge_set):
"""Aggregrates node features, and applies edge function."""
sender_features = tf.gather(node_features, edge_set.senders)
receiver_features = tf.gather(node_features, edge_set.receivers)
features = [sender_features, receiver_features, edge_set.features]
with tf.variable_scope(edge_set.name+'_edge_fn'):
return self._model_fn()(tf.concat(features, axis=-1))
def _update_node_features(self, node_features, edge_sets):
"""Aggregrates edge features, and applies node function."""
num_nodes = tf.shape(node_features)[0]
features = [node_features]
for edge_set in edge_sets:
features.append(tf.math.unsorted_segment_sum(edge_set.features,
edge_set.receivers,
num_nodes))
with tf.variable_scope('node_fn'):
return self._model_fn()(tf.concat(features, axis=-1))
def _build(self, graph):
"""Applies GraphNetBlock and returns updated MultiGraph."""
# apply edge functions
new_edge_sets = []
for edge_set in graph.edge_sets:
updated_features = self._update_edge_features(graph.node_features,
edge_set)
new_edge_sets.append(edge_set._replace(features=updated_features))
# apply node function
new_node_features = self._update_node_features(graph.node_features,
new_edge_sets)
# add residual connections
new_node_features += graph.node_features
new_edge_sets = [es._replace(features=es.features + old_es.features)
for es, old_es in zip(new_edge_sets, graph.edge_sets)]
return MultiGraph(new_node_features, new_edge_sets)
class EncodeProcessDecode(snt.AbstractModule):
"""Encode-Process-Decode GraphNet model."""
def __init__(self,
output_size,
latent_size,
num_layers,
message_passing_steps,
name='EncodeProcessDecode'):
super(EncodeProcessDecode, self).__init__(name=name)
self._latent_size = latent_size
self._output_size = output_size
self._num_layers = num_layers
self._message_passing_steps = message_passing_steps
def _make_mlp(self, output_size, layer_norm=True):
"""Builds an MLP."""
widths = [self._latent_size] * self._num_layers + [output_size]
network = snt.nets.MLP(widths, activate_final=False)
if layer_norm:
network = snt.Sequential([network, snt.LayerNorm()])
return network
def _encoder(self, graph):
"""Encodes node and edge features into latent features."""
with tf.variable_scope('encoder'):
node_latents = self._make_mlp(self._latent_size)(graph.node_features)
new_edges_sets = []
for edge_set in graph.edge_sets:
latent = self._make_mlp(self._latent_size)(edge_set.features)
new_edges_sets.append(edge_set._replace(features=latent))
return MultiGraph(node_latents, new_edges_sets)
def _decoder(self, graph):
"""Decodes node features from graph."""
with tf.variable_scope('decoder'):
decoder = self._make_mlp(self._output_size, layer_norm=False)
return decoder(graph.node_features)
def _build(self, graph):
"""Encodes and processes a multigraph, and returns node features."""
model_fn = functools.partial(self._make_mlp, output_size=self._latent_size)
latent_graph = self._encoder(graph)
for _ in range(self._message_passing_steps):
latent_graph = GraphNetBlock(model_fn)(latent_graph)
return self._decoder(latent_graph)