-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathmodel_seq2seq.py
71 lines (57 loc) · 2.78 KB
/
model_seq2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib import rnn
from trnn import *
from trnn_imply import *
def TLSTM(enc_inps, dec_inps, is_training, config):
    """Build a tensor-LSTM encoder/decoder graph and return the decoder outputs.

    Args:
        enc_inps: Encoder input, forwarded to ``tensor_rnn_with_feed_prev``.
        dec_inps: Decoder input, forwarded to ``tensor_rnn_with_feed_prev``.
        is_training: Truthy when building the training graph. It is forwarded
            to the decoder only; the encoder is always built with ``True``.
        config: Object providing ``hidden_size``, ``num_lags``, ``rank_vals``
            and ``num_layers``. NOTE: ``config.inp_steps`` is overwritten
            with 0 in place before the decoder is built and stays 0 for the
            caller afterwards.

    Returns:
        The first element of the pair returned by
        ``tensor_rnn_with_feed_prev`` for the decoder.
    """
    def tlstm_cell():
        return TensorLSTMCell(
            config.hidden_size, config.num_lags, config.rank_vals)

    print('Training -->' if is_training else 'Testing -->')

    # Dropout is currently disabled for this model; config.keep_prob is not
    # read here.

    # Build a fresh cell per layer instead of repeating one instance, so the
    # layers never alias the same cell object.
    cell = tf.contrib.rnn.MultiRNNCell(
        [tlstm_cell() for _ in range(config.num_layers)])

    with tf.variable_scope("Encoder", reuse=None):
        print(' '*10+'Create Encoder ...')
        enc_outs, enc_states = tensor_rnn_with_feed_prev(
            cell, enc_inps, True, config)

    with tf.variable_scope("Decoder", reuse=None):
        print(' '*10+'Create Decoder ...')
        # Mutates the caller's config; the decoder is built with inp_steps=0.
        config.inp_steps = 0
        dec_outs, dec_states = tensor_rnn_with_feed_prev(
            cell, dec_inps, is_training, config, enc_states)
    return dec_outs
def RNN(enc_inps, dec_inps, is_training, config):
    """Build a stacked BasicRNNCell encoder/decoder and return decoder outputs.

    Args:
        enc_inps: Encoder input, forwarded to ``rnn_with_feed_prev``.
        dec_inps: Decoder input, forwarded to ``rnn_with_feed_prev``.
        is_training: Truthy when building the training graph. It enables
            output dropout (when ``config.keep_prob < 1``) and is forwarded
            to the decoder; the encoder is always built with ``True``.
        config: Object providing ``hidden_size``, ``num_layers`` and
            ``keep_prob``. NOTE: ``config.inp_steps`` is overwritten with 0
            in place before the decoder is built and stays 0 for the caller
            afterwards.

    Returns:
        The first element of the pair returned by ``rnn_with_feed_prev`` for
        the decoder.
    """
    use_dropout = is_training and config.keep_prob < 1

    def rnn_cell():
        layer = tf.contrib.rnn.BasicRNNCell(config.hidden_size)
        if use_dropout:
            # Previously the dropout-wrapped cell was built and then
            # immediately overwritten, so dropout was never applied. Wrap
            # every layer so keep_prob actually takes effect in training.
            layer = tf.contrib.rnn.DropoutWrapper(
                layer, output_keep_prob=config.keep_prob)
        return layer

    cell = tf.contrib.rnn.MultiRNNCell(
        [rnn_cell() for _ in range(config.num_layers)])

    with tf.variable_scope("Encoder", reuse=None):
        enc_outs, enc_states = rnn_with_feed_prev(cell, enc_inps, True, config)

    with tf.variable_scope("Decoder", reuse=None):
        # Mutates the caller's config; the decoder is built with inp_steps=0.
        config.inp_steps = 0
        dec_outs, dec_states = rnn_with_feed_prev(
            cell, dec_inps, is_training, config, enc_states)
    return dec_outs
def LSTM(enc_inps, dec_inps, is_training, config):
    """Build a stacked BasicLSTMCell encoder/decoder and return decoder outputs.

    Args:
        enc_inps: Encoder input, forwarded to ``rnn_with_feed_prev``.
            Presumably shaped (batch_size, timesteps, n_input), per the
            original author's note -- TODO confirm against the helper.
        dec_inps: Decoder input, forwarded to ``rnn_with_feed_prev``.
        is_training: Forwarded to the decoder only; the encoder is always
            built with ``True``.
        config: Object providing ``hidden_size`` and ``num_layers``. NOTE:
            ``config.burn_in_steps`` is overwritten with 0 in place before
            the decoder is built and stays 0 for the caller afterwards.

    Returns:
        The first element of the pair returned by ``rnn_with_feed_prev`` for
        the decoder.
    """
    # Dropout is currently disabled for this model; config.keep_prob is not
    # read here.
    stacked_cell = tf.contrib.rnn.MultiRNNCell([
        tf.contrib.rnn.BasicLSTMCell(
            config.hidden_size, forget_bias=1.0, reuse=None)
        for _ in range(config.num_layers)
    ])

    # Encoder pass: only its final state is used below.
    with tf.variable_scope("Encoder", reuse=None):
        _, encoder_state = rnn_with_feed_prev(
            stacked_cell, enc_inps, True, config)

    # Decoder pass: seeded with the encoder state.
    with tf.variable_scope("Decoder", reuse=None):
        # NOTE(review): the sibling RNN builder resets config.inp_steps here
        # instead; confirm which attribute rnn_with_feed_prev actually reads.
        config.burn_in_steps = 0
        decoder_outputs, _ = rnn_with_feed_prev(
            stacked_cell, dec_inps, is_training, config, encoder_state)
    return decoder_outputs