From 9a05b7aac4ca9e2f040b3e16ae43dcd3e8720ee3 Mon Sep 17 00:00:00 2001
From: Ren Tianhe <48727989+rentainhe@users.noreply.github.com>
Date: Tue, 23 May 2023 13:17:51 +0800
Subject: [PATCH] Refine ImageBind-SAM playground (#273)
* update imagebind sam
* refine readme
---
README.md | 1 +
playground/ImageBind_SAM/README.md | 28 ++++-
.../ImageBind_SAM/audio_referring_seg_demo.py | 97 +++++++++++++++
.../ImageBind_SAM/image_referring_seg_demo.py | 110 ++++++++++++++++++
.../ImageBind_SAM/text_referring_seg_demo.py | 97 +++++++++++++++
5 files changed, 332 insertions(+), 1 deletion(-)
create mode 100644 playground/ImageBind_SAM/audio_referring_seg_demo.py
create mode 100644 playground/ImageBind_SAM/image_referring_seg_demo.py
create mode 100644 playground/ImageBind_SAM/text_referring_seg_demo.py
diff --git a/README.md b/README.md
index eb2b7b5f..e4ac64cf 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ We are very willing to **help everyone share and promote new projects** based on
The **core idea** behind this project is to **combine the strengths of different models in order to build a very powerful pipeline for solving complex problems**. And it's worth mentioning that this is a workflow for combining strong expert models, where **all parts can be used separately or in combination, and can be replaced with any similar but different models (like replacing Grounding DINO with GLIP or other detectors / replacing Stable-Diffusion with ControlNet or GLIGEN/ Combining with ChatGPT)**.
**🍇 Updates**
+- **`2023/05/23`**: Support `Image-Referring-Segment`, `Audio-Referring-Segment` and `Text-Referring-Segment` in [ImageBind-SAM](./playground/ImageBind_SAM/).
- **`2023/05/16`**: Release [ImageBind-SAM](./playground/ImageBind_SAM/) simple demo which aims to segment with different modalities.
- **`2023/05/15`**: Release [LaMa](./playground/LaMa/) and [RePaint](./playground/RePaint/) demo, thanks for nice tips by [Tao Yu](https://github.com/geekyutao).
- **`2023/05/14`**: Release [PaintByExample](./playground/PaintByExample/) demo with SAM.
diff --git a/playground/ImageBind_SAM/README.md b/playground/ImageBind_SAM/README.md
index 9f1c81ae..045b7962 100644
--- a/playground/ImageBind_SAM/README.md
+++ b/playground/ImageBind_SAM/README.md
@@ -12,6 +12,10 @@ This basic idea is followed with [IEA: Image Editing Anything](https://github.co
## Table of contents
- [Installation](#installation)
- [ImageBind-SAM Demo](#run-the-demo)
+- [Audio Referring Segment](#run-audio-referring-segmentation-demo)
+- [Text Referring Segment](#run-text-referring-segmentation-demo)
+- [Image Referring Segment](#run-image-referring-segmentation-demo)
+
## Installation
@@ -45,7 +49,29 @@ We implement `Text Seg` and `Audio Seg` in this demo, the generate masks will be
|:----:|:----:|:----:|
|  | [car audio](./.assets/car_audio.wav) |  |
|  | "A car" |  |
-By setting different threshold may influence a lot on the final results.
\ No newline at end of file
+Note that setting a different threshold may influence the final results a lot.
+
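+For reference, each demo script exposes this threshold as a plain variable right before mask selection, so it is easy to tune. A minimal sketch, with names taken from the demo scripts in this playground:
+
+```python
+# keep only the masks whose ImageBind similarity score passes the threshold
+threshold = 0.05  # defaults: 0.05 (text demo), 0.025 (audio demo), 0.017 (image demo)
+index = get_indices_of_values_above_threshold(vision_text_result, threshold)
+```
+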
+## Run image referring segmentation demo
+```bash
+# download the referring image
+cd .assets
+wget https://github.com/IDEA-Research/detrex-storage/releases/download/grounded-sam-storage/referring_car_image.jpg
+cd ..
+
+python image_referring_seg_demo.py
+```
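+
+Under the hood, the script appends the referring image to the batch of SAM crops, embeds everything with ImageBind's vision encoder in a single forward pass, and keeps the crops whose similarity to the referring embedding passes a threshold. A condensed sketch of the core step (names follow `image_referring_seg_demo.py`):
+
+```python
+image_list = cropped_boxes + [referring_image]   # SAM crops + the referring image
+inputs = {ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(image_list, device)}
+with torch.no_grad():
+    embeddings = bind_model(inputs)[ModalityType.VISION]
+crop_emb, ref_emb = embeddings[:-1], embeddings[-1]   # split the crops from the query image
+scores = torch.softmax(crop_emb @ ref_emb, dim=0)     # one similarity score per crop
+index = get_indices_of_values_above_threshold(scores, 0.017)
+```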
+
+## Run audio referring segmentation demo
+```bash
+python audio_referring_seg_demo.py
+```
+
+## Run text referring segmentation demo
+```bash
+python text_referring_seg_demo.py
+```
\ No newline at end of file
diff --git a/playground/ImageBind_SAM/audio_referring_seg_demo.py b/playground/ImageBind_SAM/audio_referring_seg_demo.py
new file mode 100644
index 00000000..d5108aeb
--- /dev/null
+++ b/playground/ImageBind_SAM/audio_referring_seg_demo.py
@@ -0,0 +1,97 @@
+import data
+import cv2
+import torch
+from PIL import Image, ImageDraw
+from tqdm import tqdm
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+
+from utils import (
+    segment_image,
+    convert_box_xywh_to_xyxy,
+    get_indices_of_values_above_threshold,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+"""
+Step 1: Instantiate model
+"""
+# Segment Anything
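+# SamAutomaticMaskGenerator proposes class-agnostic masks over the whole image;
+# points_per_side=16 keeps the number of proposals (and hence ImageBind crops) moderate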
+mask_generator = SamAutomaticMaskGenerator(
+    build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device),
+    points_per_side=16,
+)
+
+# ImageBind
+bind_model = imagebind_model.imagebind_huge(pretrained=True)
+bind_model.eval()
+bind_model.to(device)
+
+
+"""
+Step 2: Generate auto masks with SAM
+"""
+image_path = ".assets/car_image.jpg"
+image = cv2.imread(image_path)
+image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+masks = mask_generator.generate(image)
+
+
+"""
+Step 3: Get cropped images based on mask and box
+"""
+cropped_boxes = []
+image = Image.open(image_path)
+for mask in tqdm(masks):
+    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+
+"""
+Step 4: Run ImageBind model to get similarity between cropped image and different modalities
+"""
+def retriev_vision_and_audio(elements, audio_list):
+    inputs = {
+        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
+        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_list, device),
+    }
+    with torch.no_grad():
+        embeddings = bind_model(inputs)
+    # softmax over the crops turns the vision-audio similarities into per-crop scores
+    vision_audio = torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=0)
+    return vision_audio
+
+vision_audio_result = retriev_vision_and_audio(cropped_boxes, [".assets/car_audio.wav"])
+
+
+"""
+Step 5: Merge the top similarity masks to get the final mask and save the merged mask
+
+This is the audio retrieval result
+"""
+
+# keep the masks whose similarity score is above the threshold
+# vision_audio_result shape: [num_crops, 1]
+threshold = 0.025
+index = get_indices_of_values_above_threshold(vision_audio_result, threshold)
+
+segmentation_masks = []
+for seg_idx in index:
+    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    segmentation_masks.append(segmentation_mask_image)
+
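+# paint the selected masks in white on a black canvas; converting to RGB at the end
+# drops the unused alpha channel so the saved file is a plain black-and-white merged mask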
+original_image = Image.open(image_path)
+overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
+overlay_color = (255, 255, 255, 0)
+
+draw = ImageDraw.Draw(overlay_image)
+for segmentation_mask_image in segmentation_masks:
+    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+mask_image = overlay_image.convert("RGB")
+mask_image.save("./audio_sam_merged_mask.jpg")
+
diff --git a/playground/ImageBind_SAM/image_referring_seg_demo.py b/playground/ImageBind_SAM/image_referring_seg_demo.py
new file mode 100644
index 00000000..d7f03ef4
--- /dev/null
+++ b/playground/ImageBind_SAM/image_referring_seg_demo.py
@@ -0,0 +1,110 @@
+import data
+import cv2
+import torch
+from PIL import Image, ImageDraw
+from tqdm import tqdm
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+
+from utils import (
+    segment_image,
+    convert_box_xywh_to_xyxy,
+    get_indices_of_values_above_threshold,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+"""
+Step 1: Instantiate model
+"""
+# Segment Anything
+mask_generator = SamAutomaticMaskGenerator(
+    build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device),
+    points_per_side=16,
+)
+
+# ImageBind
+bind_model = imagebind_model.imagebind_huge(pretrained=True)
+bind_model.eval()
+bind_model.to(device)
+
+
+"""
+Step 2: Generate auto masks with SAM
+"""
+image_path = ".assets/car_image.jpg"
+image = cv2.imread(image_path)
+image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+masks = mask_generator.generate(image)
+
+
+"""
+Step 3: Get cropped images based on mask and box
+"""
+cropped_boxes = []
+image = Image.open(image_path)
+for mask in tqdm(masks):
+    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+
+"""
+Step 4: Run ImageBind model to get similarity between cropped image and different modalities
+"""
+# load referring image
+referring_image_path = ".assets/referring_car_image.jpg"
+referring_image = Image.open(referring_image_path)
+
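+# append the referring image as the last element so a single ImageBind vision
+# forward pass embeds both the SAM crops and the query image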
+image_list = []
+image_list += cropped_boxes
+image_list.append(referring_image)
+
+def retriev_vision_and_vision(elements):
+    inputs = {
+        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
+    }
+    with torch.no_grad():
+        embeddings = bind_model(inputs)
+
+    # the first len(elements) - 1 embeddings are the SAM crops, the last one is the referring image
+    cropped_box_embeddings = embeddings[ModalityType.VISION][:-1, :]
+    referring_image_embeddings = embeddings[ModalityType.VISION][-1, :]
+
+    # softmax over the crops turns the similarities into per-crop scores
+    vision_referring_result = torch.softmax(cropped_box_embeddings @ referring_image_embeddings, dim=0)
+    return vision_referring_result  # shape: [num_crops]
+
+
+vision_referring_result = retriev_vision_and_vision(image_list)
+
+
+"""
+Step 5: Merge the top similarity masks to get the final mask and save the merged mask
+
+This is the image referring retrieval result
+"""
+
+# keep the masks whose similarity score is above the threshold
+# vision_referring_result shape: [num_crops]
+threshold = 0.017
+index = get_indices_of_values_above_threshold(vision_referring_result, threshold)
+
+
+segmentation_masks = []
+for seg_idx in index:
+    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    segmentation_masks.append(segmentation_mask_image)
+
+original_image = Image.open(image_path)
+overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
+overlay_color = (255, 255, 255, 0)
+
+draw = ImageDraw.Draw(overlay_image)
+for segmentation_mask_image in segmentation_masks:
+    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+mask_image = overlay_image.convert("RGB")
+mask_image.save("./image_referring_sam_merged_mask.jpg")
diff --git a/playground/ImageBind_SAM/text_referring_seg_demo.py b/playground/ImageBind_SAM/text_referring_seg_demo.py
new file mode 100644
index 00000000..17ee74a7
--- /dev/null
+++ b/playground/ImageBind_SAM/text_referring_seg_demo.py
@@ -0,0 +1,97 @@
+import data
+import cv2
+import torch
+from PIL import Image, ImageDraw
+from tqdm import tqdm
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+
+from utils import (
+    segment_image,
+    convert_box_xywh_to_xyxy,
+    get_indices_of_values_above_threshold,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+"""
+Step 1: Instantiate model
+"""
+# Segment Anything
+mask_generator = SamAutomaticMaskGenerator(
+    build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device),
+    points_per_side=16,
+)
+
+# ImageBind
+bind_model = imagebind_model.imagebind_huge(pretrained=True)
+bind_model.eval()
+bind_model.to(device)
+
+
+"""
+Step 2: Generate auto masks with SAM
+"""
+image_path = ".assets/car_image.jpg"
+image = cv2.imread(image_path)
+image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+masks = mask_generator.generate(image)
+
+
+"""
+Step 3: Get cropped images based on mask and box
+"""
+cropped_boxes = []
+image = Image.open(image_path)
+for mask in tqdm(masks):
+    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+
+"""
+Step 4: Run ImageBind model to get similarity between cropped image and different modalities
+"""
+def retriev_vision_and_text(elements, text_list):
+    inputs = {
+        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
+        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
+    }
+    with torch.no_grad():
+        embeddings = bind_model(inputs)
+    # softmax over the crops turns the vision-text similarities into per-crop scores
+    vision_text = torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=0)
+    return vision_text  # shape: [num_crops, 1]
+
+
+vision_text_result = retriev_vision_and_text(cropped_boxes, ["A car"])
+
+
+"""
+Step 5: Merge the top similarity masks to get the final mask and save the merged mask
+
+This is the text retrieval result
+"""
+
+# keep the masks whose similarity score is above the threshold
+# vision_text_result shape: [num_crops, 1]
+threshold = 0.05
+index = get_indices_of_values_above_threshold(vision_text_result, threshold)
+
+segmentation_masks = []
+for seg_idx in index:
+    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    segmentation_masks.append(segmentation_mask_image)
+
+original_image = Image.open(image_path)
+overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
+overlay_color = (255, 255, 255, 0)
+
+draw = ImageDraw.Draw(overlay_image)
+for segmentation_mask_image in segmentation_masks:
+    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+mask_image = overlay_image.convert("RGB")
+mask_image.save("./text_sam_merged_mask.jpg")