Support SAM-HQ (IDEA-Research#300)
* support sam-hq

* add sam_hq and update README
rentainhe authored Jun 12, 2023
1 parent 07f0afc commit 9daeb4b
Showing 9 changed files with 441 additions and 30 deletions.
43 changes: 37 additions & 6 deletions README.md
@@ -14,13 +14,11 @@ We are very willing to **help everyone share and promote new projects** based on
The **core idea** behind this project is to **combine the strengths of different models in order to build a very powerful pipeline for solving complex problems**. It is worth mentioning that this is a workflow for combining strong expert models, where **all parts can be used separately or in combination, and can be replaced with any similar but different models (e.g., replacing Grounding DINO with GLIP or other detectors, replacing Stable-Diffusion with ControlNet or GLIGEN, or combining with ChatGPT)**.

**🍇 Updates**
- **`2023/06/13`** Support [SAM-HQ](https://github.com/SysCV/sam-hq) in the [Grounded-SAM Demo](#running_man-grounded-sam-detect-and-segment-everything-with-text-prompt) for higher-quality predictions.
- **`2023/06/12`** Support [RAM-Grounded-SAM](#label-grounded-sam-with-ram-or-tag2text-for-automatic-labeling) for a strong automatic labeling pipeline! Thanks to [Recognize-Anything](https://github.com/xinyu1205/Recognize_Anything-Tag2Text).
- **`2023/06/01`** Our Grounded-SAM has been accepted to present a **demo** at **ICCV 2023**! See you in Paris!
- **`2023/05/23`**: Support `Image-Referring-Segment`, `Audio-Referring-Segment` and `Text-Referring-Segment` in [ImageBind-SAM](./playground/ImageBind_SAM/).
- **`2023/05/16`**: Release [ImageBind-SAM](./playground/ImageBind_SAM/) simple demo which aims to segment with different modalities.
- **`2023/05/15`**: Release the [LaMa](./playground/LaMa/) and [RePaint](./playground/RePaint/) demos; thanks to [Tao Yu](https://github.com/geekyutao) for the nice tips.
- **`2023/05/14`**: Release [PaintByExample](./playground/PaintByExample/) demo with SAM.
- **`2023/05/11`**: We decided to share more interesting demos in [playground](./playground/); we've already tested [DeepFloyd](./playground/DeepFloyd/) for image generation and style transfer and shared some notes about using IF.
- **`2023/05/05`**: Release simpler code for automatic labeling (combined with the `Tag2Text` model): please see [automatic_label_simple_demo.py](./automatic_label_simple_demo.py)
- **`2023/05/03`**: Check out [Automated Dataset Annotation and Evaluation with GroundingDINO and SAM](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/automated-dataset-annotation-and-evaluation-with-grounding-dino-and-sam.ipynb), which is an amazing tutorial on automatic labeling! Thanks a lot to [Piotr Skalski](https://github.com/SkalskiP) and [Roboflow](https://github.com/roboflow/notebooks)!


@@ -312,7 +310,41 @@ The annotated results will be saved in `./outputs` as follows

</div>

**Step 3: Running the updated grounded-sam demo (optional)**
**Step 3: Running the grounded-sam demo with SAM-HQ**
- Download the demo image
```bash
wget https://github.com/IDEA-Research/detrex-storage/releases/download/grounded-sam-storage/sam_hq_demo_image.png
```

- Download a SAM-HQ checkpoint from the [SAM-HQ model zoo](https://github.com/SysCV/sam-hq#model-checkpoints) (the demo below uses `sam_hq_vit_h.pth`)

- Run the grounded-sam-hq demo as follows (`--sam_hq_checkpoint` points to the SAM-HQ checkpoint downloaded above, and `--use_sam_hq` switches prediction to the SAM-HQ model):
```bash
export CUDA_VISIBLE_DEVICES=0
python grounded_sam_demo.py \
  --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
  --grounded_checkpoint groundingdino_swint_ogc.pth \
  --sam_hq_checkpoint ./sam_hq_vit_h.pth \
  --use_sam_hq \
  --input_image sam_hq_demo_image.png \
  --output_dir "outputs" \
  --box_threshold 0.3 \
  --text_threshold 0.25 \
  --text_prompt "chair." \
  --device "cuda"
```

The annotated results will be saved in `./outputs` as follows

<div align="center">

| Input Image | SAM Output | SAM-HQ Output |
|:----:|:----:|:----:|
| ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/sam_hq/sam_hq_demo.png?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/sam_hq/sam_output.jpg?raw=true) | ![](https://github.com/IDEA-Research/detrex-storage/blob/main/assets/grounded_sam/sam_hq/sam_hq_output.jpg?raw=true) |

</div>

**Step 4: Running the updated grounded-sam demo (optional)**
Note that this demo is almost the same as the original demo, but **with more elegant code**.

@@ -330,7 +362,6 @@ The annotated results will be saved as `./groundingdino_annotated_image.jpg` and

</div>


### :skier: Grounded-SAM with Inpainting: Detect, Segment and Generate Everything with Text Prompt

**Step 1: Download the pretrained weights**
21 changes: 18 additions & 3 deletions grounded_sam_demo.py
@@ -15,7 +15,11 @@
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import build_sam, SamPredictor
from segment_anything import (
build_sam,
build_sam_hq,
SamPredictor
)
import cv2
import numpy as np
import matplotlib.pyplot as plt
@@ -136,7 +140,13 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
"--grounded_checkpoint", type=str, required=True, help="path to checkpoint file"
)
parser.add_argument(
"--sam_checkpoint", type=str, required=True, help="path to checkpoint file"
"--sam_checkpoint", type=str, required=False, help="path to sam checkpoint file"
)
parser.add_argument(
"--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
)
parser.add_argument(
"--use_sam_hq", action="store_true", help="using sam-hq for prediction"
)
parser.add_argument("--input_image", type=str, required=True, help="path to image file")
parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
@@ -154,6 +164,8 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
config_file = args.config # change the path of the model config file
grounded_checkpoint = args.grounded_checkpoint # change the path of the model
sam_checkpoint = args.sam_checkpoint
sam_hq_checkpoint = args.sam_hq_checkpoint
use_sam_hq = args.use_sam_hq
image_path = args.input_image
text_prompt = args.text_prompt
output_dir = args.output_dir
@@ -177,7 +189,10 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
)

# initialize SAM
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
if use_sam_hq:
predictor = SamPredictor(build_sam_hq(checkpoint=sam_hq_checkpoint).to(device))
else:
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))
image = cv2.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)
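A side note on the argument change above: with `--sam_checkpoint` now optional, a run that passes neither checkpoint only fails later inside `build_sam`. Below is a self-contained sketch of a fail-fast guard; it is not part of this commit and only reuses the flag names defined above.

```python
import argparse

# Sketch only (not in this commit): the same checkpoint flags as grounded_sam_demo.py,
# plus a fail-fast check that one usable checkpoint was actually given.
parser = argparse.ArgumentParser("checkpoint-flag sketch")
parser.add_argument("--sam_checkpoint", type=str, required=False)
parser.add_argument("--sam_hq_checkpoint", type=str, default=None)
parser.add_argument("--use_sam_hq", action="store_true")
args = parser.parse_args()

if args.use_sam_hq and args.sam_hq_checkpoint is None:
    parser.error("--use_sam_hq requires --sam_hq_checkpoint")
if not args.use_sam_hq and args.sam_checkpoint is None:
    parser.error("provide --sam_checkpoint, or --use_sam_hq with --sam_hq_checkpoint")
```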
7 changes: 7 additions & 0 deletions segment_anything/segment_anything/__init__.py
@@ -11,5 +11,12 @@
build_sam_vit_b,
sam_model_registry,
)
from .build_sam_hq import (
build_sam_hq,
build_sam_hq_vit_h,
build_sam_hq_vit_l,
build_sam_hq_vit_b,
sam_hq_model_registry,
)
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
113 changes: 113 additions & 0 deletions segment_anything/segment_anything/build_sam_hq.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from functools import partial

from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer


def build_sam_hq_vit_h(checkpoint=None):
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)


build_sam_hq = build_sam_hq_vit_h


def build_sam_hq_vit_l(checkpoint=None):
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)


def build_sam_hq_vit_b(checkpoint=None):
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)


sam_hq_model_registry = {
"default": build_sam_hq_vit_h,
"vit_h": build_sam_hq_vit_h,
"vit_l": build_sam_hq_vit_l,
"vit_b": build_sam_hq_vit_b,
}


def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
):
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
sam = Sam(
image_encoder=ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
),
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoderHQ(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
vit_dim=encoder_embed_dim,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
# sam.eval()
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
info = sam.load_state_dict(state_dict, strict=False)
print(info)
for n, p in sam.named_parameters():
if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
p.requires_grad = False

return sam
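
For orientation, here is a minimal usage sketch of the new builders together with `SamPredictor`. It assumes the predictor-side code path from the rest of this commit (not shown here) keeps the vanilla `SamPredictor` interface; the checkpoint path, image path, and box prompt are placeholders rather than values from this commit.

```python
import cv2
import numpy as np
import torch
from segment_anything import sam_hq_model_registry, SamPredictor

# Sketch only: the checkpoint and image paths are assumptions, mirroring the README demo above.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_hq = sam_hq_model_registry["vit_h"](checkpoint="./sam_hq_vit_h.pth").to(device)
predictor = SamPredictor(sam_hq)

# SamPredictor expects an RGB uint8 image (the demo converts from OpenCV's BGR).
image = cv2.cvtColor(cv2.imread("sam_hq_demo_image.png"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Prompt with a single box (x0, y0, x1, y1) in pixel coordinates.
masks, scores, _ = predictor.predict(
    box=np.array([100, 100, 400, 400]),
    multimask_output=False,
)
print(masks.shape, scores)
```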
1 change: 1 addition & 0 deletions segment_anything/segment_anything/modeling/__init__.py
@@ -6,6 +6,7 @@

from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder_hq import MaskDecoderHQ
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
from .transformer import TwoWayTransformer
21 changes: 12 additions & 9 deletions segment_anything/segment_anything/modeling/image_encoder.py
@@ -108,12 +108,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.pos_embed is not None:
x = x + self.pos_embed

interm_embeddings=[]
for blk in self.blocks:
x = blk(x)
if blk.window_size == 0:
interm_embeddings.append(x)

x = self.neck(x.permute(0, 3, 1, 2))

return x
return x, interm_embeddings


class Block(nn.Module):
@@ -144,8 +147,8 @@ def __init__(
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
window_size (int): Window size for window attention blocks. If it equals 0, then
use global attention.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.norm1 = norm_layer(dim)
@@ -198,11 +201,11 @@ def __init__(
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
qkv_bias (bool: If True, add a learnable bias to query, key, value.
qkv_bias (bool): If True, add a learnable bias to query, key, value.
rel_pos (bool): If True, add relative positional embeddings to the attention map.
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
input_size (int or None): Input resolution for calculating the relative positional
parameter size.
input_size (tuple(int, int) or None): Input resolution for calculating the relative
positional parameter size.
"""
super().__init__()
self.num_heads = num_heads
@@ -270,7 +273,7 @@ def window_unpartition(
"""
Window unpartition into original sequences and removing padding.
Args:
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
@@ -380,7 +383,7 @@ def __init__(
stride (Tuple): stride of the projection layer.
padding (Tuple): padding size of the projection layer.
in_chans (int): Number of input image channels.
embed_dim (int): embed_dim (int): Patch embedding dimension.
embed_dim (int): Patch embedding dimension.
"""
super().__init__()

@@ -392,4 +395,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
# B C H W -> B H W C
x = x.permute(0, 2, 3, 1)
return x
return x
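
One consequence of the `forward` change above: `ImageEncoderViT` now returns a tuple, so anything calling the encoder directly must unpack two values; the HQ decoder receives the second one as `interm_embeddings` (its `forward` signature below gains that argument). A quick shape-check sketch with random weights (no checkpoint, ViT-B config), offered as an illustration rather than code from this commit:

```python
import torch
from segment_anything import sam_hq_model_registry

# Sketch only: a random-weight ViT-B SAM-HQ model, just to inspect the new encoder outputs.
sam_hq = sam_hq_model_registry["vit_b"]()   # checkpoint=None, no weights loaded
pixels = torch.zeros(1, 3, 1024, 1024)      # (B, 3, 1024, 1024), the input size Sam expects

with torch.no_grad():
    image_embeddings, interm_embeddings = sam_hq.image_encoder(pixels)

# One intermediate embedding per global-attention block (4 for ViT-B).
print(image_embeddings.shape, len(interm_embeddings))
```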
8 changes: 5 additions & 3 deletions segment_anything/segment_anything/modeling/mask_decoder.py
@@ -26,7 +26,7 @@
) -> None:
"""
Predicts masks given an image and prompt embeddings, using a
tranformer architecture.
transformer architecture.
Arguments:
transformer_dim (int): the channel dimension of the transformer
@@ -75,6 +75,8 @@ def forward(
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
hq_token_only: bool,
interm_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
@@ -98,7 +100,7 @@
dense_prompt_embeddings=dense_prompt_embeddings,
)

# Select the correct mask or masks for outptu
# Select the correct mask or masks for output
if multimask_output:
mask_slice = slice(1, None)
else:
@@ -173,4 +175,4 @@ def forward(self, x):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = F.sigmoid(x)
return x
return x
(The remaining changed files of this commit are not shown in this view.)
