Skip to content

Commit

Permalink
Update grounded_sam_demo.py (IDEA-Research#382)
Browse files Browse the repository at this point in the history
Enable the use of vit_l and vit_b with sam_hq in grounded_sam_demo.py.
  • Loading branch information
andoasaki authored Oct 16, 2023
1 parent aca56fe commit b579761
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions grounded_sam_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# segment anything
from segment_anything import (
sam_model_registry,
build_sam_hq,
sam_hq_model_registry,
SamPredictor
)
import cv2
Expand Down Expand Up @@ -99,7 +99,7 @@ def show_mask(mask, ax, random_color=False):
def show_box(box, ax, label):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
ax.text(x0, y0, label)


Expand Down Expand Up @@ -130,7 +130,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
})
with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
json.dump(json_data, f)


if __name__ == "__main__":

Expand Down Expand Up @@ -194,7 +194,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):

# initialize SAM
if use_sam_hq:
predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device))
predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
else:
predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
image = cv2.imread(image_path)
Expand All @@ -217,7 +217,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
boxes = transformed_boxes.to(device),
multimask_output = False,
)

# draw output image
plt.figure(figsize=(10, 10))
plt.imshow(image)
Expand All @@ -228,9 +228,8 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):

plt.axis('off')
plt.savefig(
os.path.join(output_dir, "grounded_sam_output.jpg"),
os.path.join(output_dir, "grounded_sam_output.jpg"),
bbox_inches="tight", dpi=300, pad_inches=0.0
)

save_mask_data(output_dir, masks, boxes_filt, pred_phrases)

0 comments on commit b579761

Please sign in to comment.