add facornet val task (#71)
- Refactor predict to fix memory leak caused by appending tensors to a
  python list.
vitalwarley committed Mar 19, 2024
1 parent 01caa6b commit 22bb0cf
Showing 3 changed files with 115 additions and 39 deletions.
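
For context, the old predict in ours/tasks/facornet.py (removed in the diff below) appended a per-batch GPU tensor to a Python list and only concatenated at the end, so every batch's tensor stayed alive for the whole loop. The refactored version preallocates result tensors sized to the dataset and fills them slice by slice under torch.no_grad(). A minimal sketch of that pattern, with illustrative names only (accumulate_scores is not part of the repo):

import torch

@torch.no_grad()  # no autograd state is retained for the accumulated values
def accumulate_scores(model, loader, device="cuda:0"):
    n = len(loader.dataset)
    scores = torch.zeros(n, device=device)  # preallocated once, filled in place
    start = 0
    for x1, x2, _ in loader:
        x1, x2 = x1.to(device), x2.to(device)
        f1, f2, _ = model([x1, x2])  # same call convention as FaCoR below
        batch = x1.size(0)  # last batch may be smaller than batch_size
        scores[start:start + batch] = torch.cosine_similarity(f1, f2)
        start += batch
    return scores
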
44 changes: 44 additions & 0 deletions ours/guild.yml
@@ -130,3 +130,47 @@
- file: ../datasets/
target-type: link
rename: data
val:
description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
main: tasks.facornet val
sourcecode:
- utils.py
- losses.py
- models/facornet.py
- datasets/fiw.py
- datasets/facornet.py
- datasets/utils.py
- tasks/facornet.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../datasets/
target-type: link
rename: data
- operation: facornet:train
select: exp
test:
description: Reproduction of Kinship Representation Learning with Face Componential Relation (2023)
main: tasks.facornet test
sourcecode:
- utils.py
- losses.py
- models/facornet.py
- datasets/fiw.py
- datasets/facornet.py
- datasets/utils.py
- tasks/facornet.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../datasets/
target-type: link
rename: data
- operation: facornet:train
select: exp
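
A side note on the output-scalars pattern shared by both operations: Guild uses it to pull metrics out of the task's stdout, which for val would presumably be the "auc: ... | acc: ... | threshold: ..." line printed in tasks/facornet.py below. A rough Python approximation of that matching (the exact patterns Guild substitutes for \key and \value may differ; the regex here is only illustrative):

import re

# Illustrative stand-in for Guild's '(\key):\s+(\value)' output-scalars pattern:
# \key approximated by a run of word characters, \value by a float literal.
SCALAR_RE = re.compile(r"([\w.]+):\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)")

line = "auc: 0.912345 | acc: 0.876543 | threshold: 0.4321"
print(dict(SCALAR_RE.findall(line)))
# -> {'auc': '0.912345', 'acc': '0.876543', 'threshold': '0.4321'}
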
109 changes: 70 additions & 39 deletions ours/tasks/facornet.py
@@ -13,30 +13,30 @@
from datasets.facornet import FIWFaCoRNet as FIW


# Validation loop
def predict(model, val_loader, device=0):
similarities = []
# logits_list = []
y_true = []
y_true_kin_relations = []

with torch.no_grad():
for i, (img1, img2, labels) in tqdm(enumerate(val_loader), total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
# Transfer to GPU if available
img1, img2 = img1.to(device), img2.to(device)
(kin_relation, is_kin) = labels
labels = (kin_relation.to(device), is_kin.to(device))

f1, f2, _ = model([img1, img2])
sim = torch.cosine_similarity(f1, f2).detach()
similarities.append(sim)
y_true.append(is_kin)
y_true_kin_relations.append(kin_relation)

# Concat
similarities = torch.concatenate(similarities)
y_true = torch.concatenate(y_true)
y_true_kin_relations = torch.concatenate(y_true_kin_relations)
@torch.no_grad()
def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, torch.Tensor]:
dataset_size = len(val_loader.dataset)
# Preallocate tensors based on the total dataset size
similarities = torch.zeros(dataset_size, device=device)
y_true = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
# y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.long, device=device)

current_index = 0
for img1, img2, labels in tqdm(val_loader, total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
batch_size_current = img1.size(0) # Handle last batch potentially being smaller
img1, img2 = img1.to(device), img2.to(device)
(kin_relation, is_kin) = labels
# kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)

f1, f2, _ = model([img1, img2])
sim = torch.cosine_similarity(f1, f2)

# Fill preallocated tensors
similarities[current_index : current_index + batch_size_current] = sim
y_true[current_index : current_index + batch_size_current] = is_kin
# y_true_kin_relations[current_index:current_index + batch_size_current] = kin_relation

current_index += batch_size_current

return similarities, y_true

@@ -45,9 +45,6 @@ def validate(model, dataloader, device=0):
model.eval()
# Compute similarities
similarities, y_true = predict(model, dataloader)
# Move all to device
similarities = similarities.to(device)
y_true = y_true.to(device)
# Compute metrics
auc = tm.functional.auroc(similarities, y_true, task="binary")
fpr, tpr, thresholds = tm.functional.roc(similarities, y_true, task="binary")
@@ -65,6 +62,17 @@


def train(args):

set_seed(args.seed)

args.output_dir = Path(args.output_dir)
# Create output directory
args.output_dir.mkdir(parents=True, exist_ok=True)

# Write args to args.yaml
with open(args.output_dir / "args.yaml", "w") as f:
f.write(str(args))

# Define transformations for training and validation sets
    # Did they mention augmentations?
transform = transforms.Compose(
@@ -85,7 +93,6 @@ def train(args):
model.to(args.device)

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# scheduler = MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_factor)

total_steps = len(train_loader)
print(f"Total steps: {total_steps}")
@@ -135,6 +142,27 @@ def train(args):
)


def val(args):

transform = transforms.Compose(
[
transforms.ToTensor(),
]
)

val_model_sel_dataset = FIW(root_dir=args.root_dir, sample_path=Path(FIW.VAL_PAIRS_THRES_SEL), transform=transform)
val_model_sel_loader = DataLoader(
val_model_sel_dataset, batch_size=args.batch_size, num_workers=0, pin_memory=True, shuffle=False
)

model = FaCoR()
model.load_state_dict(torch.load(args.weights))
model.to(args.device)

auc, threshold, val_acc = validate(model, val_model_sel_loader)
print(f"auc: {auc:.6f} | acc: {val_acc:.6f} | threshold: {threshold}")


def create_parser_train(subparsers):
parser = subparsers.add_parser("train", help="Train the model")
parser.add_argument("--root-dir", type=str, required=True)
@@ -146,29 +174,32 @@ def create_parser_train(subparsers):
parser.add_argument("--momentum", type=float, default=0.9, help="Momentum")
parser.add_argument("--weight-decay", type=float, default=0, help="Weight decay")
parser.add_argument("--device", type=str, default="0", help="Device to use for training")
parser.add_argument("--seed", type=int, default=42, help="Seed for reproducibility")
parser.set_defaults(func=train)


def create_parser_val(subparsers):
parser = subparsers.add_parser("val", help="Select best threshold for the model")
parser.add_argument("--weights", type=str, required=True)
parser.add_argument("--root-dir", type=str, required=True)
parser.add_argument("--batch-size", type=int, default=100, help="Batch size")
parser.add_argument("--device", type=str, default="0", help="Device to use for training")
parser.set_defaults(func=val)


if __name__ == "__main__":
# Necessary for dataloaders?
torch.multiprocessing.set_start_method("spawn")
set_seed(42)

parser = ArgumentParser(description="Configuration for the FaCoRNet strategy")
subparsers = parser.add_subparsers()
create_parser_train(subparsers)
create_parser_val(subparsers)
args = parser.parse_args()

args.output_dir = Path(args.output_dir)
# Create output directory
args.output_dir.mkdir(parents=True, exist_ok=True)
# Necessary for dataloaders?
torch.multiprocessing.set_start_method("spawn")

print(args)

# Write args to args.yaml
with open(args.output_dir / "args.yaml", "w") as f:
f.write(str(args))

if torch.cuda.is_available():
args.device = torch.device(f"cuda:{args.device}")
current_device = torch.cuda.current_device()
@@ -178,4 +209,4 @@ def create_parser_train(subparsers):
else:
print("CUDA is not available.")

train(args)
args.func(args)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -6,6 +6,7 @@ target-version = ['py311']
max-line-length = 120
exclude = [".git", "__pycache__", "dist"]
max-complexity = 10
ignore = "E203, W503"

[tool.isort]
atomic = true