forked from IDEA-Research/Grounded-Segment-Anything
Commit
Added args for grounded_mobile_sam (IDEA-Research#364)
* trial_1
* added args and option to save binary mask
* updated grounded_mobile_sam usage
1 parent a4d76a2 · commit f34e1a3
Showing 2 changed files with 137 additions and 100 deletions.
@@ -1,108 +1,145 @@

# ----- before: the original top-level script -----
import cv2
import numpy as np
import supervision as sv

import argparse
import torch
import torchvision

from groundingdino.util.inference import Model
from segment_anything import SamPredictor
from MobileSAM.setup_mobile_sam import setup_model

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

# Building GroundingDINO inference model
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

# Building MobileSAM predictor
MOBILE_SAM_CHECKPOINT_PATH = "./EfficientSAM/mobile_sam.pt"
checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
mobile_sam = setup_model()
mobile_sam.load_state_dict(checkpoint, strict=True)
mobile_sam.to(device=DEVICE)

sam_predictor = SamPredictor(mobile_sam)


# Predict classes and hyper-param for GroundingDINO
SOURCE_IMAGE_PATH = "./assets/demo2.jpg"
CLASSES = ["The running dog"]
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
NMS_THRESHOLD = 0.8


# load image
image = cv2.imread(SOURCE_IMAGE_PATH)

# detect objects
detections = grounding_dino_model.predict_with_classes(
    image=image,
    classes=CLASSES,
    box_threshold=BOX_THRESHOLD,
    text_threshold=BOX_THRESHOLD
)
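# note: the original script passes BOX_THRESHOLD as text_threshold above
# (TEXT_THRESHOLD is defined but never used); the rewrite below fixes this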

# annotate image with detections
box_annotator = sv.BoxAnnotator()
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _
    in detections]
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

# save the annotated grounding dino image
cv2.imwrite("groundingdino_annotated_image.jpg", annotated_frame)


# NMS post process
print(f"Before NMS: {len(detections.xyxy)} boxes")
nms_idx = torchvision.ops.nms(
    torch.from_numpy(detections.xyxy),
    torch.from_numpy(detections.confidence),
    NMS_THRESHOLD
).numpy().tolist()

detections.xyxy = detections.xyxy[nms_idx]
detections.confidence = detections.confidence[nms_idx]
detections.class_id = detections.class_id[nms_idx]

print(f"After NMS: {len(detections.xyxy)} boxes")

# Prompting SAM with detected boxes
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)


# convert detections to masks
detections.mask = segment(
    sam_predictor=sam_predictor,
    image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
    xyxy=detections.xyxy
)

# annotate image with detections
box_annotator = sv.BoxAnnotator()
mask_annotator = sv.MaskAnnotator()
labels = [
    f"{CLASSES[class_id]} {confidence:0.2f}"
    for _, _, confidence, class_id, _
    in detections]
annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)

# save the annotated grounded-sam image
cv2.imwrite("grounded_mobile_sam_annotated_image.jpg", annotated_image)

# ----- after: rewritten with parse_args() and main() -----

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--MOBILE_SAM_CHECKPOINT_PATH", type=str, default="./EfficientSAM/mobile_sam.pt", help="path to the MobileSAM checkpoint"
    )
    parser.add_argument(
        "--SOURCE_IMAGE_PATH", type=str, default="./assets/demo2.jpg", help="path to image file"
    )
    parser.add_argument(
        "--CAPTION", type=str, default="The running dog", help="text prompt for GroundingDINO"
    )
    parser.add_argument(
        "--OUT_FILE_BOX", type=str, default="groundingdino_annotated_image.jpg", help="output filename for the box-annotated image"
    )
    parser.add_argument(
        "--OUT_FILE_SEG", type=str, default="grounded_mobile_sam_annotated_image.jpg", help="output filename for the mask-annotated image"
    )
    parser.add_argument(
        "--OUT_FILE_BIN_MASK", type=str, default="grounded_mobile_sam_bin_mask.jpg", help="output filename for the binary mask"
    )
    parser.add_argument("--BOX_THRESHOLD", type=float, default=0.25, help="GroundingDINO box confidence threshold")
    parser.add_argument("--TEXT_THRESHOLD", type=float, default=0.25, help="GroundingDINO text confidence threshold")
    parser.add_argument("--NMS_THRESHOLD", type=float, default=0.8, help="IoU threshold for NMS post-processing")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument(
        "--DEVICE", type=str, default=device, help="device to run on, e.g. 'cuda:0' or 'cpu'"
    )
    return parser.parse_args()


def main(args):
    DEVICE = args.DEVICE

    # GroundingDINO config and checkpoint
    GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    GROUNDING_DINO_CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"

    # Building GroundingDINO inference model
    grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)

    # Building MobileSAM predictor
    MOBILE_SAM_CHECKPOINT_PATH = args.MOBILE_SAM_CHECKPOINT_PATH
    checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
    mobile_sam = setup_model()
    mobile_sam.load_state_dict(checkpoint, strict=True)
    mobile_sam.to(device=DEVICE)

    sam_predictor = SamPredictor(mobile_sam)


    # Predict classes and hyper-param for GroundingDINO
    SOURCE_IMAGE_PATH = args.SOURCE_IMAGE_PATH
    CLASSES = [args.CAPTION]
    BOX_THRESHOLD = args.BOX_THRESHOLD
    TEXT_THRESHOLD = args.TEXT_THRESHOLD
    NMS_THRESHOLD = args.NMS_THRESHOLD


    # load image
    image = cv2.imread(SOURCE_IMAGE_PATH)

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _
        in detections]
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)

    # save the annotated grounding dino image
    cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)


    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()
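
    # nms() returns the indices of the boxes to keep; boxes that overlap a
    # higher-confidence box with IoU above NMS_THRESHOLD are dropped below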
    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    print(f"After NMS: {len(detections.xyxy)} boxes")

    # Prompting SAM with detected boxes
    def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = sam_predictor.predict(
                box=box,
                multimask_output=True
            )
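            # with multimask_output=True, SAM proposes three candidate masks;
            # keep the one with the highest predicted quality score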
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)


    # convert detections to masks
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )

    binary_mask = detections.mask[0].astype(np.uint8) * 255
    cv2.imwrite(args.OUT_FILE_BIN_MASK, binary_mask)
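    # note: only the mask of the first detection is written to OUT_FILE_BIN_MASK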

    # annotate image with detections
    box_annotator = sv.BoxAnnotator()
    mask_annotator = sv.MaskAnnotator()
    labels = [
        f"{CLASSES[class_id]} {confidence:0.2f}"
        for _, _, confidence, class_id, _
        in detections]
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
    # save the annotated grounded-sam image
    cv2.imwrite(args.OUT_FILE_SEG, annotated_image)

if __name__ == "__main__":
    args = parse_args()
    main(args)
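
With the new arguments in place, the script can be invoked from the command line. A minimal example, assuming it is run from the repository root (the script's exact location is not shown in this diff; the flags come straight from parse_args above):

python grounded_mobile_sam.py \
    --SOURCE_IMAGE_PATH ./assets/demo2.jpg \
    --CAPTION "The running dog" \
    --BOX_THRESHOLD 0.25 \
    --NMS_THRESHOLD 0.8 \
    --OUT_FILE_BIN_MASK grounded_mobile_sam_bin_mask.jpg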