diff --git a/src/loss_fns.py b/src/loss_fns.py
index 689f9f9..0c3cffb 100644
--- a/src/loss_fns.py
+++ b/src/loss_fns.py
@@ -6,33 +6,6 @@
 from typing import Optional, List
 
 
-def focal_loss(input_values, gamma):
-    """Based on https://github.com/kaidic/LDAM-DRW/blob/master/losses.py"""
-    p = torch.exp(-input_values)
-    loss = (1 - p) ** gamma * input_values
-    return loss
-
-
-class FocalLoss(nn.Module):
-    """Based on https://github.com/kaidic/LDAM-DRW/blob/master/losses.py"""
-
-    def __init__(self, weight: Optional[torch.Tensor], gamma: float, reduction: str):
-        super(FocalLoss, self).__init__()
-        assert gamma >= 0
-        self.gamma = gamma
-        assert reduction == "none"
-        if weight is not None:
-            self.register_buffer("weight", weight)
-        else:
-            self.weight = None
-
-    def forward(self, input, target):
-        return focal_loss(
-            F.cross_entropy(input, target, reduction="none", weight=self.weight),
-            self.gamma,
-        )
-
-
 class LDAMLoss(nn.Module):
     """Reimplementation of https://github.com/kaidic/LDAM-DRW/blob/master/losses.py
 
@@ -77,7 +50,9 @@ def forward(self, logits, target):
         mask = torch.nn.functional.one_hot(target, num_classes=logits.shape[-1])
         new_logits = logits - self.margins.reshape(1, -1) * mask
         return F.cross_entropy(
-            self.inv_temperature * new_logits, target, reduction=self.reduction,
+            self.inv_temperature * new_logits,
+            target,
+            reduction=self.reduction,
         )
 
 
@@ -113,7 +88,10 @@ def forward(self, logits, target):
         assert logits.shape[-1] == len(self.adjustments)
         adjusted_logits = logits + self.temperature * self.adjustments.reshape(1, -1)
         return F.cross_entropy(
-            adjusted_logits, target, weight=self.weight, reduction=self.reduction,
+            adjusted_logits,
+            target,
+            weight=self.weight,
+            reduction=self.reduction,
         )
 
 
@@ -149,9 +127,11 @@ def margin_fn(self, margin_vals: torch.Tensor):
             return scores
         if self.type == "logit":
             indicator = margin_vals <= 1
-            inv_part = torch.pow(margin_vals.abs(),-1*self.alpha)
+            inv_part = torch.pow(margin_vals.abs(), -1 * self.alpha)
             logit_inner = -1 * margin_vals
-            logit_part = torch.nn.functional.softplus(logit_inner)/(math.log(1+math.exp(-1)))
+            logit_part = torch.nn.functional.softplus(logit_inner) / (
+                math.log(1 + math.exp(-1))
+            )
             scores = logit_part * indicator + inv_part * (~indicator)
             return scores
         if self.type == "linear":
@@ -179,7 +159,7 @@ def __init__(
         num_per_class: List[int],
         reduction: str = "none",
     ):
-        """ Note this only works for class imbalance and not group imbalance for now"""
+        """Note this only works for class imbalance and not group imbalance for now"""
         super().__init__()
         num_per_class = torch.tensor(num_per_class, dtype=torch.long)
         pi = num_per_class / num_per_class.sum(0)  # prior probabilities
@@ -207,7 +187,10 @@ class VSGroupLoss(nn.Module):
     _takes_groups = True
 
     def __init__(
-        self, gamma: float, num_per_group: List[int], reduction: str = "none",
+        self,
+        gamma: float,
+        num_per_group: List[int],
+        reduction: str = "none",
     ):
         super().__init__()
         num_per_group = torch.tensor(num_per_group, dtype=torch.long)
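
Reviewer note: the focal_loss helper and the FocalLoss module are removed above with no replacement in this file. Their composition reduces to two lines, so any caller that still needs the behavior can inline it. A minimal standalone sketch (the name focal_loss_equivalent and the gamma default are mine, not part of this diff):

    import torch
    import torch.nn.functional as F

    def focal_loss_equivalent(logits, target, gamma=1.0, weight=None):
        # Per-sample cross-entropy, then the focal modulation (1 - p) ** gamma * CE,
        # with p = exp(-CE), i.e. the true-class probability when weight is None;
        # this is exactly the computation the removed focal_loss/FocalLoss pair performed.
        ce = F.cross_entropy(logits, target, reduction="none", weight=weight)
        return (1 - torch.exp(-ce)) ** gamma * ce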
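
One non-obvious constant in the reformatted margin_fn hunk: math.log(1 + math.exp(-1)) is softplus(-1), so the logit branch is normalized to equal exactly 1 at the indicator boundary margin_vals == 1, where the inverse-power branch also gives abs(1) ** (-alpha) == 1 for any alpha. A quick standalone check (the alpha value here is arbitrary, chosen only for the demonstration):

    import math
    import torch

    alpha = 2.0  # arbitrary: abs(1) ** (-alpha) == 1 at the boundary regardless
    m = torch.tensor([1.0])  # the boundary where the indicator flips
    inv_part = torch.pow(m.abs(), -1 * alpha)
    logit_part = torch.nn.functional.softplus(-1 * m) / math.log(1 + math.exp(-1))
    print(inv_part.item(), logit_part.item())  # both print 1.0: the two pieces join continuously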