Skip to content

Commit

Permalink
Merge pull request #19 from lombokai/feature/downloader
Browse files Browse the repository at this point in the history
Feature/downloader
  • Loading branch information
nunenuh authored May 22, 2024
2 parents b98c4eb + 9a1f0c5 commit b9d347c
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
2 changes: 1 addition & 1 deletion predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def predict_image(
image_path: str = typer.Argument(help="image path", show_default=True),
model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True),
model_path: str = typer.Argument(None, help="path to your model (pth)", show_default=True),
device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True)
):
predictor = ImageRecognition(model_path=model_path, device=device)
Expand Down
25 changes: 23 additions & 2 deletions src/onepiece_classify/infer/recognition.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Tuple

Expand All @@ -7,13 +9,23 @@

from onepiece_classify.models import image_recog
from onepiece_classify.transforms import get_test_transforms
from onepiece_classify.utils import downloader

from .base import BaseInference


class ImageRecognition(BaseInference):
def __init__(self, model_path: str, device: str):
self.model_path = Path(model_path)
def __init__(self, device: str, model_path=None):

path_to_save = str(self._get_cache_dir()) + "/model.pth"
if model_path is None:
downloader(path_to_save)
self.model_path = path_to_save
else:
self.model_path = Path(model_path)
if not self.model_path.exists():
raise FileNotFoundError("Model does not exist, check your model location and read README for more information")

self.device = device
self.class_dict = {
0: "Ace",
Expand Down Expand Up @@ -44,6 +56,15 @@ def _build_model(self):
model_backbone = image_recog(self.nclass)
model_backbone.load_state_dict(state_dict)
return model_backbone

def _get_cache_dir(self):
if sys.platform.startswith("win"):
cache_dir = Path(os.getenv("LOCALAPPDATA", Path.home() / "AppData" / "Local")) / "OnepieceClassifyCache"
else:
cache_dir = Path.home() / ".cache" / "OnepieceClassifyCache"

cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir

def pre_process(
self, image: Optional[str | np.ndarray | Image.Image]
Expand Down
1 change: 1 addition & 0 deletions src/onepiece_classify/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .downloader import *
14 changes: 14 additions & 0 deletions src/onepiece_classify/utils/downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import gdown
from pathlib import Path


def downloader(output_file):
file_id = "1M1-1Hs198XDD6Xx-kSWLThv1elZBzJ0j"
prefix = 'https://drive.google.com/uc?/export=download&id='

url_download = prefix+file_id

if not Path(output_file).exists():
print("Downloading...")
gdown.download(url_download, output_file)
print("Download Finish...")

0 comments on commit b9d347c

Please sign in to comment.