fix facornet metrics by keeping it simple (#71)
vitalwarley committed Apr 3, 2024
1 parent 0622901 commit e5600af
Showing 3 changed files with 55 additions and 79 deletions.
ours/configs/facornet.yml (4 changes: 2 additions & 2 deletions)

@@ -11,8 +11,8 @@ trainer:
     class_path: lightning.pytorch.callbacks.ModelCheckpoint
     init_args:
       dirpath: ./
-      filename: '{epoch}-{auc/val:.3f}-{auc/train:.3f}'
-      monitor: auc/val
+      filename: '{epoch}-{loss/val:.3f}-{loss/train:.3f}-{auc:.6f}'
+      monitor: auc
       verbose: no
       save_last: yes
       save_top_k: 1
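For readers who prefer code to YAML, the new checkpoint callback corresponds roughly to the constructor call below (a sketch; only fields visible in the diff are included, and any options set in the collapsed part of the config are omitted):

from lightning.pytorch.callbacks import ModelCheckpoint

# Mirrors the new YAML: checkpoint on the epoch-level `auc` metric.
# Note: ModelCheckpoint defaults to mode="min"; if the collapsed part of the
# config does not set mode="max", the "best" checkpoint would be the lowest AUC.
checkpoint = ModelCheckpoint(
    dirpath="./",
    filename="{epoch}-{loss/val:.3f}-{loss/train:.3f}-{auc:.6f}",
    monitor="auc",
    verbose=False,
    save_last=True,
    save_top_k=1,
)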
ours/datasets/fiw.py (1 change: 1 addition & 0 deletions)

@@ -61,6 +61,7 @@ def _process_labels(self, sample):
         # fid1, fid2 = int(sample.f1fid[1:]), int(sample.f2fid[1:])
         # labels = (kin_id, is_kin, fid1, fid2)
         labels = (kin_id, is_kin)
+        # labels = is_kin
         return labels

     def __getitem__(self, item):
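To make the label layout concrete, here is an illustration with hypothetical values (the real kin-type ids live in Sample.NAME2LABEL, which is not shown in this diff):

# Hypothetical: a father-son pair whose kinship type maps to id 4.
labels = (4, 1)  # (kin_id, is_kin), as returned above
# The commented-out alternative would keep only the binary flag:
# labels = 1     # is_kin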
ours/models/facornet.py (129 changes: 52 additions & 77 deletions)
@@ -658,24 +658,6 @@ def IR_SE_200(input_size):
     return model


-class DynamicThresholdAccuracy(tm.Metric):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
-        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
-
-    def update(self, preds: torch.Tensor, target: torch.Tensor, threshold: torch.Tensor):
-        if preds.shape != target.shape:
-            raise ValueError("preds and target must have the same shape")
-        preds_thresholded = preds >= threshold
-        correct = torch.sum(preds_thresholded == target)
-        self.correct += correct
-        self.total += target.numel()
-
-    def compute(self):
-        return self.correct.float() / self.total
-
-
 class CollectPreds(tm.Metric):
     def __init__(self, name: str, **kwargs):
         self.name = name
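The body of CollectPreds is collapsed in this view. For orientation, a predictions-collector metric of this kind typically accumulates every batch it sees and concatenates them on compute(); a minimal sketch under that assumption (not the repository's confirmed implementation):

import torch
import torchmetrics as tm

class CollectPreds(tm.Metric):
    def __init__(self, name: str, **kwargs):
        super().__init__(**kwargs)
        self.name = name
        # List state: torchmetrics concatenates list states across processes.
        self.add_state("preds", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor):
        # Accumulate raw predictions; no reduction happens here.
        self.preds.append(preds.detach())

    def compute(self) -> torch.Tensor:
        # Return everything collected so far as one flat tensor.
        return torch.cat(self.preds, dim=0) if self.preds else torch.empty(0)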
@@ -696,47 +678,18 @@ def __init__(
         self, model: torch.nn.Module, lr=1e-4, momentum=0.9, weight_decay=0, weights_path=None, threshold=None
     ):
         super().__init__()
-        self.model = FaCoR() or model
+        self.save_hyperparameters(ignore=("model"))
+
+        self.model = FaCoR() or model
         self.lr = lr
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.loss_fn = facornet_contrastive_loss

         self.threshold = threshold

-        self.similarities = tm.MetricCollection(
-            {
-                "similarities/train": CollectPreds("similarities/train"),
-                "similarities/val": CollectPreds("similarities/val"),
-            }
-        )
-        self.is_kin_labels = tm.MetricCollection(
-            {
-                "is_kin_labels/train": CollectPreds("is_kin_labels/train"),
-                "is_kin_labels/val": CollectPreds("is_kin_labels/val"),
-            }
-        )
-        self.kin_labels = tm.MetricCollection(
-            {
-                "kin_labels/train": CollectPreds("kin_labels/train"),
-                "kin_labels/val": CollectPreds("kin_labels/val"),
-            }
-        )
-
-        # Metrics
-        self.train_auc = tm.AUROC(task="binary")
-        self.val_auc = tm.AUROC(task="binary")
-        self.train_acc = DynamicThresholdAccuracy()
-        self.val_acc = DynamicThresholdAccuracy()
-        self.train_acc_kin_relations = tm.MetricCollection(
-            {f"accuracy/train/{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.keys()}
-        )
-        self.val_acc_kin_relations = tm.MetricCollection(
-            {f"accuracy/val/{kin}": DynamicThresholdAccuracy() for kin in Sample.NAME2LABEL.keys()}
-        )
-
-        self.save_hyperparameters(ignore=("model"))
+        self.similarities = CollectPreds("similarities")
+        self.is_kin_labels = CollectPreds("is_kin_labels")
+        self.kin_labels = CollectPreds("kin_labels")

     def setup(self, stage):
         # TODO: use checkpoint callback to load the weights
@@ -763,10 +716,11 @@ def _step(self, batch, stage="train"):
         loss = self.loss_fn(f1, f2, beta=att)
         sim = torch.cosine_similarity(f1, f2)

-        # Compute best threshold for training or validation
-        self.similarities[f"similarities/{stage}"](sim)
-        self.is_kin_labels[f"is_kin_labels/{stage}"](is_kin)
-        self.kin_labels[f"kin_labels/{stage}"](kin_relation)
+        if stage != "train":
+            # Compute best threshold for training or validation
+            self.similarities(sim)
+            self.is_kin_labels(is_kin)
+            self.kin_labels(kin_relation)

         self.log(f"loss/{stage}", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
@@ -795,7 +749,6 @@ def on_train_epoch_end(self):
         # Update the dataset's bias or sampling strategy
         self.trainer.datamodule.train_dataset.set_bias(use_sample)
         print(f"Updated dataset bias to {use_sample}")
-        self._on_epoch_end("train")

     def on_validation_epoch_end(self):
         self._on_epoch_end("val")
@@ -805,14 +758,14 @@ def on_test_epoch_end(self):

     def _on_epoch_end(self, stage):
         # Compute predictions
-        similarities = self.similarities[f"similarities/{stage}"].compute()
-        is_kin_labels = self.is_kin_labels[f"is_kin_labels/{stage}"].compute()
-        kin_labels = self.kin_labels[f"kin_labels/{stage}"].compute()
+        similarities = self.similarities.compute()
+        is_kin_labels = self.is_kin_labels.compute()
+        kin_labels = self.kin_labels.compute()
         self.__compute_metrics(similarities, is_kin_labels, kin_labels, stage=stage)
         # Reset predictions
-        self.similarities[f"similarities/{stage}"].reset()
-        self.is_kin_labels[f"is_kin_labels/{stage}"].reset()
-        self.kin_labels[f"kin_labels/{stage}"].reset()
+        self.similarities.reset()
+        self.is_kin_labels.reset()
+        self.kin_labels.reset()

     def __compute_metrics(self, similarities, is_kin_labels, kin_labels, stage="train"):
         # Compute best threshold
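compute_best_threshold, called in the hunk below, is defined outside this diff. Given that it receives tpr, fpr, and thresholds from the ROC curve, a plausible implementation maximizes Youden's J statistic; this sketch is an assumption, not the repository's definition:

import torch

def compute_best_threshold(tpr: torch.Tensor, fpr: torch.Tensor, thresholds: torch.Tensor) -> torch.Tensor:
    # Youden's J = TPR - FPR: pick the threshold where the ROC curve
    # is farthest above the chance diagonal.
    j = tpr - fpr
    return thresholds[torch.argmax(j)]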
@@ -824,30 +777,52 @@ def __compute_metrics(self, similarities, is_kin_labels, kin_labels, stage="train"):
         fpr, tpr, thresholds = tm.functional.roc(similarities, is_kin_labels, task="binary")
         best_threshold = compute_best_threshold(tpr, fpr, thresholds)

+        # Log similarities histogram by is_kin_labels
+        self.logger.experiment.add_histogram(
+            "similarities/positive",
+            similarities[is_kin_labels == 1],
+            global_step=self.current_epoch,
+        )
+        self.logger.experiment.add_histogram(
+            "similarities/negative",
+            similarities[is_kin_labels == 0],
+            global_step=self.current_epoch,
+        )
+
         # Compute metrics
-        auc_fn = self.train_auc if stage == "train" else self.val_auc
-        acc_fn = self.train_acc if stage == "train" else self.val_acc
-        auc = auc_fn(similarities, is_kin_labels)
-        acc = acc_fn(similarities, is_kin_labels, best_threshold)
+        auc = tm.functional.auroc(similarities, is_kin_labels, task="binary")
+        acc = tm.functional.accuracy(similarities, is_kin_labels, threshold=best_threshold, task="binary")
+        precision = tm.functional.precision(similarities, is_kin_labels, threshold=best_threshold, task="binary")
+        recall = tm.functional.recall(similarities, is_kin_labels, threshold=best_threshold, task="binary")
+
+        # Log metrics
+        self.log("threshold", best_threshold, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log("accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log("auc", auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log("precision", precision, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        self.log("recall", recall, on_step=False, on_epoch=True, prog_bar=True, logger=True)

         # Accuracy for each kinship relation
-        acc_kin_relations = self.train_acc_kin_relations if stage == "train" else self.val_acc_kin_relations
         for kin, kin_id in Sample.NAME2LABEL.items():  # TODO: pass Sample class as argument
             mask = kin_labels == kin_id
             if torch.any(mask):
-                acc_kin_relations[f"accuracy/{stage}/{kin}"](
-                    similarities[mask], is_kin_labels[mask].int(), best_threshold
+                acc = tm.functional.accuracy(
+                    similarities[mask], is_kin_labels[mask].int(), threshold=best_threshold, task="binary"
                 )
                 self.log(
-                    f"accuracy/{stage}/{kin}",
-                    acc_kin_relations[f"accuracy/{stage}/{kin}"],
+                    f"accuracy/{kin}",
+                    acc,
                     on_step=False,
                     on_epoch=True,
                     prog_bar=False,
                     logger=True,
                 )
-
-        # Log metrics
-        self.log(f"threshold/{stage}", best_threshold, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log(f"accuracy/{stage}", auc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
-        self.log(f"auc/{stage}", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+                # Add similarities
+                # Negative pairs are "non-kin" pairs, which are equal to the overall similarities/negative
+                positives = similarities[mask][is_kin_labels[mask] == 1]
+                if positives.numel() > 0:
+                    self.logger.experiment.add_histogram(
+                        f"similarities/{kin}",
+                        positives,
+                        global_step=self.current_epoch,
+                    )