Skip to content

Commit

Permalink
feat: add batch sampler (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalwarley committed Jun 26, 2024
1 parent 686b227 commit 0269c52
Show file tree
Hide file tree
Showing 6 changed files with 326 additions and 26 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ insightface/
weights/
outputs/
plots/
datasets/
**/models/*.pth
lightning_logs/

Expand Down
15 changes: 6 additions & 9 deletions ours/configs/facornet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ trainer:
accelerator: "gpu"
deterministic: yes
fast_dev_run: no
max_epochs: 43
limit_train_batches: 50
max_epochs: 100
default_root_dir: exp/
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
Expand All @@ -30,15 +29,13 @@ trainer:
mode: 'max' # Use 'min' if the metric should decrease

data:
class_path: datasets.facornet.FaCoRNetDataModule
class_path: datasets.sclff.SCLFFDataModule
init_args:
dataset: facornet
biased: false
batch_size: 20
root_dir: data/facornet

model:
class_path: models.facornet.FaCoRNetLightning
class_path: models.facornet.FaCoRNetBasic
init_args:
num_families: 0
loss_factor: 0
Expand All @@ -59,9 +56,9 @@ model:
anneal_strategy: cos
weights: null
model:
class_path: models.facornet.FaCoR
class_path: models.facornet.FaCoRV5
init_args:
model: adaface_ir_101
attention: models.attention.FaCoRAttention
attention: models.attention.FaCoRAttentionDummy
loss:
class_path: losses.FaCoRContrastiveLoss
class_path: losses.ContrastiveLoss
116 changes: 112 additions & 4 deletions ours/datasets/fiw.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from itertools import combinations, islice
from pathlib import Path

Expand Down Expand Up @@ -51,7 +52,8 @@ def __len__(self):

def read_image(self, path):
# TODO: add to utils.py
img = cv2.imread(f"{self.root_dir / self.images_dir}/{path}")
image_path = f"{self.root_dir / self.images_dir}/{path}"
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (112, 112))
return img
Expand All @@ -61,11 +63,16 @@ def set_bias(self, bias):
self.bias = bias

def _process_images(self, sample):
img1, img2 = self.read_image(sample.f1), self.read_image(sample.f2)
if self.transform is not None:
img1, img2 = self.transform(img1), self.transform(img2)
img1 = self._process_one_image(sample.f1)
img2 = self._process_one_image(sample.f2)
return img1, img2

def _process_one_image(self, image_path):
image = self.read_image(image_path)
if self.transform is not None:
image = self.transform(image)
return image

def _process_labels(self, sample):
is_kin = torch.tensor(sample.is_kin)
kin_id = self.sample_cls.NAME2LABEL[sample.kin_relation]
Expand Down Expand Up @@ -154,6 +161,107 @@ def _process_labels(self, sample):
return labels


class FIWFamilyV3(FIW):
"""
To be used with the KinshipBatchSampler.
"""

# FaCoRNet dataset
TRAIN_PAIRS = "txt/train_sort_A2_m.txt"
VAL_PAIRS_MODEL_SEL = "txt/val_choose_A.txt"
VAL_PAIRS_THRES_SEL = "txt/val_A.txt"
TEST_PAIRS = "txt/test_A.txt"

def __init__(self, **kwargs):
super().__init__(**kwargs)
# Enconde all samples f1fid and f2fid to set of unique values
self.fids = []
self.filepaths = []
for sample in self.sample_list:
self.fids.append(sample.f1fid)
self.fids.append(sample.f2fid)
self.filepaths.append(sample.f1)
self.filepaths.append(sample.f2)
self.filepaths = set(self.filepaths)
# Map each fid to an index
self.fids = sorted(list(set(self.fids)))
print(f"Found {len(self.fids)} unique fids")
self.fid2idx = {fid: idx for idx, fid in enumerate(self.fids)}

whitelist_dir = "MID"
self.families = [
[
cur_person
for cur_person in cur_family.iterdir()
if cur_person.is_dir() and cur_person.name.startswith(whitelist_dir)
]
for cur_family in self.root_dir.iterdir()
if cur_family.is_dir()
]
self.fam2rel = defaultdict(list)
self.persons = []
self.persons2idx = {}
self.idx2persons = {}
self.person2family = {}
self.cache = {}
self.relationships = self._generate_relationships()

def _generate_relationships(self):
relationships = []
unique_relations = set()
persons = []
for sample_idx, sample in enumerate(self.sample_list):
# Only consider training set, therefore only positive samples
relation = (sample.f1fid,) + tuple(sorted([sample.f1mid, sample.f2mid]))
persons.append(sample.f1)
persons.append(sample.f2)
if relation not in unique_relations:
unique_relations.add(relation)
person1_images = list(Path(self.root_dir, self.images_dir, sample.f1).parent.glob("*.jpg"))
person2_images = list(Path(self.root_dir, self.images_dir, sample.f2).parent.glob("*.jpg"))
# Filter path relative to images_dir
person1_images = [str(img.relative_to(self.root_dir / self.images_dir)) for img in person1_images]
person2_images = [str(img.relative_to(self.root_dir / self.images_dir)) for img in person2_images]
# Filter relative to self.filepaths; apparently some images are missing in the train.csv
person1_images = [person for person in person1_images if person in self.filepaths]
person2_images = [person for person in person2_images if person in self.filepaths]
if person1_images and person2_images:
labels = (sample.f1mid, sample.f2mid, sample.f1fid)
relationships.append((person1_images, person2_images, labels))
self.fam2rel[sample.f1fid].append(len(relationships) - 1)

self.person2family = {person: int(person.split("/")[2][1:]) for person in persons}
self.persons = sorted(list(set(persons)))
self.persons2idx = {person: idx for idx, person in enumerate(self.persons)}
self.idx2persons = {idx: person for person, idx in self.persons2idx.items()}

print(f"Generated {len(relationships)} relationships")
print(f"Found {len(self.persons)} unique persons")

return relationships

def _process_one_image(self, image_path):
if image_path in self.cache:
return self.cache[image_path]
image = super()._process_one_image(image_path)
self.cache[image_path] = image
return image

def __getitem__(self, idx: list[tuple[int, int]]):
img1_idx, img2_idx = list(zip(*idx))
imgs1, imgs2 = [self.persons[idx] for idx in img1_idx], [self.persons[idx] for idx in img2_idx]
is_kin = [
int(self.person2family[person1] == self.person2family[person2]) for person1, person2 in zip(imgs1, imgs2)
]
imgs1 = [self._process_one_image(img) for img in imgs1]
imgs2 = [self._process_one_image(img) for img in imgs2]
sample = (imgs1, imgs2, is_kin) # collate!
return sample

def __len__(self):
return len(self.relationships)


class FIWPairs(FIW):

# FaCoRNet dataset
Expand Down
163 changes: 163 additions & 0 deletions ours/datasets/sclff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import random
from collections import defaultdict
from pathlib import Path

import lightning as L
from datasets.fiw import FIWFamilyV3
from datasets.utils import collate_fn_fiw_family_v3
from torch.utils.data import DataLoader
from torchvision import transforms as T

from datasets.facornet import FIWFaCoRNet


class KinshipBatchSampler:
def __init__(self, dataset, batch_size):
self.dataset = dataset
self.batch_size = batch_size
self.image_counters = defaultdict(int)
self.indices = list(range(len(self.dataset)))
self._shuffle_indices()

def _shuffle_indices(self):
random.shuffle(self.indices)

def _get_image_with_min_count(self, person_images):
min_count_image = min(person_images, key=lambda person: self.image_counters[person])
return min_count_image

def _replace_duplicates(self, sub_batch):
family_counts = defaultdict(int)
for pair in sub_batch:
fam = pair[2][2] # Label, Family ID
family_counts[fam] += 1

while any(count > 1 for count in family_counts.values()):
# print(f"Family counts: {family_counts}")
for i in range(len(sub_batch)):
current_fam = sub_batch[i][2][2]
if family_counts[current_fam] > 1:
# print(f"Checking pair {i + 1}: {current_fam}")
while (replacement_fam := random.choice(list(self.dataset.fam2rel.keys()))) in family_counts:
pass
replacement_pair_idx = random.choice(self.dataset.fam2rel[replacement_fam])
replacement_pair = self.dataset.relationships[replacement_pair_idx]
sub_batch[i] = replacement_pair
family_counts[current_fam] -= 1
family_counts[replacement_fam] += 1
# print(f"Replaced pair {i + 1} with a new pair")
return sub_batch

def __iter__(self):
for i in range(0, len(self.indices), self.batch_size):
sub_batch_indices = self.indices[i : i + self.batch_size]
sub_batch = [self.dataset.relationships[idx] for idx in sub_batch_indices]
sub_batch = self._replace_duplicates(sub_batch)
batch = []

for pair in sub_batch:
imgs1, imgs2, _ = pair
img1 = self._get_image_with_min_count(imgs1)
img2 = self._get_image_with_min_count(imgs2)
img1_id = self.dataset.persons2idx[img1]
img2_id = self.dataset.persons2idx[img2]
self.image_counters[img1] += 1
self.image_counters[img2] += 1
batch.append((img1_id, img2_id))

yield batch

def __len__(self):
return len(self.dataset) // self.batch_size


# Example usage:
# Assuming dataset is an instance of a Dataset class where __getitem__ returns (img1, img2, labels)
# batch_size = 32
# sampler = KinshipBatchSampler(dataset, batch_size)
# data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)


class SCLFFDataModule(L.LightningDataModule):

def __init__(self, batch_size=20, root_dir=".", transform=None):
super().__init__()
self.batch_size = batch_size
self.root_dir = root_dir
self.transform = transform or T.Compose([T.ToTensor()])

def setup(self, stage=None):
if stage == "fit" or stage is None:
self.train_dataset = FIWFamilyV3(
root_dir=self.root_dir,
sample_path=Path(FIWFamilyV3.TRAIN_PAIRS),
batch_size=self.batch_size,
transform=self.transform,
)
self.val_dataset = FIWFaCoRNet(
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL),
batch_size=self.batch_size,
transform=self.transform,
)
if stage == "validate" or stage is None:
self.val_dataset = FIWFaCoRNet(
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.VAL_PAIRS_THRES_SEL),
batch_size=self.batch_size,
transform=self.transform,
)
if stage == "test" or stage is None:
self.test_dataset = FIWFaCoRNet(
root_dir=self.root_dir,
sample_path=Path(FIWFaCoRNet.TEST_PAIRS),
batch_size=self.batch_size,
transform=self.transform,
)
print(f"Setup {stage} datasets")

def train_dataloader(self):
sampler = KinshipBatchSampler(self.train_dataset, self.batch_size)
return DataLoader(
self.train_dataset,
num_workers=1,
pin_memory=True,
persistent_workers=True,
sampler=sampler,
collate_fn=collate_fn_fiw_family_v3,
)

def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
persistent_workers=True,
)

def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
persistent_workers=True,
)


if __name__ == "__main__":
batch_size = 20
dm = SCLFFDataModule(
dataset="fiw",
batch_size=batch_size,
root_dir="../datasets/facornet",
)
dm.setup("fit")
data_loader = dm.train_dataloader()

# Iterate over DataLoader
for i, batch in enumerate(data_loader):
print(f"Batch {i + 1}", batch[0].shape, batch[1].shape)
22 changes: 22 additions & 0 deletions ours/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ def sr_collate_fn_v2(batch):
return (unique_probe_id, probe_images), (gallery_indexes, gallery_images)


def collate_fn_fiw_family_v3(batch):
imgs1_batch = [item[0] for item in batch]
imgs2_batch = [item[1] for item in batch]
is_kin = [item[2] for item in batch]

# Flatten the list of lists into a single list of tensors
imgs1_flat = [img for imgs in imgs1_batch for img in imgs]
imgs2_flat = [img for imgs in imgs2_batch for img in imgs]
is_kin_flat = [label for labels in is_kin for label in labels]

# Stack tensors along the batch dimension
imgs1_tensor = torch.stack(imgs1_flat)
imgs2_tensor = torch.stack(imgs2_flat)
is_kin_tensor = torch.tensor(is_kin_flat)

return imgs1_tensor, imgs2_tensor, is_kin_tensor


# Example usage:
# dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)


class Sample:
# TODO: move to utils.py
NAME2LABEL = {
Expand Down
Loading

0 comments on commit 0269c52

Please sign in to comment.