Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
MateoLostanlen committed May 28, 2024
1 parent 7683aaf commit 75d7466
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, model_path: Optional[str] = "data/model.onnx", base_img_size:
self.ort_session = onnxruntime.InferenceSession(model_path)
self.base_img_size = base_img_size

def preprocess_image(self, pil_img: Image.Image, new_img_size: tuple) -> Tuple[np.ndarray, Tuple[int, int]]:
def preprocess_image(self, pil_img: Image.Image, new_img_size: list) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference
Args:
Expand All @@ -67,8 +67,8 @@ def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] =

w, h = pil_img.size
ratio = self.base_img_size / max(w, h)
new_img_size = (int(ratio * w), int(ratio * h))
new_img_size = tuple([x - x % 32 for x in new_img_size]) # size need to be a multiple of 32 to fit the model
new_img_size = [int(ratio * w), int(ratio * h)]
new_img_size = [x - x % 32 for x in new_img_size] # size need to be a multiple of 32 to fit the model
np_img = self.preprocess_image(pil_img, new_img_size)

# ONNX inference
Expand Down

0 comments on commit 75d7466

Please sign in to comment.