-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
686b227
commit 0269c52
Showing
6 changed files
with
326 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -134,7 +134,6 @@ insightface/ | |
weights/ | ||
outputs/ | ||
plots/ | ||
datasets/ | ||
**/models/*.pth | ||
lightning_logs/ | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import random | ||
from collections import defaultdict | ||
from pathlib import Path | ||
|
||
import lightning as L | ||
from datasets.fiw import FIWFamilyV3 | ||
from datasets.utils import collate_fn_fiw_family_v3 | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms as T | ||
|
||
from datasets.facornet import FIWFaCoRNet | ||
|
||
|
||
class KinshipBatchSampler:
    """Batch sampler yielding lists of ``(img1_id, img2_id)`` index pairs.

    The dataset is expected to expose:

    * ``__len__`` -- number of relationship entries,
    * ``relationships`` -- sequence of ``(imgs1, imgs2, labels)`` triples,
      where ``labels[2]`` is the family ID,
    * ``fam2rel`` -- mapping from family ID to a list of indices into
      ``relationships``,
    * ``persons2idx`` -- mapping from an entry of ``imgs1``/``imgs2`` to the
      integer ID that is emitted in the batch.

    Within one batch the sampler tries to keep at most one pair per family,
    and for every pair it picks the least-used entry on each side so usage is
    spread out over time.
    """

    def __init__(self, dataset, batch_size):
        """Store the dataset, and shuffle the relationship indices once."""
        self.dataset = dataset
        self.batch_size = batch_size
        # Number of times each entry has been emitted so far (persists across
        # successive iterations of the sampler).
        self.image_counters = defaultdict(int)
        self.indices = list(range(len(self.dataset)))
        self._shuffle_indices()

    def _shuffle_indices(self):
        """Shuffle the relationship indices in place."""
        random.shuffle(self.indices)

    def _get_image_with_min_count(self, person_images):
        """Return the entry of ``person_images`` emitted the fewest times.

        Ties are resolved in favour of the earliest entry (``min`` semantics).
        """
        return min(person_images, key=lambda person: self.image_counters[person])

    def _replace_duplicates(self, sub_batch):
        """Replace pairs so that each family appears at most once, if possible.

        ``sub_batch`` is modified in place and also returned. For each family
        that occurs more than once, all but its last occurrence are swapped
        for a random pair of a family that is not yet present in the batch.

        If no unused family is left (fewer distinct families than needed),
        the remaining duplicates are kept instead of looping forever.
        """
        family_counts = defaultdict(int)
        for pair in sub_batch:
            family_counts[pair[2][2]] += 1  # labels[2] is the family ID

        for i, pair in enumerate(sub_batch):
            current_fam = pair[2][2]
            if family_counts[current_fam] <= 1:
                continue

            # Choose uniformly among families not yet in the batch; this is
            # the same distribution as rejection sampling but always ends.
            candidates = [fam for fam in self.dataset.fam2rel if fam not in family_counts]
            if not candidates:
                # Every known family is already present: nothing to swap in.
                break

            replacement_fam = random.choice(candidates)
            replacement_pair_idx = random.choice(self.dataset.fam2rel[replacement_fam])
            sub_batch[i] = self.dataset.relationships[replacement_pair_idx]
            family_counts[current_fam] -= 1
            family_counts[replacement_fam] += 1
        return sub_batch

    def __iter__(self):
        """Yield one list of ``(img1_id, img2_id)`` tuples per batch.

        The final batch may be smaller than ``batch_size``.
        """
        for start in range(0, len(self.indices), self.batch_size):
            sub_batch_indices = self.indices[start : start + self.batch_size]
            sub_batch = [self.dataset.relationships[idx] for idx in sub_batch_indices]
            sub_batch = self._replace_duplicates(sub_batch)
            batch = []

            for imgs1, imgs2, _ in sub_batch:
                img1 = self._get_image_with_min_count(imgs1)
                img2 = self._get_image_with_min_count(imgs2)
                img1_id = self.dataset.persons2idx[img1]
                img2_id = self.dataset.persons2idx[img2]
                # Count after both picks, exactly one increment per side.
                self.image_counters[img1] += 1
                self.image_counters[img2] += 1
                batch.append((img1_id, img2_id))

            yield batch

    def __len__(self):
        """Return the number of batches ``__iter__`` yields (partial included)."""
        # Ceiling division: __iter__ also emits the trailing partial batch.
        return (len(self.indices) + self.batch_size - 1) // self.batch_size
|
||
|
||
# Example usage:
# The sampler expects `dataset` to define __len__ and to expose `relationships`
# (a sequence of (imgs1, imgs2, labels) triples), `fam2rel` (family ID -> list of
# relationship indices) and `persons2idx` (entry -> integer ID).
# batch_size = 32
# sampler = KinshipBatchSampler(dataset, batch_size)
# data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
|
||
|
||
class SCLFFDataModule(L.LightningDataModule):
    """Lightning data module wiring FIW datasets to train/val/test loaders.

    Training uses ``FIWFamilyV3`` driven by ``KinshipBatchSampler``;
    validation and testing use ``FIWFaCoRNet`` with plain sequential loaders.
    """

    def __init__(self, batch_size=20, root_dir=".", transform=None):
        """Remember loader settings; fall back to ``ToTensor`` if no transform."""
        super().__init__()
        if not transform:
            transform = T.Compose([T.ToTensor()])
        self.batch_size = batch_size
        self.root_dir = root_dir
        self.transform = transform

    def _facornet_dataset(self, pairs):
        """Build a ``FIWFaCoRNet`` dataset for the given pairs file."""
        return FIWFaCoRNet(
            root_dir=self.root_dir,
            sample_path=Path(pairs),
            batch_size=self.batch_size,
            transform=self.transform,
        )

    def _eval_loader(self, dataset):
        """Build the non-shuffled loader shared by validation and test."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
        )

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when ``None``).

        With ``stage=None`` the "validate" branch runs after the "fit" branch,
        so ``val_dataset`` ends up built from ``VAL_PAIRS_THRES_SEL``.
        """
        if stage in ("fit", None):
            self.train_dataset = FIWFamilyV3(
                root_dir=self.root_dir,
                sample_path=Path(FIWFamilyV3.TRAIN_PAIRS),
                batch_size=self.batch_size,
                transform=self.transform,
            )
            self.val_dataset = self._facornet_dataset(FIWFaCoRNet.VAL_PAIRS_MODEL_SEL)
        if stage in ("validate", None):
            self.val_dataset = self._facornet_dataset(FIWFaCoRNet.VAL_PAIRS_THRES_SEL)
        if stage in ("test", None):
            self.test_dataset = self._facornet_dataset(FIWFaCoRNet.TEST_PAIRS)
        print(f"Setup {stage} datasets")

    def train_dataloader(self):
        """Return the training loader driven by a fresh ``KinshipBatchSampler``."""
        kinship_sampler = KinshipBatchSampler(self.train_dataset, self.batch_size)
        # NOTE(review): the sampler yields whole batches but is passed as
        # `sampler=` rather than `batch_sampler=`; presumably the dataset and
        # collate function expect that -- confirm against FIWFamilyV3.
        return DataLoader(
            self.train_dataset,
            num_workers=1,
            pin_memory=True,
            persistent_workers=True,
            sampler=kinship_sampler,
            collate_fn=collate_fn_fiw_family_v3,
        )

    def val_dataloader(self):
        """Return the validation loader over ``val_dataset``."""
        return self._eval_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader over ``test_dataset``."""
        return self._eval_loader(self.test_dataset)
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build the training loader and print the shape of each batch.
    batch_size = 20
    # SCLFFDataModule.__init__ accepts only batch_size, root_dir and transform;
    # passing an extra `dataset=` keyword would raise TypeError.
    dm = SCLFFDataModule(
        batch_size=batch_size,
        root_dir="../datasets/facornet",
    )
    dm.setup("fit")
    data_loader = dm.train_dataloader()

    # Iterate over DataLoader
    for i, batch in enumerate(data_loader):
        print(f"Batch {i + 1}", batch[0].shape, batch[1].shape)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.