From 9a05b7aac4ca9e2f040b3e16ae43dcd3e8720ee3 Mon Sep 17 00:00:00 2001
From: Ren Tianhe <48727989+rentainhe@users.noreply.github.com>
Date: Tue, 23 May 2023 13:17:51 +0800
Subject: [PATCH] Refine ImageBind-SAM playground (#273)

* update imagebind sam

* refine readme
---
 README.md                                     |   1 +
 playground/ImageBind_SAM/README.md            |  28 ++++-
 .../ImageBind_SAM/audio_referring_seg_demo.py |  97 +++++++++++++++
 .../ImageBind_SAM/image_referring_seg_demo.py | 110 ++++++++++++++++++
 .../ImageBind_SAM/text_referring_seg_demo.py  |  97 +++++++++++++++
 5 files changed, 332 insertions(+), 1 deletion(-)
 create mode 100644 playground/ImageBind_SAM/audio_referring_seg_demo.py
 create mode 100644 playground/ImageBind_SAM/image_referring_seg_demo.py
 create mode 100644 playground/ImageBind_SAM/text_referring_seg_demo.py

diff --git a/README.md b/README.md
index eb2b7b5f..e4ac64cf 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,7 @@ We are very willing to **help everyone share and promote new projects** based on
 The **core idea** behind this project is to **combine the strengths of different models in order to build a very powerful pipeline for solving complex problems**. And it's worth mentioning that this is a workflow for combining strong expert models, where **all parts can be used separately or in combination, and can be replaced with any similar but different models (like replacing Grounding DINO with GLIP or other detectors / replacing Stable-Diffusion with ControlNet or GLIGEN/ Combining with ChatGPT)**.
 
 **🍇 Updates**
+- **`2023/05/23`**: Support `Image-Referring-Segment`, `Audio-Referring-Segment` and `Text-Referring-Segment` in [ImageBind-SAM](./playground/ImageBind_SAM/).
 - **`2023/05/16`**: Release [ImageBind-SAM](./playground/ImageBind_SAM/) simple demo which aims to segment with different modalities.
 - **`2023/05/15`**: Release [LaMa](./playground/LaMa/) and [RePaint](./playground/RePaint/) demo, thanks for nice tips by [Tao Yu](https://github.com/geekyutao).
 - **`2023/05/14`**: Release [PaintByExample](./playground/PaintByExample/) demo with SAM.
diff --git a/playground/ImageBind_SAM/README.md b/playground/ImageBind_SAM/README.md
index 9f1c81ae..045b7962 100644
--- a/playground/ImageBind_SAM/README.md
+++ b/playground/ImageBind_SAM/README.md
@@ -12,6 +12,10 @@ This basic idea is followed with [IEA: Image Editing Anything](https://github.co
 ## Table of contents
 - [Installation](#installation)
 - [ImageBind-SAM Demo](#run-the-demo)
+- [Audio Referring Segment](#run-audio-referring-segmentation-demo)
+- [Text Referring Segment](#run-text-referring-segmentation-demo)
+- [Image Referring Segment](#run-image-referring-segmentation-demo)
+
 
 ## Installation
 
@@ -45,7 +49,29 @@ We implement `Text Seg` and `Audio Seg` in this demo, the generate masks will be
 |:----:|:----:|:----:|
 | ![](./.assets/car_image.jpg) | [car audio](./.assets/car_audio.wav) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/imagebind_sam/audio_sam_merged_mask_new.jpg?raw=true) |
 | ![](./.assets/car_image.jpg) | "A car" | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/imagebind_sam/text_sam_merged_mask.jpg?raw=true) |
+| ![](./.assets/car_image.jpg) | ![](./.assets/referring_car_image.jpg) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/imagebind_sam/image_referring_sam_merged_mask.jpg?raw=true) |
+
 
-By setting different threshold may influence a lot on the final results.
\ No newline at end of file
+Setting a different threshold may strongly influence the final results.
+
+## Run image referring segmentation demo
+```bash
+# download the referring image
+cd .assets
+wget https://github.com/IDEA-Research/detrex-storage/releases/download/grounded-sam-storage/referring_car_image.jpg
+cd ..
+
+python image_referring_seg_demo.py
+```
+
+## Run audio referring segmentation demo
+```bash
+python audio_referring_seg_demo.py
+```
+
+## Run text referring segmentation demo
+```bash
+python text_referring_seg_demo.py
+```
\ No newline at end of file
diff --git a/playground/ImageBind_SAM/audio_referring_seg_demo.py b/playground/ImageBind_SAM/audio_referring_seg_demo.py
new file mode 100644
index 00000000..d5108aeb
--- /dev/null
+++ b/playground/ImageBind_SAM/audio_referring_seg_demo.py
@@ -0,0 +1,97 @@
+import data
+import cv2
+import torch
+from PIL import Image, ImageDraw
+from tqdm import tqdm
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+
+from utils import (
+    segment_image,
+    convert_box_xywh_to_xyxy,
+    get_indices_of_values_above_threshold,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+"""
+Step 1: Instantiate model
+"""
+# Segment Anything
+mask_generator = SamAutomaticMaskGenerator(
+    build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device),
+    points_per_side=16,
+)
+
+# ImageBind
+bind_model = imagebind_model.imagebind_huge(pretrained=True)
+bind_model.eval()
+bind_model.to(device)
+
+
+"""
+Step 2: Generate auto masks with SAM
+"""
+image_path = ".assets/car_image.jpg"
+image = cv2.imread(image_path)
+image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+masks = mask_generator.generate(image)
+
+
+"""
+Step 3: Get cropped images based on mask and box
+"""
+cropped_boxes = []
+image = Image.open(image_path)
+for mask in tqdm(masks):
+    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+
+"""
+Step 4: Run ImageBind model to get similarity between cropped image and different modalities
+"""
+def retriev_vision_and_audio(elements, audio_list):
+    inputs = {
+        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
+        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_list, device),
+    }
+    with torch.no_grad():
+        embeddings = bind_model(inputs)
+    vision_audio = torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=0),
+    return vision_audio
+
+vision_audio_result = retriev_vision_and_audio(cropped_boxes, [".assets/car_audio.wav"])
+
+
+"""
+Step 5: Merge the top similarity masks to get the final mask and save the merged mask
+
+This is the audio retrieval result
+"""
+
+# get highest similar mask with threshold
+# result[0] shape: [113, 1]
+threshold = 0.025
+index = get_indices_of_values_above_threshold(vision_audio_result[0], threshold)
+
+segmentation_masks = []
+for seg_idx in index:
+    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    segmentation_masks.append(segmentation_mask_image)
+
+original_image = Image.open(image_path)
+overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
+overlay_color = (255, 255, 255, 0)
+
+draw = ImageDraw.Draw(overlay_image)
+for segmentation_mask_image in segmentation_masks:
+    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+mask_image = overlay_image.convert("RGB")
+mask_image.save("./audio_sam_merged_mask.jpg")
+
diff --git a/playground/ImageBind_SAM/image_referring_seg_demo.py b/playground/ImageBind_SAM/image_referring_seg_demo.py
new file mode 100644
index 00000000..d7f03ef4
--- /dev/null
+++ b/playground/ImageBind_SAM/image_referring_seg_demo.py
@@ -0,0 +1,110 @@
+import data
+import cv2
+import torch
+from PIL import Image, ImageDraw
+from tqdm import tqdm
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+
+from utils import (
+    segment_image,
+    convert_box_xywh_to_xyxy,
+    get_indices_of_values_above_threshold,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+"""
+Step 1: Instantiate model
+"""
+# Segment Anything
+mask_generator = SamAutomaticMaskGenerator(
+    build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device),
+    points_per_side=16,
+)
+
+# ImageBind
+bind_model = imagebind_model.imagebind_huge(pretrained=True)
+bind_model.eval()
+bind_model.to(device)
+
+
+"""
+Step 2: Generate auto masks with SAM
+"""
+image_path = ".assets/car_image.jpg"
+image = cv2.imread(image_path)
+image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+masks = mask_generator.generate(image)
+
+
+"""
+Step 3: Get cropped images based on mask and box
+"""
+cropped_boxes = []
+image = Image.open(image_path)
+for mask in tqdm(masks):
+    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+
+"""
+Step 4: Run ImageBind model to get similarity between cropped image and different modalities
+"""
+# load referring image
+referring_image_path = ".assets/referring_car_image.jpg"
+referring_image = Image.open(referring_image_path)
+
+image_list = []
+image_list += cropped_boxes
+image_list.append(referring_image)
+
+def retriev_vision_and_vision(elements):
+    inputs = {
+        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
+    }
+    with torch.no_grad():
+        embeddings = bind_model(inputs)
+
+    # cropped box region embeddings
+    cropped_box_embeddings = embeddings[ModalityType.VISION][:-1, :]
+    referring_image_embeddings = embeddings[ModalityType.VISION][-1, :]
+
+    vision_referring_result = torch.softmax(cropped_box_embeddings @ referring_image_embeddings.T, dim=0),
+    return vision_referring_result  # [113, 1]
+
+
+vision_referring_result = retriev_vision_and_vision(image_list)
+
+
+"""
+Step 5: Merge the top similarity masks to get the final mask and save the merged mask
+
+Image / Text mask
+"""
+
+# get highest similar mask with threshold
+# result[0] shape: [113, 1]
+threshold = 0.017
+index = get_indices_of_values_above_threshold(vision_referring_result[0], threshold)
+
+
+segmentation_masks = []
+for seg_idx in index:
+    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    segmentation_masks.append(segmentation_mask_image)
+
+original_image = Image.open(image_path)
+overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
+overlay_color = (255, 255, 255, 0)
+
+draw = ImageDraw.Draw(overlay_image)
+for segmentation_mask_image in segmentation_masks:
+    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+mask_image = overlay_image.convert("RGB")
+mask_image.save("./image_referring_sam_merged_mask.jpg")
diff --git a/playground/ImageBind_SAM/text_referring_seg_demo.py b/playground/ImageBind_SAM/text_referring_seg_demo.py
new file mode 100644
index 00000000..17ee74a7
--- /dev/null
+++ b/playground/ImageBind_SAM/text_referring_seg_demo.py
@@ -0,0 +1,97 @@
+import data
+import cv2
+import torch
+from PIL import Image, ImageDraw
+from tqdm import tqdm
+from models import imagebind_model
+from models.imagebind_model import ModalityType
+
+from segment_anything import build_sam, SamAutomaticMaskGenerator
+
+from utils import (
+    segment_image,
+    convert_box_xywh_to_xyxy,
+    get_indices_of_values_above_threshold,
+)
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+"""
+Step 1: Instantiate model
+"""
+# Segment Anything
+mask_generator = SamAutomaticMaskGenerator(
+    build_sam(checkpoint=".checkpoints/sam_vit_h_4b8939.pth").to(device),
+    points_per_side=16,
+)
+
+# ImageBind
+bind_model = imagebind_model.imagebind_huge(pretrained=True)
+bind_model.eval()
+bind_model.to(device)
+
+
+"""
+Step 2: Generate auto masks with SAM
+"""
+image_path = ".assets/car_image.jpg"
+image = cv2.imread(image_path)
+image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+masks = mask_generator.generate(image)
+
+
+"""
+Step 3: Get cropped images based on mask and box
+"""
+cropped_boxes = []
+image = Image.open(image_path)
+for mask in tqdm(masks):
+    cropped_boxes.append(segment_image(image, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))
+
+
+"""
+Step 4: Run ImageBind model to get similarity between cropped image and different modalities
+"""
+def retriev_vision_and_text(elements, text_list):
+    inputs = {
+        ModalityType.VISION: data.load_and_transform_vision_data_from_pil_image(elements, device),
+        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
+    }
+    with torch.no_grad():
+        embeddings = bind_model(inputs)
+    vision_audio = torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=0),
+    return vision_audio  # [113, 1]
+
+
+vision_text_result = retriev_vision_and_text(cropped_boxes, ["A car"])
+
+
+"""
+Step 5: Merge the top similarity masks to get the final mask and save the merged mask
+
+Image / Text mask
+"""
+
+# get highest similar mask with threshold
+# result[0] shape: [113, 1]
+threshold = 0.05
+index = get_indices_of_values_above_threshold(vision_text_result[0], threshold)
+
+segmentation_masks = []
+for seg_idx in index:
+    segmentation_mask_image = Image.fromarray(masks[seg_idx]["segmentation"].astype('uint8') * 255)
+    segmentation_masks.append(segmentation_mask_image)
+
+original_image = Image.open(image_path)
+overlay_image = Image.new('RGBA', image.size, (0, 0, 0, 255))
+overlay_color = (255, 255, 255, 0)
+
+draw = ImageDraw.Draw(overlay_image)
+for segmentation_mask_image in segmentation_masks:
+    draw.bitmap((0, 0), segmentation_mask_image, fill=overlay_color)
+
+# return Image.alpha_composite(original_image.convert('RGBA'), overlay_image)
+mask_image = overlay_image.convert("RGB")
+mask_image.save("./text_sam_merged_mask.jpg")
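
Note for reviewers: all three demo scripts import `segment_image`, `convert_box_xywh_to_xyxy`, and `get_indices_of_values_above_threshold` from the playground's existing `utils.py`, which this patch does not touch. The sketch below is only an illustration of what those helpers are assumed to do so the scripts above can be followed end to end; the actual implementations in the repository may differ in details.

```python
# Hypothetical sketch of the utils.py helpers assumed by the demos above.
import numpy as np
import torch
from PIL import Image


def convert_box_xywh_to_xyxy(box):
    # SAM reports boxes as [x, y, w, h]; PIL's Image.crop expects (x0, y0, x1, y1).
    x, y, w, h = box
    return [x, y, x + w, y + h]


def segment_image(image: Image.Image, segmentation_mask: np.ndarray) -> Image.Image:
    # Keep only the pixels covered by the SAM mask and zero out everything else,
    # so the subsequent crop contains just the masked region.
    arr = np.array(image)
    arr[~segmentation_mask] = 0
    return Image.fromarray(arr)


def get_indices_of_values_above_threshold(values: torch.Tensor, threshold: float):
    # Return the indices of the cropped regions whose similarity score exceeds
    # the threshold; the matching SAM masks are then merged into the final mask.
    return (values.flatten() > threshold).nonzero(as_tuple=True)[0].tolist()
```

Under this reading, lowering the per-demo `threshold` constant merges more masks into the result, while raising it keeps only the closest matches, which is why each script exposes it as a separate tunable value.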