From ba916ced812c418d2578dccb86e1950969db9752 Mon Sep 17 00:00:00 2001
From: Warley Vital Barbosa
Date: Tue, 19 Mar 2024 18:04:18 -0300
Subject: [PATCH] add facornet test task (#71)

---
 ours/tasks/facornet.py | 107 +++++++++++++++++++++++++++++++----------
 1 file changed, 81 insertions(+), 26 deletions(-)

diff --git a/ours/tasks/facornet.py b/ours/tasks/facornet.py
index 19e8901..a1f5d30 100644
--- a/ours/tasks/facornet.py
+++ b/ours/tasks/facornet.py
@@ -3,6 +3,7 @@
 import torch
 import torchmetrics as tm
+from datasets.utils import Sample
 from losses import facornet_contrastive_loss
 from models.facornet import FaCoR
 from torch.utils.data import DataLoader
@@ -13,20 +14,29 @@ from datasets.facornet import FIWFaCoRNet as FIW
 
 
+def acc_kr_to_str(out, acc_kr):
+    # Add acc_kr to out
+    id2name = {v: k for k, v in Sample.NAME2LABEL.items()}
+    for kin_id, acc in acc_kr.items():
+        kr = id2name[kin_id]
+        out += f" | acc_{kr}: {acc:.6f}"
+    return out
+
+
 @torch.no_grad()
-def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, torch.Tensor]:
+def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     dataset_size = len(val_loader.dataset)
     # Preallocate tensors based on the total dataset size
     similarities = torch.zeros(dataset_size, device=device)
     y_true = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
-    # y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.long, device=device)
+    y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
 
     current_index = 0
     for img1, img2, labels in tqdm(val_loader, total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
         batch_size_current = img1.size(0)  # Handle last batch potentially being smaller
 
         img1, img2 = img1.to(device), img2.to(device)
         (kin_relation, is_kin) = labels
-        # kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)
+        kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)
 
         f1, f2, _ = model([img1, img2])
         sim = torch.cosine_similarity(f1, f2)
@@ -34,31 +44,38 @@ def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, tor
         # Fill preallocated tensors
         similarities[current_index : current_index + batch_size_current] = sim
         y_true[current_index : current_index + batch_size_current] = is_kin
-        # y_true_kin_relations[current_index:current_index + batch_size_current] = kin_relation
+        y_true_kin_relations[current_index : current_index + batch_size_current] = kin_relation
 
         current_index += batch_size_current
 
-    return similarities, y_true
+    return similarities, y_true, y_true_kin_relations
 
 
-def validate(model, dataloader, device=0):
+def validate(model, dataloader, device=0, threshold=None):
     model.eval()
     # Compute similarities
-    similarities, y_true = predict(model, dataloader)
+    similarities, y_true, y_true_kin_relations = predict(model, dataloader)
     # Compute metrics
     auc = tm.functional.auroc(similarities, y_true, task="binary")
     fpr, tpr, thresholds = tm.functional.roc(similarities, y_true, task="binary")
-    # Get the best threshold
-    maxindex = (tpr - fpr).argmax()
-    threshold = thresholds[maxindex]
-    if threshold.isnan().item():
-        threshold = 0.01
-    else:
-        threshold = threshold.item()
+    if threshold is None:
+        # Get the best threshold
+        maxindex = (tpr - fpr).argmax()
+        threshold = thresholds[maxindex]
+        if threshold.isnan().item():
+            threshold = 0.01
+        else:
+            threshold = threshold.item()
     # Compute acc
-    acc_metric = tm.Accuracy(task="binary", threshold=threshold).to(device)
-    acc = acc_metric(similarities, y_true)
-    return auc, threshold, acc
+    acc = tm.functional.accuracy(similarities, y_true, task="binary", threshold=threshold)
+    # Compute accuracy with respect to kinship relations
+    acc_kin_relations = {}
+    for kin_relation in Sample.NAME2LABEL.values():
+        mask = y_true_kin_relations == kin_relation
+        acc_kin_relations[kin_relation] = tm.functional.accuracy(
+            similarities[mask], y_true[mask], task="binary", threshold=threshold
+        )
+    return auc, threshold, acc, acc_kin_relations
 
 
 def train(args):
@@ -97,8 +114,12 @@ def train(args):
     total_steps = len(train_loader)
     print(f"Total steps: {total_steps}")
     global_step = 0
-    best_model_auc, _, val_acc = validate(model, val_model_sel_loader)
-    print(f"epoch: 0 | auc: {best_model_auc:.6f} | acc: {val_acc:.6f}")
+    best_model_auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
+    out = f"epoch: 0 | auc: {best_model_auc:.6f} | acc: {val_acc:.6f}"
+    # Add acc_kr to out
+    for kin_relation, acc in acc_kr.items():
+        out += f" | acc_{kin_relation}: {acc:.6f}"
+    print(out)
 
     for epoch in range(args.num_epoch):
         model.train()
@@ -130,16 +151,18 @@ def train(args):
         train_dataset.set_bias(use_sample)
 
         # Save model checkpoints
-        auc, _, val_acc = validate(model, val_model_sel_loader)
+        auc, _, val_acc, acc_kr = validate(model, val_model_sel_loader)
         if auc > best_model_auc:
            best_model_auc = auc
            torch.save(model.state_dict(), args.output_dir / "best.pth")
 
-        print(
+        out = (
             f"epoch: {epoch + 1:>2} | step: {global_step} "
             + f"| loss: {loss_epoch / args.steps_per_epoch:.3f} | auc: {auc:.6f} | acc: {val_acc:.6f}"
         )
+        out = acc_kr_to_str(out, acc_kr)
+        print(out)
 
 
 def val(args):
@@ -150,17 +173,38 @@
     transform = transforms.Compose(
         [
             transforms.ToTensor(),
         ]
     )
 
-    val_model_sel_dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_THRES_SEL), transform=transform)
-    val_model_sel_loader = DataLoader(
-        val_model_sel_dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False
+    dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_THRES_SEL), transform=transform)
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False)
+
+    model = FaCoR()
+    model.load_state_dict(torch.load(args.weights))
+    model.to(args.device)
+
+    auc, threshold, val_acc, acc_kr = validate(model, dataloader)
+    out = f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}"
+    out = acc_kr_to_str(out, acc_kr)
+    print(out)
+
+
+def test(args):
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+        ]
     )
+    dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.TEST_PAIRS), transform=transform)
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False)
+
     model = FaCoR()
     model.load_state_dict(torch.load(args.weights))
     model.to(args.device)
 
-    auc, threshold, val_acc = validate(model, val_model_sel_loader)
-    print(f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}")
+    auc, threshold, val_acc, acc_kr = validate(model, dataloader, threshold=args.threshold)
+    out = f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}"
+    out = acc_kr_to_str(out, acc_kr)
+    print(out)
 
 
 def create_parser_train(subparsers):
@@ -187,12 +231,23 @@ def create_parser_val(subparsers):
     parser.set_defaults(func=val)
 
 
+def create_parser_test(subparsers):
+    parser = subparsers.add_parser("test", help="Test the model")
+    parser.add_argument("--weights", type=str, required=True)
+    parser.add_argument("--root-dir", type=str, required=True)
+    parser.add_argument("--batch-size", type=int, default=100, help="Batch size")
+    parser.add_argument("--threshold", type=float, required=True)
+    parser.add_argument("--device", type=str, default="0", help="Device to use for training")
+    parser.set_defaults(func=test)
+
+
 if __name__ == "__main__":
     parser = ArgumentParser(description="Configuration for the FaCoRNet strategy")
     subparsers = parser.add_subparsers()
     create_parser_train(subparsers)
     create_parser_val(subparsers)
+    create_parser_test(subparsers)
     args = parser.parse_args()
 
     # Necessary for dataloaders?
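
A hypothetical usage sketch for the new "test" task follows; it is not part of the patch. The import path, checkpoint path, dataset root, threshold value, and device string are placeholder assumptions: in practice the weights would come from a prior "train" run and the threshold from a prior "val" run on FIW.VAL_PAIRS_THRES_SEL.

# Hypothetical invocation sketch in Python; every concrete value below is a placeholder.
from argparse import Namespace

from ours.tasks.facornet import test  # assumed import path, adjust to the actual package layout

args = Namespace(
    weights="exp/facornet/best.pth",  # checkpoint written by the train task
    root_dir="data/fiw",              # root directory of the FIW images
    batch_size=100,                   # matches the parser default added above
    threshold=0.17,                   # decision threshold selected during validation
    device="cuda:0",                  # forwarded to model.to(); the CLI default is the string "0"
)
test(args)

Through the command line the equivalent call would go through the new subparser, e.g. python ours/tasks/facornet.py test --weights <ckpt> --root-dir <fiw-root> --threshold <t>, assuming the repository's modules are importable from the working directory.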