Skip to content

Commit

Permalink
add facornet test task (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalwarley committed Mar 19, 2024
1 parent 22bb0cf commit ba916ce
Showing 1 changed file with 81 additions and 26 deletions.
107 changes: 81 additions & 26 deletions ours/tasks/facornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torchmetrics as tm
from datasets.utils import Sample
from losses import facornet_contrastive_loss
from models.facornet import FaCoR
from torch.utils.data import DataLoader
Expand All @@ -13,52 +14,68 @@
from datasets.facornet import FIWFaCoRNet as FIW


def acc_kr_to_str(out, acc_kr):
# Add acc_kr to out
id2name = {v: k for k, v in Sample.NAME2LABEL.items()}
for kin_id, acc in acc_kr.items():
kr = id2name[kin_id]
out += f" | acc_{kr}: {acc:.6f}"
return out


@torch.no_grad()
def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, torch.Tensor]:
def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
dataset_size = len(val_loader.dataset)
# Preallocate tensors based on the total dataset size
similarities = torch.zeros(dataset_size, device=device)
y_true = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
# y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.long, device=device)
y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)

current_index = 0
for img1, img2, labels in tqdm(val_loader, total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
batch_size_current = img1.size(0) # Handle last batch potentially being smaller
img1, img2 = img1.to(device), img2.to(device)
(kin_relation, is_kin) = labels
# kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)
kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)

f1, f2, _ = model([img1, img2])
sim = torch.cosine_similarity(f1, f2)

# Fill preallocated tensors
similarities[current_index : current_index + batch_size_current] = sim
y_true[current_index : current_index + batch_size_current] = is_kin
# y_true_kin_relations[current_index:current_index + batch_size_current] = kin_relation
y_true_kin_relations[current_index : current_index + batch_size_current] = kin_relation

current_index += batch_size_current

return similarities, y_true
return similarities, y_true, y_true_kin_relations


def validate(model, dataloader, device=0):
def validate(model, dataloader, device=0, threshold=None):
model.eval()
# Compute similarities
similarities, y_true = predict(model, dataloader)
similarities, y_true, y_true_kin_relations = predict(model, dataloader)
# Compute metrics
auc = tm.functional.auroc(similarities, y_true, task="binary")
fpr, tpr, thresholds = tm.functional.roc(similarities, y_true, task="binary")
# Get the best threshold
maxindex = (tpr - fpr).argmax()
threshold = thresholds[maxindex]
if threshold.isnan().item():
threshold = 0.01
else:
threshold = threshold.item()
if threshold is None:
# Get the best threshold
maxindex = (tpr - fpr).argmax()
threshold = thresholds[maxindex]
if threshold.isnan().item():
threshold = 0.01
else:
threshold = threshold.item()
# Compute acc
acc_metric = tm.Accuracy(task="binary", threshold=threshold).to(device)
acc = acc_metric(similarities, y_true)
return auc, threshold, acc
acc = tm.functional.accuracy(similarities, y_true, task="binary", threshold=threshold)
# Compute accuracy with respect to kinship relations
acc_kin_relations = {}
for kin_relation in Sample.NAME2LABEL.values():
mask = y_true_kin_relations == kin_relation
acc_kin_relations[kin_relation] = tm.functional.accuracy(
similarities[mask], y_true[mask], task="binary", threshold=threshold
)
return auc, threshold, acc, acc_kin_relations


def train(args):
Expand Down Expand Up @@ -97,8 +114,12 @@ def train(args):
total_steps = len(train_loader)
print(f"Total steps: {total_steps}")
global_step = 0
best_model_auc, _, val_acc = validate(model, val_model_sel_loader)
print(f"epoch: 0 | auc: {best_model_auc:.6f} | acc: {val_acc:.6f}")
best_model_auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
out = f"epoch: 0 | auc: {best_model_auc:.6f} | acc: {val_acc:.6f}"
# Add acc_kr to out
for kin_relation, acc in acc_kr.items():
out += f" | acc_{kin_relation}: {acc:.6f}"
print(out)

for epoch in range(args.num_epoch):
model.train()
Expand Down Expand Up @@ -130,16 +151,18 @@ def train(args):
train_dataset.set_bias(use_sample)

# Save model checkpoints
auc, _, val_acc = validate(model, val_model_sel_loader)
auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)

if auc > best_model_auc:
best_model_auc = auc
torch.save(model.state_dict(), args.output_dir / "best.pth")

print(
out = (
f"epoch: {epoch + 1:>2} | step: {global_step} "
+ f"| loss: {loss_epoch / args.steps_per_epoch:.3f} | auc: {auc:.6f} | acc: {val_acc:.6f}"
)
out = acc_kr_to_str(out, acc_kr)
print(out)


def val(args):
Expand All @@ -150,17 +173,38 @@ def val(args):
]
)

val_model_sel_dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_THRES_SEL), transform=transform)
val_model_sel_loader = DataLoader(
val_model_sel_dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False
dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_THRES_SEL), transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False)

model = FaCoR()
model.load_state_dict(torch.load(args.weights))
model.to(args.device)

auc, threshold, val_acc, acc_kr = validate(model, dataloader)
out = f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}"
out = acc_kr_to_str(out, acc_kr)
print(out)


def test(args):

transform = transforms.Compose(
[
transforms.ToTensor(),
]
)

dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.TEST_PAIRS), transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False)

model = FaCoR()
model.load_state_dict(torch.load(args.weights))
model.to(args.device)

auc, threshold, val_acc = validate(model, val_model_sel_loader)
print(f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}")
auc, threshold, val_acc, acc_kr = validate(model, dataloader, threshold=args.threshold)
out = f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}"
out = acc_kr_to_str(out, acc_kr)
print(out)


def create_parser_train(subparsers):
Expand All @@ -187,12 +231,23 @@ def create_parser_val(subparsers):
parser.set_defaults(func=val)


def create_parser_test(subparsers):
parser = subparsers.add_parser("test", help="Test the model")
parser.add_argument("--weights", type=str, required=True)
parser.add_argument("--root-dir", type=str, required=True)
parser.add_argument("--batch-size", type=int, default=100, help="Batch size")
parser.add_argument("--threshold", type=float, required=True)
parser.add_argument("--device", type=str, default="0", help="Device to use for training")
parser.set_defaults(func=test)


if __name__ == "__main__":

parser = ArgumentParser(description="Configuration for the FaCoRNet strategy")
subparsers = parser.add_subparsers()
create_parser_train(subparsers)
create_parser_val(subparsers)
create_parser_test(subparsers)
args = parser.parse_args()

# Necessary for dataloaders?
Expand Down

0 comments on commit ba916ce

Please sign in to comment.