Merge pull request #2146 from mozilla/refactor-model-impls
Refactor TF and TFLite implementations into their own classes/files and fix concurrent/interleaved stream bugs by tracking LSTM state in StreamingState
reuben committed Jun 20, 2019
2 parents 4b29b78 + f12ea5e commit a2306cf
Showing 18 changed files with 1,020 additions and 711 deletions.
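The gist of the `create_inference_graph` changes below: the LSTM state no longer lives in graph variables driven by `initialize_state`/`tf.assign`; the graph now takes `previous_state_c`/`previous_state_h` placeholders and hands back `new_state_c`/`new_state_h`, so every stream carries its own state between calls (in the native client this per-stream state lives in `StreamingState`). A minimal Python sketch of how a caller might drive such a graph — only the input/output dictionary keys come from the diff; the `StreamState` class, the `feed_chunk` helper, and the session/graph setup are illustrative assumptions:

```python
import numpy as np

class StreamState:
    """Hypothetical per-stream LSTM state, so interleaved streams don't clobber each other."""
    def __init__(self, n_cell_dim):
        self.state_c = np.zeros([1, n_cell_dim])
        self.state_h = np.zeros([1, n_cell_dim])

def feed_chunk(session, inputs, outputs, stream, features, features_len):
    # Feed this stream's current state and keep the updated state it gets back.
    logits, stream.state_c, stream.state_h = session.run(
        [outputs['outputs'], outputs['new_state_c'], outputs['new_state_h']],
        feed_dict={
            inputs['input']: features,
            inputs['input_lengths']: features_len,
            inputs['previous_state_c']: stream.state_c,
            inputs['previous_state_h']: stream.state_h,
        })
    return logits
```

Two interleaved streams would each hold their own `StreamState`, which is the concurrency fix described in the commit message; the old variable-based graph shared a single state across all callers.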
94 changes: 37 additions & 57 deletions DeepSpeech.py
@@ -574,12 +574,8 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
# no state management since n_step is expected to be dynamic too (see below)
previous_state = previous_state_c = previous_state_h = None
else:
-if tflite:
-previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
-previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')
-else:
-previous_state_c = variable_on_cpu('previous_state_c', [batch_size, Config.n_cell_dim], initializer=None)
-previous_state_h = variable_on_cpu('previous_state_h', [batch_size, Config.n_cell_dim], initializer=None)
+previous_state_c = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_c')
+previous_state_h = tf.placeholder(tf.float32, [batch_size, Config.n_cell_dim], name='previous_state_h')

previous_state = tf.contrib.rnn.LSTMStateTuple(previous_state_c, previous_state_h)

@@ -592,7 +588,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
rnn_impl = rnn_impl_lstmblockfusedcell

logits, layers = create_model(batch_x=input_tensor,
-seq_length=seq_length if FLAGS.use_seq_length else None,
+seq_length=seq_length if not FLAGS.export_tflite else None,
dropout=no_dropout,
previous_state=previous_state,
overlap=False,
@@ -605,7 +601,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
logits = tf.squeeze(logits, [1])

# Apply softmax for CTC decoder
-logits = tf.nn.softmax(logits)
+logits = tf.nn.softmax(logits, name='logits')

if batch_size <= 0:
if tflite:
@@ -618,51 +614,31 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
'input_lengths': seq_length,
},
{
-'outputs': tf.identity(logits, name='logits'),
+'outputs': logits,
},
layers
)

new_state_c, new_state_h = layers['rnn_output_state']
-if tflite:
-logits = tf.identity(logits, name='logits')
-new_state_c = tf.identity(new_state_c, name='new_state_c')
-new_state_h = tf.identity(new_state_h, name='new_state_h')

-inputs = {
-'input': input_tensor,
-'previous_state_c': previous_state_c,
-'previous_state_h': previous_state_h,
-'input_samples': input_samples,
-}

-if FLAGS.use_seq_length:
-inputs.update({'input_lengths': seq_length})

-outputs = {
-'outputs': logits,
-'new_state_c': new_state_c,
-'new_state_h': new_state_h,
-'mfccs': mfccs,
-}
-else:
-zero_state = tf.zeros([batch_size, Config.n_cell_dim], tf.float32)
-initialize_c = tf.assign(previous_state_c, zero_state)
-initialize_h = tf.assign(previous_state_h, zero_state)
-initialize_state = tf.group(initialize_c, initialize_h, name='initialize_state')
-with tf.control_dependencies([tf.assign(previous_state_c, new_state_c), tf.assign(previous_state_h, new_state_h)]):
-logits = tf.identity(logits, name='logits')

-inputs = {
-'input': input_tensor,
-'input_lengths': seq_length,
-'input_samples': input_samples,
-}
-outputs = {
-'outputs': logits,
-'initialize_state': initialize_state,
-'mfccs': mfccs,
-}
+new_state_c = tf.identity(new_state_c, name='new_state_c')
+new_state_h = tf.identity(new_state_h, name='new_state_h')

+inputs = {
+'input': input_tensor,
+'previous_state_c': previous_state_c,
+'previous_state_h': previous_state_h,
+'input_samples': input_samples,
+}

+if not FLAGS.export_tflite:
+inputs.update({'input_lengths': seq_length})

+outputs = {
+'outputs': logits,
+'new_state_c': new_state_c,
+'new_state_h': new_state_h,
+'mfccs': mfccs,
+}

return inputs, outputs, layers

@@ -682,10 +658,12 @@ def export():
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops)

-if not FLAGS.export_tflite:
-mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-else:
+mapping = None
+if FLAGS.export_tflite:
# Create a saver using variables from the above newly created graph
+# Training graph uses LSTMFusedCell, but the TFLite inference graph uses
+# a static RNN with a normal cell, so we need to rewrite the names to
+# match the training weights when restoring.
def fixup(name):
if name.startswith('rnn/lstm_cell/'):
return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
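The hunk above is truncated by the diff view, but the new comment spells out the intent: checkpoints are written by the training graph's fused LSTM cell under `lstm_fused_cell/`, while the TFLite inference graph builds a static RNN whose variables live under `rnn/lstm_cell/`, so names must be rewritten before restoring. A hypothetical sketch of such a remapping saver — the `return name` fall-through and the `mapping` comprehension are not visible in the hunk and are assumptions here:

```python
import tensorflow as tf

def fixup(name):
    # TFLite inference-graph variables are named 'rnn/lstm_cell/...', but the
    # training checkpoint stored them under 'lstm_fused_cell/...'.
    if name.startswith('rnn/lstm_cell/'):
        return name.replace('rnn/lstm_cell/', 'lstm_fused_cell/')
    return name

# Key the saver on the (possibly rewritten) checkpoint names so training
# weights can be restored into the TFLite inference graph.
mapping = {fixup(v.op.name): v for v in tf.global_variables()}
saver = tf.train.Saver(mapping)
```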
@@ -710,7 +688,7 @@ def fixup(name):
if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir)

-def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=None):
+def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
frozen = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=saver.as_saver_def(),
@@ -731,7 +709,7 @@ def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklis
placeholder_type_enum=tf.float32.as_datatype_enum)

if not FLAGS.export_tflite:
-frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='previous_state_c,previous_state_h')
+frozen_graph = do_graph_freeze(output_node_names=output_names)
frozen_graph.version = int(file_relative_read('GRAPH_VERSION').strip())

# Add a no-op node to the graph with metadata information to be loaded by the native client
@@ -747,7 +725,7 @@ def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklis
with open(output_graph_path, 'wb') as fout:
fout.write(frozen_graph.SerializeToString())
else:
-frozen_graph = do_graph_freeze(output_node_names=output_names, variables_blacklist='')
+frozen_graph = do_graph_freeze(output_node_names=output_names)
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))

converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
@@ -771,8 +749,7 @@ def do_single_file_inference(input_file_path):
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

# Create a saver using variables from the above newly created graph
-mapping = {v.op.name: v for v in tf.global_variables() if not v.op.name.startswith('previous_state_')}
-saver = tf.train.Saver(mapping)
+saver = tf.train.Saver()

# Restore variables from training checkpoint
# TODO: This restores the most recent checkpoint, but if we use validation to counteract
@@ -784,9 +761,10 @@ def do_single_file_inference(input_file_path):

checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
-session.run(outputs['initialize_state'])

features, features_len = audiofile_to_features(input_file_path)
+previous_state_c = np.zeros([1, Config.n_cell_dim])
+previous_state_h = np.zeros([1, Config.n_cell_dim])

# Add batch dimension
features = tf.expand_dims(features, 0)
@@ -799,6 +777,8 @@ def do_single_file_inference(input_file_path):
logits = outputs['outputs'].eval(feed_dict={
inputs['input']: features,
inputs['input_lengths']: features_len,
+inputs['previous_state_c']: previous_state_c,
+inputs['previous_state_h']: previous_state_h,
}, session=session)

logits = np.squeeze(logits)
2 changes: 1 addition & 1 deletion GRAPH_VERSION
@@ -1 +1 @@
-1
+2
2 changes: 1 addition & 1 deletion README.md
@@ -343,7 +343,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information

### Exporting a model for TFLite

-If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--nouse_seq_length --export_tflite --export_dir /model/export/destination`.
+If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--export_tflite --export_dir /model/export/destination`.

### Making a mmap-able model for inference

2 changes: 1 addition & 1 deletion bin/run-tc-ldc93s1_tflite.sh
@@ -20,4 +20,4 @@ python -u DeepSpeech.py --noshow_progressbar \
--export_dir '/tmp/train_tflite' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
---export_tflite --nouse_seq_length
+--export_tflite
Binary file added data/smoke_test/new-home-in-the-stars-16k.wav
27 changes: 15 additions & 12 deletions native_client/BUILD
@@ -70,9 +70,20 @@ tf_cc_shared_object(
srcs = ["deepspeech.cc",
"deepspeech.h",
"alphabet.h",
+"modelstate.h",
+"modelstate.cc",
"ds_version.h",
"ds_graph_version.h"] +
-DECODER_SOURCES,
+DECODER_SOURCES +
+select({
+"//native_client:tflite": [
+"tflitemodelstate.h",
+"tflitemodelstate.cc"
+],
+"//conditions:default": [
+"tfmodelstate.h",
+"tfmodelstate.cc"
+]}),
copts = select({
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
"//tensorflow:windows": ["/w"],
@@ -103,34 +114,26 @@
### => Trying to be more fine-grained
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
### CPU only build, libdeepspeech.so file size reduced by ~50%
"//tensorflow/core/kernels:dense_update_ops", # Assign
"//tensorflow/core/kernels:constant_op", # Const
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst
"//tensorflow/core/kernels:dense_update_ops", # Assign (remove once prod model no longer depends on it)
"//tensorflow/core/kernels:constant_op", # Placeholder
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
"//tensorflow/core/kernels:identity_op", # Identity
"//tensorflow/core/kernels:softmax_op", # Softmax
"//tensorflow/core/kernels:transpose_op", # Transpose
"//tensorflow/core/kernels:reshape_op", # Reshape
"//tensorflow/core/kernels:shape_ops", # Shape
"//tensorflow/core/kernels:concat_op", # ConcatV2
"//tensorflow/core/kernels:split_op", # Split
"//tensorflow/core/kernels:variable_ops", # VariableV2
"//tensorflow/core/kernels:relu_op", # Relu
"//tensorflow/core/kernels:bias_op", # BiasAdd
"//tensorflow/core/kernels:math", # Range, MatMul
"//tensorflow/core/kernels:control_flow_ops", # Enter
"//tensorflow/core/kernels:tile_ops", # Tile
"//tensorflow/core/kernels:gather_op", # Gather
"//tensorflow/core/kernels:mfcc_op", # Mfcc
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
"//tensorflow/core/kernels:strided_slice_op", # StridedSlice
"//tensorflow/core/kernels:slice_op", # Slice, needed by StridedSlice
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
"//tensorflow/core/kernels:random_ops", # RandomGammaGrad
"//tensorflow/core/kernels:pack_op", # Pack
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
#### Needed by production model produced without "--use_seq_length False"
#"//tensorflow/core/kernels:logging_ops", # Assert
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
],
}) + if_cuda([
"//tensorflow/core:core",