Skip to content

Commit

Permalink
refactor facornet to use pytorch lightning (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalwarley committed Mar 20, 2024
1 parent def26ca commit c850a42
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 259 deletions.
47 changes: 47 additions & 0 deletions ours/datasets/facornet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from pathlib import Path

import lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms

from .fiw import FIW


Expand All @@ -13,6 +19,47 @@ class FIWFaCoRNet(FIW):
def __init__(self, **kwargs):
    """Forward every keyword argument, unchanged, to the parent dataset constructor."""
    super().__init__(**kwargs)

def __getitem__(self, item):
    """Return the ``(img1, img2, labels)`` triple produced by the parent dataset for ``item``."""
    sample = super().__getitem__(item)
    first_img, second_img, labels = sample
    # NOTE: an RGB -> BGR channel swap (``img[[2, 1, 0], :, :]`` on (C, H, W) tensors)
    # used to be applied to both images here; it is currently disabled, so the
    # images are returned exactly as the parent class produced them.
    return first_img, second_img, labels


class FaCoRNetDataModule(pl.LightningDataModule):
    """Lightning data module that builds the ``FIW`` pair datasets used by FaCoRNet.

    Args:
        batch_size: Number of pairs per batch for every dataloader.
        root_dir: Dataset root directory forwarded to ``FIW``.
        transform: Transform forwarded to ``FIW``; when falsy,
            ``transforms.Compose([transforms.ToTensor()])`` is used.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(self, batch_size=20, root_dir=".", transform=None, num_workers=4):
        super().__init__()
        self.batch_size = batch_size
        self.root_dir = root_dir
        self.transform = transform or transforms.Compose([transforms.ToTensor()])
        self.num_workers = num_workers

    def _build_dataset(self, sample_path):
        """Create a ``FIW`` dataset reading its pairs from ``sample_path``."""
        return FIW(root_dir=self.root_dir, sample_path=Path(sample_path), transform=self.transform)

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the module-wide settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def setup(self, stage=None):
        """Instantiate the datasets required by ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"`` (``"val"`` is still accepted for
                backward compatibility), ``"test"`` or ``None``. ``None``
                builds the train and test datasets and uses the
                threshold-selection pairs for validation.
        """
        if stage in ("fit", None):
            self.train_dataset = self._build_dataset(FIWFaCoRNet.TRAIN_PAIRS)
        if stage == "fit":
            # While fitting, validation is used for model selection.
            self.val_dataset = self._build_dataset(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL)
        # Lightning passes "validate" (not "val") when running Trainer.validate().
        # With stage=None the threshold-selection pairs win, as they did before,
        # but the model-selection dataset is no longer built just to be discarded.
        if stage in ("validate", "val", None):
            self.val_dataset = self._build_dataset(FIWFaCoRNet.VAL_PAIRS_THRES_SEL)
        if stage in ("test", None):
            self.test_dataset = self._build_dataset(FIWFaCoRNet.TEST_PAIRS)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return self._build_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        return self._build_loader(self.test_dataset, shuffle=False)


if __name__ == "__main__":
    # Manual smoke check: instantiate the dataset on the training pair list.
    fiw = FIW(sample_path=FIWFaCoRNet.TRAIN_PAIRS, root_dir="../../datasets/")
7 changes: 7 additions & 0 deletions ours/guild-config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ remotes:
user: warley
guild-env: ~/.virtualenvs/research
guild-home: ~/.guild
rig-2-facor:
type: ssh
description: RIG2
host: rig2
user: warley
guild-env: ~/.virtualenvs/facor
guild-home: ~/.guild
4 changes: 2 additions & 2 deletions ours/guild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
operations:
train:
description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
main: tasks.facornet train
main: tasks.facornet fit
sourcecode:
- utils.py
- losses.py
Expand All @@ -132,7 +132,7 @@
rename: data
val:
description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
main: tasks.facornet val
main: tasks.facornet validate
sourcecode:
- utils.py
- losses.py
Expand Down
186 changes: 186 additions & 0 deletions ours/models/facornet.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from collections import namedtuple
from pathlib import Path

import lightning as pl
import numpy as np
import torch
import torch.nn as nn
import torchmetrics as tm
from datasets.utils import Sample
from losses import facornet_contrastive_loss
from models.utils import compute_best_threshold
from torch.nn import (
BatchNorm1d,
BatchNorm2d,
Expand All @@ -18,6 +23,8 @@
Sigmoid,
)

# Assuming the necessary imports are done for FaCoR, facornet_contrastive_loss, FIW, and other utilities

HERE = Path(__file__).parent

adaface_models = {
Expand Down Expand Up @@ -636,6 +643,185 @@ def IR_SE_200(input_size):
return model


class DynamicThresholdAccuracy(tm.Metric):
    """Binary accuracy whose decision threshold is supplied with every update.

    A prediction counts as positive when ``preds >= threshold``. The running
    ``correct`` and ``total`` counters are summed across processes.

    Args:
        compute_on_step: Accepted for backward compatibility only and ignored.
            It is not forwarded to ``Metric.__init__`` because recent
            torchmetrics releases reject it as an unexpected keyword argument.
        dist_sync_on_step: Forwarded to ``Metric.__init__``.
    """

    def __init__(self, compute_on_step=True, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def __call__(self, preds: torch.Tensor, target: torch.Tensor, threshold: torch.Tensor):
        """Accumulate the batch and return the accuracy over everything seen since the last reset."""
        self.update(preds, target, threshold)
        return self.compute()

    def update(self, preds: torch.Tensor, target: torch.Tensor, threshold: torch.Tensor):
        """Accumulate correct/total counts for one batch.

        Args:
            preds: Scores to be thresholded.
            target: Binary labels with as many elements as ``preds``.
            threshold: Either a single value (Python number, 0-d or one-element
                tensor) applied to every score, or one threshold per score
                (same number of elements as ``preds``).
        """
        threshold = torch.as_tensor(threshold, device=preds.device)
        if threshold.numel() == 1:
            # A single threshold broadcasts against every prediction.
            threshold = threshold.reshape(())
        else:
            # Per-sample thresholds are compared element-wise. The previous
            # ``unsqueeze(1)`` broadcast them into a (B, B) comparison instead.
            threshold = threshold.reshape(preds.shape)
        predicted_positive = preds >= threshold
        # Align target with preds so a (B, 1) target cannot silently broadcast.
        expected_positive = target.reshape(preds.shape).to(torch.bool)
        self.correct += torch.sum(predicted_positive == expected_positive)
        self.total += target.numel()

    def compute(self):
        """Return ``correct / total`` as a float tensor."""
        return self.correct.float() / self.total


class CollectPreds(tm.Metric):
    """Metric that stores every tensor passed to ``update`` and concatenates them on ``compute``.

    Args:
        compute_on_step: Accepted for backward compatibility only and ignored.
            It is not forwarded to ``Metric.__init__`` because recent
            torchmetrics releases reject it as an unexpected keyword argument.
        dist_sync_on_step: Forwarded to ``Metric.__init__``.
    """

    def __init__(self, compute_on_step=False, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("predictions", default=[], dist_reduce_fx=None)

    def update(self, preds: torch.Tensor):
        """Append a detached copy of ``preds`` to the collected list."""
        # Keep every stored tensor on the device of the first stored batch.
        target_device = self.predictions[0].device if self.predictions else preds.device
        self.predictions.append(preds.detach().to(target_device))

    def compute(self):
        """Return all collected tensors concatenated along dimension 0."""
        return torch.cat(self.predictions, dim=0)

    def reset(self):
        """Clear the collected tensors and the base-class bookkeeping."""
        # Delegating to the base class empties the registered list state and also
        # clears its cached compute result and update counter; the previous
        # override only replaced the list.
        super().reset()


class FaCoRNetLightning(pl.LightningModule):
    """Lightning module that trains and evaluates a ``FaCoR`` model on image pairs.

    Training computes AUROC/accuracy per batch with a threshold derived from
    that batch's ROC curve. Validation and test collect similarities and labels
    over the whole epoch and compute the metrics once at epoch end; the test
    stage uses the fixed ``threshold`` passed to the constructor.

    Args:
        lr: Learning rate of the SGD optimizer.
        momentum: Momentum of the SGD optimizer.
        weight_decay: Weight decay of the SGD optimizer.
        weights_path: Optional path to a state dict loaded into the model in ``setup``.
        threshold: Decision threshold used in the test stage (required there).
        **kwargs: Extra values; they are only recorded by ``save_hyperparameters``.
    """

    def __init__(self, lr=1e-4, momentum=0.9, weight_decay=0, weights_path=None, threshold=None, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.model = FaCoR()

        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.loss_fn = facornet_contrastive_loss

        self.threshold = threshold

        self.similarities = CollectPreds()  # Collects pair similarities during val/test
        self.is_kin_labels = CollectPreds()  # Collects binary kin / non-kin labels
        self.kin_labels = CollectPreds()  # Collects kinship-relation labels

        # Metrics
        self.train_auc = tm.AUROC(task="binary")
        self.val_auc = tm.AUROC(task="binary")
        self.train_acc = DynamicThresholdAccuracy()
        self.val_acc = DynamicThresholdAccuracy()
        self.train_acc_kin_relations = tm.MetricCollection(
            {f"train/acc_{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.values()}
        )
        self.val_acc_kin_relations = tm.MetricCollection(
            {f"val/acc_{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.values()}
        )

    def setup(self, stage):
        """Load model weights from ``hparams.weights_path`` when one was given.

        A missing file or a ``RuntimeError`` raised while loading is reported
        with ``print`` and otherwise ignored (best effort).
        """
        # TODO: use checkpoint callback to load the weights
        weights_path = self.hparams.weights_path
        if weights_path is None:
            return
        map_location = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            # NOTE(review): torch.load unpickles the file; only load weights
            # from trusted sources.
            state_dict = torch.load(weights_path, map_location=map_location)
            self.model.load_state_dict(state_dict)
            print(f"Loaded weights from {weights_path}")
        except FileNotFoundError:
            print(f"Failed to load weights from {weights_path}. File does not exist.")
        except RuntimeError as e:
            print(f"Failed to load weights due to a runtime error: {e}")

    def forward(self, img1, img2):
        """Run the wrapped model on the image pair."""
        return self.model([img1, img2])

    def step(self, batch, stage="train"):
        """Shared train/val/test step: compute the loss and update or collect metric inputs.

        Args:
            batch: ``(img1, img2, labels)`` where ``labels`` unpacks into
                ``(kin_relation, is_kin)``.
            stage: ``"train"``, ``"val"`` or ``"test"``; used as the logging prefix.

        Returns:
            The loss for the batch.
        """
        img1, img2, labels = batch
        kin_relation, is_kin = labels
        f1, f2, att = self.forward(img1, img2)
        loss = self.loss_fn(f1, f2, beta=att)
        sim = torch.cosine_similarity(f1, f2)

        if stage == "train":
            self.__compute_metrics(sim, is_kin.int(), kin_relation, stage)
        else:
            # Only accumulate here; the metrics are computed over the whole
            # epoch in the corresponding *_epoch_end hook.
            self.similarities.update(sim)
            self.is_kin_labels.update(is_kin.int())
            self.kin_labels.update(kin_relation)

        self.log(f"{stage}/loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self.step(batch, "val")

    def test_step(self, batch, batch_idx):
        self.step(batch, "test")

    def configure_optimizers(self):
        """Return an SGD optimizer over all module parameters."""
        optimizer = torch.optim.SGD(
            self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay
        )
        return optimizer

    def on_train_epoch_end(self):
        """Update the training dataset's sampling bias and clear the collectors.

        This logic used to live in ``on_epoch_end``, which the ``lightning``
        package does not invoke, so it never ran.
        """
        datamodule = self.trainer.datamodule
        if datamodule is not None:
            # Calculate the number of samples processed
            use_sample = (self.current_epoch + 1) * datamodule.batch_size * self.trainer.limit_train_batches
            # Update the dataset's bias or sampling strategy. The lookup is
            # guarded because not every dataset necessarily implements it.
            set_bias = getattr(datamodule.train_dataset, "set_bias", None)
            if callable(set_bias):
                set_bias(use_sample)
        self._reset_collected()

    def on_validation_epoch_end(self, outputs=None):
        """Compute validation metrics over everything collected this epoch.

        ``outputs`` is kept (now optional) for backward compatibility; the
        framework calls this hook without arguments.
        """
        similarities = self.similarities.compute()
        is_kin_labels = self.is_kin_labels.compute()
        kin_labels = self.kin_labels.compute()
        self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage="val")
        # Without this, the next validation run (including the one that follows
        # the sanity check) would also see this run's samples.
        self._reset_collected()

    def on_test_epoch_end(self):
        """Compute test metrics over everything collected during the test run."""
        similarities = self.similarities.compute()
        is_kin_labels = self.is_kin_labels.compute()
        kin_labels = self.kin_labels.compute()
        self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage="test")
        self._reset_collected()

    def _reset_collected(self):
        """Clear the three prediction/label collectors."""
        self.similarities.reset()
        self.is_kin_labels.reset()
        self.kin_labels.reset()

    def __compute_metrics(self, similarities, is_kin_labels, kin_labels, stage="train"):
        """Update and log threshold, AUROC, accuracy and per-relation accuracy.

        Args:
            similarities: Pair similarity scores.
            is_kin_labels: Binary kin labels aligned with ``similarities``.
            kin_labels: Kinship-relation ids aligned with ``similarities``.
            stage: ``"train"``, ``"val"`` or ``"test"``.

        Raises:
            ValueError: If ``stage`` is ``"test"`` and no threshold was configured.
        """
        if stage == "test":
            if self.threshold is None:
                raise ValueError("Threshold must be provided for test stage")
            best_threshold = self.threshold
        else:  # Compute best threshold for training or validation
            fpr, tpr, thresholds = tm.functional.roc(similarities, is_kin_labels, task="binary")
            best_threshold = compute_best_threshold(tpr, fpr, thresholds)
        self.log(f"{stage}/threshold", best_threshold, on_epoch=True, prog_bar=True, logger=True)

        # The accuracy metrics receive the threshold as a one-element 1-D tensor
        # so that a Python float (test stage) and a scalar tensor are treated alike.
        threshold = torch.as_tensor(
            best_threshold, device=similarities.device, dtype=similarities.dtype
        ).reshape(1)

        # Only the train stage has dedicated metric objects; val and test share
        # the "val" ones.
        metric_prefix = "train" if stage == "train" else "val"
        auc_fn = self.train_auc if stage == "train" else self.val_auc
        acc_fn = self.train_acc if stage == "train" else self.val_acc

        # AUROC takes only (preds, target); the threshold is not one of its inputs.
        auc_fn(similarities, is_kin_labels)
        acc_fn(similarities, is_kin_labels, threshold)
        # Log the metric objects, not their values, so Lightning computes them
        # at epoch end and resets them; otherwise their state would keep
        # accumulating across epochs.
        self.log(f"{stage}/auc", auc_fn, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"{stage}/acc", acc_fn, on_epoch=True, prog_bar=True, logger=True)

        # Accuracy for each kinship relation
        acc_kin_relations = self.train_acc_kin_relations if stage == "train" else self.val_acc_kin_relations
        for kin_id in Sample.NAME2LABEL.values():
            mask = kin_labels == kin_id
            if torch.any(mask):
                # The key must match the prefix the collection was built with.
                # Always using "val/..." raised KeyError in the train stage.
                kin_metric = acc_kin_relations[f"{metric_prefix}/acc_{kin_id}"]
                kin_metric(similarities[mask], is_kin_labels[mask].int(), threshold)
                self.log(
                    f"{stage}/acc_{kin_id}",
                    kin_metric,
                    on_epoch=True,
                    prog_bar=True,
                    logger=True,
                )


if __name__ == "__main__":
    # Manual smoke check: build the IR-101 backbone (the tuple is presumably the
    # input size - confirm against IR_101) and print its structure.
    model = IR_101((112, 112))
    print(model)
Expand Down
Loading

0 comments on commit c850a42

Please sign in to comment.