From 062290112a64d80a3375906766e947822d419bda Mon Sep 17 00:00:00 2001 From: Warley Vital Barbosa Date: Wed, 27 Mar 2024 14:52:15 -0300 Subject: [PATCH] refactor FaCoRNetLightning module Add FaCoRNet lightning config, among other things. --- ours/configs/facornet.yml | 39 ++++++++++ ours/datasets/facornet.py | 38 ++++++---- ours/datasets/fiw.py | 6 +- ours/datasets/utils.py | 2 +- ours/guild.yml | 6 +- ours/models/facornet.py | 150 ++++++++++++++++++++------------------ ours/tasks/facornet.py | 2 +- 7 files changed, 151 insertions(+), 92 deletions(-) create mode 100644 ours/configs/facornet.yml diff --git a/ours/configs/facornet.yml b/ours/configs/facornet.yml new file mode 100644 index 0000000..c3ae171 --- /dev/null +++ b/ours/configs/facornet.yml @@ -0,0 +1,39 @@ +seed_everything: 100 +trainer: + num_sanity_val_steps: 1 + log_every_n_steps: 10 + accelerator: "gpu" + deterministic: yes + fast_dev_run: no + max_epochs: 53 + limit_train_batches: 100 + callbacks: + class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + dirpath: ./ + filename: '{epoch}-{auc/val:.3f}-{auc/train:.3f}' + monitor: auc/val + verbose: no + save_last: yes + save_top_k: 1 + save_weights_only: no + auto_insert_metric_name: no + mode: max + + +data: + class_path: datasets.facornet.FaCoRNetDataModule + init_args: + batch_size: 20 + root_dir: data/facornet + +model: + class_path: models.facornet.FaCoRNetLightning + init_args: + lr: 1e-4 + momentum: 0.9 + weight_decay: 0 + weights_path: null + threshold: null + model: + class_path: models.facornet.FaCoR diff --git a/ours/datasets/facornet.py b/ours/datasets/facornet.py index 7cf6c83..8a1b9cc 100644 --- a/ours/datasets/facornet.py +++ b/ours/datasets/facornet.py @@ -1,8 +1,8 @@ from pathlib import Path -import lightning as pl +import lightning as L from torch.utils.data import DataLoader -from torchvision import transforms +from torchvision import transforms as T from .fiw import FIW @@ -16,7 +16,8 @@ class FIWFaCoRNet(FIW): # AdaFace uses BGR -- should I revert conversion read_image here? - def __init__(self, **kwargs): + def __init__(self, batch_size: int, **kwargs): + self.batch_size = batch_size super().__init__(**kwargs) def __getitem__(self, item): @@ -27,40 +28,49 @@ def __getitem__(self, item): return img1, img2, labels -class FaCoRNetDataModule(pl.LightningDataModule): +class FaCoRNetDataModule(L.LightningDataModule): def __init__(self, batch_size=20, root_dir=".", transform=None): super().__init__() self.batch_size = batch_size self.root_dir = root_dir - self.transform = transform or transforms.Compose([transforms.ToTensor()]) + self.transform = transform or T.Compose([T.ToTensor()]) def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = FIW( - root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.TRAIN_PAIRS), transform=self.transform + root_dir=self.root_dir, + sample_path=Path(FIWFaCoRNet.TRAIN_PAIRS), + batch_size=self.batch_size, + biased=True, + transform=self.transform, ) self.val_dataset = FIW( - root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL), transform=self.transform + root_dir=self.root_dir, + sample_path=Path(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL), + batch_size=self.batch_size, + transform=self.transform, ) if stage == "validate" or stage is None: self.val_dataset = FIW( - root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.VAL_PAIRS_THRES_SEL), transform=self.transform + root_dir=self.root_dir, + sample_path=Path(FIWFaCoRNet.VAL_PAIRS_THRES_SEL), + batch_size=self.batch_size, + transform=self.transform, ) if stage == "test" or stage is None: self.test_dataset = FIW( - root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.TEST_PAIRS), transform=self.transform + root_dir=self.root_dir, + sample_path=Path(FIWFaCoRNet.TEST_PAIRS), + batch_size=self.batch_size, + transform=self.transform, ) print(f"Setup {stage} datasets") def train_dataloader(self): - return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True) + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True) def val_dataloader(self): return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True) def test_dataloader(self): return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True) - - -if __name__ == "__main__": - fiw = FIW(root_dir="../../datasets/", sample_path=FIWFaCoRNet.TRAIN_PAIRS) diff --git a/ours/datasets/fiw.py b/ours/datasets/fiw.py index 8502179..ef25bc9 100644 --- a/ours/datasets/fiw.py +++ b/ours/datasets/fiw.py @@ -8,12 +8,14 @@ class FIW(Dataset): - def __init__(self, root_dir, sample_path, transform=None): + def __init__(self, root_dir: str, sample_path: Path, batch_size: int = 20, biased: bool = False, transform=None): self.root_dir = Path(root_dir) self.images_dir = "images" self.sample_path = sample_path + self.batch_size = batch_size self.transform = transform self.bias = 0 + self.biased = biased self.sample_cls = Sample self.sample_list = self.load_sample() print(f"Loaded {len(self.sample_list)} samples from {sample_path}") @@ -35,7 +37,7 @@ def load_sample(self): return sample_list def __len__(self): - return len(self.sample_list) + return len(self.sample_list) // self.batch_size if self.biased else len(self.sample_list) def read_image(self, path): # TODO: add to utils.py diff --git a/ours/datasets/utils.py b/ours/datasets/utils.py index f86618c..699fc46 100644 --- a/ours/datasets/utils.py +++ b/ours/datasets/utils.py @@ -7,7 +7,7 @@ class Sample: # TODO: move to utils.py NAME2LABEL = { - # "non-kin": 0, + "non-kin": 0, "md": 1, "ms": 2, "sibs": 3, diff --git a/ours/guild.yml b/ours/guild.yml index 1a79a7d..c5395b5 100644 --- a/ours/guild.yml +++ b/ours/guild.yml @@ -122,14 +122,16 @@ - datasets/utils.py - tasks/facornet.py flags-import: all - flags-dest: config:facornet.yml - output-scalars: '(\key):\s+(\value)' + flags-dest: config:configs/facornet.yml + output-scalars: '(\key)=(\value)' requires: - file: weights target-type: link - file: ../datasets/ target-type: link rename: data + - file: configs/facornet.yml + target-type: link val: description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023) main: tasks.facornet validate diff --git a/ours/models/facornet.py b/ours/models/facornet.py index 02f9807..cc0c8c3 100644 --- a/ours/models/facornet.py +++ b/ours/models/facornet.py @@ -1,7 +1,7 @@ from collections import namedtuple from pathlib import Path -import lightning as pl +import lightning as L import numpy as np import torch import torch.nn as nn @@ -22,6 +22,7 @@ Sequential, Sigmoid, ) +from torchmetrics.utilities import dim_zero_cat # Assuming the necessary imports are done for FaCoR, facornet_contrastive_loss, FIW, and other utilities @@ -47,7 +48,7 @@ def load_pretrained_model(architecture="ir_101"): statedict = torch.load(adaface_models[architecture])["state_dict"] model_statedict = {key[6:]: val for key, val in statedict.items() if key.startswith("model.")} model.load_state_dict(model_statedict) - model.eval() + # model.eval() return model @@ -94,8 +95,7 @@ def __init__(self): self.channel_ca = ChannelCrossAttention(self.channel * 8) self.CCA = ChannelInteraction(1024) self.avg_pool = nn.AdaptiveAvgPool2d(1) - - self.task_kin = HeadKin(512, 12, 8) + # self.task_kin = HeadKin(512, 12, 8) def forward(self, imgs, aug=False): img1, img2 = imgs @@ -664,11 +664,9 @@ def __init__(self, **kwargs): self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def __call__(self, preds: torch.Tensor, target: torch.Tensor, threshold: torch.Tensor): - self.update(preds, target, threshold) - return self.compute() - def update(self, preds: torch.Tensor, target: torch.Tensor, threshold: torch.Tensor): + if preds.shape != target.shape: + raise ValueError("preds and target must have the same shape") preds_thresholded = preds >= threshold correct = torch.sum(preds_thresholded == target) self.correct += correct @@ -679,34 +677,25 @@ def compute(self): class CollectPreds(tm.Metric): - def __init__(self, **kwargs): + def __init__(self, name: str, **kwargs): + self.name = name super().__init__(**kwargs) - - self.add_state("predictions", default=[], dist_reduce_fx=None) + self.add_state("predictions", default=[], dist_reduce_fx="cat") def update(self, preds: torch.Tensor): - # Convert preds to the same device as the metric state - preds = preds.detach().to(self.predictions[0].device if self.predictions else preds.device) - # Append current batch predictions to the list of all predictions self.predictions.append(preds) def compute(self): # Concatenate the list of predictions into a single tensor - return torch.cat(self.predictions, dim=0) - - def reset(self): - # Reset the state (list of predictions) - self.predictions = [] + return dim_zero_cat(self.predictions) -class FaCoRNetLightning(pl.LightningModule): +class FaCoRNetLightning(L.LightningModule): def __init__( self, model: torch.nn.Module, lr=1e-4, momentum=0.9, weight_decay=0, weights_path=None, threshold=None ): super().__init__() - print(type(weights_path)) - self.save_hyperparameters() self.model = FaCoR() or model self.lr = lr @@ -716,9 +705,24 @@ def __init__( self.threshold = threshold - self.similarities = CollectPreds() # Custom metric to collect predictions - self.is_kin_labels = CollectPreds() # Custom metric to collect labels - self.kin_labels = CollectPreds() # Custom metric to collect labels + self.similarities = tm.MetricCollection( + { + "similarities/train": CollectPreds("similarities/train"), + "similarities/val": CollectPreds("similarities/val"), + } + ) + self.is_kin_labels = tm.MetricCollection( + { + "is_kin_labels/train": CollectPreds("is_kin_labels/train"), + "is_kin_labels/val": CollectPreds("is_kin_labels/val"), + } + ) + self.kin_labels = tm.MetricCollection( + { + "kin_labels/train": CollectPreds("kin_labels/train"), + "kin_labels/val": CollectPreds("kin_labels/val"), + } + ) # Metrics self.train_auc = tm.AUROC(task="binary") @@ -726,12 +730,14 @@ def __init__( self.train_acc = DynamicThresholdAccuracy() self.val_acc = DynamicThresholdAccuracy() self.train_acc_kin_relations = tm.MetricCollection( - {f"acc/{kin}/train": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.keys()} + {f"accuracy/train/{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.keys()} ) self.val_acc_kin_relations = tm.MetricCollection( - {f"acc/{kin}/val": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.keys()} + {f"accuracy/val/{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.keys()} ) + self.save_hyperparameters(ignore=("model")) + def setup(self, stage): # TODO: use checkpoint callback to load the weights if self.hparams.weights_path is not None: @@ -747,36 +753,33 @@ def setup(self, stage): print(f"Failed to load weights due to a runtime error: {e}") print("Model setup complete") - def forward(self, img1, img2): - return self.model([img1, img2]) + def forward(self, inputs): + return self.model(inputs) - def step(self, batch, stage="train"): + def _step(self, batch, stage="train"): img1, img2, labels = batch kin_relation, is_kin = labels - f1, f2, att = self.forward(img1, img2) + f1, f2, att = self((img1, img2)) loss = self.loss_fn(f1, f2, beta=att) sim = torch.cosine_similarity(f1, f2) - if stage == "train": - self.__compute_metrics(sim, is_kin.int(), kin_relation, stage) - else: - # Compute best threshold for training or validation - self.similarities(sim) - self.is_kin_labels(is_kin.int()) - self.kin_labels(kin_relation) + # Compute best threshold for training or validation + self.similarities[f"similarities/{stage}"](sim) + self.is_kin_labels[f"is_kin_labels/{stage}"](is_kin) + self.kin_labels[f"kin_labels/{stage}"](kin_relation) - self.log(f"loss/{stage}", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + self.log(f"loss/{stage}", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss def training_step(self, batch, batch_idx): - return self.step(batch, "train") + return self._step(batch, "train") def validation_step(self, batch, batch_idx): - self.step(batch, "val") + self._step(batch, "val") def test_step(self, batch, batch_idx): - self.step(batch, "test") + self._step(batch, "test") def configure_optimizers(self): optimizer = torch.optim.SGD( @@ -784,29 +787,35 @@ def configure_optimizers(self): ) return optimizer - def on_epoch_end(self): + def on_train_epoch_end(self): # Calculate the number of samples processed - use_sample = (self.current_epoch + 1) * self.trainer.datamodule.batch_size * self.trainer.limit_train_batches + use_sample = ( + (self.current_epoch + 1) * self.trainer.datamodule.batch_size * int(self.trainer.limit_train_batches) + ) # Update the dataset's bias or sampling strategy self.trainer.datamodule.train_dataset.set_bias(use_sample) - # Reset the metrics - self.similarities.reset() - self.is_kin_labels.reset() - self.kin_labels.reset() + print(f"Updated dataset bias to {use_sample}") + self._on_epoch_end("train") def on_validation_epoch_end(self): - similarities = self.similarities.compute() - is_kin_labels = self.is_kin_labels.compute() - kin_labels = self.kin_labels.compute() - self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage="val") + self._on_epoch_end("val") def on_test_epoch_end(self): - similarities = self.similarities.compute() - is_kin_labels = self.is_kin_labels.compute() - kin_labels = self.kin_labels.compute() - self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage="test") + self._on_epoch_end("test") + + def _on_epoch_end(self, stage): + # Compute predictions + similarities = self.similarities[f"similarities/{stage}"].compute() + is_kin_labels = self.is_kin_labels[f"is_kin_labels/{stage}"].compute() + kin_labels = self.kin_labels[f"kin_labels/{stage}"].compute() + self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage=stage) + # Reset predictions + self.similarities[f"similarities/{stage}"].reset() + self.is_kin_labels[f"is_kin_labels/{stage}"].reset() + self.kin_labels[f"kin_labels/{stage}"].reset() def __compute_metrics(self, similarities, is_kin_labels, kin_labels, stage="train"): + # Compute best threshold if stage == "test" and self.threshold is None: raise ValueError("Threshold must be provided for test stage") elif stage == "test": @@ -814,34 +823,31 @@ def __compute_metrics(self, similarities, is_kin_labels, kin_labels, stage="trai else: # Compute best threshold for training or validation fpr, tpr, thresholds = tm.functional.roc(similarities, is_kin_labels, task="binary") best_threshold = compute_best_threshold(tpr, fpr, thresholds) - self.log(f"threshold/{stage}", best_threshold, on_epoch=True, prog_bar=True, logger=True) - # Log AUC and Accuracy + # Compute metrics auc_fn = self.train_auc if stage == "train" else self.val_auc acc_fn = self.train_acc if stage == "train" else self.val_acc auc = auc_fn(similarities, is_kin_labels) acc = acc_fn(similarities, is_kin_labels, best_threshold) - self.log(f"accuracy/{stage}", auc, on_epoch=True, prog_bar=True, logger=True) - self.log(f"auc/{stage}", acc, on_epoch=True, prog_bar=True, logger=True) # Accuracy for each kinship relation acc_kin_relations = self.train_acc_kin_relations if stage == "train" else self.val_acc_kin_relations - for kin, kin_id in Sample.NAME2LABEL.items(): + for kin, kin_id in Sample.NAME2LABEL.items(): # TODO: pass Sample class as argument mask = kin_labels == kin_id if torch.any(mask): - acc_kin_relations[f"acc/{kin}/{stage}"](similarities[mask], is_kin_labels[mask].int(), best_threshold) + acc_kin_relations[f"accuracy/{stage}/{kin}"]( + similarities[mask], is_kin_labels[mask].int(), best_threshold + ) self.log( - f"acc/{kin}/{stage}", - acc_kin_relations[f"acc/{kin}/{stage}"], + f"accuracy/{stage}/{kin}", + acc_kin_relations[f"accuracy/{stage}/{kin}"], + on_step=False, on_epoch=True, - prog_bar=True, + prog_bar=False, logger=True, ) - -if __name__ == "__main__": - model = IR_101((112, 112)) - print(model) - input = torch.rand(2, 3, 112, 112) - output = model(input) - print(output) + # Log metrics + self.log(f"threshold/{stage}", best_threshold, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log(f"accuracy/{stage}", auc, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log(f"auc/{stage}", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True) diff --git a/ours/tasks/facornet.py b/ours/tasks/facornet.py index d55da4e..4807ed2 100644 --- a/ours/tasks/facornet.py +++ b/ours/tasks/facornet.py @@ -2,7 +2,7 @@ def main(args=None): - cli = LightningCLI(args=args, subclass_mode_model=True) # noqa: F841 + LightningCLI(args=args, subclass_mode_model=True) if __name__ == "__main__":