Skip to content

Commit

Permalink
Merge branch 'develop' into rs/check-hash-model
Browse files Browse the repository at this point in the history
  • Loading branch information
Ronan committed May 29, 2024
2 parents 257ea16 + c18de6b commit 0acd8d5
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 69 deletions.
48 changes: 1 addition & 47 deletions pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


import cv2 # type: ignore[import-untyped]
import numpy as np
from tqdm import tqdm # type: ignore[import-untyped]

__all__ = ["letterbox", "nms", "xywh2xyxy", "DownloadProgressBar"]
__all__ = ["nms", "xywh2xyxy", "DownloadProgressBar"]


def xywh2xyxy(x: np.ndarray):
Expand All @@ -20,51 +19,6 @@ def xywh2xyxy(x: np.ndarray):
return y


def letterbox(
im: np.ndarray, new_shape: tuple = (640, 640), color: tuple = (0, 0, 0), auto: bool = False, stride: int = 32
):
"""Letterbox image transform for yolo models
Args:
im (np.ndarray): Input image
new_shape (tuple, optional): Image size. Defaults to (640, 640).
color (tuple, optional): Pixel fill value for the area outside the transformed image.
Defaults to (0, 0, 0).
auto (bool, optional): auto padding. Defaults to True.
stride (int, optional): padding stride. Defaults to 32.
Returns:
np.ndarray: Output image
"""
# Resize and pad image while meeting stride-multiple constraints
im = np.array(im)
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)

# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding

if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding

dw /= 2 # divide padding into 2 sides
dh /= 2

if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
# add border
h, w = im.shape[:2]
im_b = np.zeros((h + top + bottom, w + left + right, 3)) + color
im_b[top : top + h, left : left + w, :] = im

return im_b.astype("uint8"), (left, top)


def box_iou(box1: np.ndarray, box2: np.ndarray, eps: float = 1e-7):
"""
Calculate intersection-over-union (IoU) of boxes.
Expand Down
34 changes: 19 additions & 15 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
from typing import Optional, Tuple
from urllib.request import urlretrieve

import cv2 # type: ignore[import-untyped]
import numpy as np
import onnxruntime
from huggingface_hub import HfApi # type: ignore[import-untyped]
from PIL import Image

from .utils import DownloadProgressBar, letterbox, nms, xywh2xyxy
from .utils import DownloadProgressBar, nms, xywh2xyxy

__all__ = ["Classifier"]

MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/yolov8s.onnx"
MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/model.onnx"
MODEL_ID = "pyronear/yolov8s"
MODEL_NAME = "yolov8s.onnx"
MODEL_NAME = "model.onnx"
METADATA_PATH = "data/model_metadata.json"


Expand All @@ -40,7 +41,7 @@ class Classifier:
model_path: model path
"""

def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tuple = (1024, 1024)) -> None:
def __init__(self, model_path: Optional[str] = "data/model.onnx", base_img_size: int = 1024) -> None:
if model_path is None:
model_path = "data/model.onnx"

Expand All @@ -66,7 +67,7 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", img_size: tupl
self.download_model(model_path, expected_sha256)

self.ort_session = onnxruntime.InferenceSession(model_path)
self.img_size = img_size
self.base_img_size = base_img_size

def get_sha(self, siblings):
# Extract the SHA256 hash from the model files metadata
Expand Down Expand Up @@ -98,7 +99,7 @@ def load_metadata(self, metadata_path):
return json.load(f)
return None

def preprocess_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int, int]]:
def preprocess_image(self, pil_img: Image.Image, new_img_size: list) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference
Args:
Expand All @@ -109,15 +110,21 @@ def preprocess_image(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int,
- The resized and normalized image of shape (1, C, H, W).
- Padding information as a tuple of integers (pad_height, pad_width).
"""
np_img, pad = letterbox(np.array(pil_img), self.img_size) # Applies letterbox resize with padding

np_img = cv2.resize(np.array(pil_img), new_img_size, interpolation=cv2.INTER_LINEAR)
np_img = np.expand_dims(np_img.astype("float"), axis=0) # Add batch dimension
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # Convert from BHWC to BCHW format
np_img = np_img.astype("float32") / 255 # Normalize to [0, 1]

return np_img, pad
return np_img

def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray:
np_img, pad = self.preprocess_image(pil_img)

w, h = pil_img.size
ratio = self.base_img_size / max(w, h)
new_img_size = [int(ratio * w), int(ratio * h)]
new_img_size = [x - x % 32 for x in new_img_size] # size need to be a multiple of 32 to fit the model
np_img = self.preprocess_image(pil_img, new_img_size)

# ONNX inference
y = self.ort_session.run(["output0"], {"images": np_img})[0][0]
Expand All @@ -132,12 +139,9 @@ def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] =

# Normalize preds
if len(y) > 0:
# Remove padding
left_pad, top_pad = pad
y[:, :4:2] -= left_pad
y[:, 1:4:2] -= top_pad
y[:, :4:2] /= self.img_size[1] - 2 * left_pad
y[:, 1:4:2] /= self.img_size[0] - 2 * top_pad
# Normalize Output
y[:, :4:2] /= new_img_size[0]
y[:, 1:4:2] /= new_img_size[1]
else:
y = np.zeros((0, 5)) # normalize output

Expand Down
14 changes: 7 additions & 7 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

METADATA_PATH = "data/model_metadata.json"
model_path = "data/model.onnx"
sha = "12b9b5728dfa2e60502dcde2914bfdc4e9378caa57611c567a44cdd6228838c2"


def custom_isfile_false(path):
Expand All @@ -29,18 +30,17 @@ def test_classifier(mock_wildfire_image):
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out, pad = model.preprocess_image(mock_wildfire_image)
out = model.preprocess_image(mock_wildfire_image, (1024, 576))
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 1024, 1024)
assert isinstance(pad, tuple)
assert out.shape == (1, 3, 576, 1024)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((1024, 640))
mask = np.ones((1024, 576))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

Expand All @@ -54,7 +54,7 @@ def test_classifier(mock_wildfire_image):
# Test that the model is not loaded
def test_no_download():
print("test_no_download")
data = {"sha256": "00083a41dc6468e998a40d9f6f348c10e4c7c998a7bfec9f8dbf58db6bd3471d"}
data = {"sha256": sha}
with patch("os.path.isfile", side_effect=custom_isfile_true):
with patch("pyroengine.vision.Classifier.load_metadata", return_value=data):
with patch("onnxruntime.InferenceSession", return_value=None):
Expand All @@ -67,7 +67,7 @@ def test_no_download():
@patch("pyroengine.vision.DownloadProgressBar")
def test_sha_inequality(mock_download_progress, mock_urlretrieve):
print("test_sha_inequality")
data = {"sha256": "00083a41dc6468e998a40d9f6f348c10e4c7c998a7bfec9f8dbf58db6bd3471d"}
data = {"sha256": "falsesha"}

# Mock urlretrieve to create a fake file
def fake_urlretrieve(url, filename, reporthook=None):
Expand All @@ -83,7 +83,7 @@ def fake_urlretrieve(url, filename, reporthook=None):
with patch("pyroengine.vision.Classifier.load_metadata", return_value=data):
with patch(
"pyroengine.vision.Classifier.get_sha",
return_value="00083a41dc6468e998a40d9f6f348c10e4c7c998a7bfec9f8dbf58db6bd3471e",
return_value=sha,
):
with patch("onnxruntime.InferenceSession", return_value=None):
with patch("os.remove", return_value=True):
Expand Down

0 comments on commit 0acd8d5

Please sign in to comment.