Commit bb48e2b: first commit of char-aware lang model
Authored and committed by asd on Sep 24, 2017 (1 parent: 59de402)
Showing 21 changed files with 51,727 additions and 0 deletions.
LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2016 Mike Kroutikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
3 files not shown (large diffs are not rendered by default).
data_reader.py
@@ -0,0 +1,167 @@
from __future__ import print_function
from __future__ import division

import os
import codecs
import collections
import pickle  # needed by Vocab.save/load
import numpy as np


class Vocab:

    def __init__(self, token2index=None, index2token=None):
        self._token2index = token2index or {}
        self._index2token = index2token or []

    def feed(self, token):
        if token not in self._token2index:
            # allocate new index for this token
            index = len(self._token2index)
            self._token2index[token] = index
            self._index2token.append(token)

        return self._token2index[token]

    @property
    def size(self):
        return len(self._token2index)

    def token(self, index):
        return self._index2token[index]

    def __getitem__(self, token):
        index = self.get(token)
        if index is None:
            raise KeyError(token)
        return index

    def get(self, token, default=None):
        return self._token2index.get(token, default)

    def save(self, filename):
        with open(filename, 'wb') as f:
            pickle.dump((self._token2index, self._index2token), f, pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, filename):
        with open(filename, 'rb') as f:
            token2index, index2token = pickle.load(f)

        return cls(token2index, index2token)

def load_data(data_dir, max_word_length, eos='+'):

    char_vocab = Vocab()
    char_vocab.feed(' ')  # blank is at index 0 in char vocab
    char_vocab.feed('{')  # start is at index 1 in char vocab
    char_vocab.feed('}')  # end is at index 2 in char vocab

    word_vocab = Vocab()
    word_vocab.feed('|')  # <unk> is at index 0 in word vocab

    actual_max_word_length = 0

    word_tokens = collections.defaultdict(list)
    char_tokens = collections.defaultdict(list)

    for fname in ('train', 'valid', 'test'):
        print('reading', fname)
        with codecs.open(os.path.join(data_dir, fname + '.txt'), 'r', 'utf-8') as f:
            for line in f:
                line = line.strip()
                line = line.replace('}', '').replace('{', '').replace('|', '')
                line = line.replace('<unk>', ' | ')
                if eos:
                    line = line.replace(eos, '')

                for word in line.split():
                    if len(word) > max_word_length - 2:  # space for 'start' and 'end' chars
                        word = word[:max_word_length-2]

                    word_tokens[fname].append(word_vocab.feed(word))

                    char_array = [char_vocab.feed(c) for c in '{' + word + '}']
                    char_tokens[fname].append(char_array)

                    actual_max_word_length = max(actual_max_word_length, len(char_array))

                if eos:
                    word_tokens[fname].append(word_vocab.feed(eos))

                    char_array = [char_vocab.feed(c) for c in '{' + eos + '}']
                    char_tokens[fname].append(char_array)

    assert actual_max_word_length <= max_word_length

    print()
    print('actual longest token length is:', actual_max_word_length)
    print('size of word vocabulary:', word_vocab.size)
    print('size of char vocabulary:', char_vocab.size)
    print('number of tokens in train:', len(word_tokens['train']))
    print('number of tokens in valid:', len(word_tokens['valid']))
    print('number of tokens in test:', len(word_tokens['test']))

    # now we know the sizes, create tensors
    word_tensors = {}
    char_tensors = {}
    for fname in ('train', 'valid', 'test'):
        assert len(char_tokens[fname]) == len(word_tokens[fname])

        word_tensors[fname] = np.array(word_tokens[fname], dtype=np.int32)
        char_tensors[fname] = np.zeros([len(char_tokens[fname]), actual_max_word_length], dtype=np.int32)

        for i, char_array in enumerate(char_tokens[fname]):
            char_tensors[fname][i, :len(char_array)] = char_array

    return word_vocab, char_vocab, word_tensors, char_tensors, actual_max_word_length

class DataReader:

    def __init__(self, word_tensor, char_tensor, batch_size, num_unroll_steps):

        length = word_tensor.shape[0]
        assert char_tensor.shape[0] == length

        max_word_length = char_tensor.shape[1]

        # round down length to whole number of slices
        reduced_length = (length // (batch_size * num_unroll_steps)) * batch_size * num_unroll_steps
        word_tensor = word_tensor[:reduced_length]
        char_tensor = char_tensor[:reduced_length, :]

        # targets are the inputs shifted one word ahead, wrapping around at the end
        ydata = np.zeros_like(word_tensor)
        ydata[:-1] = word_tensor[1:].copy()
        ydata[-1] = word_tensor[0].copy()

        x_batches = char_tensor.reshape([batch_size, -1, num_unroll_steps, max_word_length])
        y_batches = ydata.reshape([batch_size, -1, num_unroll_steps])

        x_batches = np.transpose(x_batches, axes=(1, 0, 2, 3))
        y_batches = np.transpose(y_batches, axes=(1, 0, 2))

        self._x_batches = list(x_batches)
        self._y_batches = list(y_batches)
        assert len(self._x_batches) == len(self._y_batches)
        self.length = len(self._y_batches)
        self.batch_size = batch_size
        self.num_unroll_steps = num_unroll_steps

    def iter(self):

        for x, y in zip(self._x_batches, self._y_batches):
            yield x, y


if __name__ == '__main__':

    _, _, wt, ct, _ = load_data('data', 65)
    print(wt.keys())

    # print the first batch of the validation split, then stop
    count = 0
    for x, y in DataReader(wt['valid'], ct['valid'], 20, 35).iter():
        count += 1
        print(x, y)
        if count > 0:
            break
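
As a quick check of the reshaping logic above, here is a minimal sketch (not part of the commit) of the batch shapes DataReader yields, assuming the file above is importable as data_reader and using toy sizes:

    import numpy as np
    from data_reader import DataReader

    # toy corpus: 140 word ids, each word padded to 10 character ids
    words = np.arange(140, dtype=np.int32)
    chars = np.zeros((140, 10), dtype=np.int32)

    # 140 tokens / (batch_size=2 * num_unroll_steps=7) gives 10 batches
    reader = DataReader(words, chars, batch_size=2, num_unroll_steps=7)
    for x, y in reader.iter():
        print(x.shape, y.shape)  # (2, 7, 10) (2, 7)
        break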
evaluate.py
@@ -0,0 +1,136 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np
import tensorflow as tf

import model
from data_reader import load_data, DataReader


flags = tf.flags

# data
flags.DEFINE_string('data_dir', 'data', 'data directory. Should contain train.txt/valid.txt/test.txt with input data')
flags.DEFINE_string('train_dir', 'cv', 'training directory (models and summaries are saved there periodically)')
flags.DEFINE_string('load_model', None, '(optional) filename of the model to load. Useful for re-starting training from a checkpoint')

# model params
flags.DEFINE_integer('rnn_size', 650, 'size of LSTM internal state')
flags.DEFINE_integer('highway_layers', 2, 'number of highway layers')
flags.DEFINE_integer('char_embed_size', 15, 'dimensionality of character embeddings')
flags.DEFINE_string('kernels', '[1,2,3,4,5,6,7]', 'CNN kernel widths')
flags.DEFINE_string('kernel_features', '[50,100,150,200,200,200,200]', 'number of features in the CNN kernel')
flags.DEFINE_integer('rnn_layers', 2, 'number of layers in the LSTM')
flags.DEFINE_float('dropout', 0.5, 'dropout. 0 = no dropout')

# optimization
flags.DEFINE_integer('num_unroll_steps', 35, 'number of timesteps to unroll for')
flags.DEFINE_integer('batch_size', 20, 'number of sequences to train on in parallel')
flags.DEFINE_integer('max_word_length', 65, 'maximum word length')

# bookkeeping
flags.DEFINE_integer('seed', 3435, 'random number generator seed')
flags.DEFINE_string('EOS', '+', '<EOS> symbol. Should be a single unused character (like +) for PTB and blank for others')

FLAGS = flags.FLAGS


def run_test(session, m, data, batch_size, num_steps):
    """Runs the model on the given data (a DataReader instance)."""

    costs = 0.0
    iters = 0
    rnn_state = session.run(m.initial_rnn_state)

    # batch_size and num_steps are already baked into the DataReader
    for x, y in data.iter():
        cost, rnn_state = session.run([m.loss, m.final_rnn_state], {
            m.input: x,
            m.targets: y,
            m.initial_rnn_state: rnn_state
        })

        costs += cost
        iters += 1

    return costs / iters


def main(_):
    ''' Loads trained model and evaluates it on test split '''

    if FLAGS.load_model is None:
        print('Please specify checkpoint file to load model from')
        return -1

    if not os.path.exists(FLAGS.load_model + ".index"):
        print('Checkpoint file not found', FLAGS.load_model)
        return -1

    word_vocab, char_vocab, word_tensors, char_tensors, max_word_length = \
        load_data(FLAGS.data_dir, FLAGS.max_word_length, eos=FLAGS.EOS)

    test_reader = DataReader(word_tensors['test'], char_tensors['test'],
                             FLAGS.batch_size, FLAGS.num_unroll_steps)

    print('initialized test dataset reader')

    with tf.Graph().as_default(), tf.Session() as session:

        # tensorflow seed must be inside graph
        tf.set_random_seed(FLAGS.seed)
        np.random.seed(seed=FLAGS.seed)

        # build inference graph (dropout disabled for evaluation)
        with tf.variable_scope("Model"):
            m = model.inference_graph(
                    char_vocab_size=char_vocab.size,
                    word_vocab_size=word_vocab.size,
                    char_embed_size=FLAGS.char_embed_size,
                    batch_size=FLAGS.batch_size,
                    num_highway_layers=FLAGS.highway_layers,
                    num_rnn_layers=FLAGS.rnn_layers,
                    rnn_size=FLAGS.rnn_size,
                    max_word_length=max_word_length,
                    kernels=eval(FLAGS.kernels),
                    kernel_features=eval(FLAGS.kernel_features),
                    num_unroll_steps=FLAGS.num_unroll_steps,
                    dropout=0)
            m.update(model.loss_graph(m.logits, FLAGS.batch_size, FLAGS.num_unroll_steps))

            global_step = tf.Variable(0, dtype=tf.int32, name='global_step')

        saver = tf.train.Saver()
        saver.restore(session, FLAGS.load_model)
        print('Loaded model from', FLAGS.load_model, 'saved at global step', global_step.eval())

        # evaluation starts here
        rnn_state = session.run(m.initial_rnn_state)
        count = 0
        avg_loss = 0
        start_time = time.time()
        for x, y in test_reader.iter():
            count += 1
            loss, rnn_state = session.run([
                m.loss,
                m.final_rnn_state
            ], {
                m.input: x,
                m.targets: y,
                m.initial_rnn_state: rnn_state
            })

            avg_loss += loss

        avg_loss /= count
        time_elapsed = time.time() - start_time

        print("test loss = %6.8f, perplexity = %6.8f" % (avg_loss, np.exp(avg_loss)))
        print("test samples:", count * FLAGS.batch_size, "time elapsed:", time_elapsed,
              "time per one batch:", time_elapsed / count)


if __name__ == "__main__":
    tf.app.run()
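
Assuming this script is saved as evaluate.py alongside model.py and data_reader.py, and that training has written a checkpoint into the cv/ directory, a run might look like the following (the checkpoint filename here is hypothetical):

    python evaluate.py --data_dir data --load_model cv/epoch024_4.4962.model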