adjust facornet training and datasets (#71)
vitalwarley committed Mar 28, 2024
1 parent d91c6d4 commit bd76555
Showing 5 changed files with 39 additions and 30 deletions.
ours/datasets/facornet.py (8 changes: 4 additions & 4 deletions)
@@ -3,10 +3,10 @@

class FIWFaCoRNet(FIW):

TRAIN_PAIRS = "facornet/train_sort_A2_m.txt"
VAL_PAIRS_MODEL_SEL = "facornet/val_choose_A.txt"
VAL_PAIRS_THRES_SEL = "facornet/val_A.txt"
TEST_PAIRS = "facornet/test_A.txt"
TRAIN_PAIRS = "txt/train_sort_A2_m.txt"
VAL_PAIRS_MODEL_SEL = "txt/val_choose_A.txt"
VAL_PAIRS_THRES_SEL = "txt/val_A.txt"
TEST_PAIRS = "txt/test_A.txt"

# AdaFace uses BGR -- should I revert conversion read_image here?

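Note on the new pair-list paths: the constants above are relative paths, presumably resolved against the dataset root (how they are joined is not shown in this hunk). A minimal sanity-check sketch; the root value and the import path are assumptions, not taken from this commit.

from pathlib import Path

from datasets.facornet import FIWFaCoRNet  # import path assumed; the class lives in ours/datasets/facornet.py

# Hypothetical dataset root; adjust to the local FIW/FaCoRNet layout.
root = Path("~/datasets/facornet").expanduser()

# After this commit the pair lists are expected under <root>/txt/ rather than <root>/facornet/.
for name in (FIWFaCoRNet.TRAIN_PAIRS, FIWFaCoRNet.VAL_PAIRS_MODEL_SEL,
             FIWFaCoRNet.VAL_PAIRS_THRES_SEL, FIWFaCoRNet.TEST_PAIRS):
    print(f"{name}: {'found' if (root / name).exists() else 'missing'}")
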
ours/datasets/fiw.py (6 changes: 3 additions & 3 deletions)
@@ -23,7 +23,7 @@ def load_sample(self):
for line in lines:
if len(line) < 1:
continue
# tmp = line.split(" ")
line = line.split(" ")
# sample = Sample(tmp[0], tmp[1], tmp[2], tmp[-2], tmp[-1])
# facornet
# id, f1, f2, kin, is_kin, sim -> train
@@ -37,7 +37,7 @@ def __len__(self):

def read_image(self, path):
# TODO: add to utils.py
img = cv2.imread(f"{self.root_dir}/{path}")
img = cv2.imread(f"{self.root_dir}/images/{path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (112, 112))
return img
@@ -54,7 +54,7 @@ def _process_images(self, sample):
def _process_labels(self, sample):
is_kin = torch.tensor(sample.is_kin)
kin_id = self.sample_cls.NAME2LABEL[sample.kin_relation] if is_kin else 0
fid1, fid2 = int(sample.f1fid), int(sample.f2fid)
fid1, fid2 = int(sample.f1fid[1:]), int(sample.f2fid[1:])
labels = (kin_id, is_kin, fid1, fid2)
return labels

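Note on the FID parsing change in _process_labels: a minimal sketch, assuming FIW family IDs have the form "F0123" (the concrete value below is made up).

# sample.f1fid is assumed to look like "F0123"; int("F0123") raises ValueError,
# so the leading "F" is stripped before the cast, as in the new line above.
f1fid = "F0123"        # hypothetical value of sample.f1fid
fid1 = int(f1fid[1:])  # -> 123; previously int(f1fid) would fail
print(fid1)
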
ours/datasets/utils.py (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@ class Sample:
"gmgs": 11,
}

def __init__(self, id: str, f1: str, f2: str, kin_relation: str, is_kin: str, **kwargs):
def __init__(self, id: str, f1: str, f2: str, kin_relation: str, is_kin: str, *args, **kwargs):
self.id = id
self.f1 = f1
self.f1fid = f1.split("/")[2]
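
Note on the added *args: per the comment in fiw.py above, facornet training rows carry a trailing similarity column (id, f1, f2, kin, is_kin, sim). A minimal sketch of the effect; the construction call and the field values are assumptions, since the actual call site sits in a collapsed part of the diff.

from datasets.utils import Sample  # import path assumed; the class lives in ours/datasets/utils.py

# Hypothetical training-list row; the image paths are placeholders, not real FIW entries.
fields = "42 F0001/MID1/face0.jpg F0002/MID3/face1.jpg fs 1 0.73".split(" ")

# With the old signature, Sample(*fields) raised TypeError (six positional values for five
# parameters); the new *args absorbs the trailing similarity value, which is then ignored.
sample = Sample(*fields)
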
ours/models/facornet.py (9 changes: 5 additions & 4 deletions)
@@ -88,7 +88,7 @@ def __init__(self):
self.CCA = ChannelInteraction(1024)
self.avg_pool = nn.AdaptiveAvgPool2d(1)

self.task_kin = HeadKin(512, 12, 8)
# self.task_kin = HeadKin(512, 12, 8)

def forward(self, imgs, aug=False):
img1, img2 = imgs
@@ -121,10 +121,11 @@ def forward(self, imgs, aug=False):
f1s = torch.flatten(f1s, 1)
f2s = torch.flatten(f2s, 1)

fc = torch.cat([f1s, f2s], dim=1)
kin = self.task_kin(fc)
# fc = torch.cat([f1s, f2s], dim=1)
# kin = self.task_kin(fc)

return kin, f1s, f2s, att_map0
# return kin, f1s, f2s, att_map0
return f1s, f2s, att_map0


class SpatialCrossAttention(nn.Module):
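
Note on the changed forward contract: the model now returns three values (f1s, f2s, att_map0) instead of four, so call sites that still unpack the kin logits fail on the first batch. A tiny illustration with plain stand-in values, since the model itself is not constructed here; the new unpacking matches the call sites in tasks/facornet.py below.

outputs = ("f1s", "f2s", "att_map0")  # stand-ins for the three tensors now returned
f1s, f2s, att_map = outputs           # new-style unpacking
# kin, f1s, f2s, att_map = outputs    # old-style unpacking would raise ValueError: not enough values to unpack
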
ours/tasks/facornet.py (44 changes: 26 additions & 18 deletions)
@@ -31,32 +31,36 @@ def predict(model, val_loader, device: int | str = 0) -> tuple[torch.Tensor, tor
y_true = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
y_true_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
pred_kin_relations = torch.zeros(dataset_size, dtype=torch.uint8, device=device)
loss_values = torch.zeros(dataset_size, device=device)

current_index = 0
for img1, img2, labels in tqdm(val_loader, total=len(val_loader), bar_format=TQDM_BAR_FORMAT):
batch_size_current = img1.size(0) # Handle last batch potentially being smaller
img1, img2 = img1.to(device), img2.to(device)
(kin_relation, is_kin) = labels
(kin_relation, is_kin, f1fid, f2fid) = labels
kin_relation, is_kin = kin_relation.to(device), is_kin.to(device)

kin, f1, f2, _ = model([img1, img2])
# kin, f1, f2, att = model([img1, img2])
f1, f2, att = model([img1, img2])
sim = torch.cosine_similarity(f1, f2)
loss = facornet_contrastive_loss(f1, f2, beta=att)

# Fill preallocated tensors
similarities[current_index : current_index + batch_size_current] = sim
loss_values[current_index : current_index + batch_size_current] = loss
y_true[current_index : current_index + batch_size_current] = is_kin
y_true_kin_relations[current_index : current_index + batch_size_current] = kin_relation
pred_kin_relations[current_index : current_index + batch_size_current] = kin.argmax(dim=1)
# pred_kin_relations[current_index : current_index + batch_size_current] = kin.argmax(dim=1)

current_index += batch_size_current

return similarities, y_true, pred_kin_relations, y_true_kin_relations
return loss, similarities, y_true, pred_kin_relations, y_true_kin_relations


def validate(model, dataloader, device=0, threshold=None):
model.eval()
# Compute similarities
similarities, y_true, pred_kin_relations, y_true_kin_relations = predict(model, dataloader)
loss, similarities, y_true, pred_kin_relations, y_true_kin_relations = predict(model, dataloader)
# Compute metrics
auc = tm.functional.auroc(similarities, y_true, task="binary")
fpr, tpr, thresholds = tm.functional.roc(similarities, y_true, task="binary")
@@ -77,8 +81,9 @@ def validate(model, dataloader, device=0, threshold=None):
acc_kin_relations[kin_relation] = tm.functional.accuracy(
similarities[mask], y_true[mask], task="binary", threshold=threshold
)
kin_acc = tm.functional.accuracy(pred_kin_relations, y_true_kin_relations, task="multiclass", num_classes=12)
return auc, threshold, acc, acc_kin_relations, kin_acc
# kin_acc = tm.functional.accuracy(pred_kin_relations, y_true_kin_relations, task="multiclass", num_classes=12)
# return loss, auc, threshold, acc, acc_kin_relations, kin_acc
return loss, auc, threshold, acc, acc_kin_relations


def train(args):
@@ -117,8 +122,8 @@ def train(args):
total_steps = len(train_loader)
print(f"Total steps: {total_steps}")
global_step = 0
best_model_auc, _, val_acc, acc_kv, acc_clf_kr = validate(model, val_model_sel_loader)
out = f"epoch: 0 | auc: {best_model_auc:.6f} | acc_kv: {val_acc:.6f} | acc_clf_kr: {acc_clf_kr:.6f}"
val_loss, best_model_auc, _, val_acc, acc_kv = validate(model, val_model_sel_loader)
out = f"epoch: 0 | val_loss: {val_loss:.6f} | auc: {best_model_auc:.6f} | acc_kv: {val_acc:.6f}"
out = acc_kr_to_str(out, acc_kv)
print(out)

@@ -132,20 +137,22 @@
global_step = step + epoch * args.steps_per_epoch

image1, image2, labels = data
(kin_relation, is_kin) = labels
(kin_relation, is_kin, f1fid, f2fid) = labels

image1 = image1.to(args.device)
image2 = image2.to(args.device)
kin_relation = kin_relation.to(args.device)
is_kin = is_kin.to(args.device)

kin, x1, x2, att = model([image1, image2])
# kin, x1, x2, att = model([image1, image2])
x1, x2, att = model([image1, image2])
contrastive_loss = facornet_contrastive_loss(x1, x2, beta=att)
kin_loss = ce_loss(kin, kin_relation)
# kin_loss = ce_loss(kin, kin_relation)

contrastive_loss_epoch += contrastive_loss.item()
kin_loss_epoch += kin_loss.item()
loss = contrastive_loss + kin_loss
# kin_loss_epoch += kin_loss.item()
# loss = contrastive_loss + kin_loss
loss = contrastive_loss

optimizer.zero_grad()
loss.backward()
@@ -158,17 +165,18 @@
train_dataset.set_bias(use_sample)

# Save model checkpoints
auc, _, val_acc, acc_kv, acc_clf_kr = validate(model, val_model_sel_loader)
# auc, _, val_acc, acc_kv, acc_clf_kr = validate(model, val_model_sel_loader)
loss, auc, _, val_acc, acc_kv = validate(model, val_model_sel_loader)

if auc > best_model_auc:
best_model_auc = auc
torch.save(model.state_dict(), args.output_dir / "best.pth")

out = (
f"epoch: {epoch + 1:>2} | step: {global_step} "
+ f"| loss: {contrastive_loss_epoch / args.steps_per_epoch:.3f} "
+ f"| kin_loss: {kin_loss_epoch / args.steps_per_epoch:.3f} "
+ f"| auc: {auc:.6f} | acc_kv: {val_acc:.6f} | acc_clf_kr: {acc_clf_kr:.6f}"
+ f"| train_loss: {contrastive_loss_epoch / args.steps_per_epoch:.3f} | val_loss: {loss:.3f} "
# + f"| kin_loss: {kin_loss_epoch / args.steps_per_epoch:.3f} "
+ f"| auc: {auc:.6f} | acc_kv: {val_acc:.6f}"
)
out = acc_kr_to_str(out, acc_kv)
print(out)
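
Note on the metrics in validate(): a self-contained toy sketch of the torchmetrics calls it relies on. The similarity and label values are made up, and the threshold pick shown here (Youden's J) is only a placeholder; the actual selection rule sits in a collapsed part of the diff.

import torch
import torchmetrics as tm

# Toy cosine similarities between embedding pairs and their is_kin labels.
sims = torch.tensor([0.82, 0.10, 0.55, 0.91])
y_true = torch.tensor([1, 0, 0, 1])

auc = tm.functional.auroc(sims, y_true, task="binary")
fpr, tpr, thresholds = tm.functional.roc(sims, y_true, task="binary")
best = thresholds[torch.argmax(tpr - fpr)]  # placeholder rule, not necessarily what validate() does
acc = tm.functional.accuracy(sims, y_true, task="binary", threshold=best.item())
print(f"auc: {auc:.3f} | threshold: {best:.2f} | acc: {acc:.3f}")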
