-
Notifications
You must be signed in to change notification settings - Fork 2
/
seq2seq_model.py
169 lines (138 loc) · 7.86 KB
/
seq2seq_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#coding=utf-8
__author__ = 'liyang54'
from distutils.version import LooseVersion
import tensorflow as tf
from tensorflow.python.layers.core import Dense
import numpy as np
import time
# encoder layer
# decoder layer
# seq2seq model
# Check TensorFlow Version
# Import-time guard: stop on TensorFlow releases older than 1.1, then report the version in use.
assert LooseVersion('1.1') <= LooseVersion(tf.__version__), 'Please use TensorFlow version 1.1 or newer'
print('TensorFlow Version: {}'.format(tf.__version__))
# Output of the encoder layer
def get_encoder_layer(input_data, rnn_size, num_layers,
                      source_sequence_length, source_vocab_size,
                      encoding_embedding_size):
    '''
    Build the encoder: embed the source input and run it through stacked LSTM cells.

    Parameters:
    - input_data: input tensor
    - rnn_size: number of hidden units in each rnn cell
    - num_layers: number of stacked rnn cells
    - source_sequence_length: sequence lengths of the source data
    - source_vocab_size: vocabulary size of the source data
    - encoding_embedding_size: size of the embedding

    Returns:
    - the (outputs, state) pair produced by tf.nn.dynamic_rnn; the state is
      the encoder's final state
    '''
    # Encoder embedding
    embedded_source = tf.contrib.layers.embed_sequence(
        input_data, source_vocab_size, encoding_embedding_size)

    # One LSTM cell per layer; every cell gets its own uniform initializer
    # over [-0.1, 0.1] built with the same fixed seed.
    layer_cells = []
    for _ in range(num_layers):
        layer_cells.append(
            tf.contrib.rnn.LSTMCell(
                rnn_size,
                initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2)))
    stacked_cell = tf.contrib.rnn.MultiRNNCell(layer_cells)

    # final_state is the state of the encoder after the last step.
    outputs, final_state = tf.nn.dynamic_rnn(
        stacked_cell,
        embedded_source,
        sequence_length=source_sequence_length,
        dtype=tf.float32)
    return outputs, final_state
def process_decoder_input(data, vocab_to_int, batch_size):
    '''
    Prepend <GO> to every row of `data` and remove the last element of each row.

    Parameters:
    - data: 2-D target tensor (it is sliced with two-element begin/end/stride lists)
    - vocab_to_int: mapping that contains the id of the '<GO>' token
    - batch_size: number of rows to keep and number of <GO> entries to create

    Returns:
    - the decoder input: a column of <GO> ids concatenated in front of the
      truncated rows along axis 1
    '''
    # Drop the last column: rows [0, batch_size), columns [0, -1), stride 1.
    without_last = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
    # A [batch_size, 1] column filled with the <GO> id.
    go_column = tf.fill([batch_size, 1], vocab_to_int['<GO>'])
    return tf.concat([go_column, without_last], 1)
def decoding_layer(target_letter_to_int, decoding_embedding_size, num_layers, rnn_size,
                   target_sequence_length, max_target_sequence_length, encoder_state, decoder_input, batch_size):
    '''
    Build the decoder layer (a training decoder and a greedy inference decoder).

    Parameters:
    - target_letter_to_int: mapping table of the target data; must contain
      the '<GO>' and '<EOS>' keys
    - decoding_embedding_size: size of the embedding vectors
    - num_layers: number of stacked RNN cells
    - rnn_size: number of hidden units in each RNN cell
    - target_sequence_length: sequence lengths of the target data
    - max_target_sequence_length: maximum target sequence length, used as the
      decode iteration limit
    - encoder_state: state vector produced by the encoder, used as the
      initial state of both decoders
    - decoder_input: input of the decoder
    - batch_size: number of <GO> start tokens to create for inference

    Returns:
    - (training_decoder_output, predicting_decoder_output): the first value
      returned by dynamic_decode for the training decoder and for the greedy
      decoder respectively
    '''
    # 1. Embedding
    target_vocab_size = len(target_letter_to_int)
    decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
    decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)

    # 2. Build the RNN cells of the decoder
    def get_decoder_cell(rnn_size):
        decoder_cell = tf.contrib.rnn.LSTMCell(rnn_size,
                                               initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))
        return decoder_cell

    cell = tf.contrib.rnn.MultiRNNCell([get_decoder_cell(rnn_size) for _ in range(num_layers)])

    # 3. Fully connected output layer
    output_layer = Dense(target_vocab_size,
                         kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))

    # 4. Training decoder
    # TrainingHelper is used in the training phase.
    # encoder_state is the final state of the encoder.
    with tf.variable_scope("decode"):
        # Build the helper object
        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,
                                                            sequence_length=target_sequence_length,
                                                            time_major=False)
        # Build the decoder
        training_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                           training_helper,
                                                           encoder_state,
                                                           output_layer)
        # Take element 0 instead of unpacking a fixed number of values: the
        # length of the tuple returned by dynamic_decode is not the same in
        # every TensorFlow 1.x release accepted by this file's version check,
        # while the decoder output is always the first element.
        training_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            training_decoder,
            impute_finished=True,
            maximum_iterations=max_target_sequence_length)[0]

    # 5. Predicting decoder
    # Shares parameters with the training decoder (same scope, reuse=True).
    # GreedyEmbeddingHelper is used in the test phase.
    with tf.variable_scope("decode", reuse=True):
        # Create a constant tensor and tile it to batch_size entries
        start_tokens = tf.tile(tf.constant([target_letter_to_int['<GO>']], dtype=tf.int32), [batch_size],
                               name='start_tokens')
        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,
                                                                     start_tokens,
                                                                     target_letter_to_int['<EOS>'])
        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                             predicting_helper,
                                                             encoder_state,
                                                             output_layer)
        # Same element-0 access as for the training decoder.
        predicting_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            predicting_decoder,
            impute_finished=True,
            maximum_iterations=max_target_sequence_length)[0]

    return training_decoder_output, predicting_decoder_output
def seq2seq_model(input_data, targets, lr, target_sequence_length,
                  max_target_sequence_length, source_sequence_length,
                  source_vocab_size, target_vocab_size,
                  encoding_embedding_size, decoder_embedding_size,
                  rnn_size, num_layers, target_letter_to_int, batch_size, decoding_embedding_size):
    '''
    Wire the encoder and the decoder together into the seq2seq model.

    The encoder is built from input_data, the targets are turned into the
    decoder input (prefixed with <GO>, last element removed), and the
    encoder's final state is handed to the decoder.

    Note: lr, target_vocab_size and decoder_embedding_size are accepted but
    not used in this function; the decoder embedding size is taken from
    decoding_embedding_size.

    Returns:
    - (training_decoder_output, predicting_decoder_output) as returned by
      decoding_layer
    '''
    # Only the encoder's state is needed; its outputs are discarded.
    _unused_outputs, final_encoder_state = get_encoder_layer(
        input_data=input_data,
        rnn_size=rnn_size,
        num_layers=num_layers,
        source_sequence_length=source_sequence_length,
        source_vocab_size=source_vocab_size,
        encoding_embedding_size=encoding_embedding_size)

    # Pre-processed decoder input
    prepared_targets = process_decoder_input(targets, target_letter_to_int, batch_size)

    # Pass the state vector and the input on to the decoder
    train_output, predict_output = decoding_layer(
        target_letter_to_int=target_letter_to_int,
        decoding_embedding_size=decoding_embedding_size,
        num_layers=num_layers,
        rnn_size=rnn_size,
        target_sequence_length=target_sequence_length,
        max_target_sequence_length=max_target_sequence_length,
        encoder_state=final_encoder_state,
        decoder_input=prepared_targets,
        batch_size=batch_size)
    return train_output, predict_output