-
Notifications
You must be signed in to change notification settings - Fork 741
/
losses.py
executable file
·87 lines (71 loc) · 2.76 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
'''
Portions of this code copyright 2017, Clement Pinard
'''
# TODO(freda): adversarial loss
import torch
import torch.nn as nn
import math
def EPE(input_flow, target_flow):
    """Return the mean end-point error between two flow tensors.

    The Euclidean (p=2) norm of the difference is taken along dim 1 and the
    resulting per-location distances are averaged over all remaining
    elements. (Dim 1 is presumably the (u, v) channel axis of an
    (N, 2, H, W) flow tensor -- TODO confirm with callers.)
    """
    residual = target_flow - input_flow
    per_location = residual.norm(p=2, dim=1)
    return per_location.mean()
class L1(nn.Module):
    """Mean absolute error: ``|output - target|`` averaged over every element.

    No state of its own; the inherited ``nn.Module.__init__`` is all that is
    needed to construct it.
    """

    def forward(self, output, target):
        gap = output - target
        return gap.abs().mean()
class L2(nn.Module):
    """Euclidean distance along dim 1, averaged over the remaining elements.

    Despite the name this is not a squared error: it takes the 2-norm of
    ``output - target`` along dim 1 (the same quantity ``EPE`` computes) and
    then the mean. No state of its own, so no custom ``__init__`` is needed.
    """

    def forward(self, output, target):
        dist = (output - target).norm(p=2, dim=1)
        return dist.mean()
class L1Loss(nn.Module):
    """Criterion that reports the L1 loss together with the EPE metric.

    Args:
        args: Run configuration; kept on the instance but not read here.
    """

    def __init__(self, args):
        super(L1Loss, self).__init__()
        self.args = args
        self.loss = L1()
        # Labels for the two values returned by forward(), in the same order.
        self.loss_labels = ['L1', 'EPE']

    def forward(self, output, target):
        """Return ``[l1, epe]`` comparing ``output`` with ``target``."""
        return [self.loss(output, target), EPE(output, target)]
class L2Loss(nn.Module):
    """Criterion that reports the L2 (non-squared norm) loss plus the EPE metric.

    Args:
        args: Run configuration; kept on the instance but not read here.
    """

    def __init__(self, args):
        super(L2Loss, self).__init__()
        self.args = args
        self.loss = L2()
        # Labels for the two values returned by forward(), in the same order.
        self.loss_labels = ['L2', 'EPE']

    def forward(self, output, target):
        """Return ``[l2, epe]`` comparing ``output`` with ``target``."""
        return [self.loss(output, target), EPE(output, target)]
class MultiScale(nn.Module):
    """Weighted multi-scale flow loss with an EPE companion metric.

    ``forward`` returns ``[loss, epe]``, labelled by ``self.loss_labels``.

    * If ``output`` is a tuple or list of predictions, ``target`` is first
      multiplied by ``div_flow``. For prediction ``i`` it is then
      average-pooled with kernel size and stride ``startScale * 2**i``. The
      per-scale loss and EPE are weighted by ``l_weight / 2**i`` and summed.
      Predictions are presumably ordered finest-first so that each matches
      its pooled target's spatial size -- TODO confirm against the network.
    * Otherwise ``output`` is compared with the unmodified ``target`` at unit
      weight.

    Args:
        args: Run configuration; stored on the instance, not read here.
        startScale: Pooling kernel size and stride for the first prediction.
        numScales: Number of pooling levels and weights prepared. Passing
            more predictions than this raises ``IndexError``.
        l_weight: Weight of the first scale; each subsequent scale halves it.
        norm: ``'L1'`` selects the L1 criterion; any other value selects L2.
        div_flow: Factor applied to ``target`` in the multi-scale branch only.
            Defaults to 0.05.
    """

    def __init__(self, args, startScale=4, numScales=5, l_weight=0.32, norm='L1', div_flow=0.05):
        super(MultiScale, self).__init__()
        self.startScale = startScale
        self.numScales = numScales
        # One weight per scale: l_weight, l_weight/2, l_weight/4, ...
        self.loss_weights = torch.FloatTensor([(l_weight / 2 ** scale) for scale in range(self.numScales)])
        self.args = args
        self.l_type = norm
        self.div_flow = div_flow
        if self.l_type == 'L1':
            self.loss = L1()
        else:
            self.loss = L2()
        # ModuleList registers the pooling layers as proper submodules
        # (visible to repr/.to()/.train()). AvgPool2d holds no parameters or
        # buffers, so state_dict keys are unaffected.
        self.multiScales = nn.ModuleList([
            nn.AvgPool2d(self.startScale * (2 ** scale), self.startScale * (2 ** scale))
            for scale in range(self.numScales)
        ])
        # Must be a flat two-item list (no trailing comma) so it pairs
        # one-to-one with the [loss, epe] list returned by forward().
        self.loss_labels = ['MultiScale-' + self.l_type, 'EPE']

    def forward(self, output, target):
        """Return ``[loss, epe]`` for a single prediction or a sequence of them."""
        if isinstance(output, (tuple, list)):
            lossvalue = 0
            epevalue = 0
            # The multi-scale predictions are compared against a rescaled target.
            target = self.div_flow * target
            for i, output_ in enumerate(output):
                target_ = self.multiScales[i](target)
                epevalue += self.loss_weights[i] * EPE(output_, target_)
                lossvalue += self.loss_weights[i] * self.loss(output_, target_)
            return [lossvalue, epevalue]
        epevalue = EPE(output, target)
        lossvalue = self.loss(output, target)
        return [lossvalue, epevalue]