From 33575d06fedf766dd37b473fefa858aac8004f48 Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Wed, 22 May 2024 14:08:16 +0800 Subject: [PATCH 1/6] feat: add downloader module in utils dir #18 --- src/onepiece_classify/utils/downloader.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 src/onepiece_classify/utils/downloader.py diff --git a/src/onepiece_classify/utils/downloader.py b/src/onepiece_classify/utils/downloader.py new file mode 100644 index 0000000..4127e19 --- /dev/null +++ b/src/onepiece_classify/utils/downloader.py @@ -0,0 +1,16 @@ +import gdown +from pathlib import Path + + +def downloader(output_file): + file_id = "1M1-1Hs198XDD6Xx-kSWLThv1elZBzJ0j" + prefix = 'https://drive.google.com/uc?/export=download&id=' + + url_download = prefix+file_id + filename = "checkpoint_notebook.pth" + if Path(output_file).joinpath(filename).exists(): + print("The model has been downloaded") + else: + print("Downloading...") + gdown.download(url_download, output_file) + print("Download Finish...") From 5988b2dd89a50faf0732bb2b1ca26e5451f78ec7 Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Wed, 22 May 2024 14:10:04 +0800 Subject: [PATCH 2/6] feat: edit recognition module to adopt new schema --- src/onepiece_classify/infer/recognition.py | 27 ++++++++++++++++++++-- src/onepiece_classify/utils/__init__.py | 1 + 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/onepiece_classify/infer/recognition.py b/src/onepiece_classify/infer/recognition.py index c11a987..ea136bb 100644 --- a/src/onepiece_classify/infer/recognition.py +++ b/src/onepiece_classify/infer/recognition.py @@ -1,3 +1,5 @@ +import os +import sys from pathlib import Path from typing import Dict, Optional, Tuple @@ -7,13 +9,25 @@ from onepiece_classify.models import image_recog from onepiece_classify.transforms import get_test_transforms +from onepiece_classify.utils import downloader from .base import BaseInference class ImageRecognition(BaseInference): - def __init__(self, model_path: str, device: str): - self.model_path = Path(model_path) + def __init__(self, device: str, download: bool = True): + # self.model_path = Path(model_path) + self.model_path = self._get_cache_dir() + + if download: + downloader(str(self.model_path) + "/") + + filename = "checkpoint_notebook.pth" + if (self.model_path.joinpath(filename)).exists(): + self.model_path = self.model_path.joinpath("checkpoint_notebook.pth") + else: + raise FileNotFoundError("Model does not exist, set download parameter to True and read README for more information") + self.device = device self.class_dict = { 0: "Ace", @@ -44,6 +58,15 @@ def _build_model(self): model_backbone = image_recog(self.nclass) model_backbone.load_state_dict(state_dict) return model_backbone + + def _get_cache_dir(self): + if sys.platform.startswith("win"): + cache_dir = Path(os.getenv("LOCALAPPDATA", Path.home() / "AppData" / "Local")) / "OnepieceClassifyCache" + else: + cache_dir = Path.home() / ".cache" / "OnepieceClassifyCache" + + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir def pre_process( self, image: Optional[str | np.ndarray | Image.Image] diff --git a/src/onepiece_classify/utils/__init__.py b/src/onepiece_classify/utils/__init__.py index e69de29..4be7d9c 100644 --- a/src/onepiece_classify/utils/__init__.py +++ b/src/onepiece_classify/utils/__init__.py @@ -0,0 +1 @@ +from .downloader import * \ No newline at end of file From 0897c69ad29d992957c8f05685e5ec73245ed1af Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Wed, 22 May 2024 14:11:55 +0800 Subject: [PATCH 3/6] feat: edit predict.py to test inference module new schema --- predict.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/predict.py b/predict.py index ccfe2bc..9e01659 100644 --- a/predict.py +++ b/predict.py @@ -8,10 +8,11 @@ def predict_image( image_path: str = typer.Argument(help="image path", show_default=True), - model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True), + # model_path: str = typer.Argument("checkpoint_notebook.pth", help="model path (pth)", show_default=True), + download: bool = typer.Argument(True, help="True for download the model automatically", show_default=True), device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True) ): - predictor = ImageRecognition(model_path=model_path, device=device) + predictor = ImageRecognition(download=download, device=device) result = predictor.predict(image=image_path) typer.echo(f"Prediction: {result}") From 675454ee44e8288fa97524aac6a5e9514143207d Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Wed, 22 May 2024 16:02:46 +0800 Subject: [PATCH 4/6] feat: remove download parameter, change logic to auto download #18 --- src/onepiece_classify/infer/recognition.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/onepiece_classify/infer/recognition.py b/src/onepiece_classify/infer/recognition.py index ea136bb..263255e 100644 --- a/src/onepiece_classify/infer/recognition.py +++ b/src/onepiece_classify/infer/recognition.py @@ -15,19 +15,17 @@ class ImageRecognition(BaseInference): - def __init__(self, device: str, download: bool = True): - # self.model_path = Path(model_path) - self.model_path = self._get_cache_dir() - - if download: - downloader(str(self.model_path) + "/") - - filename = "checkpoint_notebook.pth" - if (self.model_path.joinpath(filename)).exists(): - self.model_path = self.model_path.joinpath("checkpoint_notebook.pth") + def __init__(self, device: str, model_path=None): + + path_to_save = str(self._get_cache_dir()) + "/model.pth" + if model_path is None: + downloader(path_to_save) + self.model_path = path_to_save else: - raise FileNotFoundError("Model does not exist, set download parameter to True and read README for more information") - + self.model_path = Path(model_path) + if not self.model_path.exists(): + raise FileNotFoundError("Model does not exist, check your model location and read README for more information") + self.device = device self.class_dict = { 0: "Ace", From af7fb439b48dd0d49552dab42c8e46056e808b90 Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Wed, 22 May 2024 16:03:34 +0800 Subject: [PATCH 5/6] feat: edit downloader module, make it simpler --- src/onepiece_classify/utils/downloader.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/onepiece_classify/utils/downloader.py b/src/onepiece_classify/utils/downloader.py index 4127e19..26f8009 100644 --- a/src/onepiece_classify/utils/downloader.py +++ b/src/onepiece_classify/utils/downloader.py @@ -7,10 +7,8 @@ def downloader(output_file): prefix = 'https://drive.google.com/uc?/export=download&id=' url_download = prefix+file_id - filename = "checkpoint_notebook.pth" - if Path(output_file).joinpath(filename).exists(): - print("The model has been downloaded") - else: + + if not Path(output_file).exists(): print("Downloading...") gdown.download(url_download, output_file) print("Download Finish...") From 9a1f0c5179cbd62a9b5e75157adcada23451fb6d Mon Sep 17 00:00:00 2001 From: AnnurAfgoni Date: Wed, 22 May 2024 16:06:21 +0800 Subject: [PATCH 6/6] feat: edit predict.py, keep model path for flexibility load the model #18 --- predict.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/predict.py b/predict.py index 9e01659..a4aafee 100644 --- a/predict.py +++ b/predict.py @@ -8,11 +8,10 @@ def predict_image( image_path: str = typer.Argument(help="image path", show_default=True), - # model_path: str = typer.Argument("checkpoint_notebook.pth", help="model path (pth)", show_default=True), - download: bool = typer.Argument(True, help="True for download the model automatically", show_default=True), + model_path: str = typer.Argument(None, help="path to your model (pth)", show_default=True), device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True) ): - predictor = ImageRecognition(download=download, device=device) + predictor = ImageRecognition(model_path=model_path, device=device) result = predictor.predict(image=image_path) typer.echo(f"Prediction: {result}")