Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-8003: Unit tests #6

Merged
merged 12 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Publish workflow: builds and uploads the module when a semver release tag is pushed.
on:
  push:
    tags:
      - '[0-9]+.[0-9]+.[0-9]+'      # release tags, e.g. 1.2.3
      # NOTE(review): this glob matches only a literal "-rc" suffix; tags like
      # "1.2.3-rc1" or "1.2.3-rc.1" would NOT trigger — confirm that is intended.
      - '[0-9]+.[0-9]+.[0-9]+-rc'

jobs:
  publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      # viamrobotics/build-action packages and uploads the module to the registry.
      - uses: viamrobotics/build-action@v4
        with:
          # note: you can replace this line with 'version: ""' if you want to test the build process without deploying
          version: ${{ github.ref_name }}
          ref: ${{ github.sha }}
          key-id: ${{ secrets.viam_key_id }}
          key-value: ${{ secrets.viam_key_value }}
10 changes: 8 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
name: Run lint
name: Run lint & unit tests


on:
push:
Expand All @@ -8,7 +9,8 @@ on:

jobs:
build:
name: "Run lint"

name: "Run unit tests"
runs-on: ubuntu-latest

steps:
Expand All @@ -20,3 +22,7 @@ jobs:
- name: Run lint
run: |
make lint

- name: Run unit tests
run: make test

25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

__pycache__
src/model_inspector/__pycache__
src/model/__pycache__

.venv

build

dist

lib

main.spec
pyvenv.cfg
bin/

tests/.pytest_cache

.pytest_cache

.DS_Store



12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Prepend the local torch directory so test runs import the vendored package first.
PYTHONPATH := ./torch:$(PYTHONPATH)

# 'dist/archive.tar.gz' is a real file target, so it does not belong in .PHONY
# (the previous list also named a nonexistent 'dist' target).
.PHONY: test lint build

test:
	PYTHONPATH=$(PYTHONPATH) pytest src/

lint:
	pylint --disable=E1101,W0719,C0202,R0801,W0613,C0411 src/

build:
	./build.sh

# build.sh produces dist/main and meta.json's entrypoint is dist/main; the
# previous recipe archived dist/__main__, which is never created (see PR review).
dist/archive.tar.gz:
	tar -czvf dist/archive.tar.gz dist/main
23 changes: 23 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash
# Build the module binary: ensure venv tooling exists, create a virtualenv,
# install dependencies, bundle src/main.py with PyInstaller, and tar the result.
set -e

UNAME=$(uname -s)

if [ "$UNAME" = "Linux" ]
then
    # python3-venv ships separately on Debian/Ubuntu; install it only if missing.
    if dpkg -l python3-venv; then
        echo "python3-venv is installed, skipping setup"
    else
        echo "Installing venv on Linux"
        sudo apt-get install -y python3-venv
    fi
fi
if [ "$UNAME" = "Darwin" ]
then
    echo "Installing venv on Darwin"
    brew install virtualenv
fi

source .env
# BUGFIX: previously 'python3 -m venv .venv .' — the stray '.' argument made
# venv also scaffold an environment in the repository root (the pyvenv.cfg and
# bin/ entries in .gitignore were hiding that pollution). Create .venv only.
python3 -m venv .venv
source .venv/bin/activate
pip3 install -r requirements.txt
python3 -m PyInstaller --onefile --hidden-import="googleapiclient" src/main.py
tar -czvf dist/archive.tar.gz dist/main
18 changes: 18 additions & 0 deletions meta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"module_id": "viam:torch-cpu",
"visibility": "public",
"url": "https://github.com/viam-labs/torch",
"description": "Viam ML Module service serving PyTorch models.",
"models": [
{
"api": "rdk:service:mlmodel",
"model": "viam:mlmodel:torch-cpu"
}
],
"build": {
"build": "./build.sh",
"path": "dist/archive.tar.gz",
"arch": ["linux/arm64", "linux/amd64"]
},
"entrypoint": "dist/main"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the entrypoint dist/__main__ or dist/main? It seems like the Makefile and the meta.json do different things; they should be consistent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's my bad, fixing this

}
8 changes: 6 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Runtime dependencies
viam-sdk
typing-extensions
numpy<2.0.0
torch==2.2.1
torchvision
google-api-python-client

# Development / build tooling
pylint
pyinstaller
pytest
9 changes: 5 additions & 4 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ set -euxo pipefail

cd $(dirname $0)
exec dist/main $@
# source .env
# ./setup.sh

source .env
./setup.sh
source .venv/bin/activate
# # Be sure to use `exec` so that termination signals reach the python process,
# # or handle forwarding termination signals manually
# exec $PYTHON -m src.main $@
echo which python3
python3 -m src.main $@
5 changes: 4 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Main runner method of the module"""

import asyncio


from viam.module.module import Module
from viam.resource.registry import Registry, ResourceCreatorRegistration
from torch_mlmodel_module import TorchMLModelModule
from viam.services.mlmodel import MLModel

from torch_mlmodel_module import TorchMLModelModule


async def main():
"""
Expand Down
54 changes: 37 additions & 17 deletions src/model/model.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,93 @@
import torch
"""
This module provides a class for loading and performing inference with a PyTorch model.
The TorchModel class handles loading a serialized model, preparing inputs, and wrapping outputs.
"""

import os
from typing import List, Iterable, Dict, Any
from numpy.typing import NDArray
import torch.nn as nn
from collections import OrderedDict

from numpy.typing import NDArray
from viam.logging import getLogger
import os

import torch
from torch import nn


LOGGER = getLogger(__name__)


class TorchModel:
    """Load a PyTorch model and run inference on dictionaries of NumPy arrays.

    The model is either supplied directly as an ``nn.Module`` or deserialized
    from ``path_to_serialized_file`` via ``torch.load``. Inputs are dicts of
    NumPy arrays; outputs are wrapped back into a name -> NumPy-array dict.
    """

    def __init__(
        self,
        path_to_serialized_file: str,
        model: nn.Module = None,
    ) -> None:
        """Initialize from a provided module or by loading a serialized file.

        :param path_to_serialized_file: path to a torch-saved model file;
            ignored when ``model`` is given.
        :param model: an already-constructed ``nn.Module`` to wrap directly.
        :raises TypeError: if the loaded object is not an ``nn.Module``
            (e.g. an ``OrderedDict`` of weights rather than a full model).
        """
        if model is not None:
            self.model = model
        else:
            size_mb = os.stat(path_to_serialized_file).st_size / (1024 * 1024)
            if size_mb > 500:
                # LOGGER.warning replaces the deprecated .warn, so the
                # pylint deprecated-method disable is no longer needed.
                LOGGER.warning(
                    "model file may be large for certain hardware (%s MB)", size_mb
                )
            # SECURITY NOTE(review): torch.load unpickles arbitrary objects;
            # only load model files from trusted sources.
            self.model = torch.load(path_to_serialized_file)
            if not isinstance(self.model, nn.Module):
                if isinstance(self.model, OrderedDict):
                    LOGGER.error(
                        """the file %s provided as model file
                        is of type collections.OrderedDict,
                        which suggests that the provided file
                        describes weights instead of a standalone model""",
                        path_to_serialized_file,
                    )
                raise TypeError(
                    f"the model is of type {type(self.model)} instead of nn.Module type"
                )
        # Inference mode: disables dropout/batch-norm training behavior.
        self.model.eval()

    def infer(self, input_data):
        """Prepare ``input_data``, run the model without gradients, wrap the output."""
        input_data = self.prepare_input(input_data)
        with torch.no_grad():
            output = self.model(*input_data)
        return self.wrap_output(output)

    @staticmethod
    def prepare_input(input_tensor: Dict[str, NDArray]) -> List[torch.Tensor]:
        """Convert a dict of NumPy arrays into a list of tensors (insertion order).

        Return annotation fixed: the list holds ``torch.Tensor``, not NDArray.
        """
        return [torch.from_numpy(tensor) for tensor in input_tensor.values()]

    @staticmethod
    def wrap_output(output: Any) -> Dict[str, NDArray]:
        """Convert a model's output to a dict of NumPy arrays.

        :raises TypeError: if the output cannot be converted.
        """
        # Unpack single-element batched results. BUGFIX: the previous blanket
        # `isinstance(output, Iterable)` + `len(output)` check crashed on
        # generators (no len()) and broke single-key dict outputs
        # (`output[0]` raised KeyError); restrict unpacking to sequences and
        # tensors with a leading batch dimension of 1.
        if isinstance(output, (list, tuple)) and len(output) == 1:
            output = output[0]
        elif isinstance(output, torch.Tensor) and output.dim() > 0 and output.shape[0] == 1:
            output = output[0]

        if isinstance(output, torch.Tensor):
            return {"output_0": output.numpy()}

        if isinstance(output, dict):
            # Converts tensor values in place and returns the same dict object.
            for tensor_name, tensor in output.items():
                if isinstance(tensor, torch.Tensor):
                    output[tensor_name] = tensor.numpy()
            return output

        if isinstance(output, Iterable):
            return {f"output_{i}": out for i, out in enumerate(output)}

        raise TypeError(f"can't convert output of type {type(output)} to array")
Loading