Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New model #195

Merged
merged 6 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pkg&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"opencv-python==4.5.5.64",
"tqdm>=4.62.0",
]

[project.optional-dependencies]
Expand Down
14 changes: 11 additions & 3 deletions pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import cv2 # type: ignore[import-untyped]
import numpy as np
from tqdm import tqdm

__all__ = ["letterbox", "nms", "xywh2xyxy"]
__all__ = ["letterbox", "nms", "xywh2xyxy", "DownloadProgressBar"]


def xywh2xyxy(x: np.ndarray):
Expand All @@ -20,14 +21,14 @@ def xywh2xyxy(x: np.ndarray):


def letterbox(
im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (114, 114, 114), auto: bool = False, stride: int = 32
im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (0, 0, 0), auto: bool = False, stride: int = 32
):
"""Letterbox image transform for yolo models
Args:
im (np.ndarray): Input image
new_shape (tuple, optional): Image size. Defaults to (640, 640).
color (tuple, optional): Pixel fill value for the area outside the transformed image.
Defaults to (114, 114, 114).
Defaults to (0, 0, 0).
auto (bool, optional): auto padding. Defaults to True.
stride (int, optional): padding stride. Defaults to 32.
Returns:
Expand Down Expand Up @@ -109,3 +110,10 @@ def nms(boxes: np.ndarray, overlapThresh: int = 0):
indices = indices[indices != i]

return boxes[indices]


class DownloadProgressBar(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
10 changes: 6 additions & 4 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
import onnxruntime
from PIL import Image

from .utils import letterbox, nms, xywh2xyxy
from .utils import DownloadProgressBar, letterbox, nms, xywh2xyxy

__all__ = ["Classifier"]

MODEL_URL = "https://github.com/pyronear/pyro-vision/releases/download/v0.2.0/yolov8s_v001.onnx"
MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/yolov8s.onnx"
RonanMorgan marked this conversation as resolved.
Show resolved Hide resolved


class Classifier:
Expand All @@ -29,14 +29,16 @@
model_path: model path
"""

def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (384, 640)) -> None:
def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (1024, 1024)) -> None:
if model_path is None:
model_path = "data/model.onnx"

if not os.path.isfile(model_path):
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
print(f"Downloading model from {MODEL_URL} ...")
urlretrieve(MODEL_URL, model_path)
with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)

Check warning on line 40 in pyroengine/vision.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

pyroengine/vision.py#L40

Audit url open for permitted schemes. Allowing use of file:/ or custom schemes is often unexpected.
print("Model downloaded!")

self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size
Expand Down
6 changes: 3 additions & 3 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_classifier(mock_wildfire_image):
# Check preprocessing
out, pad = model.preprocess_image(mock_wildfire_image)
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 384, 640)
assert out.shape == (1, 3, 1024, 1024)
assert isinstance(pad, tuple)
# Check inference
out = model(mock_wildfire_image)
Expand All @@ -18,10 +18,10 @@ def test_classifier(mock_wildfire_image):
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((384, 640))
mask = np.ones((1024, 640))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((384, 640))
mask = np.zeros((1024, 1024))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)
Loading