Skip to content

Commit

Permalink
Download model only when needed (#200)
Browse files Browse the repository at this point in the history
* feat: use hash to check if a model should be downloaded

* fix: add hugging face dependency

* fix:make style

* fix: make quality

* fix: don't understand why there is an error locally but not in the gitaction

* linting

* error in headers

* feat: add tests
  • Loading branch information
RonanMorgan authored May 29, 2024
1 parent c18de6b commit 347acca
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 28 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ dependencies = [
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pkg&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"opencv-python==4.5.5.64",
"tqdm>=4.62.0",
"tqdm>=4.62.0",
"huggingface_hub==0.23.1",
]

[project.optional-dependencies]
Expand Down
69 changes: 62 additions & 7 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
# Copyright (C) 2022-2024, Pyronear.
# Copyright (C) 2023-2024, Pyronear.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from typing import Optional, Tuple
from urllib.request import urlretrieve

import cv2 # type: ignore[import-untyped]
import numpy as np
import onnxruntime
from huggingface_hub import HfApi # type: ignore[import-untyped]
from PIL import Image

from .utils import DownloadProgressBar, nms, xywh2xyxy

__all__ = ["Classifier"]

MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/model.onnx"
MODEL_ID = "pyronear/yolov8s"
MODEL_NAME = "model.onnx"
METADATA_PATH = "data/model_metadata.json"


# Utility function to save metadata
def save_metadata(metadata_path: str, metadata: dict) -> None:
    """Write ``metadata`` to ``metadata_path`` as JSON.

    Args:
        metadata_path: destination file; missing parent directories are created.
        metadata: JSON-serializable mapping to store.
    """
    # The metadata file can live in a directory nobody created yet (the model
    # itself may be stored elsewhere), so create it rather than fail in open().
    parent_dir = os.path.dirname(metadata_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(metadata_path, "w") as f:
        json.dump(metadata, f)


class Classifier:
Expand All @@ -34,16 +45,60 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", base_img_size:
if model_path is None:
model_path = "data/model.onnx"

if not os.path.isfile(model_path):
os.makedirs(os.path.split(model_path)[0], exist_ok=True)
print(f"Downloading model from {MODEL_URL} ...")
with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
print("Model downloaded!")
# Get the expected SHA256 from Hugging Face
api = HfApi()
model_info = api.model_info(MODEL_ID, files_metadata=True)
expected_sha256 = self.get_sha(model_info.siblings)

if not expected_sha256:
raise ValueError("SHA256 hash for the model file not found in the Hugging Face model metadata.")

# Check if the model file exists
if os.path.isfile(model_path):
# Load existing metadata
metadata = self.load_metadata(METADATA_PATH)
if metadata and metadata.get("sha256") == expected_sha256:
print("Model already exists and the SHA256 hash matches. No download needed.")
else:
print("Model exists but the SHA256 hash does not match or the file doesn't exist.")
os.remove(model_path)
self.download_model(model_path, expected_sha256)
else:
self.download_model(model_path, expected_sha256)

self.ort_session = onnxruntime.InferenceSession(model_path)
self.base_img_size = base_img_size

def get_sha(self, siblings):
    """Return the SHA256 recorded for the model file in the repo file listing.

    Args:
        siblings: iterable of repo file entries (``model_info.siblings``); each
            entry is read through its ``rfilename`` and ``lfs`` attributes.

    Returns:
        The ``lfs.sha256`` value of the entry named ``MODEL_NAME``, or ``None``
        when no such entry exists or the entry has no ``lfs`` information.
    """
    target_name = os.path.basename(MODEL_NAME)
    for file in siblings or ():
        if file.rfilename == target_name:
            # NOTE(review): ``lfs`` is presumably None for files that are not
            # stored through LFS -- confirm against huggingface_hub.
            lfs = file.lfs
            return lfs.sha256 if lfs is not None else None
    # Returning None (rather than reading a never-assigned local) lets the
    # caller's falsy check raise its own explicit ValueError.
    return None

def download_model(self, model_path, expected_sha256):
    """Download the model from ``MODEL_URL`` and record its expected hash.

    Args:
        model_path: destination path of the model file.
        expected_sha256: hash to store in the metadata file at ``METADATA_PATH``.
            NOTE: the downloaded bytes are not hashed here; the value is saved
            exactly as passed in.
    """
    # Ensure the directory exists. A bare file name has an empty directory
    # component and os.makedirs("") raises FileNotFoundError, so skip it then.
    target_dir = os.path.dirname(model_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    # Download the model
    print(f"Downloading model from {MODEL_URL} ...")
    with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
        urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
    print("Model downloaded!")

    # Save the metadata
    save_metadata(METADATA_PATH, {"sha256": expected_sha256})
    print("Metadata saved!")

# Utility function to load metadata
def load_metadata(self, metadata_path):
    """Load the JSON metadata stored at ``metadata_path``.

    Args:
        metadata_path: path of the metadata file.

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist or
        does not contain valid JSON.
    """
    try:
        with open(metadata_path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Unreadable metadata is treated like missing metadata: the caller
        # then takes its "hash does not match" path and downloads again.
        return None

def preprocess_image(self, pil_img: Image.Image, new_img_size: list) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference
Expand Down
119 changes: 99 additions & 20 deletions tests/test_vision.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,105 @@
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from pyroengine.vision import Classifier

METADATA_PATH = "data/model_metadata.json"
model_path = "data/model.onnx"
sha = "12b9b5728dfa2e60502dcde2914bfdc4e9378caa57611c567a44cdd6228838c2"


def custom_isfile_false(path):
    """Stand-in for ``os.path.isfile`` that reports the model file as missing.

    Every path other than ``model_path`` is reported as an existing file.
    """
    return not (path == model_path)


def custom_isfile_true(path):
    """Stand-in for ``os.path.isfile`` that reports every path as an existing file.

    The model path needs no special case here: it is reported as existing just
    like any other path.
    """
    return True


# Test for the case where the model doesn't exist
def test_classifier(mock_wildfire_image):
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out = model.preprocess_image(mock_wildfire_image, (1024, 576))
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 576, 1024)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((1024, 576))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((1024, 1024))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)
print("test_classifier")
with patch("os.path.isfile", side_effect=custom_isfile_false):
# Instantiate the ONNX model
model = Classifier()
# Check preprocessing
out = model.preprocess_image(mock_wildfire_image, (1024, 576))
assert isinstance(out, np.ndarray) and out.dtype == np.float32
assert out.shape == (1, 3, 576, 1024)
# Check inference
out = model(mock_wildfire_image)
assert out.shape == (1, 5)
conf = np.max(out[:, 4])
assert conf >= 0 and conf <= 1

# Test mask
mask = np.ones((1024, 576))
out = model(mock_wildfire_image, mask)
assert out.shape == (1, 5)

mask = np.zeros((1024, 1024))
out = model(mock_wildfire_image, mask)
assert out.shape == (0, 5)
os.remove(model_path)
os.remove(METADATA_PATH)


# Test that the model is not downloaded when it already exists and its hash matches
def test_no_download():
    """Build a Classifier while the model is reported present and check nothing lands on disk.

    ``load_metadata`` is patched to return ``sha`` (presumably the hash currently
    published for the model -- confirm), so no download should be triggered.
    """
    print("test_no_download")
    stored_metadata = {"sha256": sha}
    with patch("os.path.isfile", side_effect=custom_isfile_true), \
            patch("pyroengine.vision.Classifier.load_metadata", return_value=stored_metadata), \
            patch("onnxruntime.InferenceSession", return_value=None):
        Classifier()
    # Checked outside the patches, i.e. against the real file system.
    assert os.path.isfile(model_path) is False


# Test the case where the stored and expected SHA256 hashes differ
@patch("pyroengine.vision.urlretrieve")
@patch("pyroengine.vision.DownloadProgressBar")
def test_sha_inequality(mock_download_progress, mock_urlretrieve):
    """Build a Classifier whose stored hash differs from the expected one.

    The download is faked, so the test only checks that a model file gets
    written and that ``load_metadata`` returns None for a missing file.
    """
    print("test_sha_inequality")
    stale_metadata = {"sha256": "falsesha"}

    def fake_urlretrieve(url, filename, reporthook=None):
        # Stand-in for the real download: just create the target file.
        with open(filename, "w") as f:
            f.write("fake model content")

    mock_urlretrieve.side_effect = fake_urlretrieve
    # Entering the mocked DownloadProgressBar context yields a plain mock.
    mock_download_progress.return_value.__enter__.return_value = MagicMock()

    with patch("os.path.isfile", side_effect=custom_isfile_true), \
            patch("pyroengine.vision.Classifier.load_metadata", return_value=stale_metadata), \
            patch("pyroengine.vision.Classifier.get_sha", return_value=sha), \
            patch("onnxruntime.InferenceSession", return_value=None), \
            patch("os.remove", return_value=True):
        model = Classifier()

    # From here on the real os functions are back in place.
    assert os.path.isfile(model_path) is True
    assert model.load_metadata("non_existent_metadata.json") is None
    os.remove(model_path)
    os.remove(METADATA_PATH)


# Test for raising ValueError if expected_sha256 is not found
def test_raise_value_error_if_no_sha256():
    """Check that Classifier raises ValueError when ``get_sha`` yields an empty hash."""
    print("test_raise_value_error_if_no_sha256")
    expected_message = "SHA256 hash for the model file not found in the Hugging Face model metadata."
    with patch("pyroengine.vision.Classifier.get_sha", return_value=""), \
            pytest.raises(ValueError, match=expected_message):
        Classifier(model_path="non_existent_model.onnx")

0 comments on commit 347acca

Please sign in to comment.