From c6a2e33bb73595551cf9f9f6b94354626f3dd848 Mon Sep 17 00:00:00 2001 From: prernadh Date: Thu, 19 Sep 2024 10:58:53 -0700 Subject: [PATCH] Adding scores --- fiftyone/utils/sam.py | 9 +++++++-- fiftyone/utils/sam2.py | 5 ++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/fiftyone/utils/sam.py b/fiftyone/utils/sam.py index 32afa9334c..82bfe92536 100644 --- a/fiftyone/utils/sam.py +++ b/fiftyone/utils/sam.py @@ -271,14 +271,19 @@ def _forward_pass_boxes(self, imgs): device=sam_predictor.device, ) - masks, _, _ = sam_predictor.predict_torch( + masks, scores, _ = sam_predictor.predict_torch( point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False, ) outputs.append( - {"boxes": input_boxes, "labels": labels, "masks": masks} + { + "boxes": input_boxes, + "labels": labels, + "masks": masks, + "scores": scores, + } ) return outputs diff --git a/fiftyone/utils/sam2.py b/fiftyone/utils/sam2.py index c0bba8806c..d9ed379566 100644 --- a/fiftyone/utils/sam2.py +++ b/fiftyone/utils/sam2.py @@ -201,7 +201,7 @@ def _forward_pass_boxes(self, imgs): device=sam2_predictor.device, ) - masks, _, _ = sam2_predictor.predict( + masks, scores, _ = sam2_predictor.predict( point_coords=None, point_labels=None, box=sam_boxes[None, :], @@ -214,6 +214,9 @@ def _forward_pass_boxes(self, imgs): "boxes": input_boxes, "labels": labels, "masks": torch.tensor(masks, device=sam2_predictor.device), + "scores": torch.tensor( + scores, device=sam2_predictor.device + ), } )