From 68ccd14dd7f03e54a967ab17f2cd0fa39ce2402b Mon Sep 17 00:00:00 2001 From: Mateo Date: Tue, 18 Jul 2023 08:44:43 +0200 Subject: [PATCH] keep all preds --- pyroengine/vision.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyroengine/vision.py b/pyroengine/vision.py index 29084cb6..e9965eac 100644 --- a/pyroengine/vision.py +++ b/pyroengine/vision.py @@ -62,11 +62,13 @@ def __call__(self, pil_img: Image.Image) -> np.ndarray: # ONNX inference y = self.ort_session.run(["output0"], {"images": np_img})[0][0] - y = y[:, y[-1, :] > 0.1] + # Post processing y = np.transpose(y) y = xywh2xyxy(y) + # Sort by confidence y = y[y[:, 4].argsort()] y = NMS(y) + # Normalize preds if len(y) > 0: y[:, :4:2] /= self.img_size[1] y[:, 1:4:2] /= self.img_size[0]