-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathmi_gru_cell.py
64 lines (48 loc) · 2.63 KB
/
mi_gru_cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import tensorflow as tf
import numpy as np
class MiGRUCell(tf.nn.rnn_cell.RNNCell):
    """GRU cell whose pre-activations combine input and state multiplicatively.

    Each pre-activation (reset gate, update gate, candidate state) is formed
    by `addBiases` from two linear projections, Wx = x @ W and Uh = h @ U:

        beta_1 * Wx + beta_2 * Uh + beta_3 * (Wx * Uh) + b

    The beta_i are learned per-unit scales initialized to ones. b is a
    learned bias initialized to zeros, plus an optional constant offset.

    Args:
        num_units: Width of the hidden state, which is also the output width.
        input_size: Ignored; accepted only so existing call sites keep working.
        activation: Nonlinearity applied to the candidate state (default tanh).
        reuse: Forwarded as `reuse` to `tf.variable_scope` in `__call__`.
    """

    def __init__(self, num_units, input_size = None, activation = tf.tanh, reuse = None):
        # RNNCell is a Layer subclass in TF 1.x, and Layer subclasses are
        # expected to initialize the base class before assigning attributes;
        # otherwise Layer bookkeeping (e.g. attribute tracking) is left
        # uninitialized. No arguments are passed so this works across TF 1.x.
        super(MiGRUCell, self).__init__()
        self.numUnits = num_units
        self.activation = activation
        self.reuse = reuse

    @property
    def state_size(self):
        """Size of the recurrent state: `num_units`."""
        return self.numUnits

    @property
    def output_size(self):
        """Size of the output: `num_units` (the output is the new state)."""
        return self.numUnits

    def mulWeights(self, inp, inDim, outDim, name = ""):
        """Return `inp @ W` for a Xavier-initialized (inDim, outDim) matrix W.

        W is obtained with `tf.get_variable` as "weights<name>/weights"
        relative to the current variable scope, so it is shared whenever that
        scope is reused.
        """
        with tf.variable_scope("weights" + name):
            W = tf.get_variable("weights", shape = (inDim, outDim),
                initializer = tf.contrib.layers.xavier_initializer())
        output = tf.matmul(inp, W)
        return output

    def addBiases(self, inp1, inp2, dim, bInitial = 0, name = ""):
        """Combine two projections with multiplicative integration.

        Args:
            inp1: First projection, presumably (batch, dim) -- TODO confirm.
            inp2: Second projection, same shape as `inp1`.
            dim: Feature width of `inp1` / `inp2`.
            bInitial: Constant added to the zero-initialized bias variable.
                The offset is added in the graph, so it stays applied for the
                whole of training.
            name: Suffix for the "additiveBiases" / "multiplicativeBias" scopes.

        Returns:
            beta_1 * inp1 + beta_2 * inp2 + beta_3 * (inp1 * inp2) + b, where
            beta = [beta_1, beta_2, beta_3] is one (3 * dim,) variable of ones
            and b is a (dim,) variable of zeros plus `bInitial`.
        """
        with tf.variable_scope("additiveBiases" + name):
            b = tf.get_variable("biases", shape = (dim,),
                initializer = tf.zeros_initializer()) + bInitial
        with tf.variable_scope("multiplicativeBias" + name):
            beta = tf.get_variable("biases", shape = (3 * dim,),
                initializer = tf.ones_initializer())

        # Scale all three terms with one concatenated beta, then split them back.
        Wx, Uh, inter = tf.split(beta * tf.concat([inp1, inp2, inp1 * inp2], axis = 1),
            num_or_size_splits = 3, axis = 1)
        output = Wx + Uh + inter + b
        return output

    def __call__(self, inputs, state, scope = None):
        """Run one step of the cell.

        Args:
            inputs: 2-D tensor (batch, input width); the input width must be
                statically known because it sizes the weight matrices.
            state: Previous hidden state, (batch, num_units).
            scope: Variable scope name; defaults to the class name.

        Returns:
            (new_state, new_state): the output equals the new hidden state.

        Raises:
            ValueError: if `inputs.shape[1]` is not statically known.
        """
        scope = scope or type(self).__name__
        with tf.variable_scope(scope, reuse = self.reuse):
            # TF 1.x returns a Dimension here; read its value so an unknown
            # width gives a clear error instead of a TypeError from int().
            inputDim = inputs.shape[1]
            inputSize = getattr(inputDim, "value", inputDim)
            if inputSize is None:
                raise ValueError("MiGRUCell requires a statically known input "
                                 "dimension (inputs.shape[1]), got None.")
            inputSize = int(inputSize)

            # Reset gate r. The +1 bias offset pushes it toward 1 (letting the
            # previous state through) while the projection terms are small.
            Wxr = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxr")
            Uhr = self.mulWeights(state, self.numUnits, self.numUnits, name = "Uhr")
            r = tf.nn.sigmoid(self.addBiases(Wxr, Uhr, self.numUnits, bInitial = 1, name = "r"))

            # Update gate u, with the same +1 bias offset.
            Wxu = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxu")
            Uhu = self.mulWeights(state, self.numUnits, self.numUnits, name = "Uhu")
            u = tf.nn.sigmoid(self.addBiases(Wxu, Uhu, self.numUnits, bInitial = 1, name = "u"))

            # Candidate state from the input and the reset-gated previous state.
            Wx = self.mulWeights(inputs, inputSize, self.numUnits, name = "Wxl")
            Urh = self.mulWeights(r * state, self.numUnits, self.numUnits, name = "Uhl")
            c = self.activation(self.addBiases(Wx, Urh, self.numUnits, name = "2"))

            # u is the weight on the *previous* state: u -> 1 keeps h, u -> 0
            # replaces it with the candidate c. Because u is biased toward 1,
            # the cell leans toward preserving its state early in training.
            newH = u * state + (1 - u) * c
        return newH, newH

    def zero_state(self, batchSize, dtype = tf.float32):
        """Return an all-zeros initial state of shape (batchSize, num_units)."""
        return tf.zeros((batchSize, self.numUnits), dtype = dtype)