Skip to content

Commit

Permalink
fixed linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
dhritinaidu committed Jul 12, 2024
1 parent d913199 commit 298b190
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 43 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ PYTHONPATH := ./torch:$(PYTHONPATH)
.PHONY: test lint setup dist

test:
PYTHONPATH=$(PYTHONPATH) pytest tests/
PYTHONPATH=$(PYTHONPATH) pytest src/
lint:
pylint --disable=E1101,W0719,C0202,R0801,W0613 src/
pylint --disable=E1101,W0719,C0202,R0801,W0613,C0411 src/
setup:
python3 -m pip install -r requirements.txt -U
dist/archive.tar.gz:
Expand Down
1 change: 0 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,4 @@ async def main():


if __name__ == "__main__":

asyncio.run(main())
5 changes: 2 additions & 3 deletions src/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __init__(
if size_mb > 500:
# pylint: disable=deprecated-method
LOGGER.warn(
"model file may be large for certain hardware (%s MB)",
size_mb
"model file may be large for certain hardware (%s MB)", size_mb
)
self.model = torch.load(path_to_serialized_file)
if not isinstance(self.model, nn.Module):
Expand All @@ -47,7 +46,7 @@ def __init__(
is of type collections.OrderedDict,
which suggests that the provided file
describes weights instead of a standalone model""",
path_to_serialized_file
path_to_serialized_file,
)
raise TypeError(
f"the model is of type {type(self.model)} instead of nn.Module type"
Expand Down
4 changes: 2 additions & 2 deletions src/model_inspector/input_size_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Dict, Tuple
from viam.logging import getLogger
from torch import nn
from src.model_inspector.utils import is_defined_shape
from model_inspector.utils import is_defined_shape


LOGGER = getLogger(__name__)
Expand Down Expand Up @@ -312,6 +312,6 @@ def get_input_size(
"For layer type %s, with output shape: %s input shape found is %s",
type(layer).__name__,
output_shape,
input_shape
input_shape,
)
return input_shape
19 changes: 10 additions & 9 deletions src/model_inspector/input_tester.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
"A class for testing input shapes on a PyTorch model."
from typing import List, Optional, Dict
from src.model_inspector.utils import is_defined_shape, output_to_shape_dict
from model_inspector.utils import is_defined_shape, output_to_shape_dict
import torch


class InputTester:
"""This class provides methods to test various input shapes
on a given PyTorch model and collect information
about working and non-working input sizes."""
"""This class provides methods to test various input shapes
on a given PyTorch model and collect information
about working and non-working input sizes."""

def __init__(self, model, input_candidate=None):
"""
A class for testing input shapes on a PyTorch model.
This class provides methods to test various input shapes
This class provides methods to test various input shapes
on a given PyTorch model and collect information
about working and non-working input sizes.
Args:
model (torch.nn.Module): The PyTorch model to be tested.
Note:
The try_image_input and try_audio_input methods test the
The try_image_input and try_audio_input methods test the
model with predefined input sizes for image-like and
audio-like data, respectively. The get_shapes method retrieves the final input
audio-like data, respectively. The get_shapes method retrieves the final input
and output shapes after testing various input sizes.
"""
self.model = model
Expand Down Expand Up @@ -123,7 +124,7 @@ def test_input_size(self, input_size):
output = None
try:
output = self.model.infer(input_tensor)
except Exception: #pylint: disable=(broad-exception-caught)
except Exception: # pylint: disable=(broad-exception-caught)
pass
if output is not None:
self.working_input_sizes["input"].append(input_size)
Expand All @@ -136,7 +137,7 @@ def test_input_size(self, input_size):
self.working_output_sizes[output] = [shape]

def try_inputs(self):
"""Test candidate input (if provided),
"""Test candidate input (if provided),
image-like inputs, and audio-like inputs."""
if self.input_candidate:
if is_defined_shape(self.input_candidate):
Expand Down
9 changes: 6 additions & 3 deletions src/model_inspector/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from viam.logging import getLogger
from viam.utils import dict_to_struct

from src.model_inspector.input_size_calculator import InputSizeCalculator
from src.model_inspector.input_tester import InputTester
from model_inspector.input_size_calculator import InputSizeCalculator
from model_inspector.input_tester import InputTester

from torch import nn
import torch
Expand All @@ -20,6 +20,7 @@

class Inspector:
"Inspector class for analyzing and gathering metadata from a PyTorch model."

def __init__(self, model: nn.Module) -> None:
# self.summary: ModelStatistics = summary(module, input_size=[1,3, 640,480])
self.model = model
Expand Down Expand Up @@ -85,7 +86,9 @@ def reverse_module(self):
output_shape = self.input_size_calculator.get_input_size(
module, input_shape
)
LOGGER.info("For module %s, the output shape is %s", module, output_shape)
LOGGER.info(
"For module %s, the output shape is %s", module, output_shape
)
else:
continue # sometimes some children are None

Expand Down
7 changes: 3 additions & 4 deletions src/model_inspector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@
import torch



def is_valid_input_shape(model, input_shape, add_batch_dimension: bool = False):
"""
Check if the input shape is valid for a given PyTorch model.
Args:
model (torch.nn.Module): The PyTorch model to validate the input shape for.
input_shape (tuple): The shape of the input tensor.
input_shape (tuple): The shape of the input tensor.
It should be in the format (C, H, W) for image-like data,
where C is the number of channels, H is the height, and W is the width.
add_batch_dimension (bool, optional):
add_batch_dimension (bool, optional):
Whether to add a batch dimension to the input tensor. Default is False.
Returns:
list or None: A list representing the shape of the output tensor
list or None: A list representing the shape of the output tensor
if the input shape is valid for the model,
or None if an exception occurs during the model evaluation.
"""
Expand Down
65 changes: 53 additions & 12 deletions tests/test_local.py → src/test_local.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,52 @@
from google.protobuf.struct_pb2 import Struct
import torch
"Module for unit testing functionalities related to TorchModel and TorchMLModelModule."

from google.protobuf.struct_pb2 import Struct #pylint: disable=(no-name-in-module)

import unittest
from src.model.model import TorchModel
from src.model_inspector.inspector import Inspector
from src.torch_mlmodel_module import TorchMLModelModule
from viam.services.mlmodel import Metadata
from viam.proto.app.robot import ComponentConfig
from torchvision.models.detection.rpn import AnchorGenerator

from torchvision.models.detection import FasterRCNN

from torchvision.models import MobileNet_V2_Weights
import torchvision
import os
from torchvision.models.detection.rpn import AnchorGenerator
from typing import List, Iterable, Dict, Any, Mapping
from model.model import TorchModel
from model_inspector.inspector import Inspector
from torch_mlmodel_module import TorchMLModelModule
from viam.services.mlmodel import Metadata
from viam.proto.app.robot import ComponentConfig

from typing import Any, Mapping
import numpy as np
import os



import torch



def make_component_config(dictionary: Mapping[str, Any]) -> ComponentConfig:
" makes a mock config"
struct = Struct()
struct.update(dictionary)
return ComponentConfig(attributes=struct)


config = (
make_component_config({"model_path": "model path"}),
"received only one dimension attribute",
)


class TestInputs(unittest.IsolatedAsyncioTestCase):
"""
Unit tests for validating TorchModel and TorchMLModelModule functionalities.
"""
@staticmethod
def load_resnet_weights():
"""
Load ResNet weights from a serialized file.
"""
return TorchModel(
path_to_serialized_file=os.path.join(
"examples", "resnet_18", "resnet18-f37072fd.pth"
Expand All @@ -38,6 +55,9 @@ def load_resnet_weights():

@staticmethod
def load_standalone_resnet():
"""
Load a standalone ResNet model.
"""
return TorchModel(
path_to_serialized_file=os.path.join(
"examples", "resnet_18_scripted", "resnet-18.pt"
Expand All @@ -46,6 +66,9 @@ def load_standalone_resnet():

@staticmethod
def load_detector_from_torchvision():
"""
Load a detector model using torchvision.
"""
backbone = torchvision.models.mobilenet_v2(
weights=MobileNet_V2_Weights.DEFAULT
).features
Expand All @@ -66,28 +89,40 @@ def load_detector_from_torchvision():
model.eval()
return TorchModel(path_to_serialized_file=None, model=model)

def __init__(self, methodName: str = "runTest") -> None:
def __init__(self, methodName: str = "runTest") -> None: #pylint: disable=(useless-parent-delegation)
super().__init__(methodName)

async def test_validate(self):
"""
Test validation of configuration using TorchMLModelModule.
"""
response = TorchMLModelModule.validate_config(config=config[0])
self.assertEqual(response, [])

async def test_validate_empty_config(self):
"""
Test validation with an empty configuration.
"""
empty_config = make_component_config({})
with self.assertRaises(Exception) as excinfo:
await TorchMLModelModule.validate_config(config=empty_config)

self.assertIn(
"model_path can't be empty. model is required for torch mlmoded service module.",
str(excinfo.exception)
str(excinfo.exception),
)

def test_error_loading_weights(self):
"""
Test error handling when loading ResNet weights.
"""
with self.assertRaises(TypeError):
_ = self.load_resnet_weights()

def test_resnet_metadata(self):
"""
Test metadata retrieval for ResNet model.
"""
model: TorchModel = self.load_standalone_resnet()
x = torch.ones(3, 300, 400).unsqueeze(0)
output = model.infer({"any_input_name_you_want": x.numpy()})
Expand All @@ -107,6 +142,9 @@ def test_resnet_metadata(self):
self.assertTrue(output_checked)

def test_detector_metadata(self):
"""
Test metadata retrieval for detector model.
"""
model: TorchModel = self.load_detector_from_torchvision()
x = torch.ones(3, 300, 400).unsqueeze(0)
output = model.infer({"any_input_name_you_want": x.numpy()})
Expand All @@ -126,6 +164,9 @@ def test_detector_metadata(self):
self.assertTrue(output_checked)

def test_infer_method(self):
"""
Test inference method of the detector model.
"""
model: TorchModel = self.load_detector_from_torchvision()
x = torch.ones(3, 300, 400).unsqueeze(0)
output = model.infer({"input_name": x.numpy()})
Expand Down
15 changes: 8 additions & 7 deletions src/torch_mlmodel_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from viam.resource.base import ResourceBase
from viam.utils import ValueTypes
from viam.logging import getLogger
from src.model.model import TorchModel
from src.model_inspector.inspector import Inspector
from model.model import TorchModel
from model_inspector.inspector import Inspector

LOGGER = getLogger(__name__)

Expand All @@ -27,6 +27,7 @@ class TorchMLModelModule(MLModel, Reconfigurable):
This class integrates a PyTorch model with Viam's MLModel and Reconfigurable interfaces,
providing functionality to create, configure, and use the model for inference.
"""

MODEL: ClassVar[Model] = Model(ModelFamily("viam", "mlmodel"), "torch-cpu")

def __init__(self, name: str):
Expand Down Expand Up @@ -60,6 +61,7 @@ def reconfigure(
self, config: ServiceConfig, dependencies: Mapping[ResourceName, ResourceBase]
):
"Reconfigure the service with the given configuration and dependencies."

# pylint: disable=too-many-return-statements
def get_attribute_from_config(attribute_name: str, default, of_type=None):
if attribute_name not in config.attributes.fields:
Expand Down Expand Up @@ -92,25 +94,24 @@ def get_attribute_from_config(attribute_name: str, default, of_type=None):
self.inspector = Inspector(self.torch_model)
self._metadata = self.inspector.find_metadata(label_file)


async def infer(
self, input_tensors: Dict[str, NDArray], *, timeout: Optional[float]
) -> Dict[str, NDArray]:
"""Take an already ordered input tensor as an array,
"""Take an already ordered input tensor as an array,
make an inference on the model, and return an output tensor map.
Args:
input_tensors (Dict[str, NDArray]):
input_tensors (Dict[str, NDArray]):
A dictionary of input flat tensors as specified in the metadata
Returns:
Dict[str, NDArray]:
Dict[str, NDArray]:
A dictionary of output flat tensors as specified in the metadata
"""
return self.torch_model.infer(input_tensors)

async def metadata(self, *, timeout: Optional[float]) -> Metadata:
"""Get the metadata (such as name, type, expected tensor/array shape,
"""Get the metadata (such as name, type, expected tensor/array shape,
inputs, and outputs) associated with the ML model.
Returns:
Expand Down

0 comments on commit 298b190

Please sign in to comment.