Skip to content

Commit

Permalink
refactor FaCoRNetLightning module
Browse files Browse the repository at this point in the history
Add FaCoRNet lightning config, among other things.
  • Loading branch information
vitalwarley committed Mar 27, 2024
1 parent e0fbebb commit 0622901
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 92 deletions.
39 changes: 39 additions & 0 deletions ours/configs/facornet.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
seed_everything: 100
trainer:
num_sanity_val_steps: 1
log_every_n_steps: 10
accelerator: "gpu"
deterministic: yes
fast_dev_run: no
max_epochs: 53
limit_train_batches: 100
callbacks:
class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
dirpath: ./
filename: '{epoch}-{auc/val:.3f}-{auc/train:.3f}'
monitor: auc/val
verbose: no
save_last: yes
save_top_k: 1
save_weights_only: no
auto_insert_metric_name: no
mode: max


data:
class_path: datasets.facornet.FaCoRNetDataModule
init_args:
batch_size: 20
root_dir: data/facornet

model:
class_path: models.facornet.FaCoRNetLightning
init_args:
lr: 1e-4
momentum: 0.9
weight_decay: 0
weights_path: null
threshold: null
model:
class_path: models.facornet.FaCoR
38 changes: 24 additions & 14 deletions ours/datasets/facornet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pathlib import Path

import lightning as pl
import lightning as L
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import transforms as T

from .fiw import FIW

Expand All @@ -16,7 +16,8 @@ class FIWFaCoRNet(FIW):

# AdaFace uses BGR -- should I revert conversion read_image here?

def __init__(self, **kwargs):
def __init__(self, batch_size: int, **kwargs):
self.batch_size = batch_size
super().__init__(**kwargs)

def __getitem__(self, item):
Expand All @@ -27,40 +28,49 @@ def __getitem__(self, item):
return img1, img2, labels


class FaCoRNetDataModule(pl.LightningDataModule):
class FaCoRNetDataModule(L.LightningDataModule):
def __init__(self, batch_size=20, root_dir=".", transform=None):
super().__init__()
self.batch_size = batch_size
self.root_dir = root_dir
self.transform = transform or transforms.Compose([transforms.ToTensor()])
self.transform = transform or T.Compose([T.ToTensor()])

def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_dataset = FIW(
root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.TRAIN_PAIRS), transform=self.transform
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.TRAIN_PAIRS),
batch_size=self.batch_size,
biased=True,
transform=self.transform,
)
self.val_dataset = FIW(
root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL), transform=self.transform
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL),
batch_size=self.batch_size,
transform=self.transform,
)
if stage == "validate" or stage is None:
self.val_dataset = FIW(
root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.VAL_PAIRS_THRES_SEL), transform=self.transform
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.VAL_PAIRS_THRES_SEL),
batch_size=self.batch_size,
transform=self.transform,
)
if stage == "test" or stage is None:
self.test_dataset = FIW(
root_dir=self.root_dir, sample_path=Path(FIWFaCoRNet.TEST_PAIRS), transform=self.transform
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.TEST_PAIRS),
batch_size=self.batch_size,
transform=self.transform,
)
print(f"Setup {stage} datasets")

def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)

def val_dataloader(self):
return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)

def test_dataloader(self):
return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4, pin_memory=True)


if __name__ == "__main__":
fiw = FIW(root_dir="../../datasets/", sample_path=FIWFaCoRNet.TRAIN_PAIRS)
6 changes: 4 additions & 2 deletions ours/datasets/fiw.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@


class FIW(Dataset):
def __init__(self, root_dir, sample_path, transform=None):
def __init__(self, root_dir: str, sample_path: Path, batch_size: int = 20, biased: bool = False, transform=None):
self.root_dir = Path(root_dir)
self.images_dir = "images"
self.sample_path = sample_path
self.batch_size = batch_size
self.transform = transform
self.bias = 0
self.biased = biased
self.sample_cls = Sample
self.sample_list = self.load_sample()
print(f"Loaded {len(self.sample_list)} samples from {sample_path}")
Expand All @@ -35,7 +37,7 @@ def load_sample(self):
return sample_list

def __len__(self):
return len(self.sample_list)
return len(self.sample_list) // self.batch_size if self.biased else len(self.sample_list)

def read_image(self, path):
# TODO: add to utils.py
Expand Down
2 changes: 1 addition & 1 deletion ours/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class Sample:
# TODO: move to utils.py
NAME2LABEL = {
# "non-kin": 0,
"non-kin": 0,
"md": 1,
"ms": 2,
"sibs": 3,
Expand Down
6 changes: 4 additions & 2 deletions ours/guild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,16 @@
- datasets/utils.py
- tasks/facornet.py
flags-import: all
flags-dest: config:facornet.yml
output-scalars: '(\key):\s+(\value)'
flags-dest: config:configs/facornet.yml
output-scalars: '(\key)=(\value)'
requires:
- file: weights
target-type: link
- file: ../datasets/
target-type: link
rename: data
- file: configs/facornet.yml
target-type: link
val:
description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
main: tasks.facornet validate
Expand Down
Loading

0 comments on commit 0622901

Please sign in to comment.