Prepared for detection training
mantasu committed Jan 7, 2024
1 parent 3ab1165 commit 54995fa
Showing 10 changed files with 407 additions and 46 deletions.
16 changes: 11 additions & 5 deletions scripts/preprocess.py
@@ -395,18 +395,24 @@ def parse_coco_json(
img.save(os.path.join(path, img_info["file_name"]))
elif "detection" in path_splits:
# Normalize bbox (x_center, y_center, width, height)
x = (ann["bbox"][0] + ann["bbox"][2] / 2) / img_info["width"]
y = (ann["bbox"][1] + ann["bbox"][3] / 2) / img_info["height"]
w = ann["bbox"][2] / img_info["width"]
h = ann["bbox"][3] / img_info["height"]
# x = (ann["bbox"][0] + ann["bbox"][2] / 2) / img_info["width"]
# y = (ann["bbox"][1] + ann["bbox"][3] / 2) / img_info["height"]
# w = ann["bbox"][2] / img_info["width"]
# h = ann["bbox"][3] / img_info["height"]

# Convert to pascal_voc format (with resized bbox)
x1 = int(ann["bbox"][0] * size[0] / img_info["width"])
y1 = int(ann["bbox"][1] * size[1] / img_info["height"])
x2 = x1 + int(ann["bbox"][2] * size[0] / img_info["width"])
y2 = y1 + int(ann["bbox"][3] * size[1] / img_info["height"])

# Copy the image and create .txt annotation filename
img.save(os.path.join(path, "images", img_info["file_name"]))
txt = img_info["file_name"].rsplit(".", 1)[0] + ".txt"

with open(os.path.join(path, "annotations", txt), "w") as f:
# Write the bounding box
f.write(f"{x} {y} {w} {h}")
f.write(f"{x1} {y1} {x2} {y2}")
elif "segmentation" in path_splits:
# Update the mask for the current class
mask = masks[class_name]
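To make the detection-branch conversion above concrete, here is a small worked example (the numbers and target size are made up): COCO stores a box as (x_min, y_min, width, height), and the new code rescales it to pascal_voc (x_min, y_min, x_max, y_max) at the resized resolution before writing it to the .txt annotation file.

    # Worked example of the COCO -> pascal_voc conversion above (values are hypothetical)
    bbox = [40, 30, 100, 60]      # COCO format: x_min, y_min, width, height
    old_w, old_h = 640, 480       # img_info["width"], img_info["height"]
    size = (256, 256)             # resize target (width, height)

    x1 = int(bbox[0] * size[0] / old_w)       # 40 * 256 / 640 = 16
    y1 = int(bbox[1] * size[1] / old_h)       # 30 * 256 / 480 = 16
    x2 = x1 + int(bbox[2] * size[0] / old_w)  # 16 + 40 = 56
    y2 = y1 + int(bbox[3] * size[1] / old_h)  # 16 + 32 = 48

    print(x1, y1, x2, y2)  # 16 16 56 48 -> one "x_min y_min x_max y_max" line in the .txt file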
14 changes: 8 additions & 6 deletions scripts/run.py
@@ -17,8 +17,12 @@
torch.set_float32_matmul_precision("medium")

from glasses_detector import GlassesClassifier, GlassesDetector, GlassesSegmenter
from glasses_detector._data import ImageClassificationDataset, ImageSegmentationDataset
from glasses_detector._wrappers import BinaryClassifier, BinarySegmenter
from glasses_detector._data import (
ImageClassificationDataset,
ImageDetectionDataset,
ImageSegmentationDataset,
)
from glasses_detector._wrappers import BinaryClassifier, BinaryDetector, BinarySegmenter


class RunCLI(LightningCLI):
@@ -152,7 +156,7 @@ def create_wrapper_callback(
# Get model and dataset classes
model_cls, data_cls = {
"classification": (GlassesClassifier, ImageClassificationDataset),
"detection": (GlassesDetector, None),
"detection": (GlassesDetector, ImageDetectionDataset),
"segmentation": (GlassesSegmenter, ImageSegmentationDataset),
}[task]

@@ -168,10 +172,8 @@
kwargs["label_type"] = {kind: 1, "no_" + kind: 0}
wrapper_cls = BinaryClassifier
elif task == "detection":
raise NotImplementedError("Detection is not implemented yet!")
wrapper_cls = BinaryDetector
elif task == "segmentation":
kwargs["img_dirname"] = "images"
kwargs["name_map_fn"] = {"masks": lambda x: f"{int(x[:5])}.jpg"}
wrapper_cls = BinarySegmenter

# Initialize model architecture and load weights if needed
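The detection branch now follows the same pattern as the other two tasks: a dict keyed by task name selects the model, dataset, and wrapper classes. A self-contained sketch of that dispatch pattern (stand-in class names, not the real ones):

    # Stand-in sketch of the dict-based task dispatch (class names are hypothetical)
    class ClassifierStub: ...
    class DetectorStub: ...
    class SegmenterStub: ...

    def select_model(task: str):
        # Unknown task names raise KeyError, same as the dict lookup above
        return {
            "classification": ClassifierStub,
            "detection": DetectorStub,
            "segmentation": SegmenterStub,
        }[task]

    print(select_model("detection").__name__)  # DetectorStub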
3 changes: 2 additions & 1 deletion src/glasses_detector/_data/__init__.py
@@ -1,3 +1,4 @@
from .mixins import ImageLoaderMixin, DataLoaderMixin
from .classification_dataset import ImageClassificationDataset
from .detection_dataset import ImageDetectionDataset
from .mixins import DataLoaderMixin, ImageLoaderMixin
from .segmentation_dataset import ImageSegmentationDataset
107 changes: 107 additions & 0 deletions src/glasses_detector/_data/detection_dataset.py
@@ -0,0 +1,107 @@
import os
import random
from collections import defaultdict
from typing import Callable

import albumentations as A
import torch
from torch.utils.data import Dataset

from .mixins import DataLoaderMixin, ImageLoaderMixin


class ImageDetectionDataset(Dataset, ImageLoaderMixin, DataLoaderMixin):
# It is more efficient to implement a separate dataset for each
# task, and it is very unlikely that multiple tasks will be
# considered at once, so a generic multi-task dataset is not needed
def __init__(
self,
root: str = ".",
split_type: str = "train",
img_folder: str = "images",
ann2img_fn: dict[str, Callable[[str], str]] = {},
# For each annotation folder name, a function that maps an annotation file name to the image file name it belongs to
seed: int = 0,
):
super().__init__()

self.data = []
cat2paths = defaultdict(lambda: {"names": [], "paths": []})

for dataset in os.listdir(root):
if not os.path.isdir(p := os.path.join(root, dataset, split_type)):
continue

for cat in os.scandir(p):
# Read the list of names and paths to images/masks
name_fn = ann2img_fn.get(cat.name, lambda x: x.replace(".txt", ".jpg"))
names = list(map(name_fn, os.listdir(cat.path)))
paths = [f.path for f in os.scandir(cat.path)]

# Extend the lists of image/annot names + paths
cat2paths[cat.name]["names"].extend(names)
cat2paths[cat.name]["paths"].extend(paths)

# Pop the non-category folder (get image names and paths)
img_names, img_paths = cat2paths.pop(img_folder).values()

for img_name, img_path in zip(img_names, img_paths):
# Add the default image entry
self.data.append({"image": img_path})

for cat_dirname, names_and_paths in cat2paths.items():
if img_name in names_and_paths["names"]:
# Get the index of corresponding annotation
i = names_and_paths["names"].index(img_name)
annotation_path = names_and_paths["paths"][i]
self.data[-1][cat_dirname] = annotation_path
else:
# No annotation but add for equally sized batches
self.data[-1][cat_dirname] = None

# Sort & shuffle
self.data.sort(key=lambda x: x["image"])
random.seed(seed)
random.shuffle(self.data)

# Create image augmentation pipeline based on split type
p = A.BboxParams(format="pascal_voc", label_fields=["classes"])
self.transform = self.create_transform(split_type == "train", bbox_params=p)

@property
def name2idx(self):
return dict(zip(self.data[0].keys(), range(len(self.data[0]))))

@property
def idx2name(self):
return dict(zip(range(len(self.data[0])), self.data[0].keys()))

def __getitem__(self, index):
# Load the image, bboxes and classes (skip missing annotations)
image = self.data[index]["image"]
bboxes = [p for p in list(self.data[index].values())[1:] if p is not None]
labels = [1] * len(bboxes)
# labels = [self.cat2label(k) for k in list(self.data[index].keys())[1:]]

(image, bboxes, labels) = self.load_image(
image=image,
bboxes=bboxes,
classes=labels,
transform=self.transform,
)

# TODO: create cat2label map and map class names to labels
# TODO: there may be more bboxes read than classes after loading
# the transformed image so consider adding either a max_bbox
# argument or implement a custom collate function for dataloader

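# If no boxes survived loading/augmentation, fall back to a single
# dummy 1x1 box with label 0 (background) so targets stay well-formed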
if len(bboxes) == 0:
bboxes = torch.tensor([[0, 0, 1, 1]], dtype=torch.float32)
labels = torch.tensor([0], dtype=torch.int64)

annotations = {"boxes": bboxes, "labels": labels}

return image, annotations

def __len__(self):
return len(self.data)
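A minimal usage sketch of the new dataset, assuming the directory layout the constructor walks (root/<dataset>/<split>/ with an `images` folder plus per-category annotation folders of `.txt` files; the paths below are hypothetical):

    # Hypothetical usage of ImageDetectionDataset (paths are made up)
    from glasses_detector._data import ImageDetectionDataset

    dataset = ImageDetectionDataset(
        root="data/detection",  # e.g. data/detection/some-dataset/train/images/...
        split_type="train",
    )

    image, annotations = dataset[0]
    print(image.shape)            # CHW float tensor (size depends on augmentation)
    print(annotations["boxes"])   # (N, 4) pascal_voc boxes, or the dummy [0, 0, 1, 1]
    print(annotations["labels"])  # one label per box: 1, or 0 for the dummy box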
117 changes: 105 additions & 12 deletions src/glasses_detector/_data/mixins.py
@@ -1,14 +1,17 @@
from typing import Any

import albumentations as A
import numpy
import PIL.Image as Image
import skimage.transform as st
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader


class ImageLoaderMixin:
@staticmethod
def create_transform(is_train: bool = False) -> A.Compose:
def create_transform(is_train: bool = False, **kwargs) -> A.Compose:
# Default augmentation
transform = [
A.VerticalFlip(),
@@ -18,10 +21,10 @@ def create_transform(is_train: bool = False) -> A.Compose:
A.OneOf(
[
A.RandomResizedCrop(256, 256, p=0.5),
A.GridDistortion(),
A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1),
A.PiecewiseAffine(),
A.Perspective(),
A.GridDistortion(),
]
),
A.OneOf(
@@ -45,21 +48,26 @@
A.GaussNoise(),
]
),
A.CoarseDropout(max_holes=5, p=0.3),
A.Normalize(),
ToTensorV2(),
]

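# CoarseDropout erases image patches without updating the bboxes, so it is
# presumably left out whenever the transform is built with bbox_params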
if "bbox_params" not in kwargs:
transform.insert(-2, A.CoarseDropout(max_holes=5, p=0.3))

if not is_train:
# Only keep the last two
transform = transform[-2:]

return A.Compose(transform)
return A.Compose(transform, **kwargs)

@staticmethod
def load_image(
image: str | Image.Image | numpy.ndarray,
masks: list[str | Image.Image | numpy.ndarray] = [],
bboxes: list[str | list[int | float | str]] = [], # x_min, y_min, x_max, y_max
classes: list[Any] = [], # one for each bbox
resize: tuple[int, int] | None = None,
transform: A.Compose | bool = False,
) -> torch.Tensor:
def open_image_file(image_file, is_mask=False):
@@ -81,24 +89,109 @@ def open_image_file(image_file, is_mask=False):
# Image is not a mask, so convert it to RGB
image_file = numpy.stack([image_file] * 3, axis=-1)

if resize is not None:
# Resize image to new (w, h)
size = resize[1], resize[0]
image_file = st.resize(image_file, size)

return image_file

def open_bbox_files(bbox_files, classes):
# Init new
_bboxes, _classes = [], []

for i, bbox_file in enumerate(bbox_files):
if isinstance(bbox_file, str):
with open(bbox_file, "r") as f:
# Each line is bbox: "x_min y_min x_max y_max"
batch = [xyxy.strip().split() for xyxy in f.readlines()]
else:
# bbox_file is a single bbox (list[str | int | float])
batch = [bbox_file]

batch = [list(map(float, xyxy)) for xyxy in batch]

# Use "j" here: the outer index "i" is still needed below for classes[i]
for j, xyxy in enumerate(batch):
if xyxy[2] <= xyxy[0]:
# Degenerate width: clamp x_min and force a 1 px extent
batch[j][0] = min(xyxy[0], image.shape[1] - 1)
batch[j][2] = batch[j][0] + 1

if xyxy[3] <= xyxy[1]:
# Degenerate height: clamp y_min and force a 1 px extent
batch[j][1] = min(xyxy[1], image.shape[0] - 1)
batch[j][3] = batch[j][1] + 1

if resize is not None:
# Get old and new width, height
old_h, old_w = image.shape[:2]
new_w, new_h = resize

# Convert bboxes to new (w, h)
batch = [
[
xyxy[0] * new_w / old_w,
xyxy[1] * new_h / old_h,
xyxy[2] * new_w / old_w,
xyxy[3] * new_h / old_h,
]
for xyxy in batch
]

# Add to list
_bboxes.extend(batch)

if classes != []:
# If classes are provided, add them
_classes.extend([classes[i]] * len(batch))

return _bboxes, _classes

kwargs = {}

if isinstance(transform, bool):
if bboxes != []:
kwargs.update(
{
"bbox_params": A.BboxParams(
format="pascal_voc",
label_fields=["classes"] if classes != [] else None,
)
}
)

# Load transform (train/test is based on bool)
transform = ImageLoaderMixin.create_transform(transform)
transform = ImageLoaderMixin.create_transform(transform, **kwargs)
kwargs.pop("bbox_params", None)  # only configures the transform; not an input to the call

# Load image and mask files
# Load image, mask, bbox files
image = open_image_file(image)
masks = [open_image_file(m, True) for m in masks]
bboxes, classes = open_bbox_files(bboxes, classes)

# Create transform kwargs
kwargs["image"] = image
kwargs.update({"masks": masks} if masks != [] else {})
kwargs.update({"bboxes": bboxes} if bboxes != [] else {})
kwargs.update({"classes": classes} if classes != [] else {})

# Transform everything, init returns
transformed = transform(**kwargs)
return_list = [transformed["image"]]

if masks != []:
# TODO: check if transformation is converted to a tensor
return_list.append(transformed["masks"])

if bboxes != []:
bboxes = torch.tensor(transformed["bboxes"], dtype=torch.float32)
return_list.append(bboxes)

if masks == []:
return transform(image=image)["image"]
if classes != []:
classes = torch.tensor(transformed["classes"], dtype=torch.int64)
return_list.append(classes)

# Transform the image and masks
transformed = transform(image=image, masks=masks)
image, masks = transformed["image"], transformed["masks"]
if len(return_list) == 1:
return return_list[0]

return image, masks
return tuple(return_list)


class DataLoaderMixin:
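With these changes, `load_image` can be exercised directly with boxes; a hedged sketch on synthetic data (transform=False selects the eval-time pipeline of Normalize + ToTensorV2, extended with bbox support since bboxes are passed):

    # Hedged sketch: load_image with bboxes and classes on a synthetic image
    import numpy as np

    img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
    boxes = [[40, 30, 140, 90], [200, 100, 260, 180]]  # x_min, y_min, x_max, y_max

    image, bboxes, classes = ImageLoaderMixin.load_image(
        image=img,
        bboxes=boxes,
        classes=[1, 1],   # one label per box
        transform=False,  # bool -> an eval transform with bbox params is created
    )
    # image: (3, 480, 640) float tensor; bboxes: (2, 4) float; classes: (2,) int64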
3 changes: 2 additions & 1 deletion src/glasses_detector/_wrappers/__init__.py
@@ -1,2 +1,3 @@
from .binary_classifier import BinaryClassifier
from .binary_segmenter import BinarySegmenter
from .binary_detector import BinaryDetector
from .binary_segmenter import BinarySegmenter