-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathactorPPOmiddR.py
69 lines (55 loc) · 3.42 KB
/
actorPPOmiddR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow import keras
from keras import regularizers
class myModel(tf.keras.Model):
def __init__(self, hparams, hidden_init_actor, kernel_init_actor):
super(myModel, self).__init__()
self.hparams = hparams
# Define layers here
self.Message = tf.keras.models.Sequential()
self.Message.add(keras.layers.Dense(self.hparams['link_state_dim'],
kernel_initializer=hidden_init_actor,
activation=tf.nn.selu, name="FirstLayer"))
self.Update = tf.keras.layers.GRUCell(self.hparams['link_state_dim'], dtype=tf.float32)
self.Readout = tf.keras.models.Sequential()
self.Readout.add(keras.layers.Dense(self.hparams['readout_units'],
activation=tf.nn.selu,
kernel_initializer=hidden_init_actor,
kernel_regularizer=regularizers.l2(hparams['l2']),
name="Readout1"))
#self.Readout.add(keras.layers.Dropout(rate=hparams['dropout_rate']))
self.Readout.add(keras.layers.Dense(self.hparams['readout_units'],
activation=tf.nn.selu,
kernel_initializer=hidden_init_actor,
kernel_regularizer=regularizers.l2(hparams['l2']),
name="Readout2"))
#self.Readout.add(keras.layers.Dropout(rate=hparams['dropout_rate']))
self.Readout.add(keras.layers.Dense(1, kernel_initializer=kernel_init_actor, name="Readout3"))
def build(self, input_shape=None):
# Create the weights of the layer
self.Message.build(input_shape=tf.TensorShape([None, self.hparams['link_state_dim']*2]))
self.Update.build(input_shape=tf.TensorShape([None,self.hparams['link_state_dim']]))
self.Readout.build(input_shape=[None, self.hparams['link_state_dim']])
self.built = True
#@tf.function
def call(self, link_state, states_graph_ids, states_first, states_second, sates_num_edges, training=False):
# Execute T times
for _ in range(self.hparams['T']):
# We have the combination of the hidden states of the main edges with the neighbours
mainEdges = tf.gather(link_state, states_first)
neighEdges = tf.gather(link_state, states_second)
edgesConcat = tf.concat([mainEdges, neighEdges], axis=1)
### 1.a Message passing for link with all it's neighbours
outputs = self.Message(edgesConcat)
### 1.b Sum of output values according to link id index
edges_inputs = tf.math.unsorted_segment_sum(data=outputs, segment_ids=states_second,
num_segments=sates_num_edges)
### 2. Update for each link
# GRUcell needs a 3D tensor as state because there is a matmul: Wrap the link state
outputs, links_state_list = self.Update(edges_inputs, [link_state])
link_state = links_state_list[0]
# Perform sum of all hidden states
edges_combi_outputs = tf.math.segment_sum(link_state, states_graph_ids, name=None)
r = self.Readout(edges_combi_outputs,training=training)
return r