diff --git a/predict.py b/predict.py index ccfe2bc..a4aafee 100644 --- a/predict.py +++ b/predict.py @@ -8,7 +8,7 @@ def predict_image( image_path: str = typer.Argument(help="image path", show_default=True), - model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True), + model_path: str = typer.Argument(None, help="path to your model (pth)", show_default=True), device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True) ): predictor = ImageRecognition(model_path=model_path, device=device) diff --git a/src/onepiece_classify/infer/recognition.py b/src/onepiece_classify/infer/recognition.py index c11a987..263255e 100644 --- a/src/onepiece_classify/infer/recognition.py +++ b/src/onepiece_classify/infer/recognition.py @@ -1,3 +1,5 @@ +import os +import sys from pathlib import Path from typing import Dict, Optional, Tuple @@ -7,13 +9,23 @@ from onepiece_classify.models import image_recog from onepiece_classify.transforms import get_test_transforms +from onepiece_classify.utils import downloader from .base import BaseInference class ImageRecognition(BaseInference): - def __init__(self, model_path: str, device: str): - self.model_path = Path(model_path) + def __init__(self, device: str, model_path=None): + + path_to_save = str(self._get_cache_dir()) + "/model.pth" + if model_path is None: + downloader(path_to_save) + self.model_path = path_to_save + else: + self.model_path = Path(model_path) + if not self.model_path.exists(): + raise FileNotFoundError("Model does not exist, check your model location and read README for more information") + self.device = device self.class_dict = { 0: "Ace", @@ -44,6 +56,15 @@ def _build_model(self): model_backbone = image_recog(self.nclass) model_backbone.load_state_dict(state_dict) return model_backbone + + def _get_cache_dir(self): + if sys.platform.startswith("win"): + cache_dir = Path(os.getenv("LOCALAPPDATA", Path.home() / "AppData" / "Local")) / "OnepieceClassifyCache" + else: + cache_dir = Path.home() / ".cache" / "OnepieceClassifyCache" + + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir def pre_process( self, image: Optional[str | np.ndarray | Image.Image] diff --git a/src/onepiece_classify/utils/__init__.py b/src/onepiece_classify/utils/__init__.py index e69de29..4be7d9c 100644 --- a/src/onepiece_classify/utils/__init__.py +++ b/src/onepiece_classify/utils/__init__.py @@ -0,0 +1 @@ +from .downloader import * \ No newline at end of file diff --git a/src/onepiece_classify/utils/downloader.py b/src/onepiece_classify/utils/downloader.py new file mode 100644 index 0000000..26f8009 --- /dev/null +++ b/src/onepiece_classify/utils/downloader.py @@ -0,0 +1,14 @@ +import gdown +from pathlib import Path + + +def downloader(output_file): + file_id = "1M1-1Hs198XDD6Xx-kSWLThv1elZBzJ0j" + prefix = 'https://drive.google.com/uc?/export=download&id=' + + url_download = prefix+file_id + + if not Path(output_file).exists(): + print("Downloading...") + gdown.download(url_download, output_file) + print("Download Finish...")