diff --git a/tfjs-converter/python/tensorflowjs/converters/common.py b/tfjs-converter/python/tensorflowjs/converters/common.py index 2f76019d8b5..388b59e26a8 100644 --- a/tfjs-converter/python/tensorflowjs/converters/common.py +++ b/tfjs-converter/python/tensorflowjs/converters/common.py @@ -20,6 +20,7 @@ # JSON string keys for fields of the indexing JSON. ARTIFACT_MODEL_TOPOLOGY_KEY = 'modelTopology' +ARTIFACT_MODEL_INITIALIZER = 'modelInitializer' ARTIFACT_WEIGHTS_MANIFEST_KEY = 'weightsManifest' FORMAT_KEY = 'format' diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index f1d8f6e69ef..f52f7ebd9b9 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -112,7 +112,9 @@ def _run_grappler(config, graph_def, graph, signature_def): def optimize_graph(graph, signature_def, output_graph, tf_version, quantization_dtype_map=None, skip_op_check=False, strip_debug_ops=False, - weight_shard_size_bytes=1024 * 1024 * 4, experiments=False): + weight_shard_size_bytes=1024 * 1024 * 4, + experiments=False, + initializer_graph=None): """Takes a Python Graph object and optimizes the graph. Args: @@ -127,6 +129,7 @@ def optimize_graph(graph, signature_def, output_graph, strip_debug_ops: Bool whether to strip debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. + initializer_graph: The frozen graph for initializers. """ # Add a collection 'train_op' so that Grappler knows the outputs. 
@@ -173,6 +176,7 @@ def optimize_graph(graph, signature_def, output_graph, ] optimized_graph = _run_grappler(config, optimized_graph, graph, signature_def) + optimized_graph = _remove_unused_control_flow_inputs(optimized_graph) # Because TF break the Prelu op into 6 ops, for performance we are @@ -194,9 +198,15 @@ def optimize_graph(graph, signature_def, output_graph, raise ValueError('Unsupported Ops in the model after optimization\n' + ', '.join(unsupported)) + initializer_graph_def = None + if initializer_graph: + initializer_graph_def = initializer_graph.as_graph_def() + extract_weights( optimized_graph, output_graph, tf_version, - signature_def, quantization_dtype_map, weight_shard_size_bytes) + signature_def, quantization_dtype_map, weight_shard_size_bytes, + initializer_graph_def) + return optimize_graph @@ -235,7 +245,8 @@ def extract_weights(graph_def, tf_version, signature_def, quantization_dtype_map=None, - weight_shard_size_bytes=1024 * 1024 * 4): + weight_shard_size_bytes=1024 * 1024 * 4, + initializer_graph_def=None): """Takes a Python GraphDef object and extract the weights. Args: @@ -249,6 +260,7 @@ def extract_weights(graph_def, supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. + initializer_graph_def: tf.GraphDef proto object for initializer graph. 
""" global_manifest = extract_const_nodes(graph_def.node) @@ -260,15 +272,21 @@ def extract_weights(graph_def, func.node_def.extend(nodes) function_manifests += extract_const_nodes(func.node_def) + initializer_manifests = [] + if initializer_graph_def: + initializer_manifests = extract_const_nodes(initializer_graph_def.node) + print('Writing weight file ' + output_graph + '...') write_artifacts(MessageToDict(graph_def), - [global_manifest + function_manifests], + [global_manifest + + function_manifests + + initializer_manifests], output_graph, tf_version, signature_def, quantization_dtype_map=quantization_dtype_map, - weight_shard_size_bytes=weight_shard_size_bytes) - + weight_shard_size_bytes=weight_shard_size_bytes, + initializer_graph_def=initializer_graph_def) def write_artifacts(topology, weights, @@ -276,7 +294,8 @@ def write_artifacts(topology, tf_version, signature_def, quantization_dtype_map=None, - weight_shard_size_bytes=1024 * 1024 * 4): + weight_shard_size_bytes=1024 * 1024 * 4, + initializer_graph_def=None): """Writes weights and topology to the output_dir. If `topology` is Falsy (e.g., `None`), only emit weights to output_dir. @@ -293,8 +312,8 @@ def write_artifacts(topology, supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. + initializer_graph_def: tf.GraphDef proto object for initializer graph. """ - model_json = { common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT, # TODO(piyu): Add tensorflow version below by using `meta_info_def`. 
@@ -305,6 +324,11 @@ def write_artifacts(topology, } } model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None + + if initializer_graph_def: + model_json[common.ARTIFACT_MODEL_INITIALIZER] = MessageToDict( + initializer_graph_def) + weights_manifest = write_weights.write_weights( weights, os.path.dirname(output_graph), write_manifest=False, quantization_dtype_map=quantization_dtype_map, @@ -337,18 +361,50 @@ def _check_signature_in_model(saved_model, signature_name): def _freeze_saved_model_v1(saved_model_dir, saved_model_tags, output_node_names): + """Freeze the graph by converting variables to constants for 1.x saved model. + + Args: + saved_model_dir: dir where saved model files are stored. + saved_model_tags: inference graph tag. + output_node_names: List of name strings for the result nodes of the graph. + + Returns: + A frozen and optimized graph. + A frozen and optimized initializer graph, or None if the model + has no table initializer (only table initializers are supported). + """ g = tf.Graph() with g.as_default(): with tf.compat.v1.Session() as sess: - loader.load(sess, saved_model_tags, saved_model_dir) + meta_graph = loader.load(sess, saved_model_tags, saved_model_dir) + + meta_graph_def = g.as_graph_def() + frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( - sess, g.as_graph_def(), output_node_names) + sess, meta_graph_def, output_node_names) frozen_graph = tf.Graph() with frozen_graph.as_default(): tf.import_graph_def(frozen_graph_def, name='') - return frozen_graph + frozen_initializer_graph = None + initializer_output_names = None + # Only support table initializers for now. + if meta_graph.collection_def and meta_graph.collection_def[ 'table_initializer']: + initializer_output_names = meta_graph.collection_def[ 'table_initializer'].node_list.value + # This will use grappler to extract a subgraph with the + # table initializer ops as the outputs.
+ frozen_initializer_graph_def = (tf.compat.v1.graph_util + .convert_variables_to_constants( + sess, meta_graph_def, + initializer_output_names)) + frozen_initializer_graph = tf.Graph() + with frozen_initializer_graph.as_default(): + tf.import_graph_def(frozen_initializer_graph_def, name='') + + return frozen_graph, frozen_initializer_graph def _freeze_saved_model_v2(concrete_func, control_flow_v2=False): if tf.__version__ < '2.2.0': @@ -439,7 +495,8 @@ def convert_tf_saved_model(saved_model_dir, skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4, - control_flow_v2=False, experiments=False): + control_flow_v2=False, + experiments=False): """Freeze the SavedModel and check the model compatibility with Tensorflow.js. Optimize and convert the model to Tensorflow.js format, when the model passes @@ -483,18 +540,22 @@ def convert_tf_saved_model(saved_model_dir, _check_signature_in_model(model, signature_def) concrete_func = model.signatures[signature_def] + output_node_names = [] for output_tensor in concrete_func.outputs: output_node_names.append(output_tensor.name.split(':')[0]) - # TensorFlow doesn't encode the saved model version in the graph in a reliable - # way. Try to freeze the graph using V2 utils. If that fails, freeze the - # graph using V1 utils. + # TensorFlow doesn't encode the saved model version in the graph in a + # reliable way. Try to freeze the graph using V2 utils. If that fails, freeze + # the graph using V1 utils. 
+ frozen_initializer_graph = None try: frozen_graph = _freeze_saved_model_v2(concrete_func, control_flow_v2) except BaseException: - frozen_graph = _freeze_saved_model_v1(saved_model_dir, saved_model_tags, - output_node_names) + (frozen_graph, + frozen_initializer_graph) = _freeze_saved_model_v1(saved_model_dir, + saved_model_tags, + output_node_names) inputs = [x for x in concrete_func.inputs if not x.dtype == 'resource'] signature = _build_signature_def( @@ -549,7 +610,8 @@ def _strip_unused_nodes(frozen_graph, concrete_func, output_node_names): skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes, - experiments=experiments) + experiments=experiments, + initializer_graph=frozen_initializer_graph) def load_and_initialize_hub_module(module_path, signature='default'): """Loads graph of a TF-Hub module and initializes it into a session. diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index c152312160c..e72f72e81a3 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -360,7 +360,11 @@ def test_convert_saved_model_v1_with_hashtable(self): expected_weights_manifest = [{ 'paths': ['group1-shard1of1.bin'], - 'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}] + 'weights': [ + {'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}, + {'dtype': 'string', 'name': 'Const', 'shape': [1]}, + {'dtype': 'int32', 'name': 'Const_1', 'shape': [1]} + ]}] tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'js') # Check model.json and weights manifest. 
@@ -372,7 +376,7 @@ def test_convert_saved_model_v1_with_hashtable(self): self.assertIsNot(signature, None) self.assertIsNot(signature['inputs'], None) self.assertIsNot(signature['outputs'], None) - + self.assertTrue(model_json['modelInitializer']) weights_manifest = model_json['weightsManifest'] self.assertEqual(weights_manifest, expected_weights_manifest)