Revert "feat: add image-driven object detection for OWLv2"
This reverts commit 133cfc8.
sokovninn committed Jan 31, 2025
1 parent 133cfc8 · commit 48bb90c
Showing 2 changed files with 20 additions and 159 deletions.
datadreamer/dataset_annotation/owlv2_annotator.py (130 changes: 10 additions & 120 deletions)
@@ -3,8 +3,6 @@
import logging
from typing import Dict, List, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
@@ -82,116 +80,43 @@ def _init_processor(self) -> Owlv2Processor:
"google/owlv2-base-patch16-ensemble", do_pad=False, do_resize=False
)

def _generate_annotations_from_text(
def _generate_annotations(
self,
images: List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
) -> List[Dict[str, torch.Tensor]]:
"""Generates annotations for the given images and text prompts.
"""Generates annotations for the given images and prompts.
Args:
images: The images to be annotated.
prompts: The text prompts to guide the annotation.
prompts: Prompts to guide the annotation.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
Returns:
List[Dict[str, torch.Tensor]]: The annotations for the given images and text prompts.
List[Dict[str, torch.Tensor]]: The annotations for the given images and prompts.
"""

batched_prompts = [prompts] * len(images)
n = len(images)
batched_prompts = [prompts] * n
target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(self.device)

# resize the images to the model's input size
img_size = (1008, 1008) if self.size == "large" else (960, 960)
images = [img.resize(img_size) for img in images]

images = [images[i].resize(img_size) for i in range(n)]
inputs = self.processor(
images=images,
text=batched_prompts,
images=images,
return_tensors="pt",
padding="max_length",
).to(self.device)

with torch.no_grad():
outputs = self.model(**inputs)

preds = self.processor.post_process_object_detection(
outputs=outputs, target_sizes=target_sizes, threshold=conf_threshold
)

return preds

def _generate_annotations_from_image(
self,
images: List[PIL.Image.Image],
query_images: List[PIL.Image.Image],
conf_threshold: float = 0.1,
) -> List[Dict[str, torch.Tensor]]:
"""Generates annotations for the given images and query images.
Args:
images: The images to be annotated.
query_images: The query images to guide the annotation. One query image is expected per target image to be queried.
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
Returns:
List[Dict[str, torch.Tensor]]: The annotations for the given images and query images.
"""

if len(query_images) != len(images) and len(query_images) != 1:
raise ValueError(
"The number of query images must be either 1 or the same as the number of target images."
)

target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(self.device)

inputs = self.processor(
images=images,
query_images=query_images,
return_tensors="pt",
do_resize=True,
).to(self.device)

with torch.no_grad():
outputs = self.model.image_guided_detection(**inputs)

preds = self.processor.post_process_image_guided_detection(
outputs=outputs,
target_sizes=target_sizes,
threshold=conf_threshold,
)

return preds

def _generate_annotations(
self,
images: List[PIL.Image.Image],
prompts: List[str] | List[PIL.Image.Image],
conf_threshold: float = 0.1,
) -> List[Dict[str, torch.Tensor]]:
"""Generates annotations for the given images and prompts.
Args:
images: The images to be annotated.
prompts: Either text prompts (List[str]) or a list of query images (List[PIL.Image.Image]).
conf_threshold (float, optional): Confidence threshold for the annotations. Defaults to 0.1.
Returns:
List[Dict[str, torch.Tensor]]: The annotations for the given images and prompts.
"""
if isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
return self._generate_annotations_from_text(images, prompts, conf_threshold)
elif isinstance(prompts, list) and all(
isinstance(p, PIL.Image.Image) for p in prompts
):
return self._generate_annotations_from_image(
images, prompts, conf_threshold
)
else:
raise ValueError(
"Invalid prompts: Expected List[str] or List[PIL.Image.Image]"
)

def _get_annotations(
self,
pred: Dict[str, torch.Tensor],
@@ -232,9 +157,6 @@ def _get_annotations(
if use_tta:
boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]

if labels is None:
labels = torch.zeros(scores.shape, dtype=torch.int64)

return boxes, scores, labels

def _correct_bboxes_misalignment(
@@ -264,7 +186,7 @@ def _correct_bboxes_misalignment(
def annotate_batch(
self,
images: List[PIL.Image.Image],
prompts: List[str] | List[PIL.Image.Image],
prompts: List[str],
conf_threshold: float = 0.1,
iou_threshold: float = 0.2,
use_tta: bool = False,
@@ -402,42 +324,10 @@ def release(self, empty_cuda_cache: bool = False) -> None:
import requests
from PIL import Image

# Text-driven annotation
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = OWLv2Annotator(device="cpu", size="base")
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["bus", "person"]
)

# Image-driven annotation
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
query_url = "http://images.cocodataset.org/val2017/000000058111.jpg"
query_image = Image.open(requests.get(query_url, stream=True).raw)

final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], [query_image], conf_threshold=0.9
)
print(final_boxes, final_scores, final_labels)

fig, ax = plt.subplots(1)
ax.imshow(im)
for box, score, label in zip(final_boxes[0], final_scores[0], final_labels[0]):
x1, y1, x2, y2 = box
width, height = x2 - x1, y2 - y1
rect = patches.Rectangle(
(x1, y1), width, height, linewidth=2, edgecolor="r", facecolor="none"
)
ax.add_patch(rect)

plt.text(
x1,
y1,
f"{label} {score:.2f}",
bbox=dict(facecolor="yellow", alpha=0.5),
)

plt.savefig("test_image_guided.png")

annotator.release()
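
For reference, the text-driven usage kept by this revert boils down to the minimal sketch below. It mirrors the usage example above; the import path follows the file location and assumes the package is importable as installed, so treat it as an illustration rather than part of the diff.

# Minimal post-revert text-driven annotation sketch (illustrative, not part of the commit).
import requests
from PIL import Image

from datadreamer.dataset_annotation.owlv2_annotator import OWLv2Annotator

im = Image.open(requests.get("https://ultralytics.com/images/bus.jpg", stream=True).raw)
annotator = OWLv2Annotator(device="cpu", size="base")
# annotate_batch returns one list of boxes, scores and labels per input image.
boxes, scores, labels = annotator.annotate_batch([im], ["bus", "person"], conf_threshold=0.1)
print(boxes[0], scores[0], labels[0])
annotator.release()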
tests/core_tests/unittests/test_annotators.py (49 changes: 10 additions & 39 deletions)
@@ -16,26 +16,13 @@
total_disk_space = psutil.disk_usage("/").total / (1024**3)


def _check_owlv2_annotator(
device: str, size: str = "base", use_text_prompts: bool = True
):
def _check_owlv2_annotator(device: str, size: str = "base"):
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = OWLv2Annotator(device=device, size=size)

if use_text_prompts:
url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["bus", "people"]
)
else:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
query_url = "http://images.cocodataset.org/val2017/000000058111.jpg"
query_image = Image.open(requests.get(query_url, stream=True).raw)
annotator = OWLv2Annotator(device=device, size=size)
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], [query_image], conf_threshold=0.9
)
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["bus", "people"]
)
# Assert that the boxes, scores and labels are tensors
assert isinstance(final_boxes, list) and len(final_boxes) == 1
assert isinstance(final_scores, list) and len(final_scores) == 1
@@ -58,32 +45,16 @@ def _check_owlv2_annotator(
not torch.cuda.is_available() or total_disk_space < 16,
reason="Test requires GPU and 16GB of HDD",
)
def test_cuda_owlv2_annotator_text():
_check_owlv2_annotator("cuda", use_text_prompts=True)


@pytest.mark.skipif(
total_disk_space < 16,
reason="Test requires at least 16GB of HDD",
)
def test_cpu_owlv2_annotator_text():
_check_owlv2_annotator("cpu", use_text_prompts=True)


@pytest.mark.skipif(
not torch.cuda.is_available() or total_disk_space < 16,
reason="Test requires GPU and 16GB of HDD",
)
def test_cuda_owlv2_annotator_image():
_check_owlv2_annotator("cuda", use_text_prompts=False)
def test_cuda_owlv2_annotator():
_check_owlv2_annotator("cuda")


@pytest.mark.skipif(
total_disk_space < 16,
reason="Test requires at least 16GB of HDD",
)
def test_cpu_owlv2_annotator_image():
_check_owlv2_annotator("cpu", use_text_prompts=False)
def test_cpu_owlv2_annotator():
_check_owlv2_annotator("cpu")


def _check_aimv2_annotator(device: str):
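
For comparison, the image-guided call path removed by this commit reduces, at the transformers level, to roughly the sketch below. It is reconstructed from the deleted lines in owlv2_annotator.py: the checkpoint name comes from the processor initialization in the first hunk, and the image URLs and 0.9 threshold from the deleted example, so read it as an illustration rather than the repository's code.

# Sketch of the reverted image-guided detection path (illustrative reconstruction).
import requests
import torch
from PIL import Image
from transformers import Owlv2ForObjectDetection, Owlv2Processor

ckpt = "google/owlv2-base-patch16-ensemble"
processor = Owlv2Processor.from_pretrained(ckpt)
model = Owlv2ForObjectDetection.from_pretrained(ckpt)

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
query = Image.open(requests.get("http://images.cocodataset.org/val2017/000000058111.jpg", stream=True).raw)

# Post-processing expects target sizes as (height, width).
target_sizes = torch.Tensor([image.size[::-1]])
inputs = processor(images=[image], query_images=[query], return_tensors="pt")
with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)
preds = processor.post_process_image_guided_detection(
    outputs=outputs, target_sizes=target_sizes, threshold=0.9
)
# Image-guided detection predicts boxes and scores only; no class labels.
print(preds[0]["boxes"], preds[0]["scores"])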
