From 298b1903d779a8e3b015ca6072b7d3d930599679 Mon Sep 17 00:00:00 2001 From: dhritinaidu Date: Fri, 12 Jul 2024 16:15:32 -0400 Subject: [PATCH] fixed linting issues --- Makefile | 4 +- src/main.py | 1 - src/model/model.py | 5 +- src/model_inspector/input_size_calculator.py | 4 +- src/model_inspector/input_tester.py | 19 +++--- src/model_inspector/inspector.py | 9 ++- src/model_inspector/utils.py | 7 +-- {tests => src}/test_local.py | 65 ++++++++++++++++---- src/torch_mlmodel_module.py | 15 ++--- 9 files changed, 86 insertions(+), 43 deletions(-) rename {tests => src}/test_local.py (78%) diff --git a/Makefile b/Makefile index 31f9b93..b0e6834 100644 --- a/Makefile +++ b/Makefile @@ -3,9 +3,9 @@ PYTHONPATH := ./torch:$(PYTHONPATH) .PHONY: test lint setup dist test: - PYTHONPATH=$(PYTHONPATH) pytest tests/ + PYTHONPATH=$(PYTHONPATH) pytest src/ lint: - pylint --disable=E1101,W0719,C0202,R0801,W0613 src/ + pylint --disable=E1101,W0719,C0202,R0801,W0613,C0411 src/ setup: python3 -m pip install -r requirements.txt -U dist/archive.tar.gz: diff --git a/src/main.py b/src/main.py index c1cca5e..e22981f 100644 --- a/src/main.py +++ b/src/main.py @@ -30,5 +30,4 @@ async def main(): if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/model/model.py b/src/model/model.py index 69fb950..f550b57 100644 --- a/src/model/model.py +++ b/src/model/model.py @@ -36,8 +36,7 @@ def __init__( if size_mb > 500: # pylint: disable=deprecated-method LOGGER.warn( - "model file may be large for certain hardware (%s MB)", - size_mb + "model file may be large for certain hardware (%s MB)", size_mb ) self.model = torch.load(path_to_serialized_file) if not isinstance(self.model, nn.Module): @@ -47,7 +46,7 @@ def __init__( is of type collections.OrderedDict, which suggests that the provided file describes weights instead of a standalone model""", - path_to_serialized_file + path_to_serialized_file, ) raise TypeError( f"the model is of type {type(self.model)} instead of nn.Module type" diff --git a/src/model_inspector/input_size_calculator.py b/src/model_inspector/input_size_calculator.py index e094591..4ad7147 100644 --- a/src/model_inspector/input_size_calculator.py +++ b/src/model_inspector/input_size_calculator.py @@ -10,7 +10,7 @@ from typing import Dict, Tuple from viam.logging import getLogger from torch import nn -from src.model_inspector.utils import is_defined_shape +from model_inspector.utils import is_defined_shape LOGGER = getLogger(__name__) @@ -312,6 +312,6 @@ def get_input_size( "For layer type %s, with output shape: %s input shape found is %s", type(layer).__name__, output_shape, - input_shape + input_shape, ) return input_shape diff --git a/src/model_inspector/input_tester.py b/src/model_inspector/input_tester.py index 051cbab..0995c3f 100644 --- a/src/model_inspector/input_tester.py +++ b/src/model_inspector/input_tester.py @@ -1,18 +1,19 @@ "A class for testing input shapes on a PyTorch model." from typing import List, Optional, Dict -from src.model_inspector.utils import is_defined_shape, output_to_shape_dict +from model_inspector.utils import is_defined_shape, output_to_shape_dict import torch class InputTester: - """This class provides methods to test various input shapes - on a given PyTorch model and collect information - about working and non-working input sizes.""" + """This class provides methods to test various input shapes + on a given PyTorch model and collect information + about working and non-working input sizes.""" + def __init__(self, model, input_candidate=None): """ A class for testing input shapes on a PyTorch model. - This class provides methods to test various input shapes + This class provides methods to test various input shapes on a given PyTorch model and collect information about working and non-working input sizes. @@ -20,9 +21,9 @@ def __init__(self, model, input_candidate=None): model (torch.nn.Module): The PyTorch model to be tested. Note: - The try_image_input and try_audio_input methods test the + The try_image_input and try_audio_input methods test the model with predefined input sizes for image-like and - audio-like data, respectively. The get_shapes method retrieves the final input + audio-like data, respectively. The get_shapes method retrieves the final input and output shapes after testing various input sizes. """ self.model = model @@ -123,7 +124,7 @@ def test_input_size(self, input_size): output = None try: output = self.model.infer(input_tensor) - except Exception: #pylint: disable=(broad-exception-caught) + except Exception: # pylint: disable=(broad-exception-caught) pass if output is not None: self.working_input_sizes["input"].append(input_size) @@ -136,7 +137,7 @@ def test_input_size(self, input_size): self.working_output_sizes[output] = [shape] def try_inputs(self): - """Test candidate input (if provided), + """Test candidate input (if provided), image-like inputs, and audio-like inputs.""" if self.input_candidate: if is_defined_shape(self.input_candidate): diff --git a/src/model_inspector/inspector.py b/src/model_inspector/inspector.py index dfc8ebf..f1a2fa6 100644 --- a/src/model_inspector/inspector.py +++ b/src/model_inspector/inspector.py @@ -8,8 +8,8 @@ from viam.logging import getLogger from viam.utils import dict_to_struct -from src.model_inspector.input_size_calculator import InputSizeCalculator -from src.model_inspector.input_tester import InputTester +from model_inspector.input_size_calculator import InputSizeCalculator +from model_inspector.input_tester import InputTester from torch import nn import torch @@ -20,6 +20,7 @@ class Inspector: "Inspector class for analyzing and gathering metadata from a PyTorch model." + def __init__(self, model: nn.Module) -> None: # self.summary: ModelStatistics = summary(module, input_size=[1,3, 640,480]) self.model = model @@ -85,7 +86,9 @@ def reverse_module(self): output_shape = self.input_size_calculator.get_input_size( module, input_shape ) - LOGGER.info("For module %s, the output shape is %s", module, output_shape) + LOGGER.info( + "For module %s, the output shape is %s", module, output_shape + ) else: continue # sometimes some children are None diff --git a/src/model_inspector/utils.py b/src/model_inspector/utils.py index fb9d5af..b8bf56c 100644 --- a/src/model_inspector/utils.py +++ b/src/model_inspector/utils.py @@ -6,21 +6,20 @@ import torch - def is_valid_input_shape(model, input_shape, add_batch_dimension: bool = False): """ Check if the input shape is valid for a given PyTorch model. Args: model (torch.nn.Module): The PyTorch model to validate the input shape for. - input_shape (tuple): The shape of the input tensor. + input_shape (tuple): The shape of the input tensor. It should be in the format (C, H, W) for image-like data, where C is the number of channels, H is the height, and W is the width. - add_batch_dimension (bool, optional): + add_batch_dimension (bool, optional): Whether to add a batch dimension to the input tensor. Default is False. Returns: - list or None: A list representing the shape of the output tensor + list or None: A list representing the shape of the output tensor if the input shape is valid for the model, or None if an exception occurs during the model evaluation. """ diff --git a/tests/test_local.py b/src/test_local.py similarity index 78% rename from tests/test_local.py rename to src/test_local.py index 8156453..bdf7bb1 100644 --- a/tests/test_local.py +++ b/src/test_local.py @@ -1,26 +1,37 @@ -from google.protobuf.struct_pb2 import Struct -import torch +"Module for unit testing functionalities related to TorchModel and TorchMLModelModule." + +from google.protobuf.struct_pb2 import Struct #pylint: disable=(no-name-in-module) + import unittest -from src.model.model import TorchModel -from src.model_inspector.inspector import Inspector -from src.torch_mlmodel_module import TorchMLModelModule -from viam.services.mlmodel import Metadata -from viam.proto.app.robot import ComponentConfig +from torchvision.models.detection.rpn import AnchorGenerator + from torchvision.models.detection import FasterRCNN from torchvision.models import MobileNet_V2_Weights import torchvision -import os -from torchvision.models.detection.rpn import AnchorGenerator -from typing import List, Iterable, Dict, Any, Mapping +from model.model import TorchModel +from model_inspector.inspector import Inspector +from torch_mlmodel_module import TorchMLModelModule +from viam.services.mlmodel import Metadata +from viam.proto.app.robot import ComponentConfig + +from typing import Any, Mapping import numpy as np +import os + + + +import torch + def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig: + " makes a mock config" struct = Struct() struct.update(dictionary) return ComponentConfig(attributes=struct) + config = ( make_component_config({"model_path": "model path"}), "received only one dimension attribute", @@ -28,8 +39,14 @@ def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig: class TestInputs(unittest.IsolatedAsyncioTestCase): + """ + Unit tests for validating TorchModel and TorchMLModelModule functionalities. + """ @staticmethod def load_resnet_weights(): + """ + Load ResNet weights from a serialized file. + """ return TorchModel( path_to_serialized_file=os.path.join( "examples", "resnet_18", "resnet18-f37072fd.pth" @@ -38,6 +55,9 @@ def load_resnet_weights(): @staticmethod def load_standalone_resnet(): + """ + Load a standalone ResNet model. + """ return TorchModel( path_to_serialized_file=os.path.join( "examples", "resnet_18_scripted", "resnet-18.pt" @@ -46,6 +66,9 @@ def load_standalone_resnet(): @staticmethod def load_detector_from_torchvision(): + """ + Load a detector model using torchvision. + """ backbone = torchvision.models.mobilenet_v2( weights=MobileNet_V2_Weights.DEFAULT ).features @@ -66,28 +89,40 @@ def load_detector_from_torchvision(): model.eval() return TorchModel(path_to_serialized_file=None, model=model) - def __init__(self, methodName: str = "runTest") -> None: + def __init__(self, methodName: str = "runTest") -> None: #pylint: disable=(useless-parent-delegation) super().__init__(methodName) async def test_validate(self): + """ + Test validation of configuration using TorchMLModelModule. + """ response = TorchMLModelModule.validate_config(config=config[0]) self.assertEqual(response, []) async def test_validate_empty_config(self): + """ + Test validation with an empty configuration. + """ empty_config = make_component_config({}) with self.assertRaises(Exception) as excinfo: await TorchMLModelModule.validate_config(config=empty_config) self.assertIn( "model_path can't be empty. model is required for torch mlmoded service module.", - str(excinfo.exception) + str(excinfo.exception), ) def test_error_loading_weights(self): + """ + Test error handling when loading ResNet weights. + """ with self.assertRaises(TypeError): _ = self.load_resnet_weights() def test_resnet_metadata(self): + """ + Test metadata retrieval for ResNet model. + """ model: TorchModel = self.load_standalone_resnet() x = torch.ones(3, 300, 400).unsqueeze(0) output = model.infer({"any_input_name_you_want": x.numpy()}) @@ -107,6 +142,9 @@ def test_resnet_metadata(self): self.assertTrue(output_checked) def test_detector_metadata(self): + """ + Test metadata retrieval for detector model. + """ model: TorchModel = self.load_detector_from_torchvision() x = torch.ones(3, 300, 400).unsqueeze(0) output = model.infer({"any_input_name_you_want": x.numpy()}) @@ -126,6 +164,9 @@ def test_detector_metadata(self): self.assertTrue(output_checked) def test_infer_method(self): + """ + Test inference method of the detector model. + """ model: TorchModel = self.load_detector_from_torchvision() x = torch.ones(3, 300, 400).unsqueeze(0) output = model.infer({"input_name": x.numpy()}) diff --git a/src/torch_mlmodel_module.py b/src/torch_mlmodel_module.py index a50766c..ff8b90b 100644 --- a/src/torch_mlmodel_module.py +++ b/src/torch_mlmodel_module.py @@ -16,8 +16,8 @@ from viam.resource.base import ResourceBase from viam.utils import ValueTypes from viam.logging import getLogger -from src.model.model import TorchModel -from src.model_inspector.inspector import Inspector +from model.model import TorchModel +from model_inspector.inspector import Inspector LOGGER = getLogger(__name__) @@ -27,6 +27,7 @@ class TorchMLModelModule(MLModel, Reconfigurable): This class integrates a PyTorch model with Viam's MLModel and Reconfigurable interfaces, providing functionality to create, configure, and use the model for inference. """ + MODEL: ClassVar[Model] = Model(ModelFamily("viam", "mlmodel"), "torch-cpu") def __init__(self, name: str): @@ -60,6 +61,7 @@ def reconfigure( self, config: ServiceConfig, dependencies: Mapping[ResourceName, ResourceBase] ): "Reconfigure the service with the given configuration and dependencies." + # pylint: disable=too-many-return-statements def get_attribute_from_config(attribute_name: str, default, of_type=None): if attribute_name not in config.attributes.fields: @@ -92,25 +94,24 @@ def get_attribute_from_config(attribute_name: str, default, of_type=None): self.inspector = Inspector(self.torch_model) self._metadata = self.inspector.find_metadata(label_file) - async def infer( self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float] ) -> Dict[str, NDArray]: - """Take an already ordered input tensor as an array, + """Take an already ordered input tensor as an array, make an inference on the model, and return an output tensor map. Args: - input_tensors (Dict[str, NDArray]): + input_tensors (Dict[str, NDArray]): A dictionary of input flat tensors as specified in the metadata Returns: - Dict[str, NDArray]: + Dict[str, NDArray]: A dictionary of output flat tensors as specified in the metadata """ return self.torch_model.infer(input_tensors) async def metadata(self, *, timeout: Optional[float]) -> Metadata: - """Get the metadata (such as name, type, expected tensor/array shape, + """Get the metadata (such as name, type, expected tensor/array shape, inputs, and outputs) associated with the ML model. Returns: