forked from IDEA-Research/Grounded-Segment-Anything
Refine ImageBind-SAM playground (IDEA-Research#273)
* update imagebind sam
* refine readme
Showing 5 changed files with 332 additions and 1 deletion.
@@ -0,0 +1,97 @@
import data
import cv2
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from models import imagebind_model
from models.imagebind_model import ModalityType

from segment_anything import build_sam, SamAutomaticMaskGenerator

from utils import (
    segment_image,
    convert_box_xywh_to_xyxy,
    get_indices_of_values_above_threshold,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
""" | ||
Step 1: Instantiate model | ||
""" | ||
# Segment Anything | ||
mask_generator = SamAutomaticMaskGenerator( | ||
build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device), | ||
points_per_side=16, | ||
) | ||
|
||
# ImageBind | ||
bind_model = imagebind_model.imagebind_huge(pretrained=True) | ||
bind_model.eval() | ||
bind_model.to(device) | ||
|
||
|
||
""" | ||
Step 2: Generate auto masks with SAM | ||
""" | ||
image_path = ".assets/car_image.jpg" | ||
image = cv2.imread(image_path) | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
masks = mask_generator.generate(image) | ||
|
||
|
||
""" | ||
Step 3: Get cropped images based on mask and box | ||
""" | ||
cropped_boxes = [] | ||
image = Image.open(image_path) | ||
for mask in tqdm(masks): | ||
cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"]))) | ||
|
||
|
||
""" | ||
Step 4: Run ImageBind model to get similarity between cropped image and different modalities | ||
""" | ||
def retriev_vision_and_audio(elements, audio_list): | ||
inputs = { | ||
ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device), | ||
ModalityType.AUDIO: data.load_and_transform_audio_data(audio_list, device), | ||
} | ||
with torch.no_grad(): | ||
embeddings = bind_model(inputs) | ||
vision_audio = torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=0), | ||
return vision_audio | ||
|
||
vision_audio_result = retriev_vision_and_audio(cropped_boxes, [".assets/car_audio.wav"]) | ||
|
||
|
||
""" | ||
Step 5: Merge the top similarity masks to get the final mask and save the merged mask | ||
This is the audio retrival result | ||
""" | ||
|
||
# get highest similar mask with threshold | ||
# result[0] shape: [113, 1] | ||
threshold = 0.025 | ||
index = get_indices_of_values_above_threshold(vision_audio_result[0], threshold) | ||
|
||

segmentation_masks = []
for seg_idx in index:
    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
    segmentation_masks.append(segmentation_mask_image)

original_image = Image.open(image_path)
overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
overlay_color = (255, 255, 255, 0)

draw = ImageDraw.Draw(overlay_image)
for segmentation_mask_image in segmentation_masks:
    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)

# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
mask_image = overlay_image.convert("RGB")
mask_image.save("./audio_sam_merged_mask.jpg")
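
All three playground scripts import segment_image, convert_box_xywh_to_xyxy, and get_indices_of_values_above_threshold from a local utils module whose source is not part of this diff. The sketch below is one plausible implementation, inferred from the call sites above rather than copied from the repository:

import numpy as np
from PIL import Image

def convert_box_xywh_to_xyxy(box):
    # SAM reports boxes as [x, y, w, h]; PIL's crop() wants (left, top, right, bottom).
    x, y, w, h = box
    return [x, y, x + w, y + h]

def segment_image(image, segmentation_mask):
    # Keep only the masked pixels of a PIL image; everything else turns black.
    image_array = np.array(image)
    segmented = np.zeros_like(image_array)
    segmented[segmentation_mask] = image_array[segmentation_mask]
    return Image.fromarray(segmented)

def get_indices_of_values_above_threshold(values, threshold):
    # Indices of the crops whose similarity score clears the threshold.
    return [i for i, v in enumerate(values) if v > threshold]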
@@ -0,0 +1,110 @@
import data
import cv2
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from models import imagebind_model
from models.imagebind_model import ModalityType

from segment_anything import build_sam, SamAutomaticMaskGenerator

from utils import (
    segment_image,
    convert_box_xywh_to_xyxy,
    get_indices_of_values_above_threshold,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
""" | ||
Step 1: Instantiate model | ||
""" | ||
# Segment Anything | ||
mask_generator = SamAutomaticMaskGenerator( | ||
build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device), | ||
points_per_side=16, | ||
) | ||
|
||
# ImageBind | ||
bind_model = imagebind_model.imagebind_huge(pretrained=True) | ||
bind_model.eval() | ||
bind_model.to(device) | ||
|
||
|
||
""" | ||
Step 2: Generate auto masks with SAM | ||
""" | ||
image_path = ".assets/car_image.jpg" | ||
image = cv2.imread(image_path) | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
masks = mask_generator.generate(image) | ||
|
||
|
||
""" | ||
Step 3: Get cropped images based on mask and box | ||
""" | ||
cropped_boxes = [] | ||
image = Image.open(image_path) | ||
for mask in tqdm(masks): | ||
cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"]))) | ||
|
||
|
||
""" | ||
Step 4: Run ImageBind model to get similarity between cropped image and different modalities | ||
""" | ||
# load referring image | ||
referring_image_path = ".assets/referring_car_image.jpg" | ||
referring_image = Image.open(referring_image_path) | ||
|
||
image_list = [] | ||
image_list += cropped_boxes | ||
image_list.append(referring_image) | ||
|
||
def retriev_vision_and_vision(elements): | ||
inputs = { | ||
ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device), | ||
} | ||
with torch.no_grad(): | ||
embeddings = bind_model(inputs) | ||
|
||
# cropped box region embeddings | ||
cropped_box_embeddings = embeddings[ModalityType.VISION][:-1, :] | ||
referring_image_embeddings = embeddings[ModalityType.VISION][-1, :] | ||
|
||
vision_referring_result = torch.softmax(cropped_box_embeddings @ referring_image_embeddings.T, dim=0), | ||
return vision_referring_result # [113, 1] | ||
|
||
|
||
vision_referring_result = retriev_vision_and_vision(image_list) | ||
|
||
|
||
""" | ||
Step 5: Merge the top similarity masks to get the final mask and save the merged mask | ||
Image / Text mask | ||
""" | ||
|
||
# get highest similar mask with threshold | ||
# result[0] shape: [113, 1] | ||
threshold = 0.017 | ||
index = get_indices_of_values_above_threshold(vision_referring_result[0], threshold) | ||
|
||
|
||

segmentation_masks = []
for seg_idx in index:
    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
    segmentation_masks.append(segmentation_mask_image)

original_image = Image.open(image_path)
overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
overlay_color = (255, 255, 255, 0)

draw = ImageDraw.Draw(overlay_image)
for segmentation_mask_image in segmentation_masks:
    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)

# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
mask_image = overlay_image.convert("RGB")
mask_image.save("./image_referring_sam_merged_mask.jpg")
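
A note on data.load_and_transform_vision_data_from_pil_image: stock ImageBind only ships load_and_transform_vision_data, which takes file paths, so the playground evidently extends its data module with a PIL-accepting variant. A minimal sketch of such a variant, assuming ImageBind's standard CLIP-style 224x224 vision preprocessing (the transform actually committed may differ):

import torch
from torchvision import transforms

def load_and_transform_vision_data_from_pil_image(pil_images, device):
    # Same normalization ImageBind uses for vision inputs (CLIP statistics).
    transform = transforms.Compose([
        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    # Stack into a [num_images, 3, 224, 224] batch on the target device.
    return torch.stack([transform(img.convert("RGB")) for img in pil_images]).to(device)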
@@ -0,0 +1,97 @@
import data
import cv2
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
from models import imagebind_model
from models.imagebind_model import ModalityType

from segment_anything import build_sam, SamAutomaticMaskGenerator

from utils import (
    segment_image,
    convert_box_xywh_to_xyxy,
    get_indices_of_values_above_threshold,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
""" | ||
Step 1: Instantiate model | ||
""" | ||
# Segment Anything | ||
mask_generator = SamAutomaticMaskGenerator( | ||
build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device), | ||
points_per_side=16, | ||
) | ||
|
||
# ImageBind | ||
bind_model = imagebind_model.imagebind_huge(pretrained=True) | ||
bind_model.eval() | ||
bind_model.to(device) | ||
|
||
|
||
""" | ||
Step 2: Generate auto masks with SAM | ||
""" | ||
image_path = ".assets/car_image.jpg" | ||
image = cv2.imread(image_path) | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
masks = mask_generator.generate(image) | ||
|
||
|
||
""" | ||
Step 3: Get cropped images based on mask and box | ||
""" | ||
cropped_boxes = [] | ||
image = Image.open(image_path) | ||
for mask in tqdm(masks): | ||
cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"]))) | ||
|
||
|
||
""" | ||
Step 4: Run ImageBind model to get similarity between cropped image and different modalities | ||
""" | ||
def retriev_vision_and_text(elements, text_list): | ||
inputs = { | ||
ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device), | ||
ModalityType.TEXT: data.load_and_transform_text(text_list, device), | ||
} | ||
with torch.no_grad(): | ||
embeddings = bind_model(inputs) | ||
vision_audio = torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=0), | ||
return vision_audio # [113, 1] | ||
|
||
|
||
vision_text_result = retriev_vision_and_text(cropped_boxes, ["A car"] ) | ||
|
||
|
||
""" | ||
Step 5: Merge the top similarity masks to get the final mask and save the merged mask | ||
Image / Text mask | ||
""" | ||
|
||
# get highest similar mask with threshold | ||
# result[0] shape: [113, 1] | ||
threshold = 0.05 | ||
index = get_indices_of_values_above_threshold(vision_text_result[0], threshold) | ||
|
||

segmentation_masks = []
for seg_idx in index:
    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
    segmentation_masks.append(segmentation_mask_image)

original_image = Image.open(image_path)
overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
overlay_color = (255, 255, 255, 0)

draw = ImageDraw.Draw(overlay_image)
for segmentation_mask_image in segmentation_masks:
    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)

# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
mask_image = overlay_image.convert("RGB")
mask_image.save("./text_sam_merged_mask.jpg")
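
Each script saves only the merged white-on-black mask; the commented-out Image.alpha_composite line hints at compositing it back onto the source photo. One possible visualization step, not part of this commit (the output filename here is made up), which tints the selected regions red:

from PIL import Image

original = Image.open(".assets/car_image.jpg").convert("RGB")
mask = Image.open("./audio_sam_merged_mask.jpg").convert("L")
tint = Image.new("RGB", original.size, (255, 0, 0))
# Where the mask is white, blend in 50% red; elsewhere keep the original pixels.
highlighted = Image.composite(Image.blend(original, tint, 0.5), original, mask)
highlighted.save("./merged_mask_overlay.jpg")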