Further getting ready for training
mantasu committed Jan 5, 2024
1 parent e63ec07 commit 3ab1165
Showing 13 changed files with 188 additions and 85 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ scipy
rarfile
pycocotools
torchvision
tensorboard
albumentations
pytorch_lightning
jsonargparse[signatures]
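With `tensorboard` newly added to the requirements, training metrics logged through PyTorch Lightning can be inspected in TensorBoard. A minimal sketch of the wiring (illustrative only; the trainer setup, `save_dir`, and experiment name here are assumptions, not this repo's configuration):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

# Hypothetical setup; the actual trainer lives elsewhere (e.g. scripts/run.py)
logger = TensorBoardLogger(save_dir="logs", name="glasses-detector")
trainer = pl.Trainer(max_epochs=10, logger=logger)

# Anything the LightningModule reports via self.log(...) then shows up under:
#   tensorboard --logdir logs
```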
11 changes: 6 additions & 5 deletions scripts/run.py
@@ -144,6 +144,7 @@ def create_wrapper_callback(
) -> pl.LightningModule:

# Get task and kind

task_and_kind = task.split("-", maxsplit=1)
task = task_and_kind[0]
kind = DEFAULT_KINDS[task] if len(task_and_kind) == 1 else task_and_kind[1]
@@ -173,12 +174,12 @@ def create_wrapper_callback(
kwargs["name_map_fn"] = {"masks": lambda x: f"{int(x[:5])}.jpg"}
wrapper_cls = BinarySegmenter

# Initialize model arch
model = model_cls(size)
# Initialize model architecture and load weights if needed
model = model_cls(kind=kind, size=size, pretrained=weights_path).model

if weights_path is not None:
# Load weights if the path is specified to them
model.load_state_dict(torch.load(weights_path))
# if weights_path is not None:
# # Load weights if the path is specified to them
# model.load_state_dict(torch.load(weights_path))

return wrapper_cls(model, *data_cls.create_loaders(**kwargs))

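The manual `load_state_dict` call is commented out because the model class now receives the checkpoint path through its `pretrained` argument and handles loading internally. A minimal sketch of that pattern, under the assumption that a string means a local checkpoint path (the class and layer here are stand-ins, not the repo's code):

```python
import torch
import torch.nn as nn

class WrappedModel:
    def __init__(self, kind: str, size: str, pretrained: bool | str | None = None):
        # Stand-in for the architecture selected by `kind` and `size`
        self.model = nn.Linear(4, 1)

        if isinstance(pretrained, str):
            # A path was passed: load weights from the local checkpoint
            self.model.load_state_dict(torch.load(pretrained))
```

This keeps the run script agnostic of how weights are stored: it simply forwards `weights_path` and unwraps the inner `.model`.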
31 changes: 18 additions & 13 deletions src/glasses_detector/_data/classification_dataset.py
@@ -1,29 +1,34 @@
import os
import torch
import random

from typing import Any
from collections import defaultdict
from typing import Any

import torch
from torch.utils.data import Dataset
from .mixins import ImageLoaderMixin, DataLoaderMixin

from .mixins import DataLoaderMixin, ImageLoaderMixin


class ImageClassificationDataset(Dataset, ImageLoaderMixin, DataLoaderMixin):
def __init__(
self,
root: str = '.',
self,
root: str = ".",
split_type: str = "train",
label_type: str | dict[str, Any] = "onehot", # enum, onehot, {} !vals must be immutable objects
label_type: str
| dict[str, Any] = "onehot", # enum, onehot, {} !vals must be immutable objects
seed: int = 0,
):
super().__init__()

# Init attributes and local vars
self.data, self.label2name = [], {}
cat2paths = defaultdict(lambda: [])

for dir in os.listdir(root):
for cat in os.scandir(os.path.join(root, dir, split_type)):
if not os.path.isdir(dir := os.path.join(root, dir, split_type)):
continue

for cat in os.scandir(dir):
# Add path to the image under category of the dir's name
cat2paths[cat.name].extend([f.path for f in os.scandir(cat.path)])

@@ -35,19 +40,19 @@ def __init__(
label = i
elif label_type == "onehot":
label = tuple(int(i == j) for j in range(len(cat2paths)))

# Update mapping and data
self.label2name[label] = key
self.data.extend([(img_path, label) for img_path in val])

# Sort & shuffle
self.data.sort()
random.seed(seed)
random.shuffle(self.data)

# Create image augmentation pipeline based on split type
self.transform = self.create_transform(split_type=="train")
self.transform = self.create_transform(split_type == "train")

@property
def name2label(self):
return dict(zip(self.label2name.values(), self.label2name.keys()))
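For context on `label_type` above: with three categories, "enum" yields integer labels 0, 1, 2, while "onehot" yields one value per category. A small standalone illustration (category names are hypothetical), using tuples since the labels serve as `label2name` keys and must therefore be immutable, as the inline comment warns:

```python
categories = ["eyeglasses", "sunglasses", "no_glasses"]  # hypothetical names

enum_labels = {name: i for i, name in enumerate(categories)}
# {'eyeglasses': 0, 'sunglasses': 1, 'no_glasses': 2}

onehot_labels = {
    name: tuple(int(i == j) for j in range(len(categories)))
    for i, name in enumerate(categories)
}
# {'eyeglasses': (1, 0, 0), 'sunglasses': (0, 1, 0), 'no_glasses': (0, 0, 1)}
```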
34 changes: 19 additions & 15 deletions src/glasses_detector/_data/segmentation_dataset.py
@@ -1,20 +1,23 @@
import os
import torch
import random

from typing import Iterable
from collections import defaultdict
from typing import Callable

import torch
from torch.utils.data import Dataset
from .mixins import ImageLoaderMixin, DataLoaderMixin

from .mixins import DataLoaderMixin, ImageLoaderMixin


class ImageSegmentationDataset(Dataset, ImageLoaderMixin, DataLoaderMixin):
def __init__(
self,
root: str = '.',
self,
root: str = ".",
split_type: str = "train",
img_dirname: str = "images",
name_map_fn: dict[str, callable] = {}, # maps mask name to image name
name_map_fn: dict[
str, Callable[[str], str]
] = {}, # maps mask name to image name
seed: int = 0,
):
super().__init__()
@@ -23,7 +26,10 @@ def __init__(
cat2paths = defaultdict(lambda: {"names": [], "paths": []})

for dir in os.listdir(root):
for cat in os.scandir(os.path.join(root, dir, split_type)):
if not os.path.isdir(dir := os.path.join(root, dir, split_type)):
continue

for cat in os.scandir(dir):
# Read the list of names and paths to images/masks
name_fn = name_map_fn.get(cat.name, lambda x: x)
names = list(map(name_fn, os.listdir(cat.path)))
@@ -43,7 +49,7 @@ def __init__(

for mask_dirname, names_and_paths in cat2paths.items():
if mask_dirname == img_dirname:
# Skip if it's image folder
# Skip if it's image folder
continue

if img_name in names_and_paths["names"]:
@@ -61,17 +67,16 @@ def __init__(
random.shuffle(self.data)

# Create image augmentation pipeline based on split type
self.transform = self.create_transform(split_type=="train")
self.transform = self.create_transform(split_type == "train")

@property
def name2idx(self):
        return dict(zip(self.data[0].keys(), range(len(self.data[0]))))

@property
def idx2name(self):
        return dict(zip(range(len(self.data[0])), self.data[0].keys()))



def __getitem__(self, index):
# Load the image and the masks
image = self.data[index]["image"]
@@ -80,6 +85,5 @@ def __getitem__(self, index):

return image, torch.stack(masks, dim=0).to(torch.float32)


def __len__(self):
return len(self.data)
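A note on `name_map_fn`: each mask directory may supply a function that converts a mask filename into the matching image filename so the two can be paired by name. This commit's `scripts/run.py` passes `{"masks": lambda x: f"{int(x[:5])}.jpg"}` for the segmentation task. A self-contained sketch of what that mapping does (the filenames are made up):

```python
mask_names = ["00012-mask.png", "00345-mask.png"]  # hypothetical mask files

# Same mapping as in scripts/run.py: take the first 5 characters,
# strip leading zeros via int(), and point at the corresponding .jpg
name_fn = lambda x: f"{int(x[:5])}.jpg"

image_names = [name_fn(name) for name in mask_names]
# ['12.jpg', '345.jpg']
```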
44 changes: 26 additions & 18 deletions src/glasses_detector/_wrappers/binary_classifier.py
@@ -1,11 +1,10 @@
import torch
import numpy as np
import torchmetrics
import torch.nn as nn
import pytorch_lightning as pl

import torch
import torch.nn as nn
import torchmetrics
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import ReduceLROnPlateau


class BinaryClassifier(pl.LightningModule):
@@ -22,17 +21,20 @@ def __init__(self, model, train_loader=None, val_loader=None, test_loader=None):
self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)

# Create F1 score and ROC-AUC metrics to monitor
self.metrics = torchmetrics.MetricCollection([
torchmetrics.F1Score(task="binary"),
torchmetrics.AUROC(task="binary")
])

self.metrics = torchmetrics.MetricCollection(
[
torchmetrics.F1Score(task="binary"),
torchmetrics.AUROC(task="binary"), # ROC-AUC
torchmetrics.AveragePrecision(task="binary"), # PR-AUC
]
)

@property
def pos_weight(self):
if self.train_loader is None:
# Not known
return None

# Calculate the positive weight to account for class imbalance
targets = np.array([y for _, y in iter(self.train_loader.dataset.data)])
pos_count = targets.sum()
@@ -50,8 +52,8 @@ def training_step(self, batch, batch_idx):
loss = self.criterion(y_hat, y)
self.log("train_loss", loss, prog_bar=True)
return loss
def eval_step(self, batch, prefix=''):

def eval_step(self, batch, prefix=""):
# Forward pass
x, y = batch
y_hat = self(x)
@@ -64,13 +66,14 @@ def eval_step(self, batch, prefix=''):
self.log(f"{prefix}_loss", loss, prog_bar=True)
self.log(f"{prefix}_f1", metrics["BinaryF1Score"], prog_bar=True)
self.log(f"{prefix}_roc_auc", metrics["BinaryAUROC"], prog_bar=True)

self.log(f"{prefix}_pr_auc", metrics["BinaryAveragePrecision"], prog_bar=True)

def validation_step(self, batch, batch_idx):
self.eval_step(batch, prefix="val")

def test_step(self, batch, batch_idx):
self.eval_step(batch, prefix="test")

def train_dataloader(self):
return self.train_loader

@@ -79,10 +82,15 @@ def val_dataloader(self):

def test_dataloader(self):
return self.test_loader

def configure_optimizers(self):
# Initialize AdamW optimizer and Cosine Annealing scheduler
optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=0.1)
scheduler = CosineAnnealingWarmRestarts(optimizer, 10, 2, 1e-6)
# scheduler = CosineAnnealingWarmRestarts(optimizer, 10, 2, 1e-6)
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)

return [optimizer], [scheduler]
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": "val_loss",
}
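The `pos_weight` property (truncated in the diff above) feeds `nn.BCEWithLogitsLoss`, which scales the loss on positive samples to counteract class imbalance; the usual formula is negatives divided by positives. A hedged, self-contained version of that computation (toy labels, not the dataset's):

```python
import numpy as np
import torch
import torch.nn as nn

targets = np.array([1, 0, 0, 0, 1, 0])     # toy labels: 2 positives, 4 negatives
pos_count = targets.sum()                   # 2
neg_count = len(targets) - pos_count        # 4
pos_weight = torch.tensor(neg_count / pos_count)  # tensor(2.)

# Positive samples now contribute 2x to the loss, balancing the classes
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```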
Empty file.
43 changes: 26 additions & 17 deletions src/glasses_detector/_wrappers/binary_segmenter.py
@@ -1,9 +1,9 @@
import torchmetrics
import torch.nn as nn
import pytorch_lightning as pl

import torch.nn as nn
import torchmetrics
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.optim.lr_scheduler import ReduceLROnPlateau


class BinarySegmenter(pl.LightningModule):
def __init__(self, model, train_loader=None, val_loader=None, test_loader=None):
@@ -19,17 +19,20 @@ def __init__(self, model, train_loader=None, val_loader=None, test_loader=None):
self.criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight)

# Initialize some metrics to monitor the performance
self.metrics = torchmetrics.MetricCollection([
torchmetrics.F1Score(task="binary"),
torchmetrics.Dice()
])
self.metrics = torchmetrics.MetricCollection(
[
torchmetrics.F1Score(task="binary"),
torchmetrics.Dice(task="binary"),
torchmetrics.JaccardIndex(task="binary"), # IoU
]
)

@property
def pos_weight(self):
if self.train_loader is None:
# Not known
return None

# Init counts
pos, neg = 0, 0

@@ -51,7 +54,7 @@ def training_step(self, batch, batch_idx):
self.log("train_loss", loss, prog_bar=True)
return loss

def eval_step(self, batch, prefix=''):
def eval_step(self, batch, prefix=""):
# Forward pass
x, y = batch
y_hat = self(x)
@@ -63,14 +66,15 @@ def eval_step(self, batch, prefix=''):
# Log the loss and the metrics
self.log(f"{prefix}_loss", loss, prog_bar=True)
self.log(f"{prefix}_f1", metrics["BinaryF1Score"], prog_bar=True)
self.log(f"{prefix}_dice", metrics["Dice"], prog_bar=True)

self.log(f"{prefix}_dice", metrics["BinaryDice"], prog_bar=True)
self.log(f"{prefix}_iou", metrics["BinaryJaccardIndex"], prog_bar=True)

def validation_step(self, batch, batch_idx):
self.eval_step(batch, prefix="val")

def test_step(self, batch, batch_idx):
self.eval_step(batch, prefix="test")

def train_dataloader(self):
return self.train_loader

@@ -81,7 +85,12 @@ def test_dataloader(self):
return self.test_loader

def configure_optimizers(self):
optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, 10, 2, 1e-6)

return [optimizer], [scheduler]
optimizer = AdamW(self.parameters(), lr=1e-3, weight_decay=1e-2)
# scheduler = CosineAnnealingWarmRestarts(optimizer, 10, 2, 1e-6)
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)

return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": "val_loss",
}
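Both wrappers switch from `CosineAnnealingWarmRestarts` to `ReduceLROnPlateau` and return the dictionary form from `configure_optimizers`. That form is what lets Lightning know which logged metric the plateau scheduler should watch. A minimal sketch of the contract (toy module; hyperparameters assumed):

```python
import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-3)
        scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,  # stepped using the monitored value
            "monitor": "val_loss",      # must match a self.log(...) key
        }
```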
6 changes: 3 additions & 3 deletions src/glasses_detector/classifier.py
@@ -15,12 +15,12 @@

@dataclass
class GlassesClassifier(BaseGlassesModel):
"""Glasses classifier for specific glasses type."""
"""Glasses classifier for people wearing spectacles."""

task: str = field(default="classification", init=False)
kind: str = "anyglasses"
size: str = "normal"
pretrained: bool | str = field(default=True, repr=False)
size: str = "medium"
pretrained: bool | str | None = field(default=True, repr=False)

DEFAULT_SIZE_MAP: ClassVar[dict[str, dict[str, str]]] = {
"small": {"name": "tinyclsnet_v1", "version": "v1.0.0"},
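Given the updated defaults (`kind="anyglasses"`, `size="medium"`, `pretrained=True`), constructing the classifier presumably looks like the following. This is inferred from the dataclass fields, not taken from the repo's documentation, and the alternative `kind` value is an assumption:

```python
from glasses_detector import GlassesClassifier  # import path assumed

# All defaults: kind="anyglasses", size="medium", pretrained=True
classifier = GlassesClassifier()

# Or explicitly, loading weights from a local checkpoint instead
classifier = GlassesClassifier(
    kind="sunglasses",                 # hypothetical kind
    size="small",
    pretrained="path/to/weights.pth",  # hypothetical path
)
```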
2 changes: 1 addition & 1 deletion src/glasses_detector/components/base_model.py
@@ -49,7 +49,7 @@ class BaseGlassesModel(nn.Module, PredInterface):
task: str
kind: str
size: str
pretrained: bool | str = field(default=False, repr=False)
pretrained: bool | str | None = field(default=False, repr=False)
device: str | torch.device = field(default="cpu", repr=False)
model: nn.Module = field(default_factory=lambda: None, init=False, repr=False)

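Widening `pretrained` to `bool | str | None` suggests three behaviours: `True` fetches released weights, a string is treated as a local checkpoint path, and `False`/`None` skips loading. A sketch of how such a flag is commonly resolved (illustrative; the repo's actual logic may differ):

```python
import torch
import torch.nn as nn

def resolve_pretrained(model: nn.Module, pretrained: bool | str | None) -> nn.Module:
    if isinstance(pretrained, str):
        model.load_state_dict(torch.load(pretrained))  # local checkpoint
    elif pretrained:
        ...  # True: fetch released weights (repo-specific, omitted here)
    return model  # False/None: keep the random initialization
```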
