Commit 3507754: up

lgc2333 committed Oct 29, 2024
1 parent 3148fda commit 3507754
Showing 4 changed files with 24 additions and 15 deletions.
12 changes: 3 additions & 9 deletions nonebot_plugin_nailongremove/handler.py
@@ -7,12 +7,9 @@
     Iterable,
     Iterator,
     List,
-    Optional,
-    Tuple,
     TypeVar,
     cast,
 )
-from typing_extensions import TypeAlias
 
 import cv2
 import numpy as np
@@ -28,7 +25,7 @@
 from PIL import Image as PilImage
 
 from .config import config
-from .model import check_image
+from .model import CheckResultTuple, check_image
 from .uniapi import mute, recall
 
 T = TypeVar("T")
@@ -115,11 +112,8 @@ async def nailong_rule(
 )
 
 
-CheckFrameResult: TypeAlias = Tuple[bool, Optional[np.ndarray]]
-
-
-async def check_frames(frames: Iterator[np.ndarray]) -> CheckFrameResult:
-    async def worker() -> CheckFrameResult:
+async def check_frames(frames: Iterator[np.ndarray]) -> CheckResultTuple:
+    async def worker() -> CheckResultTuple:
         while True:
             try:
                 frame = next(frames)
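For context on the new return shape: `CheckResultTuple` (defined in `model/__init__.py` below) is a `(matched, frame)` pair where the frame is only attached on a positive match. The following is a minimal, self-contained sketch of a consumer of such a frame iterator; the `check_image` stand-in and the sequential loop are illustrative assumptions, not the plugin's actual worker-based implementation shown above.

```python
import asyncio
from typing import Iterator, Optional, Tuple

import numpy as np


# Illustrative stand-in for the plugin's async check_image wrapper; the real
# one dispatches to the configured model backend.
async def check_image(frame: np.ndarray) -> bool:
    return bool(frame.mean() > 200)  # placeholder heuristic, not a real detector


async def check_frames_sketch(
    frames: Iterator[np.ndarray],
) -> Tuple[bool, Optional[np.ndarray]]:
    # Walk the iterator and stop at the first positive hit, mirroring the
    # (matched, frame) shape that CheckResultTuple describes.
    for frame in frames:
        if await check_image(frame):
            return True, frame
    return False, None


frames = iter(
    [np.zeros((8, 8), dtype=np.uint8), np.full((8, 8), 255, dtype=np.uint8)],
)
print(asyncio.run(check_frames_sketch(frames)))  # (True, <all-255 frame>)
```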
13 changes: 10 additions & 3 deletions nonebot_plugin_nailongremove/model/__init__.py
@@ -1,4 +1,5 @@
-from typing import Callable, NoReturn, Tuple, Union
+from typing import Awaitable, Callable, Literal, NoReturn, Tuple, Union
+from typing_extensions import TypeAlias
 
 import numpy as np
 from nonebot.utils import run_sync
@@ -14,7 +15,13 @@ def raise_extra_import_error(e: BaseException, group: str) -> NoReturn:
     ) from e
 
 
-check_image_sync: Callable[[np.ndarray], Union[bool, Tuple[bool, np.ndarray]]]
+CheckResultTuple: TypeAlias = Union[
+    Tuple[bool, None],
+    Tuple[Literal[True], np.ndarray],
+]
+CheckResult: TypeAlias = Union[bool, CheckResultTuple]
+
+check_image_sync: Callable[[np.ndarray], CheckResult]
 
 if config.nailong_model is ModelType.CLASSIFICATION:
     from .classification import check_image as check_image_sync
@@ -29,4 +36,4 @@ def raise_extra_import_error(e: BaseException, group: str) -> NoReturn:
     raise ValueError("Invalid model type")
 
 
-check_image = run_sync(check_image_sync)
+check_image: Callable[[np.ndarray], Awaitable[CheckResult]] = run_sync(check_image_sync)
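Since `CheckResult` widens the return type to `Union[bool, CheckResultTuple]`, a caller has to handle both a bare `bool` and a `(matched, frame)` tuple. Below is a minimal sketch of collapsing both shapes into one form; the aliases are repeated from the diff so the snippet is self-contained, and `normalize_result` is a hypothetical helper, not part of the plugin.

```python
from typing import Literal, Optional, Tuple, Union

import numpy as np
from typing_extensions import TypeAlias

# Aliases copied from model/__init__.py above.
CheckResultTuple: TypeAlias = Union[
    Tuple[bool, None],
    Tuple[Literal[True], np.ndarray],
]
CheckResult: TypeAlias = Union[bool, CheckResultTuple]


def normalize_result(result: CheckResult) -> Tuple[bool, Optional[np.ndarray]]:
    """Hypothetical helper: collapse both return shapes into (matched, frame)."""
    if isinstance(result, tuple):
        matched, frame = result
        return matched, frame
    return result, None


print(normalize_result(True))                      # (True, None)
print(normalize_result((True, np.zeros((2, 2)))))  # (True, <2x2 zero frame>)
```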
7 changes: 5 additions & 2 deletions nonebot_plugin_nailongremove/model/classification.py
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import cv2
 import numpy as np
@@ -8,6 +8,9 @@
 
 from ..utils import ensure_model_from_github_repo
 
+if TYPE_CHECKING:
+    from . import CheckResult
+
 cuda_available = torch.cuda.is_available()
 device = torch.device("cuda" if cuda_available else "cpu")
 transform = transforms.Compose([transforms.ToTensor()])
@@ -30,7 +33,7 @@
     model.cuda()
 
 
-def check_image(image: np.ndarray):
+def check_image(image: np.ndarray) -> "CheckResult":
     if image.shape[0] < 224 or image.shape[1] < 224:
         return False
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
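The `if TYPE_CHECKING:` block plus the quoted `"CheckResult"` annotation lets a backend module reference a type that lives in the package `__init__` without importing it at runtime, since `model/__init__.py` in turn imports `check_image` from this module. A generic illustration of that pattern under made-up names (`mypackage` is not part of the plugin):

```python
# detector.py inside a hypothetical package "mypackage"
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so the mutual
    # import between mypackage/__init__.py and this module is avoided.
    from mypackage import CheckResult


def check_image(image: np.ndarray) -> "CheckResult":
    # The string annotation is resolved lazily, so CheckResult does not need
    # to exist when this module is imported.
    return bool(image.any())


print(check_image(np.ones((2, 2))))  # True
```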
7 changes: 6 additions & 1 deletion nonebot_plugin_nailongremove/model/target_detection.py
@@ -1,10 +1,15 @@
+from typing import TYPE_CHECKING
+
 import numpy as np
 import onnxruntime
 
 from ..config import config
 from ..utils import ensure_model_from_github_release
 from .yolox_utils import demo_postprocess, multiclass_nms, preprocess, vis
 
+if TYPE_CHECKING:
+    from . import CheckResult
+
 COCO_CLASSES = ("_background_", "nailong", "anime", "human", "emoji", "long", "other")
 
 model_path = ensure_model_from_github_release(
@@ -18,7 +23,7 @@
 input_shape = config.nailong_yolox_size
 
 
-def check_image(image: np.ndarray):
+def check_image(image: np.ndarray) -> "CheckResult":
     img, ratio = preprocess(image, input_shape)
     ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
     output = session.run(None, ort_inputs)
