"""
From:
https://github.com/dougbrion/pytorch-classification-uncertainty/blob/master/losses.py
"""
import torch
import torch.nn.functional as F


def get_device():
    # Fall back to CPU when CUDA is unavailable.
    return "cuda" if torch.cuda.is_available() else "cpu"

def relu_evidence(y):
    # ReLU maps raw logits to non-negative evidence.
    return F.relu(y)


def exp_evidence(y):
    # Exponential evidence; the clamp keeps exp() numerically stable.
    return torch.exp(torch.clamp(y, -10, 10))


def softplus_evidence(y):
    # Softplus is a smooth alternative to ReLU for producing evidence.
    return F.softplus(y)
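
# Note (not in the original file): any of the evidence functions above can
# parameterise a Dirichlet. With evidence e, the losses below use the
# concentration parameters alpha = e + 1, so zero evidence recovers the
# uniform Dirichlet Dir(1, ..., 1).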

def kl_divergence(alpha, num_classes, device=None):
    """KL(Dir(alpha) || Dir(1, ..., 1)): divergence from the uniform
    Dirichlet, used to penalise evidence placed on the wrong classes."""
    if not device:
        device = get_device()
    ones = torch.ones([1, num_classes], dtype=torch.float32, device=device)
    sum_alpha = torch.sum(alpha, dim=1, keepdim=True)
    first_term = (
        torch.lgamma(sum_alpha)
        - torch.lgamma(alpha).sum(dim=1, keepdim=True)
        + torch.lgamma(ones).sum(dim=1, keepdim=True)
        - torch.lgamma(ones.sum(dim=1, keepdim=True))
    )
    second_term = (
        (alpha - ones)
        .mul(torch.digamma(alpha) - torch.digamma(sum_alpha))
        .sum(dim=1, keepdim=True)
    )
    kl = first_term + second_term
    return kl
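
# Illustrative sanity check (not in the original file): the divergence of the
# uniform Dirichlet from itself is zero, e.g.
#
#     kl_divergence(torch.ones(2, 3), 3, device="cpu")
#     # -> tensor of zeros, shape (2, 1)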

def loglikelihood_loss(y, alpha, device=None):
    """Expected squared error under Dir(alpha), which decomposes into a
    squared prediction-error term plus a predictive-variance term."""
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)
    loglikelihood_err = torch.sum((y - (alpha / S)) ** 2, dim=1, keepdim=True)
    loglikelihood_var = torch.sum(
        alpha * (S - alpha) / (S * S * (S + 1)), dim=1, keepdim=True
    )
    loglikelihood = loglikelihood_err + loglikelihood_var
    return loglikelihood
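
# Worked example (illustrative): for alpha = [[2., 1., 1.]] and one-hot
# target y = [[1., 0., 0.]], S = 4 and the mean prediction alpha / S is
# [0.5, 0.25, 0.25]; the error term is 0.25 + 0.0625 + 0.0625 = 0.375 and
# the variance term is (4 + 3 + 3) / (16 * 5) = 0.125, so:
#
#     loglikelihood_loss(torch.tensor([[1., 0., 0.]]),
#                        torch.tensor([[2., 1., 1.]]), device="cpu")
#     # -> tensor([[0.5000]])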

def mse_loss(y, alpha, epoch_num, num_classes, annealing_step, device=None):
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    loglikelihood = loglikelihood_loss(y, alpha, device=device)

    # Ramp the KL penalty in linearly over the first `annealing_step` epochs.
    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )

    # Strip the evidence for the true class before the KL term, so only
    # misleading evidence is pulled toward the uniform Dirichlet.
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return loglikelihood + kl_div

def edl_loss(func, y, alpha, epoch_num, num_classes, annealing_step, device=None):
    """Generic EDL loss: `func` is torch.log (negative log of the expected
    probability) or torch.digamma (expected negative log-likelihood under
    the Dirichlet)."""
    if not device:
        device = get_device()
    y = y.to(device)
    alpha = alpha.to(device)
    S = torch.sum(alpha, dim=1, keepdim=True)
    A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)

    annealing_coef = torch.min(
        torch.tensor(1.0, dtype=torch.float32),
        torch.tensor(epoch_num / annealing_step, dtype=torch.float32),
    )
    kl_alpha = (alpha - 1) * (1 - y) + 1
    kl_div = annealing_coef * kl_divergence(kl_alpha, num_classes, device=device)
    return A + kl_div

def edl_mse_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        mse_loss(target, alpha, epoch_num, num_classes, annealing_step, device=device)
    )
    return loss

def edl_log_loss(output, target, epoch_num, num_classes, annealing_step, device=None):
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        edl_loss(
            torch.log, target, alpha, epoch_num, num_classes, annealing_step, device
        )
    )
    return loss

def edl_digamma_loss(
    output, target, epoch_num, num_classes, annealing_step, device=None
):
    if not device:
        device = get_device()
    evidence = relu_evidence(output)
    alpha = evidence + 1
    loss = torch.mean(
        edl_loss(
            torch.digamma, target, alpha, epoch_num, num_classes, annealing_step, device
        )
    )
    return loss
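
# Minimal usage sketch (illustrative, not part of the original file): dummy
# logits and one-hot targets on CPU, run through each of the three EDL losses.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes = 3
    logits = torch.randn(4, num_classes)        # fake network outputs
    labels = torch.tensor([0, 2, 1, 0])
    y = F.one_hot(labels, num_classes).float()  # one-hot targets

    for name, fn in [
        ("mse", edl_mse_loss),
        ("log", edl_log_loss),
        ("digamma", edl_digamma_loss),
    ]:
        loss = fn(logits, y, epoch_num=1, num_classes=num_classes,
                  annealing_step=10, device="cpu")
        print(f"edl_{name}_loss: {loss.item():.4f}")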