Add SAM2.1 annotator
HonzaCuhel committed Jan 31, 2025
1 parent 48bb90c commit 32514d1
Showing 11 changed files with 223 additions and 22 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -181,7 +181,7 @@ datadreamer --config <path-to-config>
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo`, `sdxl-lightning` or `shuttle-3`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.
- `--image_annotator`: Specify the image annotator: `owlv2` for object detection, `aimv2` or `clip` for image classification, or `owlv2-slimsam` and `owlv2-sam2` for instance segmentation. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.
- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.
- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `""`.
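
For example, a run that uses the new `owlv2-sam2` annotator for instance segmentation might look like the sketch below; `--save_dir`, `--task`, and `--class_names` are assumed from the wider datadreamer CLI and do not appear in the excerpt above:

```bash
# Sketch: OWLv2 proposes boxes, SAM2.1 refines them into instance masks.
datadreamer \
  --save_dir gen_dataset \
  --task instance-segmentation \
  --class_names person bus bicycle \
  --image_annotator owlv2-sam2 \
  --image_generator sdxl-turbo \
  --conf_threshold 0.15 \
  --annotation_iou_threshold 0.2 \
  --device cuda
```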
@@ -221,6 +221,7 @@ datadreamer --config <path-to-config>
| | [CLIP](https://huggingface.co/openai/clip-vit-base-patch32) | Zero-shot-image-classification |
| | [AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224-lit) | Zero-shot-image-classification |
| | [SlimSAM](https://huggingface.co/Zigeng/SlimSAM-uniform-50) | Zero-shot-instance-segmentation |
| | [SAM2.1](https://huggingface.co/facebook/sam2-hiera-tiny) | Zero-shot-instance-segmentation |

<a name="example"></a>

2 changes: 2 additions & 0 deletions datadreamer/dataset_annotation/__init__.py
@@ -5,6 +5,7 @@
from .cls_annotator import ImgClassificationAnnotator
from .image_annotator import BaseAnnotator, TaskList
from .owlv2_annotator import OWLv2Annotator
from .sam2_annotator import SAM2Annotator
from .slimsam_annotator import SlimSAMAnnotator

__all__ = [
@@ -14,5 +15,6 @@
"OWLv2Annotator",
"ImgClassificationAnnotator",
"CLIPAnnotator",
"SAM2Annotator",
"SlimSAMAnnotator",
]
136 changes: 136 additions & 0 deletions datadreamer/dataset_annotation/sam2_annotator.py
@@ -0,0 +1,136 @@
from __future__ import annotations

import logging
from typing import List

import numpy as np
import PIL
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
from datadreamer.dataset_annotation.utils import mask_to_polygon

logger = logging.getLogger(__name__)


class SAM2Annotator(BaseAnnotator):
"""A class for image annotation using the SAM2.1 model, specializing in instance
segmentation.
Attributes:
model (SAM2ImagePredictor): The SAM2.1 model for instance segmentation.
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
size (str): The size of the SAM model to use ('base' or 'large').
Methods:
_init_model(): Initializes the SAM2.1 model.
annotate_batch(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given image with bounding boxes and labels.
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
"""

def __init__(
self,
seed: float = 42,
device: str = "cuda",
size: str = "base",
) -> None:
"""Initializes the SAMAnnotator with a specific seed and device.
Args:
seed (float): Seed for reproducibility. Defaults to 42.
device (str): The device to run the model on. Defaults to 'cuda'.
"""
super().__init__(seed)
self.size = size
self.device = device
self.model = self._init_model(device=device)
self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float16

def _init_model(self, device: str) -> SAM2ImagePredictor:
"""Initializes the SAM2.1 model for object detection.
Returns:
SAM2ImagePredictor: The initialized SAM2.1 model.
"""
logger.info(f"Initializing SAM2.1 {self.size} model...")
if self.size == "large":
return SAM2ImagePredictor.from_pretrained(
"facebook/sam2.1-hiera-base-plus", device=device
)
return SAM2ImagePredictor.from_pretrained(
"facebook/sam2-hiera-tiny", device=device
)

def annotate_batch(
self,
images: List[PIL.Image.Image],
boxes_batch: List[np.ndarray],
iou_threshold: float = 0.2,
) -> List[List[List[float]]]:
"""Annotates images for the task of instance segmentation using the SAM2.1
model.
Args:
images: The images to be annotated.
boxes_batch: The bounding boxes of found objects.
iou_threshold (float, optional): Intersection over union threshold for non-maximum suppression. Defaults to 0.2.
Returns:
List: A list containing the final segment masks represented as a polygon.
"""
final_segments = []

image_batch = [np.array(img.convert("RGB")) for img in images]
bboxes_batch = [None if len(boxes) == 0 else boxes for boxes in boxes_batch]

with torch.inference_mode(), torch.autocast(self.device, dtype=self.dtype):
self.model.set_image_batch(image_batch)
masks_batch, scores_batch, _ = self.model.predict_batch(
box_batch=bboxes_batch,
multimask_output=False,
)

n = len(images)

        for i in range(n):
            # Skip images for which the detector produced no boxes.
            if bboxes_batch[i] is None:
                final_segments.append([])
                continue
            boxes = boxes_batch[i].tolist()

image_masks = []
for j in range(len(boxes)):
mask, score = masks_batch[i][j], scores_batch[i][j]
if score < iou_threshold:
image_masks.append([])
continue
mask = mask.astype(np.uint8)
polygon = mask_to_polygon(mask)
image_masks.append(polygon if len(polygon) != 0 else [])

final_segments.append(image_masks)

return final_segments

def release(self, empty_cuda_cache: bool = False) -> None:
"""Releases the model and optionally empties the CUDA cache.
Args:
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
"""
if empty_cuda_cache:
with torch.no_grad():
torch.cuda.empty_cache()


if __name__ == "__main__":
import requests
from PIL import Image

url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = SAM2Annotator(device="cpu", size="base")
final_segments = annotator.annotate_batch([im], [np.array([[3, 229, 559, 650]])])
print(len(final_segments), len(final_segments[0]))
print(final_segments[0][0][:5])
12 changes: 7 additions & 5 deletions datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -20,6 +20,7 @@
AIMv2Annotator,
CLIPAnnotator,
OWLv2Annotator,
SAM2Annotator,
SlimSAMAnnotator,
)
from datadreamer.image_generation import (
@@ -61,8 +62,8 @@

det_annotators = {"owlv2": OWLv2Annotator}
clf_annotators = {"clip": CLIPAnnotator, "aimv2": AIMv2Annotator}
inst_seg_annotators = {"owlv2-slimsam": SlimSAMAnnotator}
inst_seg_detectors = {"owlv2-slimsam": OWLv2Annotator}
inst_seg_annotators = {"owlv2-slimsam": SlimSAMAnnotator, "owlv2-sam2": SAM2Annotator}
inst_seg_detectors = {"owlv2-slimsam": OWLv2Annotator, "owlv2-sam2": OWLv2Annotator}

setup_logging(use_rich=True)
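
The `owlv2-sam2` entry keeps OWLv2 as the box detector and only swaps the mask stage for SAM2.1, so the hand-off between the two annotators reduces to passing OWLv2's boxes into `SAM2Annotator.annotate_batch`. A minimal sketch of that hand-off (the image path and the hard-coded boxes are stand-ins for real detector output):

```python
import numpy as np
from PIL import Image

from datadreamer.dataset_annotation import SAM2Annotator

# One image and one array of xyxy boxes per image, e.g. as produced by OWLv2Annotator.
image = Image.open("bus.jpg")  # hypothetical local image
boxes_batch = [np.array([[3, 229, 559, 650]])]

segmenter = SAM2Annotator(device="cpu", size="base")
masks_batch = segmenter.annotate_batch([image], boxes_batch, iou_threshold=0.2)

# masks_batch[0][0] is the polygon (a list of [x, y] points) for the first box.
print(len(masks_batch), len(masks_batch[0]))
segmenter.release()
```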

@@ -125,7 +126,7 @@ def parse_args():
parser.add_argument(
"--image_annotator",
type=str,
choices=["owlv2", "clip", "owlv2-slimsam", "aimv2"],
choices=["owlv2", "clip", "owlv2-slimsam", "aimv2", "owlv2-sam2"],
help="Image annotator to use",
)

@@ -668,9 +669,10 @@ def read_image_batch(image_batch, batch_num, batch_size):
if args.task == "instance-segmentation":
if k < len(masks_batch[j]):
mask = masks_batch[j][k]
x_points, y_points = zip(*mask)
if len(mask) > 0:
x_points, y_points = zip(*mask)

ax.fill(x_points, y_points, label, alpha=0.5)
ax.fill(x_points, y_points, label, alpha=0.5)

labels.append(label)
x1, y1, x2, y2 = box
4 changes: 3 additions & 1 deletion datadreamer/utils/config.py
@@ -41,7 +41,9 @@ class Config(LuxonisConfig):
# Profanity filter arguments
disable_lm_filter: bool = False
# Annotation arguments
image_annotator: Literal["owlv2", "aimv2", "clip", "owlv2-slimsam"] = "owlv2"
image_annotator: Literal[
"owlv2", "aimv2", "clip", "owlv2-slimsam", "owlv2-sam2"
] = "owlv2"
conf_threshold: float = 0.15
annotation_iou_threshold: float = 0.2
use_tta: bool = False
7 changes: 3 additions & 4 deletions examples/generate_dataset_and_train_yolo.ipynb
@@ -5,7 +5,7 @@
"id": "11adc87f",
"metadata": {},
"source": [
"<img src=\"https://docs.luxonis.com/images/depthai_logo.png\" width=\"500\">\n",
"<img src=\"https://www.luxonis.com/logo.svg\" width=\"400\">\n",
"\n",
"# DataDreamer Tutorial: Generating a dataset for object detection, training a model, and deploying it to the OAK (optional)"
]
@@ -85,7 +85,7 @@
"- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.\n",
"- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.\n",
"- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo`, `sdxl-lightning` or `shuttle-3`. Default is `sdxl-turbo`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` and `owlv2-sam2` for instance segmentation. Default is `owlv2`.\n",
"- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.\n",
"- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.\n",
"- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `\"\"`.\n",
@@ -104,8 +104,7 @@
"- `--batch_size_image`: Batch size for image generation. Default is `1`.\n",
"- `--device`: Choose between `cuda` and `cpu`. Default is `cuda`.\n",
"- `--seed`: Set a random seed for image and prompt generation. Default is `42`.\n",
"- `--config`: A path to an optional `.yaml` config file specifying the pipeline's arguments.\n",
""
"- `--config`: A path to an optional `.yaml` config file specifying the pipeline's arguments.\n"
]
},
{
(changes to a second example notebook; file path not captured in this view)
@@ -7,7 +7,7 @@
"id": "8ce1517f-7258-406d-9139-9adadb1a1570"
},
"source": [
"<img src=\"https://docs.luxonis.com/images/depthai_logo.png\" width=\"500\">\n",
"<img src=\"https://www.luxonis.com/logo.svg\" width=\"400\">\n",
"\n",
"# DataDreamer Tutorial: Generating a dataset for instance segmentation, training a model, and deploying it to the OAK (optional)"
]
@@ -99,7 +99,7 @@
"- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.\n",
"- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.\n",
"- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo`, `sdxl-lightning` or `shuttle-3`. Default is `sdxl-turbo`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` and `owlv2-sam2` for instance segmentation. Default is `owlv2`.\n",
"- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.\n",
"- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.\n",
"- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `\"\"`.\n",
(remaining changed files not rendered in this view)
