diff --git a/.pylintrc b/.pylintrc index 9c86ada..ad34739 100644 --- a/.pylintrc +++ b/.pylintrc @@ -10,4 +10,5 @@ disable= too-few-public-methods, too-many-arguments, too-many-instance-attributes, + duplicate-code, invalid-name diff --git a/README.md b/README.md index d096855..324d462 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,31 @@ python -m tf_bodypix \ --threshold=0.75 ``` +## TensorFlow Lite support (experimental) + +The model path may also point to a TensorFlow Lite model (`.tflite` extension). Whether that actually improves performance may depend on the platform and available hardware. + +You could convert one of the available TensorFlow JS models to TensorFlow Lite using the following command: + +```bash +python -m tf_bodypix \ + convert-to-tflite \ + --model-path \ + "https://storage.googleapis.com/tfjs-models/savedmodel/bodypix/mobilenet/float/075/model-stride16.json" \ + --optimize \ + --quantization-type=float16 \ + --output-model-file "./mobilenet-float16-stride16.tflite" +``` + +The above command is provided for convenience. +You may use alternative methods depending on your preference and requirements. + +Relevant links: + +* [TensorFlow Lite converter](https://www.tensorflow.org/lite/convert/) +* [TF Lite post_training_quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) +* [TF GitHub #40183](https://github.com/tensorflow/tensorflow/issues/40183). 
+ ## Acknowledgements * [Original TensorFlow JS Implementation of BodyPix](https://github.com/tensorflow/tfjs-models/tree/body-pix-v2.0.4/body-pix) diff --git a/tests/cli_test.py b/tests/cli_test.py index a02937b..a4d5d37 100644 --- a/tests/cli_test.py +++ b/tests/cli_test.py @@ -2,6 +2,7 @@ from pathlib import Path from tf_bodypix.download import BodyPixModelPaths +from tf_bodypix.model import ModelArchitectureNames from tf_bodypix.cli import main @@ -87,3 +88,22 @@ def test_should_list_all_default_model_urls(self, capsys): LOGGER.debug('output_urls: %s', output_urls) missing_urls = set(expected_urls) - set(output_urls) assert not missing_urls + + def test_should_be_able_to_convert_to_tflite_and_use_model(self, temp_dir: Path): + output_model_file = temp_dir / 'model.tflite' + main([ + 'convert-to-tflite', + '--model-path=%s' % BodyPixModelPaths.MOBILENET_FLOAT_75_STRIDE_16, + '--optimize', + '--quantization-type=int8', + '--output-model-file=%s' % output_model_file + ]) + output_image_path = temp_dir / 'mask.jpg' + main([ + 'draw-mask', + '--model-path=%s' % output_model_file, + '--model-architecture=%s' % ModelArchitectureNames.MOBILENET_V1, + '--output-stride=16', + '--source=%s' % EXAMPLE_IMAGE_URL, + '--output=%s' % output_image_path + ]) diff --git a/tf_bodypix/cli.py b/tf_bodypix/cli.py index c6f4c91..c4a90fa 100644 --- a/tf_bodypix/cli.py +++ b/tf_bodypix/cli.py @@ -5,6 +5,7 @@ from abc import ABC, abstractmethod from contextlib import ExitStack from itertools import cycle +from pathlib import Path from time import time from typing import Dict, List @@ -25,8 +26,10 @@ ) from tf_bodypix.utils.s3 import iter_s3_file_urls from tf_bodypix.download import download_model +from tf_bodypix.tflite import get_tflite_converter_for_model_path from tf_bodypix.model import ( load_model, + VALID_MODEL_ARCHITECTURE_NAMES, PART_CHANNELS, DEFAULT_RESIZE_METHOD, BodyPixModelWrapper, @@ -77,6 +80,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): 
default=DEFAULT_MODEL_PATH, help="The path or URL to the bodypix model." ) + parser.add_argument( + "--model-architecture", + choices=VALID_MODEL_ARCHITECTURE_NAMES, + help=( + "The model architecture." + " It will be guessed from the model path if not specified." + ) + ) parser.add_argument( "--output-stride", type=int, @@ -219,7 +230,8 @@ def load_bodypix_model(args: argparse.Namespace) -> BodyPixModelWrapper: return load_model( local_model_path, internal_resolution=args.internal_resolution, - output_stride=args.output_stride + output_stride=args.output_stride, + architecture_name=args.model_architecture ) @@ -266,6 +278,52 @@ def run(self, args: argparse.Namespace): # pylint: disable=unused-argument print('\n'.join(bodypix_model_json_files)) +class ConvertToTFLiteSubCommand(SubCommand): + def __init__(self): + super().__init__("convert-to-tflite", "Converts the model to a tflite model") + + def add_arguments(self, parser: argparse.ArgumentParser): + add_common_arguments(parser) + parser.add_argument( + "--model-path", + default=DEFAULT_MODEL_PATH, + help="The path or URL to the bodypix model." + ) + parser.add_argument( + "--output-model-file", + required=True, + help="The path to the output file (tflite model)." + ) + parser.add_argument( + "--optimize", + action='store_true', + help="Enable optimization (quantization)." + ) + parser.add_argument( + "--quantization-type", + choices=['float16', 'float32', 'int8'], + help="The quantization type to use." 
+ ) + + def run(self, args: argparse.Namespace): # pylint: disable=unused-argument + LOGGER.info('converting model: %s', args.model_path) + converter = get_tflite_converter_for_model_path(download_model( + args.model_path + )) + tflite_model = converter.convert() + if args.optimize: + LOGGER.info('enabled optimization') + converter.optimizations = [tf.lite.Optimize.DEFAULT] + if args.quantization_type: + LOGGER.info('quantization type: %s', args.quantization_type) + quantization_type = getattr(tf, args.quantization_type) + converter.target_spec.supported_types = [quantization_type] + converter.inference_input_type = quantization_type + converter.inference_output_type = quantization_type + LOGGER.info('saving tflite model to: %s', args.output_model_file) + Path(args.output_model_file).write_bytes(tflite_model) + + class AbstractWebcamFilterApp(ABC): def __init__(self, args: argparse.Namespace): self.args = args @@ -497,6 +555,7 @@ def get_app(self, args: argparse.Namespace) -> AbstractWebcamFilterApp: SUB_COMMANDS: List[SubCommand] = [ ListModelsSubCommand(), + ConvertToTFLiteSubCommand(), DrawMaskSubCommand(), BlurBackgroundSubCommand(), ReplaceBackgroundSubCommand() diff --git a/tf_bodypix/model.py b/tf_bodypix/model.py index bae5f40..b8e766e 100644 --- a/tf_bodypix/model.py +++ b/tf_bodypix/model.py @@ -341,6 +341,66 @@ def get_structured_output_names(structured_outputs: List[tf.Tensor]) -> List[str ] +def to_number_of_dimensions(data: np.ndarray, dimension_count: int) -> np.ndarray: + while len(data.shape) > dimension_count: + data = data[0] + while len(data.shape) < dimension_count: + data = np.expand_dims(data, axis=0) + return data + + +def load_tflite_model(model_path: str): + # Load TFLite model and allocate tensors. 
+ interpreter = tf.lite.Interpreter(model_path=model_path) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + LOGGER.debug('input_details: %s', input_details) + input_names = [item['name'] for item in input_details] + LOGGER.debug('input_names: %s', input_names) + input_details_map = dict(zip(input_names, input_details)) + + output_details = interpreter.get_output_details() + LOGGER.debug('output_details: %s', output_details) + output_names = [item['name'] for item in output_details] + LOGGER.debug('output_names: %s', output_names) + + try: + image_input = input_details_map['image'] + except KeyError: + assert len(input_details_map) == 1 + image_input = list(input_details_map.values())[0] + input_shape = image_input['shape'] + LOGGER.debug('input_shape: %s', input_shape) + + def predict(image_data: np.ndarray): + nonlocal input_shape + image_data = to_number_of_dimensions(image_data, len(input_shape)) + LOGGER.debug('tflite predict, image_data.shape=%s (%s)', image_data.shape, image_data.dtype) + height, width, *_ = image_data.shape + if tuple(image_data.shape) != tuple(input_shape): + LOGGER.info('resizing input tensor: %s -> %s', tuple(input_shape), image_data.shape) + interpreter.resize_tensor_input(image_input['index'], list(image_data.shape)) + interpreter.allocate_tensors() + input_shape = image_data.shape + interpreter.set_tensor(image_input['index'], image_data) + if 'image_size' in input_details_map: + interpreter.set_tensor( + input_details_map['image_size']['index'], + np.array([height, width], dtype=float) + ) + + interpreter.invoke() + + # The function `get_tensor()` returns a copy of the tensor data. + # Use `tensor()` in order to get a pointer to the tensor. 
+ return { + item['name']: interpreter.get_tensor(item['index']) + for item in output_details + } + return predict + + def load_using_saved_model_and_get_predict_function(model_path): loaded = tf.saved_model.load(model_path) LOGGER.debug('loaded: %s', loaded) @@ -366,6 +426,8 @@ def load_using_tfjs_graph_converter_and_get_predict_function( def load_model_and_get_predict_function( model_path: str ) -> Callable[[np.ndarray], dict]: + if model_path.endswith('.tflite'): + return load_tflite_model(model_path) try: return load_using_saved_model_and_get_predict_function(model_path) except OSError: @@ -373,17 +435,17 @@ def load_model_and_get_predict_function( def get_output_stride_from_model_path(model_path: str) -> int: - match = re.search(r'stride(\d+)', model_path) + match = re.search(r'stride(\d+)|_(\d+)_quant', model_path) if not match: raise ValueError('cannot extract output stride from model path: %r' % model_path) - return int(match.group(1)) + return int(match.group(1) or match.group(2)) def get_architecture_from_model_path(model_path: str) -> int: model_path_lower = model_path.lower() if 'mobilenet' in model_path_lower: return ModelArchitectureNames.MOBILENET_V1 - if 'resnet50' in model_path_lower: + if 'resnet' in model_path_lower: return ModelArchitectureNames.RESNET_50 raise ValueError('cannot extract model architecture from model path: %r' % model_path) diff --git a/tf_bodypix/tflite.py b/tf_bodypix/tflite.py new file mode 100644 index 0000000..6429f48 --- /dev/null +++ b/tf_bodypix/tflite.py @@ -0,0 +1,25 @@ +import logging + +import tensorflow as tf + +try: + import tfjs_graph_converter +except ImportError: + tfjs_graph_converter = None + + +LOGGER = logging.getLogger(__name__) + + +def get_tflite_converter_for_tfjs_model_path(model_path: str) -> tf.lite.TFLiteConverter: + if tfjs_graph_converter is None: + raise ImportError('tfjs_graph_converter required') + graph = tfjs_graph_converter.api.load_graph_model(model_path) + tf_fn = 
tfjs_graph_converter.api.graph_to_function_v2(graph) + return tf.lite.TFLiteConverter.from_concrete_functions([tf_fn]) + + +def get_tflite_converter_for_model_path(model_path: str) -> tf.lite.TFLiteConverter: + LOGGER.debug('converting model_path: %s', model_path) + # if model_path.endswith('.json'): + return get_tflite_converter_for_tfjs_model_path(model_path)