From 347accae782af3188b6cc4fdcbefd1489e205997 Mon Sep 17 00:00:00 2001
From: RonanMorgan <49660557+RonanMorgan@users.noreply.github.com>
Date: Wed, 29 May 2024 11:41:31 +0200
Subject: [PATCH] Download model only when needed (#200)

* feat: use hash to check if a model should be downloaded

* fix: add hugging face dependency

* fix: make style

* fix: make quality

* fix: don't understand why there is an error locally but not in the gitaction

* linting

* error in headers

* feat: add tests
---
 pyproject.toml       |   3 +-
 pyroengine/vision.py |  69 ++++++++++++++++++++++---
 tests/test_vision.py | 119 +++++++++++++++++++++++++++++++++++--------
 3 files changed, 163 insertions(+), 28 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 53f69c24..5db7610e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,8 @@ dependencies = [
     "pyroclient @ git+https://github.com/pyronear/pyro-api.git@main#egg=pkg&subdirectory=client",
     "requests>=2.20.0,<3.0.0",
     "opencv-python==4.5.5.64",
-    "tqdm>=4.62.0",
+    "tqdm>=4.62.0",
+    "huggingface_hub==0.23.1",
 ]

 [project.optional-dependencies]
diff --git a/pyroengine/vision.py b/pyroengine/vision.py
index 674fe4ba..17198088 100644
--- a/pyroengine/vision.py
+++ b/pyroengine/vision.py
@@ -1,8 +1,9 @@
-# Copyright (C) 2022-2024, Pyronear.
+# Copyright (C) 2023-2024, Pyronear.

 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to  for full license details.

+import json
 import os
 from typing import Optional, Tuple
 from urllib.request import urlretrieve
@@ -10,6 +11,7 @@
 import cv2  # type: ignore[import-untyped]
 import numpy as np
 import onnxruntime
+from huggingface_hub import HfApi  # type: ignore[import-untyped]
 from PIL import Image

 from .utils import DownloadProgressBar, nms, xywh2xyxy
@@ -17,6 +19,15 @@
 __all__ = ["Classifier"]

 MODEL_URL = "https://huggingface.co/pyronear/yolov8s/resolve/main/model.onnx"
+MODEL_ID = "pyronear/yolov8s"
+MODEL_NAME = "model.onnx"
+METADATA_PATH = "data/model_metadata.json"
+
+
+# Utility function to save metadata
+def save_metadata(metadata_path, metadata):
+    with open(metadata_path, "w") as f:
+        json.dump(metadata, f)


 class Classifier:
@@ -34,16 +45,60 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", base_img_size:
         if model_path is None:
             model_path = "data/model.onnx"

-        if not os.path.isfile(model_path):
-            os.makedirs(os.path.split(model_path)[0], exist_ok=True)
-            print(f"Downloading model from {MODEL_URL} ...")
-            with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
-                urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
-            print("Model downloaded!")
+        # Get the expected SHA256 from Hugging Face
+        api = HfApi()
+        model_info = api.model_info(MODEL_ID, files_metadata=True)
+        expected_sha256 = self.get_sha(model_info.siblings)
+
+        if not expected_sha256:
+            raise ValueError("SHA256 hash for the model file not found in the Hugging Face model metadata.")
+
+        # Check if the model file exists
+        if os.path.isfile(model_path):
+            # Load existing metadata
+            metadata = self.load_metadata(METADATA_PATH)
+            if metadata and metadata.get("sha256") == expected_sha256:
+                print("Model already exists and the SHA256 hash matches. No download needed.")
+            else:
+                print("Model exists but the SHA256 hash does not match or the metadata file is missing.")
+                os.remove(model_path)
+                self.download_model(model_path, expected_sha256)
+        else:
+            self.download_model(model_path, expected_sha256)

         self.ort_session = onnxruntime.InferenceSession(model_path)
         self.base_img_size = base_img_size

+    def get_sha(self, siblings):
+        # Extract the SHA256 hash from the model files metadata
+        for file in siblings:
+            if file.rfilename == os.path.basename(MODEL_NAME):
+                expected_sha256 = file.lfs.sha256
+                break
+        return expected_sha256
+
+    def download_model(self, model_path, expected_sha256):
+        # Ensure the directory exists
+        os.makedirs(os.path.split(model_path)[0], exist_ok=True)
+
+        # Download the model
+        print(f"Downloading model from {MODEL_URL} ...")
+        with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=model_path) as t:
+            urlretrieve(MODEL_URL, model_path, reporthook=t.update_to)
+        print("Model downloaded!")
+
+        # Save the metadata
+        metadata = {"sha256": expected_sha256}
+        save_metadata(METADATA_PATH, metadata)
+        print("Metadata saved!")
+
+    # Utility function to load metadata
+    def load_metadata(self, metadata_path):
+        if os.path.exists(metadata_path):
+            with open(metadata_path, "r") as f:
+                return json.load(f)
+        return None
+
     def preprocess_image(self, pil_img: Image.Image, new_img_size: list) -> Tuple[np.ndarray, Tuple[int, int]]:
         """Preprocess an image for inference

diff --git a/tests/test_vision.py b/tests/test_vision.py
index ac2809c2..84c8a2ce 100644
--- a/tests/test_vision.py
+++ b/tests/test_vision.py
@@ -1,26 +1,105 @@
+import os
+from unittest.mock import MagicMock, patch
+
 import numpy as np
+import pytest

 from pyroengine.vision import Classifier

+METADATA_PATH = "data/model_metadata.json"
+model_path = "data/model.onnx"
+sha = "12b9b5728dfa2e60502dcde2914bfdc4e9378caa57611c567a44cdd6228838c2"
+
+
+def custom_isfile_false(path):
+    if path == model_path:
+        return False  # pretend the model file is missing
+    return True  # Default behavior for other paths
+
+
+def custom_isfile_true(path):
+    if path == model_path:
+        return True  # pretend the model file exists
+    return True  # Default behavior for other paths
+

+# Test for the case where the model doesn't exist
 def test_classifier(mock_wildfire_image):
-    # Instantiate the ONNX model
-    model = Classifier()
-    # Check preprocessing
-    out = model.preprocess_image(mock_wildfire_image, (1024, 576))
-    assert isinstance(out, np.ndarray) and out.dtype == np.float32
-    assert out.shape == (1, 3, 576, 1024)
-    # Check inference
-    out = model(mock_wildfire_image)
-    assert out.shape == (1, 5)
-    conf = np.max(out[:, 4])
-    assert conf >= 0 and conf <= 1
-
-    # Test mask
-    mask = np.ones((1024, 576))
-    out = model(mock_wildfire_image, mask)
-    assert out.shape == (1, 5)
-
-    mask = np.zeros((1024, 1024))
-    out = model(mock_wildfire_image, mask)
-    assert out.shape == (0, 5)
+    print("test_classifier")
+    with patch("os.path.isfile", side_effect=custom_isfile_false):
+        # Instantiate the ONNX model
+        model = Classifier()
+        # Check preprocessing
+        out = model.preprocess_image(mock_wildfire_image, (1024, 576))
+        assert isinstance(out, np.ndarray) and out.dtype == np.float32
+        assert out.shape == (1, 3, 576, 1024)
+        # Check inference
+        out = model(mock_wildfire_image)
+        assert out.shape == (1, 5)
+        conf = np.max(out[:, 4])
+        assert conf >= 0 and conf <= 1
+
+        # Test mask
+        mask = np.ones((1024, 576))
+        out = model(mock_wildfire_image, mask)
+        assert out.shape == (1, 5)
+
+        mask = np.zeros((1024, 1024))
+        out = model(mock_wildfire_image, mask)
+        assert out.shape == (0, 5)
+    os.remove(model_path)
+    os.remove(METADATA_PATH)
+
+
+# Test that the model is not downloaded again when the stored SHA256 matches
+def test_no_download():
+    print("test_no_download")
+    data = {"sha256": sha}
+    with patch("os.path.isfile", side_effect=custom_isfile_true):
+        with patch("pyroengine.vision.Classifier.load_metadata", return_value=data):
+            with patch("onnxruntime.InferenceSession", return_value=None):
+                Classifier()
+    assert os.path.isfile(model_path) is False
+
+
+# Test the case where the SHA256 hashes do not match
+@patch("pyroengine.vision.urlretrieve")
+@patch("pyroengine.vision.DownloadProgressBar")
+def test_sha_inequality(mock_download_progress, mock_urlretrieve):
+    print("test_sha_inequality")
+    data = {"sha256": "falsesha"}
+
+    # Mock urlretrieve to create a fake file
+    def fake_urlretrieve(url, filename, reporthook=None):
+        with open(filename, "w") as f:
+            f.write("fake model content")
+
+    mock_urlretrieve.side_effect = fake_urlretrieve
+    # Mock the DownloadProgressBar context manager
+    mock_progress_bar_instance = MagicMock()
+    mock_download_progress.return_value.__enter__.return_value = mock_progress_bar_instance
+
+    with patch("os.path.isfile", side_effect=custom_isfile_true):
+        with patch("pyroengine.vision.Classifier.load_metadata", return_value=data):
+            with patch(
+                "pyroengine.vision.Classifier.get_sha",
+                return_value=sha,
+            ):
+                with patch("onnxruntime.InferenceSession", return_value=None):
+                    with patch("os.remove", return_value=True):
+                        model = Classifier()
+
+    assert os.path.isfile(model_path) is True
+    assert model.load_metadata("non_existent_metadata.json") is None
+    os.remove(model_path)
+    os.remove(METADATA_PATH)
+
+
+# Test for raising ValueError if expected_sha256 is not found
+def test_raise_value_error_if_no_sha256():
+    print("test_raise_value_error_if_no_sha256")
+    with patch("pyroengine.vision.Classifier.get_sha", return_value=""):
+        with pytest.raises(
+            ValueError, match="SHA256 hash for the model file not found in the Hugging Face model metadata."
+        ):
+            Classifier(model_path="non_existent_model.onnx")
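
# ----------------------------------------------------------------------------
# For reference, a minimal self-contained sketch of the "check the remote hash
# before downloading" pattern this patch introduces. The helper names below
# (remote_sha256, local_sha256, needs_download) and the default paths are
# illustrative assumptions, not part of the patch; the actual implementation
# lives in Classifier.get_sha / download_model / load_metadata above.
# ----------------------------------------------------------------------------
import json
import os
from typing import Optional

from huggingface_hub import HfApi

MODEL_ID = "pyronear/yolov8s"
MODEL_NAME = "model.onnx"
METADATA_PATH = "data/model_metadata.json"


def remote_sha256(model_id: str, filename: str) -> Optional[str]:
    # Ask the Hub for per-file metadata and return the LFS SHA256 of `filename`.
    info = HfApi().model_info(model_id, files_metadata=True)
    for sibling in info.siblings:
        if sibling.rfilename == filename and sibling.lfs is not None:
            return sibling.lfs.sha256
    return None


def local_sha256(metadata_path: str = METADATA_PATH) -> Optional[str]:
    # Hash recorded locally after the last successful download, if any.
    if os.path.exists(metadata_path):
        with open(metadata_path, "r") as f:
            return json.load(f).get("sha256")
    return None


def needs_download(model_path: str = "data/model.onnx") -> bool:
    # Download only when the local file is missing or the remote hash changed.
    expected = remote_sha256(MODEL_ID, MODEL_NAME)
    if expected is None:
        raise ValueError("SHA256 hash for the model file not found in the Hugging Face model metadata.")
    return not (os.path.isfile(model_path) and local_sha256() == expected)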