diff --git a/pyroengine/vision.py b/pyroengine/vision.py index 8a3270e..17e5445 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -23,17 +23,12 @@ MODEL_URL_FOLDER = "https://huggingface.co/pyronear/yolov8s/resolve/main/" MODEL_ID = "pyronear/yolov8s" MODEL_NAME = "yolov8s.pt" -METADATA_PATH = "data/model_metadata.json" +METADATA_NAME = "model_metadata.json" logging.basicConfig(format="%(asctime)s | %(levelname)s: %(message)s", level=logging.INFO, force=True) -def is_arm_architecture(): - # Check for ARM architecture - return platform.machine().startswith("arm") or platform.machine().startswith("aarch") - - # Utility function to save metadata def save_metadata(metadata_path, metadata): with open(metadata_path, "w") as f: @@ -54,7 +49,7 @@ class Classifier: def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format="ncnn", model_path=None) -> None: if model_path is None: if format == "ncnn": - if is_arm_architecture(): + if self.is_arm_architecture(): model = "yolov8s_ncnn_model.zip" else: logging.info("NCNN format is optimized for arm architecture only, switching to onnx") @@ -63,6 +58,7 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format= model = f"yolov8s.{format}" model_path = os.path.join(model_folder, model) + metadata_path = os.path.join(model_folder, METADATA_NAME) model_url = MODEL_URL_FOLDER + model # Get the expected SHA256 from Hugging Face @@ -76,15 +72,15 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format= # Check if the model file exists if os.path.isfile(model_path): # Load existing metadata - metadata = self.load_metadata(METADATA_PATH) + metadata = self.load_metadata(metadata_path) if metadata and metadata.get("sha256") == expected_sha256: logging.info("Model already exists and the SHA256 hash matches. No download needed.") else: logging.info("Model exists but the SHA256 hash does not match or the file doesn't exist.") os.remove(model_path) - self.download_model(model_url, model_path, expected_sha256) + self.download_model(model_url, model_path, expected_sha256, metadata_path) else: - self.download_model(model_url, model_path, expected_sha256) + self.download_model(model_url, model_path, expected_sha256, metadata_path) file_name, ext = os.path.splitext(model_path) if ext == ".zip": @@ -97,6 +93,10 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0.05, format= self.conf = conf self.iou = iou + def is_arm_architecture(self): + # Check for ARM architecture + return platform.machine().startswith("arm") or platform.machine().startswith("aarch") + def get_sha(self, siblings): # Extract the SHA256 hash from the model files metadata for file in siblings: @@ -104,7 +104,7 @@ def get_sha(self, siblings): return file.lfs["sha256"] return None - def download_model(self, model_url, model_path, expected_sha256): + def download_model(self, model_url, model_path, expected_sha256, metadata_path): # Ensure the directory exists os.makedirs(os.path.split(model_path)[0], exist_ok=True) @@ -116,7 +116,7 @@ def download_model(self, model_url, model_path, expected_sha256): # Save the metadata metadata = {"sha256": expected_sha256} - save_metadata(METADATA_PATH, metadata) + save_metadata(metadata_path, metadata) logging.info("Metadata saved!") # Utility function to load metadata diff --git a/tests/test_vision.py b/tests/test_vision.py index 5c84708..5840f26 100644 --- a/tests/test_vision.py +++ b/tests/test_vision.py @@ -1,10 +1,27 @@ +import datetime +import os +from unittest.mock import patch + import numpy as np from pyroengine.vision import Classifier -METADATA_PATH = "data/model_metadata.json" -model_path = "data/yolov8s.onnx" -sha = "9f1b1c2654d98bbed91e514ce20ea73a0a5fbd1111880f230d516ed40ea2dc58" + +def get_creation_date(file_path): + if os.path.exists(file_path): + + # For Unix-like systems + stat = os.stat(file_path) + try: + creation_time = stat.st_birthtime + except AttributeError: + # On Unix, use the last modification time as a fallback + creation_time = stat.st_mtime + + creation_date = datetime.datetime.fromtimestamp(creation_time) + return creation_date + else: + return None def test_classifier(tmpdir_factory, mock_wildfire_image): @@ -19,7 +36,10 @@ def test_classifier(tmpdir_factory, mock_wildfire_image): conf = np.max(out[:, 4]) assert 0 <= conf <= 1 + # Test onnx model model = Classifier(model_folder=folder, format="onnx") + model_path = os.path.join(folder, "yolov8s.onnx") + assert os.path.isfile(model_path) # Test mask mask = np.ones((384, 640)) @@ -29,3 +49,44 @@ def test_classifier(tmpdir_factory, mock_wildfire_image): mask = np.zeros((384, 640)) out = model(mock_wildfire_image, mask) assert out.shape == (0, 5) + + # Test dl pt model + _ = Classifier(model_folder=folder, format="pt") + model_path = os.path.join(folder, "yolov8s.pt") + assert os.path.isfile(model_path) + + # Test dl ncnn model + with patch.object(Classifier, "is_arm_architecture", return_value=True): + _ = Classifier(model_folder=folder) + model_path = os.path.join(folder, "yolov8s_ncnn_model") + assert os.path.isdir(model_path) + + +def test_download(tmpdir_factory): + print("test_classifier") + folder = str(tmpdir_factory.mktemp("engine_cache")) + + # Instantiate the ONNX model + model = Classifier(model_folder=folder) + + model_path = os.path.join(folder, "yolov8s.onnx") + model_creation_date = get_creation_date(model_path) + + # No download if exist + _ = Classifier(model_folder=folder) + model_creation_date2 = get_creation_date(model_path) + assert model_creation_date == model_creation_date2 + + # Download if does not exist + os.remove(model_path) + _ = Classifier(model_folder=folder) + model_creation_date3 = get_creation_date(model_path) + print(model_creation_date, model_creation_date3) + assert model_creation_date != model_creation_date3 + + # Download if sha is not the same + with patch.object(Classifier, "get_sha", return_value="sha12"): + _ = Classifier(model_folder=folder) + model_creation_date4 = get_creation_date(model_path) + print(model_creation_date, model_creation_date3) + assert model_creation_date4 != model_creation_date3