Skip to content

Commit

Permalink
Use beam search to get better and multiple results
Browse files Browse the repository at this point in the history
  • Loading branch information
wb14123 committed Jul 8, 2021
1 parent 30bb57e commit 8687fcc
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 16 deletions.
9 changes: 6 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import reader
from os import path
import random
import numpy as np


class Model():
Expand Down Expand Up @@ -226,10 +227,12 @@ def infer(self, text):
in_seq = reader.encode_text(text.split(' ') + ['</s>',],
self.infer_vocab_indices)
in_seq_len = len(in_seq)
outputs = self.infer_session.run(self.infer_output,
(outputs, scores) = self.infer_session.run(self.infer_output,
feed_dict={
self.infer_in_seq: [in_seq],
self.infer_in_seq_len: [in_seq_len]})
output = outputs[0]
output_text = reader.decode_text(output, self.infer_vocabs)
return output_text
score = np.average(scores[0].T, axis=1)
output_text = reader.decode_multi_text(output, self.infer_vocabs)
output_without_space = [''.join(s.split(' ')) for s in output_text]
return (output_without_space, score)
14 changes: 14 additions & 0 deletions reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ def decode_text(labels, vocabs, end_token = '</s>'):
return ' '.join(results)


def decode_multi_text(labels, vocabs, end_token = '</s>'):
    """Decode beam-search output ids into one text string per beam.

    Args:
        labels: 2-D integer array of vocab indices, indexed as
            labels[t][b] = token id at time step t for beam b
            (presumably a single batch entry of TF BeamSearchDecoder's
            ``predicted_ids`` — TODO confirm against the caller).
        vocabs: sequence mapping a vocab index to its word string.
        end_token: word marking end-of-sequence; excluded from the output.

    Returns:
        A list of space-joined word strings, one per beam, in beam order,
        so the result always aligns 1:1 with per-beam scores.
    """
    all_results = []
    # NOTE: the original unpacked these as (result_count, length), which is
    # swapped relative to how the array is actually indexed below.
    (time_steps, beam_count) = labels.shape
    for beam in range(beam_count):
        words = []
        for t in range(time_steps):
            word = vocabs[labels[t][beam]]
            if word == end_token:
                break
            words.append(word)
        # Bug fix: previously the join/append only happened when end_token
        # was found, so a beam that never emitted it was silently dropped,
        # misaligning the results list with its scores.
        all_results.append(' '.join(words))
    return all_results


def read_vocab(vocab_file):
f = open(vocab_file, 'rb')
vocabs = [line.decode('utf8')[:-1] for line in f]
Expand Down
37 changes: 28 additions & 9 deletions seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,48 @@ def seq2seq(in_seq, in_seq_len, target_seq, target_seq_len, vocab_size,



decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units,
layers, input_keep_prob)
batch_size = tf.shape(in_seq_len)[0]
init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
if target_seq != None:
decoder_cell = attention_decoder_cell(encoder_output, in_seq_len, num_units,
layers, input_keep_prob)
init_state = decoder_cell.zero_state(batch_size, tf.float32).clone(
cell_state=encoder_state)

if target_seq != None:
embed_target = tf.nn.embedding_lookup(embedding, target_seq,
name='embed_target')
helper = tf.contrib.seq2seq.TrainingHelper(
embed_target, target_seq_len, time_major=False)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
init_state, output_layer=projection_layer)
else:
# TODO: start tokens and end tokens are hard code
helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
embedding, tf.fill([batch_size], 0), 1)
decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
init_state, output_layer=projection_layer)
beam_width = 10
tiled_encoder_output = tf.contrib.seq2seq.tile_batch(
encoder_output, multiplier=beam_width)
tiled_encoder_state = tf.contrib.seq2seq.tile_batch(
encoder_state, multiplier=beam_width)
tiled_in_seq_len = tf.contrib.seq2seq.tile_batch(
in_seq_len , multiplier=beam_width)
decoder_cell = attention_decoder_cell(tiled_encoder_output, tiled_in_seq_len, num_units,
layers, input_keep_prob)
init_state = decoder_cell.zero_state(batch_size * beam_width, tf.float32).clone(
cell_state=tiled_encoder_state)
decoder = tf.contrib.seq2seq.BeamSearchDecoder(
cell=decoder_cell,
embedding=embedding,
start_tokens=tf.fill([batch_size], 0),
end_token=1,
initial_state=init_state,
beam_width=beam_width,
output_layer=projection_layer,
length_penalty_weight=1.0)
outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
maximum_iterations=100)
if target_seq != None:
return outputs.rnn_output
else:
return outputs.sample_id
return (outputs.predicted_ids,
outputs.beam_search_decoder_output.scores)


def seq_loss(output, target, seq_len):
Expand Down
10 changes: 6 additions & 4 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def chat_couplet(in_str):
if len(in_str) == 0 or len(in_str) > 50:
output = u'您的输入太长了'
else:
output = m.infer(' '.join(in_str))
output = ''.join(output.split(' '))
logging.info('上联:%s;下联:%s' % (in_str, output))
return jsonify({'output': output})
result = {'text': []}
output, score = m.infer(' '.join(in_str))
score = score.tolist()
logging.info('上联:%s;下联:%s ; score: %s' % (
in_str, output, score))
return jsonify({'output': output, 'score': score})

http_server = WSGIServer(('', 5000), app)
http_server.serve_forever()

0 comments on commit 8687fcc

Please sign in to comment.