Skip to content

Commit

Permalink
Merge pull request #6 from dhritinaidu/unit-tests
Browse files Browse the repository at this point in the history
Unit tests
  • Loading branch information
dhritinaidu authored Jul 12, 2024
2 parents 9c0c409 + c5180ee commit 9ab097d
Show file tree
Hide file tree
Showing 18 changed files with 424 additions and 273 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
on:
push:
tags:
- '[0-9]+.[0-9]+.[0-9]+'
- '[0-9]+.[0-9]+.[0-9]+-rc'

jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: viamrobotics/build-action@v4
with:
# note: you can replace this line with 'version: ""' if you want to test the build process without deploying
version: ${{ github.ref_name }}
ref: ${{ github.sha }}
key-id: ${{ secrets.viam_key_id }}
key-value: ${{ secrets.viam_key_value }}
10 changes: 8 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
name: Run lint
name: Run lint & unit tests


on:
push:
Expand All @@ -8,7 +9,8 @@ on:

jobs:
build:
name: "Run lint"

name: "Run unit tests"
runs-on: ubuntu-latest

steps:
Expand All @@ -20,3 +22,7 @@ jobs:
- name: Run lint
run: |
make lint
- name: Run unit tests
run: make test

25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

__pycache__
src/model_inspector/__pycache__
src/model/__pycache__

.venv

build

dist

lib

main.spec
pyvenv.cfg
bin/

tests/.pytest_cache

.pytest_cache

.DS_Store



12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
PYTHONPATH := ./torch:$(PYTHONPATH)

.PHONY: test lint build dist

test:
PYTHONPATH=$(PYTHONPATH) pytest src/
lint:
pylint --disable=E1101,W0719,C0202,R0801,W0613,C0411 src/
build:
./build.sh
dist/archive.tar.gz:
tar -czvf dist/archive.tar.gz dist/main
23 changes: 23 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
set -e
UNAME=$(uname -s)
if [ "$UNAME" = "Linux" ]
then
if dpkg -l python3-venv; then
echo "python3-venv is installed, skipping setup"
else
echo "Installing venv on Linux"
sudo apt-get install -y python3-venv
fi
fi
if [ "$UNAME" = "Darwin" ]
then
echo "Installing venv on Darwin"
brew install virtualenv
fi
source .env
python3 -m venv .venv .
source .venv/bin/activate
pip3 install -r requirements.txt
python3 -m PyInstaller --onefile --hidden-import="googleapiclient" src/main.py
tar -czvf dist/archive.tar.gz dist/main
18 changes: 18 additions & 0 deletions meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"module_id": "viam:torch-cpu",
"visibility": "public",
"url": "https://github.com/viam-labs/torch",
"description": "Viam ML Module service serving PyTorch models.",
"models": [
{
"api": "rdk:service:mlmodel",
"model": "viam:mlmodel:torch-cpu"
}
],
"build": {
"build": "./build.sh",
"path": "dist/archive.tar.gz",
"arch": ["linux/arm64", "linux/amd64"]
},
"entrypoint": "dist/main"
}
8 changes: 6 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
viam-sdk
numpy
typing-extensions
numpy<2.0.0
pylint
pyinstaller
google-api-python-client
torch==2.2.1
pytest
torchvision
torch==2.2.1
9 changes: 5 additions & 4 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ set -euxo pipefail

cd $(dirname $0)
exec dist/main $@
# source .env
# ./setup.sh

source .env
./setup.sh
source .venv/bin/activate
# # Be sure to use `exec` so that termination signals reach the python process,
# # or handle forwarding termination signals manually
# exec $PYTHON -m src.main $@
echo which python3
python3 -m src.main $@
5 changes: 4 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Main runner method of the module"""

import asyncio


from viam.module.module import Module
from viam.resource.registry import Registry, ResourceCreatorRegistration
from torch_mlmodel_module import TorchMLModelModule
from viam.services.mlmodel import MLModel

from torch_mlmodel_module import TorchMLModelModule


async def main():
"""
Expand Down
54 changes: 37 additions & 17 deletions src/model/model.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,93 @@
import torch
"""
This module provides a class for loading and performing inference with a PyTorch model.
The TorchModel class handles loading a serialized model, preparing inputs, and wrapping outputs.
"""

import os
from typing import List, Iterable, Dict, Any
from numpy.typing import NDArray
import torch.nn as nn
from collections import OrderedDict

from numpy.typing import NDArray
from viam.logging import getLogger
import os

import torch
from torch import nn


LOGGER = getLogger(__name__)


class TorchModel:
"""
A class to load a PyTorch model from a serialized file or use a provided model,
prepare inputs for the model, perform inference, and wrap the outputs.
"""

def __init__(
self,
path_to_serialized_file: str,
model: nn.Module = None,
) -> None:
"Initializes the model by loading it from a serialized file or using a provided model."
if model is not None:
self.model = model
else:
sizeMB = os.stat(path_to_serialized_file).st_size / (1024 * 1024)
if sizeMB > 500:
size_mb = os.stat(path_to_serialized_file).st_size / (1024 * 1024)
if size_mb > 500:
# pylint: disable=deprecated-method
LOGGER.warn(
"model file may be large for certain hardware ("
+ str(sizeMB)
+ "MB)"
"model file may be large for certain hardware (%s MB)", size_mb
)
self.model = torch.load(path_to_serialized_file)
if not isinstance(self.model, nn.Module):
if isinstance(self.model, OrderedDict):
LOGGER.error(
f"the file {path_to_serialized_file} provided as model file is of type collections.OrderedDict, which suggests that the provided file describes weights instead of a standalone model"
"""the file %s provided as model file
is of type collections.OrderedDict,
which suggests that the provided file
describes weights instead of a standalone model""",
path_to_serialized_file,
)
raise TypeError(
f"the model is of type {type(self.model)} instead of nn.Module type"
)
self.model.eval()

def infer(self, input):
input = self.prepare_input(input)
def infer(self, input_data):
"Prepares the input data, performs inference using the model, and wraps the output."
input_data = self.prepare_input(input_data)
with torch.no_grad():
output = self.model(*input)
output = self.model(*input_data)
return self.wrap_output(output)

@staticmethod
def prepare_input(input_tensor: Dict[str, NDArray]) -> List[NDArray]:
"Converts a dictionary of NumPy arrays into a list of PyTorch tensors."
return [torch.from_numpy(tensor) for tensor in input_tensor.values()]

@staticmethod
def wrap_output(output: Any) -> Dict[str, NDArray]:
"Converts the output from a PyTorch model to a dictionary of NumPy arrays."
if isinstance(output, Iterable):
if len(output) == 1:
output = output[0] # unpack batched results

if isinstance(output, torch.Tensor):
return {"output_0": output.numpy()}

elif isinstance(output, dict):
if isinstance(output, dict):
for tensor_name, tensor in output.items():
if isinstance(tensor, torch.Tensor):
output[tensor_name] = tensor.numpy()

return output
elif isinstance(output, Iterable):

if isinstance(output, Iterable):
res = {}
count = 0
for out in output:
res[f"output_{count}"] = out
count += 1
return res

else:
raise TypeError(f"can't convert output of type {type(output)} to array")
raise TypeError(f"can't convert output of type {type(output)} to array")
Loading

0 comments on commit 9ab097d

Please sign in to comment.