Skip to content

Commit

Permalink
[converter] Support table initializer in converter. (#3958)
Browse files — browse the repository at this point in the history
  • Loading branch information
lina128 authored Oct 6, 2020
1 parent 81117f2 commit 7bec2d5
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 20 deletions.
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

# JSON string keys for fields of the indexing JSON.
ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology'
ARTIFACT_MODEL_INITIALIZER = 'modelInitializer'
ARTIFACT_WEIGHTS_MANIFEST_KEY = 'weightsManifest'

FORMAT_KEY = 'format'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def _run_grappler(config, graph_def, graph, signature_def):
def optimize_graph(graph, signature_def, output_graph,
tf_version, quantization_dtype_map=None,
skip_op_check=False, strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4, experiments=False):
weight_shard_size_bytes=1024 * 1024 * 4,
experiments=False,
initializer_graph=None):
"""Takes a Python Graph object and optimizes the graph.
Args:
Expand All @@ -127,6 +129,7 @@ def optimize_graph(graph, signature_def, output_graph,
strip_debug_ops: Bool whether to strip debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
initializer_graph: The frozen graph for initializers.
"""

# Add a collection 'train_op' so that Grappler knows the outputs.
Expand Down Expand Up @@ -173,6 +176,7 @@ def optimize_graph(graph, signature_def, output_graph,
]

optimized_graph = _run_grappler(config, optimized_graph, graph, signature_def)

optimized_graph = _remove_unused_control_flow_inputs(optimized_graph)

# Because TF break the Prelu op into 6 ops, for performance we are
Expand All @@ -194,9 +198,15 @@ def optimize_graph(graph, signature_def, output_graph,
raise ValueError('Unsupported Ops in the model after optimization\n' +
', '.join(unsupported))

initializer_graph_def = None
if initializer_graph:
initializer_graph_def = initializer_graph.as_graph_def()

extract_weights(
optimized_graph, output_graph, tf_version,
signature_def, quantization_dtype_map, weight_shard_size_bytes)
signature_def, quantization_dtype_map, weight_shard_size_bytes,
initializer_graph_def)

return optimize_graph


Expand Down Expand Up @@ -235,7 +245,8 @@ def extract_weights(graph_def,
tf_version,
signature_def,
quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
weight_shard_size_bytes=1024 * 1024 * 4,
initializer_graph_def=None):
"""Takes a Python GraphDef object and extract the weights.
Args:
Expand All @@ -249,6 +260,7 @@ def extract_weights(graph_def,
supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
initializer_graph_def: tf.GraphDef proto object for initializer graph.
"""
global_manifest = extract_const_nodes(graph_def.node)

Expand All @@ -260,23 +272,30 @@ def extract_weights(graph_def,
func.node_def.extend(nodes)
function_manifests += extract_const_nodes(func.node_def)

initializer_manifests = []
if initializer_graph_def:
initializer_manifests = extract_const_nodes(initializer_graph_def.node)

print('Writing weight file ' + output_graph + '...')

write_artifacts(MessageToDict(graph_def),
[global_manifest + function_manifests],
[global_manifest +
function_manifests +
initializer_manifests],
output_graph,
tf_version, signature_def,
quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)

weight_shard_size_bytes=weight_shard_size_bytes,
initializer_graph_def=initializer_graph_def)

def write_artifacts(topology,
weights,
output_graph,
tf_version,
signature_def,
quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
weight_shard_size_bytes=1024 * 1024 * 4,
initializer_graph_def=None):
"""Writes weights and topology to the output_dir.
If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.
Expand All @@ -293,8 +312,8 @@ def write_artifacts(topology,
supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
initializer_graph_def: tf.GraphDef proto object for initializer graph.
"""

model_json = {
common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
# TODO(piyu): Add tensorflow version below by using `meta_info_def`.
Expand All @@ -305,6 +324,11 @@ def write_artifacts(topology,
}
}
model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None

if initializer_graph_def:
model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict(
initializer_graph_def)

weights_manifest = write_weights.write_weights(
weights, os.path.dirname(output_graph), write_manifest=False,
quantization_dtype_map=quantization_dtype_map,
Expand Down Expand Up @@ -337,18 +361,50 @@ def _check_signature_in_model(saved_model, signature_name):

def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names):
"""Freeze the graph by converting variables to constants for 1.x saved model.
Args:
saved_model_dir: dir where saved model files are stored.
saved_model_tags: inference graph tag.
output_node_names: List of name strings for the result nodes of the graph.
Returns:
A freezed and optimized graph.
Nullable. A freezed and optimized initializer graph.
Nullable. A list of output node names of initializer.
"""
g = tf.Graph()
with g.as_default():
with tf.compat.v1.Session() as sess:
loader.load(sess, saved_model_tags, saved_model_dir)
meta_graph = loader.load(sess, saved_model_tags, saved_model_dir)

meta_graph_def = g.as_graph_def()

frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
sess, g.as_graph_def(), output_node_names)
sess, meta_graph_def, output_node_names)

frozen_graph = tf.Graph()
with frozen_graph.as_default():
tf.import_graph_def(frozen_graph_def, name='')

return frozen_graph
frozen_initializer_graph = None
initializer_output_names = None
# Only support table initializers for now.
if meta_graph.collection_def and meta_graph.collection_def[
'table_initializer']:
initializer_output_names = meta_graph.collection_def[
'table_initializer'].node_list.value
# This will use grappler to extract a subgraph with the
# table initializer ops as the outputs.
frozen_initializer_graph_def = (tf.compat.v1.graph_util
.convert_variables_to_constants(
sess, meta_graph_def,
initializer_output_names))
frozen_initializer_graph = tf.Graph()
with frozen_initializer_graph.as_default():
tf.import_graph_def(frozen_initializer_graph_def, name='')

return frozen_graph, frozen_initializer_graph

def _freeze_saved_model_v2(concrete_func, control_flow_v2=False):
if tf.__version__ < '2.2.0':
Expand Down Expand Up @@ -439,7 +495,8 @@ def convert_tf_saved_model(saved_model_dir,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4,
control_flow_v2=False, experiments=False):
control_flow_v2=False,
experiments=False):
"""Freeze the SavedModel and check the model compatibility with Tensorflow.js.
Optimize and convert the model to Tensorflow.js format, when the model passes
Expand Down Expand Up @@ -483,18 +540,22 @@ def convert_tf_saved_model(saved_model_dir,
_check_signature_in_model(model, signature_def)

concrete_func = model.signatures[signature_def]

output_node_names = []
for output_tensor in concrete_func.outputs:
output_node_names.append(output_tensor.name.split(':')[0])

# TensorFlow doesn't encode the saved model version in the graph in a reliable
# way. Try to freeze the graph using V2 utils. If that fails, freeze the
# graph using V1 utils.
# TensorFlow doesn't encode the saved model version in the graph in a
# reliable way. Try to freeze the graph using V2 utils. If that fails, freeze
# the graph using V1 utils.
frozen_initializer_graph = None
try:
frozen_graph = _freeze_saved_model_v2(concrete_func, control_flow_v2)
except BaseException:
frozen_graph = _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names)
(frozen_graph,
frozen_initializer_graph) = _freeze_saved_model_v1(saved_model_dir,
saved_model_tags,
output_node_names)

inputs = [x for x in concrete_func.inputs if not x.dtype == 'resource']
signature = _build_signature_def(
Expand Down Expand Up @@ -549,7 +610,8 @@ def _strip_unused_nodes(frozen_graph, concrete_func, output_node_names):
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
experiments=experiments)
experiments=experiments,
initializer_graph=frozen_initializer_graph)

def load_and_initialize_hub_module(module_path, signature='default'):
"""Loads graph of a TF-Hub module and initializes it into a session.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,11 @@ def test_convert_saved_model_v1_with_hashtable(self):

expected_weights_manifest = [{
'paths': ['group1-shard1of1.bin'],
'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}]
'weights': [
{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]},
{'dtype': 'string', 'name': 'Const', 'shape': [1]},
{'dtype': 'int32', 'name': 'Const_1', 'shape': [1]}
]}]

tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'js')
# Check model.json and weights manifest.
Expand All @@ -372,7 +376,7 @@ def test_convert_saved_model_v1_with_hashtable(self):
self.assertIsNot(signature, None)
self.assertIsNot(signature['inputs'], None)
self.assertIsNot(signature['outputs'], None)

self.assertTrue(model_json['modelInitializer'])

weights_manifest = model_json['weightsManifest']
self.assertEqual(weights_manifest, expected_weights_manifest)
Expand Down

0 comments on commit 7bec2d5

Please sign in to comment.