forked from TaoRuijie/TalkNet-ASD
-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss.py
executable file
·50 lines (42 loc) · 1.27 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
import torch.nn.functional as F
class lossAV(nn.Module):
    """Two-class linear head with a cross-entropy criterion.

    The module projects each input feature vector to two logits through
    ``self.FC``. Without labels it returns raw class-1 scores for
    inference; with labels it returns the loss plus prediction statistics.

    Args:
        in_features: Width of the incoming feature vectors. Defaults to 256,
            the value that was previously hard-coded.
    """

    def __init__(self, in_features=256):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        # Output width stays fixed at 2: forward() reads column 1 as the
        # positive-class score.
        self.FC = nn.Linear(in_features, 2)

    def forward(self, x, labels=None):
        """Score ``x`` and, when ``labels`` is given, compute the loss.

        Args:
            x: Feature tensor whose last dimension equals ``in_features``.
                NOTE(review): assumed to be (N, in_features) or
                (N, 1, in_features) -- confirm against the caller.
            labels: Optional class-index tensor accepted by
                ``nn.CrossEntropyLoss``. ``None`` selects the inference path.

        Returns:
            If ``labels`` is ``None``: a flattened NumPy array holding the
            raw (pre-softmax) class-1 logit of every row.

            Otherwise a tuple ``(nloss, predScore, predLabel, correctNum)``:
            the cross-entropy loss, the softmax probabilities, the rounded
            class-1 probability used as the predicted label, and the number
            of predictions equal to ``labels`` as a float tensor.
        """
        # Drops dimension 1 only when its size is 1; otherwise a no-op.
        x = x.squeeze(1)
        x = self.FC(x)

        # Identity check: ``== None`` on a tensor relies on the comparison
        # fallback and is ambiguous for array-like labels.
        if labels is None:
            predScore = x[:, 1]
            # ``t()`` leaves a 1-D tensor unchanged; kept so that inputs of
            # other ranks behave as they did before.
            predScore = predScore.t()
            return predScore.view(-1).detach().cpu().numpy()

        nloss = self.criterion(x, labels)
        # Compute the probabilities once and reuse them for the label.
        predScore = F.softmax(x, dim=-1)
        predLabel = torch.round(predScore)[:, 1]
        correctNum = (predLabel == labels).sum().float()
        return nloss, predScore, predLabel, correctNum
class lossA(nn.Module):
    """Two-class linear head that returns only the cross-entropy loss.

    Args:
        in_features: Width of the incoming feature vectors. Defaults to 128,
            the value that was previously hard-coded.
    """

    def __init__(self, in_features=128):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.FC = nn.Linear(in_features, 2)

    def forward(self, x, labels):
        """Return the cross-entropy loss of the projected ``x`` against ``labels``.

        Args:
            x: Feature tensor whose last dimension equals ``in_features``.
                NOTE(review): assumed to be (N, in_features) or
                (N, 1, in_features) -- confirm against the caller.
            labels: Class-index tensor accepted by ``nn.CrossEntropyLoss``.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        # Drops dimension 1 only when its size is 1; otherwise a no-op.
        logits = self.FC(x.squeeze(1))
        return self.criterion(logits, labels)
class lossV(nn.Module):
    """Two-class linear head that returns only the cross-entropy loss.

    Args:
        in_features: Width of the incoming feature vectors. Defaults to 128,
            the value that was previously hard-coded.
    """

    def __init__(self, in_features=128):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.FC = nn.Linear(in_features, 2)

    def forward(self, x, labels):
        """Return the cross-entropy loss of the projected ``x`` against ``labels``.

        Args:
            x: Feature tensor whose last dimension equals ``in_features``.
                NOTE(review): assumed to be (N, in_features) or
                (N, 1, in_features) -- confirm against the caller.
            labels: Class-index tensor accepted by ``nn.CrossEntropyLoss``.

        Returns:
            The loss tensor produced by ``self.criterion``.
        """
        # Drops dimension 1 only when its size is 1; otherwise a no-op.
        logits = self.FC(x.squeeze(1))
        return self.criterion(logits, labels)