Added args for grounded_mobile_sam (IDEA-Research#364)
* trial_1

* added args and option to save binary mask

* updated grounded_mobile_sam usage
tripathiarpan20 authored Aug 27, 2023
1 parent a4d76a2 commit f34e1a3
Showing 2 changed files with 137 additions and 100 deletions.
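For reference, a sketch of running the updated script with the flags this commit introduces (flag names are taken from the diff below; the paths, prompt, and threshold values shown simply mirror the argparse defaults):

```bash
cd Grounded-Segment-Anything

# every flag is optional; the values below mirror the defaults in the diff
python EfficientSAM/grounded_mobile_sam.py \
  --MOBILE_SAM_CHECKPOINT_PATH "./EfficientSAM/mobile_sam.pt" \
  --SOURCE_IMAGE_PATH "./assets/demo2.jpg" \
  --CAPTION "The running dog" \
  --BOX_THRESHOLD 0.25 \
  --TEXT_THRESHOLD 0.25 \
  --NMS_THRESHOLD 0.8 \
  --OUT_FILE_BOX "groundingdino_annotated_image.jpg" \
  --OUT_FILE_SEG "grounded_mobile_sam_annotated_image.jpg" \
  --OUT_FILE_BIN_MASK "grounded_mobile_sam_bin_mask.jpg"
```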
4 changes: 2 additions & 2 deletions EfficientSAM/README.md
@@ -69,7 +69,7 @@ python EfficientSAM/grounded_fast_sam.py --model_path "./FastSAM-x.pt" --img_pat
```bash
cd Grounded-Segment-Anything

python EfficientSAM/grounded_mobile_sam.py
python EfficientSAM/grounded_mobile_sam.py --MOBILE_SAM_CHECKPOINT_PATH "./EfficientSAM/mobile_sam.pt" --SOURCE_IMAGE_PATH "./assets/demo2.jpg" --CAPTION "the running dog"
```

- And the result will be saved as `./grounded_mobile_sam_annotated_image.jpg` as:
@@ -103,4 +103,4 @@ python EfficientSAM/grounded_light_hqsam.py
|:---:|:---:|:---:|
|![](/EfficientSAM/LightHQSAM/example_light_hqsam.png) | "Bench" | ![](/EfficientSAM/LightHQSAM/grounded_light_hqsam_annotated_image.jpg) |

</div>
233 changes: 135 additions & 98 deletions EfficientSAM/grounded_mobile_sam.py
@@ -1,108 +1,145 @@
import cv2
import numpy as np
import supervision as sv

import argparse
import torch
import torchvision

from groundingdino.util.inference import Model
from segment_anything import SamPredictor
from MobileSAM.setup_mobile_sam import setup_model

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

# Building GroundingDINO inference model
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

# Building MobileSAM predictor
MOBILE_SAM_CHECKPOINT_PATH = "./EfficientSAM/mobile_sam.pt"
checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
mobile_sam = setup_model()
mobile_sam.load_state_dict(checkpoint, strict=True)
mobile_sam.to(device=DEVICE)

sam_predictor = SamPredictor(mobile_sam)


# Predict classes and hyper-param for GroundingDINO
SOURCE_IMAGE_PATH = "./assets/demo2.jpg"
CLASSES = ["The running dog"]
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
NMS_THRESHOLD = 0.8


# load image
image = cv2.imread(SOURCE_IMAGE_PATH)

# detect objects
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=BOX_THRESHOLD
)

# annotate image with detections
box_annotator = sv.BoxAnnotator()
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _
    in detections]
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

# save the annotated grounding dino image
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)


# NMS post process
print(f"Before NMS: {len(detections.xyxy)} boxes")
nms_idx = torchvision.ops.nms(
    torch.from_numpy(detections.xyxy),
    torch.from_numpy(detections.confidence),
    NMS_THRESHOLD
).numpy().tolist()

detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]

print(f"After NMS: {len(detections.xyxy)} boxes")

# Prompting SAM with detected boxes
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)


# convert detections to masks
detections.mask = segment(
    sam_predictor=sam_predictor,
    image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
    xyxy=detections.xyxy
)

# annotate image with detections
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _
    in detections]
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

# save the annotated grounded-sam image
cv2.imwrite("grounded_mobile_sam_annotated_image.jpg", annotated_image)
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--MOBILE_SAM_CHECKPOINT_PATH", type=str, default="./EfficientSAM/mobile_sam.pt", help="model"
    )
    parser.add_argument(
        "--SOURCE_IMAGE_PATH", type=str, default="./assets/demo2.jpg", help="path to image file"
    )
    parser.add_argument(
        "--CAPTION", type=str, default="The running dog", help="text prompt for GroundingDINO"
    )
    parser.add_argument(
        "--OUT_FILE_BOX", type=str, default="groundingdino_annotated_image.jpg", help="the output filename"
    )
    parser.add_argument(
        "--OUT_FILE_SEG", type=str, default="grounded_mobile_sam_annotated_image.jpg", help="the output filename"
    )
    parser.add_argument(
        "--OUT_FILE_BIN_MASK", type=str, default="grounded_mobile_sam_bin_mask.jpg", help="the output filename"
    )
    parser.add_argument("--BOX_THRESHOLD", type=float, default=0.25, help="")
    parser.add_argument("--TEXT_THRESHOLD", type=float, default=0.25, help="")
    parser.add_argument("--NMS_THRESHOLD", type=float, default=0.8, help="")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument(
        "--DEVICE", type=str, default=device, help="cuda:[0,1,2,3,4] or cpu"
    )
    return parser.parse_args()

def main(args):
    DEVICE = args.DEVICE

    # GroundingDINO config and checkpoint
    GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

    # Building GroundingDINO inference model
    grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

    # Building MobileSAM predictor
    MOBILE_SAM_CHECKPOINT_PATH = args.MOBILE_SAM_CHECKPOINT_PATH
    checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
    mobile_sam = setup_model()
    mobile_sam.load_state_dict(checkpoint, strict=True)
    mobile_sam.to(device=DEVICE)

    sam_predictor = SamPredictor(mobile_sam)


    # Predict classes and hyper-param for GroundingDINO
    SOURCE_IMAGE_PATH = args.SOURCE_IMAGE_PATH
    CLASSES = [args.CAPTION]
    BOX_THRESHOLD = args.BOX_THRESHOLD
    TEXT_THRESHOLD = args.TEXT_THRESHOLD
    NMS_THRESHOLD = args.NMS_THRESHOLD


    # load image
    image = cv2.imread(SOURCE_IMAGE_PATH)

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _
        in detections]
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

    # save the annotated grounding dino image
    cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)


    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    print(f"After NMS: {len(detections.xyxy)} boxes")

    # Prompting SAM with detected boxes
    def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = sam_predictor.predict(
                box=box,
                multimask_output=True
            )
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)


    # convert detections to masks
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )

    binary_mask = detections.mask[0].astype(np.uint8)*255
    cv2.imwrite(args.OUT_FILE_BIN_MASK, binary_mask)

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    mask_annotator = sv.MaskAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _
        in detections]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
    # save the annotated grounded-sam image
    cv2.imwrite(args.OUT_FILE_SEG, annotated_image)

if __name__ == "__main__":
    args = parse_args()
    main(args)
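The new --OUT_FILE_BIN_MASK output contains only the mask of the first detection (detections.mask[0] in the diff above), written as a single-channel 0/255 uint8 image. A hedged example that redirects it to a custom filename (the path is only illustrative):

```bash
python EfficientSAM/grounded_mobile_sam.py --CAPTION "the running dog" --OUT_FILE_BIN_MASK "dog_mask.png"
```

The saved file can then be reloaded as a boolean mask with, for example, cv2.imread("dog_mask.png", cv2.IMREAD_GRAYSCALE) > 0.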
