-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
291 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,18 @@ | ||
from __future__ import annotations | ||
|
||
from .aimv2_annotator import AIMv2Annotator | ||
from .clip_annotator import CLIPAnnotator | ||
from .cls_annotator import ImgClassificationAnnotator | ||
from .image_annotator import BaseAnnotator, TaskList | ||
from .owlv2_annotator import OWLv2Annotator | ||
from .slimsam_annotator import SlimSAMAnnotator | ||
|
||
__all__ = [ | ||
"AIMv2Annotator", | ||
"BaseAnnotator", | ||
"TaskList", | ||
"OWLv2Annotator", | ||
"ImgClassificationAnnotator", | ||
"CLIPAnnotator", | ||
"SlimSAMAnnotator", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""This file uses pre-trained model derived from Apple's software, provided under the | ||
Apple Sample Code License license. The license is available at: | ||
https://developer.apple.com/support/downloads/terms/apple-sample-code/Apple-Sample-Code-License.pdf | ||
In addition, this file and other parts of the repository are licensed under the Apache 2.0 | ||
License. By using this file, you agree to comply with the terms of both licenses. | ||
""" | ||
from __future__ import annotations | ||
|
||
import logging | ||
|
||
import torch | ||
from PIL import Image | ||
from transformers import AutoModel, AutoProcessor | ||
|
||
from datadreamer.dataset_annotation.cls_annotator import ImgClassificationAnnotator | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AIMv2Annotator(ImgClassificationAnnotator): | ||
"""A class for image annotation using the AIMv2 model, specializing in image | ||
classification. | ||
Attributes: | ||
model (AutoModel): The AIMv2 model for image-text similarity evaluation. | ||
processor (AutoProcessor): The processor for preparing inputs to the AIMv2 model. | ||
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU). | ||
size (str): The size of the AIMv2 model to use ('base' or 'large'). | ||
Methods: | ||
_init_processor(): Initializes the AIMv2 processor. | ||
_init_model(): Initializes the AIMv2 model. | ||
annotate_batch(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given image with bounding boxes and labels. | ||
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache. | ||
""" | ||
|
||
def _init_processor(self) -> AutoProcessor: | ||
"""Initializes the AIMv2 processor. | ||
Returns: | ||
AutoProcessor: The initialized AIMv2 processor. | ||
""" | ||
return AutoProcessor.from_pretrained("apple/aimv2-large-patch14-224-lit") | ||
|
||
def _init_model(self) -> AutoModel: | ||
"""Initializes the AIMv2 model. | ||
Returns: | ||
AutoModel: The initialized AIMv2 model. | ||
""" | ||
logger.info(f"Initializing AIMv2 {self.size} model...") | ||
return AutoModel.from_pretrained( | ||
"apple/aimv2-large-patch14-224-lit", trust_remote_code=True | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
import requests | ||
|
||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
url = "https://ultralytics.com/images/bus.jpg" | ||
im = Image.open(requests.get(url, stream=True).raw) | ||
annotator = AIMv2Annotator(device=device) | ||
labels = annotator.annotate_batch([im], ["bus", "people"]) | ||
print(labels) | ||
annotator.release() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from typing import Dict, List | ||
|
||
import numpy as np | ||
import PIL | ||
import torch | ||
|
||
from datadreamer.dataset_annotation.image_annotator import BaseAnnotator, TaskList | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ImgClassificationAnnotator(BaseAnnotator): | ||
"""Base class for image classification annotators using transformers models. | ||
Attributes: | ||
model: The model for image-text similarity evaluation. | ||
processor: The processor for preparing inputs to the model. | ||
device (str): The device on which the model will run ('cuda' for GPU, 'cpu' for CPU). | ||
size (str): The size of the model to use ('base' or 'large'). | ||
Methods: | ||
_init_processor(): Initializes the processor. | ||
_init_model(): Initializes the model. | ||
annotate_batch(image, prompts, conf_threshold, use_tta, synonym_dict): Annotates the given image with bounding boxes and labels. | ||
release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache. | ||
""" | ||
|
||
def __init__( | ||
self, seed: float = 42, device: str = "cuda", size: str = "base" | ||
) -> None: | ||
"""Initializes the image classification annotator. | ||
Args: | ||
seed (float): Seed for reproducibility. Defaults to 42. | ||
device (str): The device to run the model on. Defaults to 'cuda'. | ||
size (str): The model size to use. | ||
""" | ||
super().__init__(seed, task_definition=TaskList.CLASSIFICATION) | ||
self.size = size | ||
self.device = device | ||
self.model = self._init_model() | ||
self.processor = self._init_processor() | ||
self.model.to(self.device) | ||
|
||
def _init_processor(self): | ||
"""Initializes the processor.""" | ||
raise NotImplementedError | ||
|
||
def _init_model(self): | ||
"""Initializes the model.""" | ||
raise NotImplementedError | ||
|
||
def annotate_batch( | ||
self, | ||
images: List[PIL.Image.Image], | ||
objects: List[str], | ||
conf_threshold: float = 0.1, | ||
synonym_dict: Dict[str, List[str]] | None = None, | ||
) -> List[np.ndarray]: | ||
"""Annotates images using the CLIP model. | ||
Args: | ||
images: The images to be annotated. | ||
objects: A list of objects (text) to test against the images. | ||
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1. | ||
synonym_dict (dict, optional): Dictionary for handling synonyms in labels. Defaults to None. | ||
Returns: | ||
List[np.ndarray]: A list of the annotations for each image. | ||
""" | ||
if synonym_dict is not None: | ||
objs_syn = set() | ||
for obj in objects: | ||
objs_syn.add(obj) | ||
for syn in synonym_dict[obj]: | ||
objs_syn.add(syn) | ||
objs_syn = list(objs_syn) | ||
# Make a dict to transform synonym ids to original ids | ||
synonym_dict_rev = {} | ||
for key, value in synonym_dict.items(): | ||
if key in objects: | ||
synonym_dict_rev[objs_syn.index(key)] = objects.index(key) | ||
for v in value: | ||
synonym_dict_rev[objs_syn.index(v)] = objects.index(key) | ||
objects = objs_syn | ||
|
||
inputs = self.processor( | ||
text=objects, images=images, return_tensors="pt", padding=True | ||
).to(self.device) | ||
|
||
outputs = self.model(**inputs) | ||
|
||
logits_per_image = outputs.logits_per_image # image-text similarity score | ||
probs = logits_per_image.softmax(dim=1).cpu() # label probabilities | ||
|
||
labels = [] | ||
# Get the labels for each image | ||
if synonym_dict is not None: | ||
for prob in probs: | ||
labels.append( | ||
np.unique( | ||
np.array( | ||
[ | ||
synonym_dict_rev[label.item()] | ||
for label in torch.where(prob > conf_threshold)[ | ||
0 | ||
].numpy() | ||
] | ||
) | ||
) | ||
) | ||
else: | ||
for prob in probs: | ||
labels.append(torch.where(prob > conf_threshold)[0].numpy()) | ||
|
||
return labels | ||
|
||
def release(self, empty_cuda_cache: bool = False) -> None: | ||
"""Releases the model and optionally empties the CUDA cache. | ||
Args: | ||
empty_cuda_cache (bool, optional): Whether to empty the CUDA cache. Defaults to False. | ||
""" | ||
self.model = self.model.to("cpu") | ||
if empty_cuda_cache: | ||
with torch.no_grad(): | ||
torch.cuda.empty_cache() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.