diff --git a/grounded_sam_demo.py b/grounded_sam_demo.py index 93064c82..631b3627 100644 --- a/grounded_sam_demo.py +++ b/grounded_sam_demo.py @@ -17,7 +17,7 @@ # segment anything from segment_anything import ( sam_model_registry, - build_sam_hq, + sam_hq_model_registry, SamPredictor ) import cv2 @@ -99,7 +99,7 @@ def show_mask(mask, ax, random_color=False): def show_box(box, ax, label): x0, y0 = box[0], box[1] w, h = box[2] - box[0], box[3] - box[1] - ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) + ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) ax.text(x0, y0, label) @@ -130,7 +130,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list): }) with open(os.path.join(output_dir, 'mask.json'), 'w') as f: json.dump(json_data, f) - + if __name__ == "__main__": @@ -194,7 +194,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list): # initialize SAM if use_sam_hq: - predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device)) + predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device)) else: predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device)) image = cv2.imread(image_path) @@ -217,7 +217,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list): boxes = transformed_boxes.to(device), multimask_output = False, ) - + # draw output image plt.figure(figsize=(10, 10)) plt.imshow(image) @@ -228,9 +228,8 @@ def save_mask_data(output_dir, mask_list, box_list, label_list): plt.axis('off') plt.savefig( - os.path.join(output_dir, "grounded_sam_output.jpg"), + os.path.join(output_dir, "grounded_sam_output.jpg"), bbox_inches="tight", dpi=300, pad_inches=0.0 ) save_mask_data(output_dir, masks, boxes_filt, pred_phrases) -