Addition of SAM2.1 Annotator for Instance Segmentation #79

Merged · 5 commits · Feb 4, 2025
3 changes: 2 additions & 1 deletion README.md
@@ -181,7 +181,7 @@ datadreamer --config <path-to-config>
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo`, `sdxl-lightning` or `shuttle-3`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.
- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` and `owlv2-sam2` for instance segmentation. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.
- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.
- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `""`.
@@ -221,6 +221,7 @@ datadreamer --config <path-to-config>
| | [CLIP](https://huggingface.co/openai/clip-vit-base-patch32) | Zero-shot-image-classification |
| | [AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224-lit) | Zero-shot-image-classification |
| | [SlimSAM](https://huggingface.co/Zigeng/SlimSAM-uniform-50) | Zero-shot-instance-segmentation |
| | [SAM2.1](https://huggingface.co/facebook/sam2.1-hiera-large) | Zero-shot-instance-segmentation |

<a name="example"></a>

2 changes: 2 additions & 0 deletions datadreamer/dataset_annotation/__init__.py
@@ -5,6 +5,7 @@
from .cls_annotator import ImgClassificationAnnotator
from .image_annotator import BaseAnnotator, TaskList
from .owlv2_annotator import OWLv2Annotator
from .sam2_annotator import SAM2Annotator
from .slimsam_annotator import SlimSAMAnnotator

__all__ = [
@@ -14,5 +15,6 @@
"OWLv2Annotator",
"ImgClassificationAnnotator",
"CLIPAnnotator",
"SAM2Annotator",
"SlimSAMAnnotator",
]
137 changes: 137 additions & 0 deletions datadreamer/dataset_annotation/sam2_annotator.py
@@ -0,0 +1,137 @@
from __future__ import annotations

import logging
from typing import List

import numpy as np
import PIL
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

from datadreamer.dataset_annotation.image_annotator import BaseAnnotator
from datadreamer.dataset_annotation.utils import mask_to_polygon

logger = logging.getLogger(__name__)


class SAM2Annotator(BaseAnnotator):
    """A class for image annotation using the SAM2.1 model, specializing in instance
    segmentation.

    Attributes:
        model (SAM2ImagePredictor): The SAM2.1 model for instance segmentation.
        device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU).
        size (str): The size of the SAM2.1 model to use ('base' or 'large').
        dtype (torch.dtype): The data type used for inference.

    Methods:
        _init_model(): Initializes the SAM2.1 model.
        annotate_batch(images, boxes_batch, iou_threshold): Annotates the given images with instance segmentation masks derived from the provided bounding boxes.
        release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
    """

    def __init__(
        self,
        seed: float = 42,
        device: str = "cuda",
        size: str = "base",
    ) -> None:
        """Initializes the SAM2Annotator with a specific seed, device, and model size.

        Args:
            seed (float): Seed for reproducibility. Defaults to 42.
            device (str): The device to run the model on. Defaults to 'cuda'.
            size (str): The size of the SAM2.1 model to use ('base' or 'large'). Defaults to 'base'.
        """
        super().__init__(seed)
        self.size = size
        self.device = device
        self.model = self._init_model(device=device)
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float16

    def _init_model(self, device: str) -> SAM2ImagePredictor:
        """Initializes the SAM2.1 model for instance segmentation.

        Returns:
            SAM2ImagePredictor: The initialized SAM2.1 model.
        """
        logger.info(f"Initializing SAM2.1 {self.size} model...")
        if self.size == "large":
            return SAM2ImagePredictor.from_pretrained(
                "facebook/sam2.1-hiera-large", device=device
            )
        return SAM2ImagePredictor.from_pretrained(
            "facebook/sam2.1-hiera-base-plus", device=device
        )

    def annotate_batch(
        self,
        images: List[PIL.Image.Image],
        boxes_batch: List[np.ndarray],
        iou_threshold: float = 0.2,
    ) -> List[List[List[float]]]:
        """Annotates images for the task of instance segmentation using the SAM2.1
        model.

        Args:
            images: The images to be annotated.
            boxes_batch: The bounding boxes of found objects.
            iou_threshold (float, optional): Minimum predicted IoU score required to keep a mask. Defaults to 0.2.

        Returns:
            List: A list containing, for each image, the final segmentation masks represented as polygons.
        """
        final_segments = []

        image_batch = [np.array(img.convert("RGB")) for img in images]
        # The predictor expects None (rather than an empty array) for images without boxes.
        bboxes_batch = [None if len(boxes) == 0 else boxes for boxes in boxes_batch]

        with torch.inference_mode(), torch.autocast(self.device, dtype=self.dtype):
            self.model.set_image_batch(image_batch)
            masks_batch, scores_batch, _ = self.model.predict_batch(
                box_batch=bboxes_batch,
                multimask_output=False,
            )

        n = len(images)

        for i in range(n):
            # Skip images for which the detector produced no boxes.
            if bboxes_batch[i] is None:
                final_segments.append([])
                continue
            boxes = boxes_batch[i].tolist()

            image_masks = []
            for j in range(len(boxes)):
                mask, score = masks_batch[i][j].astype(np.uint8), scores_batch[i][j]
                if score < iou_threshold:
                    image_masks.append([])
                    continue
                if len(mask.shape) == 3:
                    mask = mask.squeeze(0)
                polygon = mask_to_polygon(mask)
                image_masks.append(polygon if len(polygon) != 0 else [])

            final_segments.append(image_masks)

        return final_segments

    def release(self, empty_cuda_cache: bool = False) -> None:
        """Releases the model and optionally empties the CUDA cache.

        Args:
            empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False.
        """
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()


if __name__ == "__main__":
    import requests
    from PIL import Image

    url = "https://ultralytics.com/images/bus.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    annotator = SAM2Annotator(device="cpu", size="large")
    final_segments = annotator.annotate_batch([im], [np.array([[3, 229, 559, 650]])])
    print(len(final_segments), len(final_segments[0]))
    print(final_segments[0][0][:5])
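As registered in the pipeline changes below, the `owlv2-sam2` option pairs this class with OWLv2: the detector proposes bounding boxes and SAM2.1 converts them into polygon masks. The following is a minimal sketch of that hand-off, not code from this diff; it assumes `OWLv2Annotator` accepts a `device` argument and that its `annotate_batch` follows the `(images, prompts, conf_threshold=...)` form referenced in the docstring above, returning per-image boxes, scores, and labels.

```python
import numpy as np
import requests
from PIL import Image

from datadreamer.dataset_annotation import OWLv2Annotator, SAM2Annotator

url = "https://ultralytics.com/images/bus.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Stage 1: open-vocabulary detection proposes boxes for the prompt classes
# (constructor and annotate_batch signature assumed, see note above).
detector = OWLv2Annotator(device="cpu")
boxes_batch, scores_batch, labels_batch = detector.annotate_batch(
    [image], ["bus", "person"], conf_threshold=0.15
)

# Stage 2: SAM2.1 converts each detected box into a polygon mask.
segmenter = SAM2Annotator(device="cpu", size="base")
masks_batch = segmenter.annotate_batch(
    [image], [np.array(boxes_batch[0])], iou_threshold=0.2
)
print(f"{len(masks_batch[0])} instance masks for the first image")
```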
12 changes: 7 additions & 5 deletions datadreamer/pipelines/generate_dataset_from_scratch.py
@@ -20,6 +20,7 @@
AIMv2Annotator,
CLIPAnnotator,
OWLv2Annotator,
SAM2Annotator,
SlimSAMAnnotator,
)
from datadreamer.image_generation import (
@@ -61,8 +62,8 @@

det_annotators = {"owlv2": OWLv2Annotator}
clf_annotators = {"clip": CLIPAnnotator, "aimv2": AIMv2Annotator}
inst_seg_annotators = {"owlv2-slimsam": SlimSAMAnnotator}
inst_seg_detectors = {"owlv2-slimsam": OWLv2Annotator}
inst_seg_annotators = {"owlv2-slimsam": SlimSAMAnnotator, "owlv2-sam2": SAM2Annotator}
inst_seg_detectors = {"owlv2-slimsam": OWLv2Annotator, "owlv2-sam2": OWLv2Annotator}

setup_logging(use_rich=True)

@@ -125,7 +126,7 @@ def parse_args():
parser.add_argument(
"--image_annotator",
type=str,
choices=["owlv2", "clip", "owlv2-slimsam", "aimv2"],
choices=["owlv2", "clip", "owlv2-slimsam", "aimv2", "owlv2-sam2"],
help="Image annotator to use",
)

@@ -668,9 +669,10 @@ def read_image_batch(image_batch, batch_num, batch_size):
if args.task == "instance-segmentation":
if k < len(masks_batch[j]):
mask = masks_batch[j][k]
x_points, y_points = zip(*mask)
if len(mask) > 0:
x_points, y_points = zip(*mask)

ax.fill(x_points, y_points, label, alpha=0.5)
ax.fill(x_points, y_points, label, alpha=0.5)

labels.append(label)
x1, y1, x2, y2 = box
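The guarded `ax.fill(...)` call in the hunk above draws each polygon mask only when it is non-empty. The standalone sketch below, which is not part of this diff, shows the same pattern with an illustrative polygon; the coordinates are made up.

```python
import matplotlib.pyplot as plt

# Illustrative polygon mask as a list of [x, y] points (coordinates are made up).
mask = [[3, 229], [559, 229], [559, 650], [3, 650]]

fig, ax = plt.subplots()
if len(mask) > 0:  # empty masks are skipped, mirroring the guard above
    x_points, y_points = zip(*mask)
    ax.fill(x_points, y_points, alpha=0.5)
ax.invert_yaxis()  # image coordinates grow downward
plt.show()
```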
4 changes: 3 additions & 1 deletion datadreamer/utils/config.py
@@ -41,7 +41,9 @@ class Config(LuxonisConfig):
# Profanity filter arguments
disable_lm_filter: bool = False
# Annotation arguments
image_annotator: Literal["owlv2", "aimv2", "clip", "owlv2-slimsam"] = "owlv2"
image_annotator: Literal[
"owlv2", "aimv2", "clip", "owlv2-slimsam", "owlv2-sam2"
] = "owlv2"
conf_threshold: float = 0.15
annotation_iou_threshold: float = 0.2
use_tta: bool = False
7 changes: 3 additions & 4 deletions examples/generate_dataset_and_train_yolo.ipynb
@@ -5,7 +5,7 @@
"id": "11adc87f",
"metadata": {},
"source": [
"<img src=\"https://docs.luxonis.com/images/depthai_logo.png\" width=\"500\">\n",
"<img src=\"https://www.luxonis.com/logo.svg\" width=\"400\">\n",
"\n",
"# DataDreamer Tutorial: Generating a dataset for object detection, training a model, and deploying it to the OAK (optional)"
]
@@ -85,7 +85,7 @@
"- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.\n",
"- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.\n",
"- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo`, `sdxl-lightning` or `shuttle-3`. Default is `sdxl-turbo`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` and `owlv2-sam2` for instance segmentation. Default is `owlv2`.\n",
"- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.\n",
"- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.\n",
"- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `\"\"`.\n",
@@ -104,8 +104,7 @@
"- `--batch_size_image`: Batch size for image generation. Default is `1`.\n",
"- `--device`: Choose between `cuda` and `cpu`. Default is `cuda`.\n",
"- `--seed`: Set a random seed for image and prompt generation. Default is `42`.\n",
"- `--config`: A path to an optional `.yaml` config file specifying the pipeline's arguments.\n",
""
"- `--config`: A path to an optional `.yaml` config file specifying the pipeline's arguments.\n"
]
},
{
@@ -7,7 +7,7 @@
"id": "8ce1517f-7258-406d-9139-9adadb1a1570"
},
"source": [
"<img src=\"https://docs.luxonis.com/images/depthai_logo.png\" width=\"500\">\n",
"<img src=\"https://www.luxonis.com/logo.svg\" width=\"400\">\n",
"\n",
"# DataDreamer Tutorial: Generating a dataset for instance segmentation, training a model, and deploying it to the OAK (optional)"
]
@@ -99,7 +99,7 @@
"- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.\n",
"- `--prompt_generator`: Choose between `simple`, `lm` (Mistral-7B), `tiny` (tiny LM), and `qwen2` (Qwen2.5 LM). Default is `qwen2`.\n",
"- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo`, `sdxl-lightning` or `shuttle-3`. Default is `sdxl-turbo`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` for instance segmentation. Default is `owlv2`.\n",
"- `--image_annotator`: Specify the image annotator, like `owlv2` for object detection or `aimv2` or `clip` for image classification or `owlv2-slimsam` and `owlv2-sam2` for instance segmentation. Default is `owlv2`.\n",
"- `--conf_threshold`: Confidence threshold for annotation. Default is `0.15`.\n",
"- `--annotation_iou_threshold`: Intersection over Union (IoU) threshold for annotation. Default is `0.2`.\n",
"- `--prompt_prefix`: Prefix to add to every image generation prompt. Default is `\"\"`.\n",