
Commit

update src/loss_fns.py
KeAWang committed Mar 1, 2022
1 parent 76c9fba commit 3362a7a
Showing 1 changed file with 16 additions and 33 deletions.
49 changes: 16 additions & 33 deletions src/loss_fns.py
@@ -6,33 +6,6 @@
 from typing import Optional, List


-def focal_loss(input_values, gamma):
-    """Based on https://github.com/kaidic/LDAM-DRW/blob/master/losses.py"""
-    p = torch.exp(-input_values)
-    loss = (1 - p) ** gamma * input_values
-    return loss
-
-
-class FocalLoss(nn.Module):
-    """Based on https://github.com/kaidic/LDAM-DRW/blob/master/losses.py"""
-
-    def __init__(self, weight: Optional[torch.Tensor], gamma: float, reduction: str):
-        super(FocalLoss, self).__init__()
-        assert gamma >= 0
-        self.gamma = gamma
-        assert reduction == "none"
-        if weight is not None:
-            self.register_buffer("weight", weight)
-        else:
-            self.weight = None
-
-    def forward(self, input, target):
-        return focal_loss(
-            F.cross_entropy(input, target, reduction="none", weight=self.weight),
-            self.gamma,
-        )
-
-
 class LDAMLoss(nn.Module):
     """Reimplementation of https://github.com/kaidic/LDAM-DRW/blob/master/losses.py
@@ -77,7 +50,9 @@ def forward(self, logits, target):
         mask = torch.nn.functional.one_hot(target, num_classes=logits.shape[-1])
         new_logits = logits - self.margins.reshape(1, -1) * mask
         return F.cross_entropy(
-            self.inv_temperature * new_logits, target, reduction=self.reduction,
+            self.inv_temperature * new_logits,
+            target,
+            reduction=self.reduction,
         )


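As a rough illustration of the margin step in LDAMLoss.forward above, here is a self-contained sketch; the per-class margins and inverse temperature are hypothetical placeholders, since the constructor that sets them is not part of this hunk.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 0])
margins = torch.tensor([0.5, 0.3, 0.1])   # hypothetical per-class margins
inv_temperature = 30.0                    # hypothetical logit scale

# Subtract each class margin from the true-class logit only, then apply CE.
mask = torch.nn.functional.one_hot(target, num_classes=logits.shape[-1])
new_logits = logits - margins.reshape(1, -1) * mask
loss = F.cross_entropy(inv_temperature * new_logits, target, reduction="none")
print(loss)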
@@ -113,7 +88,10 @@ def forward(self, logits, target):
         assert logits.shape[-1] == len(self.adjustments)
         adjusted_logits = logits + self.temperature * self.adjustments.reshape(1, -1)
         return F.cross_entropy(
-            adjusted_logits, target, weight=self.weight, reduction=self.reduction,
+            adjusted_logits,
+            target,
+            weight=self.weight,
+            reduction=self.reduction,
         )


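A similar sketch for the logit-adjustment step above; the adjustment vector here is a made-up example (log class priors are one common choice), and the temperature and weight values are hypothetical placeholders.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 0])
adjustments = torch.log(torch.tensor([0.7, 0.2, 0.1]))  # hypothetical, e.g. log priors
temperature = 1.0                                        # hypothetical scale

# Shift every logit by its class adjustment before the usual cross-entropy.
adjusted_logits = logits + temperature * adjustments.reshape(1, -1)
loss = F.cross_entropy(adjusted_logits, target, weight=None, reduction="none")
print(loss)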
@@ -149,9 +127,11 @@ def margin_fn(self, margin_vals: torch.Tensor):
             return scores
         if self.type == "logit":
             indicator = margin_vals <= 1
-            inv_part = torch.pow(margin_vals.abs(),-1*self.alpha)
+            inv_part = torch.pow(margin_vals.abs(), -1 * self.alpha)
             logit_inner = -1 * margin_vals
-            logit_part = torch.nn.functional.softplus(logit_inner)/(math.log(1+math.exp(-1)))
+            logit_part = torch.nn.functional.softplus(logit_inner) / (
+                math.log(1 + math.exp(-1))
+            )
             scores = logit_part * indicator + inv_part * (~indicator)
             return scores
         if self.type == "linear":
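The reformatted "logit" branch above combines a normalized softplus penalty for margins up to 1 with a polynomial tail beyond; a small sketch with made-up margins and a hypothetical alpha:

import math
import torch

margin_vals = torch.tensor([-0.5, 0.5, 1.0, 2.0, 4.0])  # made-up margins
alpha = 2.0                                              # hypothetical exponent

indicator = margin_vals <= 1
inv_part = torch.pow(margin_vals.abs(), -1 * alpha)          # |m|^(-alpha) tail for m > 1
logit_part = torch.nn.functional.softplus(-1 * margin_vals) / math.log(1 + math.exp(-1))
scores = logit_part * indicator + inv_part * (~indicator)    # piecewise combination
print(scores)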
@@ -179,7 +159,7 @@ def __init__(
         num_per_class: List[int],
         reduction: str = "none",
     ):
-        """ Note this only works for class imbalance and not group imbalance for now"""
+        """Note this only works for class imbalance and not group imbalance for now"""
         super().__init__()
         num_per_class = torch.tensor(num_per_class, dtype=torch.long)
         pi = num_per_class / num_per_class.sum(0)  # prior probabilities
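To make the prior computation above concrete, a tiny worked example with made-up class counts:

import torch

num_per_class = torch.tensor([1000, 100, 10], dtype=torch.long)  # made-up counts
pi = num_per_class / num_per_class.sum(0)  # prior probabilities
print(pi)  # tensor([0.9009, 0.0901, 0.0090])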
@@ -207,7 +187,10 @@ class VSGroupLoss(nn.Module):
     _takes_groups = True

     def __init__(
-        self, gamma: float, num_per_group: List[int], reduction: str = "none",
+        self,
+        gamma: float,
+        num_per_group: List[int],
+        reduction: str = "none",
     ):
         super().__init__()
         num_per_group = torch.tensor(num_per_group, dtype=torch.long)
