
Commit b3b4bdc: Fix model save issue
1 parent 8f98a67 commit b3b4bdc

4 files changed: +75 -62 lines changed

README.md (+3 -1)

@@ -1,5 +1,5 @@
 # 自然语言处理算法与实战
-This book mainly introduces some basic introductory knowledge and concepts, with an emphasis on hands-on practice. Here is an overview of the code:
+This book is aimed at beginners: it introduces basic concepts and also provides hands-on code for readers to practice with. Here is an overview of the code:
 * chapter-3 Chinese word segmentation
 * chapter-4 Part-of-speech tagging and named entity recognition
 * chapter-5 Keyword extraction
@@ -8,3 +8,5 @@
 * chapter-8 Sentiment analysis
 * chapter-9 Machine learning algorithms used in NLP
 * chapter-10 Deep-learning-based NLP algorithms
+
+**Since this is the first edition, there are still quite a few small issues. Everyone is welcome to open issues; we will respond and improve actively. Many thanks.**

chapter-10/seq2seq/dynamic_seq2seq_model.py (+3 -4)

@@ -93,7 +93,7 @@ def _init_placeholders(self):
             dtype=tf.int32,
             name='encoder_inputs',
         )
-        #self.encoder_inputs = tf.Variable(np.ones((10, 50)).astype(np.int32))
+        # self.encoder_inputs = tf.Variable(np.ones((10, 50)).astype(np.int32))
         self.encoder_inputs_length = tf.placeholder(
             shape=(None,),
             dtype=tf.int32,
@@ -115,15 +115,15 @@ def _init_decoder_train_connectors(self):
         with tf.name_scope('DecoderTrainFeeds'):
             sequence_size, batch_size = tf.unstack(
                 tf.shape(self.decoder_targets))
-            #batch_size, sequence_size = tf.unstack(tf.shape(self.decoder_targets))
+            # batch_size, sequence_size = tf.unstack(tf.shape(self.decoder_targets))

             EOS_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.EOS
             PAD_SLICE = tf.ones([1, batch_size], dtype=tf.int32) * self.PAD

             self.decoder_train_inputs = tf.concat(
                 [EOS_SLICE, self.decoder_targets], axis=0)
             self.decoder_train_length = self.decoder_targets_length + 1
-            #self.decoder_train_length = self.decoder_targets_length
+            # self.decoder_train_length = self.decoder_targets_length

             decoder_train_targets = tf.concat(
                 [self.decoder_targets, PAD_SLICE], axis=0)
@@ -148,7 +148,6 @@ def _init_decoder_train_connectors(self):

     def _init_embeddings(self):
         with tf.variable_scope("embedding") as scope:
-
             sqrt3 = math.sqrt(3)
             initializer = tf.random_uniform_initializer(-sqrt3, sqrt3)
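
Note on the DecoderTrainFeeds hunk above: it prepends an EOS column to form the decoder's training inputs and pads the targets by one step, so both tensors keep the same time-major shape. A minimal NumPy sketch of that wiring, with hypothetical toy values and the same [max_time, batch_size] layout as the class:

import numpy as np

EOS, PAD = 1, 0  # assumed special-token ids, standing in for self.EOS / self.PAD
decoder_targets = np.array([[4, 5],
                            [6, 7],
                            [8, PAD]], dtype=np.int32)  # [max_time, batch_size]
decoder_targets_length = np.array([3, 2])

eos_slice = np.full((1, decoder_targets.shape[1]), EOS, dtype=np.int32)
pad_slice = np.full((1, decoder_targets.shape[1]), PAD, dtype=np.int32)

# mirrors tf.concat([EOS_SLICE, self.decoder_targets], axis=0)
decoder_train_inputs = np.concatenate([eos_slice, decoder_targets], axis=0)
decoder_train_length = decoder_targets_length + 1   # one extra EOS step
# mirrors tf.concat([self.decoder_targets, PAD_SLICE], axis=0)
decoder_train_targets = np.concatenate([decoder_targets, pad_slice], axis=0)

print(decoder_train_inputs.shape)  # (4, 2): one time step longer than targets
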
chapter-10/seq2seq/main.py (+7 -8)

@@ -167,8 +167,8 @@ def train(self):
                 self.clearModel()
                 total_time = 0
                 for i, (e_in, dt_pred) in enumerate(zip(
-                    fd[self.model.decoder_targets].T,
-                    sess.run(self.model.decoder_prediction_train, fd).T
+                        fd[self.model.decoder_targets].T,
+                        sess.run(self.model.decoder_prediction_train, fd).T
                 )):
                     print('  sample {}:'.format(i + 1))
                     print('    dec targets > {}'.format(e_in))
@@ -233,8 +233,8 @@ def onlinelearning(self, input_strs, target_strs):
                     sess, checkpoint_path, global_step=self.model.global_step)

                 for i, (e_in, dt_pred) in enumerate(zip(
-                    fd[self.model.decoder_targets].T,
-                    sess.run(self.model.decoder_prediction_train, fd).T
+                        fd[self.model.decoder_targets].T,
+                        sess.run(self.model.decoder_prediction_train, fd).T
                 )):
                     print('  sample {}:'.format(i + 1))
                     print('    dec targets > {}'.format(e_in))
@@ -282,7 +282,7 @@ def predict(self):

                 action = False
                 segements = self.segement(inputs_strs)
-                #inputs_vec = [enc_vocab.get(i) for i in segements]
+                # inputs_vec = [enc_vocab.get(i) for i in segements]
                 inputs_vec = []
                 for i in segements:
                     inputs_vec.append(self.enc_vocab.get(i, self.model.UNK))
@@ -349,8 +349,8 @@ def test(self):
                     sess.run(self.model.loss, fd)))

                 for i, (e_in, dt_pred) in enumerate(zip(
-                    fd[self.model.decoder_targets].T,
-                    sess.run(self.model.decoder_prediction_train, fd).T
+                        fd[self.model.decoder_targets].T,
+                        sess.run(self.model.decoder_prediction_train, fd).T
                 )):
                     print('  sample {}:'.format(i + 1))
                     print('    dec targets > {}'.format(e_in))
@@ -366,4 +366,3 @@ def test(self):
         seq_obj.train()
     elif sys.argv[1] == 'infer':
         seq_obj.predict()
-
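
Note on the reindented zip(...) hunks above: fd[self.model.decoder_targets] and the fetched decoder_prediction_train are time-major [max_time, batch] arrays, so transposing with .T makes each row one sample before zip pairs a sample's targets with its predictions. A small self-contained illustration with toy arrays (not taken from the repository):

import numpy as np

decoder_targets = np.array([[4, 5], [6, 7], [8, 0]])  # [max_time, batch]
predictions = np.array([[4, 5], [6, 9], [8, 0]])      # stand-in model output

for i, (e_in, dt_pred) in enumerate(zip(decoder_targets.T, predictions.T)):
    print('  sample {}:'.format(i + 1))
    print('    dec targets > {}'.format(e_in))
    print('    dec predict > {}'.format(dt_pred))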

chapter-8/sentiment-analysis/main.py (+62 -49)

@@ -1,6 +1,7 @@
 # encoding:utf-8

 import numpy as np
+
 wordsList = np.load('wordsList.npy')
 print('Loaded the word list')
 wordsList = wordsList.tolist()
@@ -9,45 +10,45 @@
 wordVectors = np.load('wordVectors.npy')
 print('Loaded the word vectors')

-
 print(len(wordsList))
 print(wordVectors.shape)

 import os
 from os.path import isfile, join
+
 pos_files = ['pos/' + f for f in os.listdir(
     'pos/') if isfile(join('pos/', f))]
 neg_files = ['neg/' + f for f in os.listdir(
     'neg/') if isfile(join('neg/', f))]
 num_words = []
 for pf in pos_files:
-  with open(pf, "r", encoding='utf-8') as f:
-    line = f.readline()
-    counter = len(line.split())
-    num_words.append(counter)
+    with open(pf, "r", encoding='utf-8') as f:
+        line = f.readline()
+        counter = len(line.split())
+        num_words.append(counter)
 print('Finished reading positive reviews')

 for nf in neg_files:
-  with open(nf, "r", encoding='utf-8') as f:
-    line = f.readline()
-    counter = len(line.split())
-    num_words.append(counter)
+    with open(nf, "r", encoding='utf-8') as f:
+        line = f.readline()
+        counter = len(line.split())
+        num_words.append(counter)
 print('Finished reading negative reviews')

 num_files = len(num_words)
 print('Total number of files', num_files)
 print('Total number of words', sum(num_words))
 print('Average number of words per file', sum(num_words) / len(num_words))

-
 import re
+
 strip_special_chars = re.compile("[^A-Za-z0-9 ]+")
 num_dimensions = 300  # Dimensions for each word vector


 def cleanSentences(string):
-  string = string.lower().replace("<br />", " ")
-  return re.sub(strip_special_chars, "", string.lower())
+    string = string.lower().replace("<br />", " ")
+    return re.sub(strip_special_chars, "", string.lower())


 max_seq_num = 250
@@ -94,38 +95,40 @@ def cleanSentences(string):
 batch_size = 24
 lstm_units = 64
 num_labels = 2
-iterations = 100000
+iterations = 100
+lr = 0.001
 ids = np.load('idsMatrix.npy')


 def get_train_batch():
-  labels = []
-  arr = np.zeros([batch_size, max_seq_num])
-  for i in range(batch_size):
-    if (i % 2 == 0):
-      num = randint(1, 11499)
-      labels.append([1, 0])
-    else:
-      num = randint(13499, 24999)
-      labels.append([0, 1])
-    arr[i] = ids[num - 1:num]
-  return arr, labels
+    labels = []
+    arr = np.zeros([batch_size, max_seq_num])
+    for i in range(batch_size):
+        if (i % 2 == 0):
+            num = randint(1, 11499)
+            labels.append([1, 0])
+        else:
+            num = randint(13499, 24999)
+            labels.append([0, 1])
+        arr[i] = ids[num - 1:num]
+    return arr, labels


 def get_test_batch():
-  labels = []
-  arr = np.zeros([batch_size, max_seq_num])
-  for i in range(batch_size):
-    num = randint(11499, 13499)
-    if (num <= 12499):
-      labels.append([1, 0])
-    else:
-      labels.append([0, 1])
-    arr[i] = ids[num - 1:num]
-  return arr, labels
+    labels = []
+    arr = np.zeros([batch_size, max_seq_num])
+    for i in range(batch_size):
+        num = randint(11499, 13499)
+        if (num <= 12499):
+            labels.append([1, 0])
+        else:
+            labels.append([0, 1])
+        arr[i] = ids[num - 1:num]
+    return arr, labels


 import tensorflow as tf
+
 tf.reset_default_graph()

 labels = tf.placeholder(tf.float32, [batch_size, num_labels])
@@ -134,33 +137,43 @@ def get_test_batch():
     tf.zeros([batch_size, max_seq_num, num_dimensions]), dtype=tf.float32)
 data = tf.nn.embedding_lookup(wordVectors, input_data)

-
 lstmCell = tf.contrib.rnn.BasicLSTMCell(lstm_units)
-lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)
+lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.5)
 value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)

-
 weight = tf.Variable(tf.truncated_normal([lstm_units, num_labels]))
 bias = tf.Variable(tf.constant(0.1, shape=[num_labels]))
 value = tf.transpose(value, [1, 0, 2])
 last = tf.gather(value, int(value.get_shape()[0]) - 1)
 prediction = (tf.matmul(last, weight) + bias)

-correctPred = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
-accuracy = tf.reduce_mean(tf.cast(correctPred, tf.float32))
-
+correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
+accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
     logits=prediction, labels=labels))
-optimizer = tf.train.AdamOptimizer().minimize(loss)
+optimizer = tf.train.AdamOptimizer(lr).minimize(loss)

-
-sess = tf.InteractiveSession()
 saver = tf.train.Saver()
-saver.restore(sess, tf.train.latest_checkpoint('models'))

-iterations = 10
-for i in range(iterations):
-  next_batch, next_batch_labels = get_test_batch()
-  print("accuracy:", (sess.run(
-      accuracy, {input_data: next_batch, labels: next_batch_labels})) * 100)
+with tf.Session() as sess:
+    if os.path.exists("models") and os.path.exists("models/checkpoint"):
+        saver.restore(sess, tf.train.latest_checkpoint('models'))
+    else:
+        if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
+            init = tf.initialize_all_variables()
+        else:
+            init = tf.global_variables_initializer()
+        sess.run(init)
+
+    iterations = 100
+    for step in range(iterations):
+        next_batch, next_batch_labels = get_test_batch()
+        if step % 20 == 0:
+            print("step:", step, " accuracy:", (sess.run(
+                accuracy, {input_data: next_batch, labels: next_batch_labels})) * 100)
+
+    if not os.path.exists("models"):
+        os.mkdir("models")
+    save_path = saver.save(sess, "models/model.ckpt")
+    print("Model saved in path: %s" % save_path)
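
Note: the rewritten session block above is the actual save fix: it restores from models/ only when models/checkpoint exists, initializes variables otherwise (choosing the initializer op by TensorFlow version), and always writes models/model.ckpt at the end. For completeness, a hedged sketch of how a separate script might reload that checkpoint for evaluation, assuming the same graph-construction code from above has already been run:

import tensorflow as tf

# Rebuild the graph first (placeholders, LSTM, accuracy op), then:
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint('models')  # resolves models/model.ckpt
    saver.restore(sess, ckpt)
    # ...run the accuracy op on test batches here...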
