Skip to content

Commit

Permalink
add facornet dataset, model and training task (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalwarley committed Mar 19, 2024
1 parent c15f580 commit 01caa6b
Show file tree
Hide file tree
Showing 7 changed files with 885 additions and 53 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
repos:
- repo: https://github.com/mwouts/jupytext
rev: v1.15.2
rev: v1.16.1
hooks:
- id: jupytext
args: [--sync, --pipe, black]
- repo: https://github.com/psf/black
rev: 23.10.1
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies: [Flake8-pyproject]
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
53 changes: 4 additions & 49 deletions ours/datasets/zhang.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,6 @@
from pathlib import Path
from fiw import FIW

import cv2
import torch
from torch.utils.data import Dataset
from utils import Sample


class FIW(Dataset):
def __init__(self, root_dir, sample_path, transform=None):
self.root_dir = root_dir
self.sample_path = sample_path
self.transform = transform
self.bias = 0
self.sample_list = self.load_sample()
print(f"Loaded {len(self.sample_list)} samples from {sample_path}")

def load_sample(self):
sample_list = []
lines = Path(self.sample_path).read_text().strip().split("\n")
for line in lines:
if len(line) < 1:
continue
tmp = line.split(" ")
sample = Sample(tmp[0], tmp[1], tmp[2], tmp[-2], tmp[-1])
sample_list.append(sample)
return sample_list

def __len__(self):
return len(self.sample_list)

def read_image(self, path):
# TODO: add to utils.py
img = cv2.imread(f"{self.root_dir}/{path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (112, 112))
return img

def set_bias(self, bias):
self.bias = bias

def __getitem__(self, item):
# id, f1, f2, kin_relation, is_kin
sample = self.sample_list[item + self.bias]
img1, img2 = self.read_image(sample.f1), self.read_image(sample.f2)
if self.transform is not None:
img1, img2 = self.transform(img1), self.transform(img2)
is_kin = torch.tensor(int(sample.is_kin))
kin_id = Sample.NAME2LABEL[sample.kin_relation] if is_kin else 0
labels = (kin_id, is_kin)
return img1, img2, labels
class FIWSCL(FIW):
def __init__(self, **kwargs):
super().__init__(**kwargs)
23 changes: 23 additions & 0 deletions ours/guild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,26 @@
target-type: link
- operation: mtcf
select: exp

- model: facornet
operations:
train:
description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
main: tasks.facornet train
sourcecode:
- utils.py
- losses.py
- models/facornet.py
- datasets/fiw.py
- datasets/facornet.py
- datasets/utils.py
- tasks/facornet.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../datasets/
target-type: link
rename: data
14 changes: 14 additions & 0 deletions ours/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch


def facornet_contrastive_loss(x1, x2, beta=0.08):
m = 0.0
x1x2 = torch.cat([x1, x2], dim=0)
x2x1 = torch.cat([x2, x1], dim=0)
beta = (beta**2).sum([1, 2]) / 500
beta = torch.cat([beta, beta]).reshape(-1)
cosine_mat = torch.cosine_similarity(torch.unsqueeze(x1x2, dim=1), torch.unsqueeze(x1x2, dim=0), dim=2) / (beta + m)
mask = 1.0 - torch.eye(2 * x1.size(0)).to(x1.device)
numerators = torch.exp(torch.cosine_similarity(x1x2, x2x1, dim=1) / (beta + m))
denominators = torch.sum(torch.exp(cosine_mat) * mask, dim=1)
return -torch.mean(torch.log(numerators / denominators), dim=0)
Loading

0 comments on commit 01caa6b

Please sign in to comment.