-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
204 lines (172 loc) · 7.5 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy_with_logits
from discriminator import Discriminator
from generator import Generator
from config import cfg
class AdversarialLoss:
    """
    Adversarial term that pushes generated images towards being
    indistinguishable from real ones.

    The loss is built from the mean of the discriminator's source logits:
    with fake logits only it is the generator-side term, with both fake and
    real logits it is the discriminator-side term.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(
        self, fake_logits: torch.Tensor, real_logits: torch.Tensor = None
    ) -> torch.Tensor:
        """
        :param torch.Tensor fake_logits: discriminator source logits for generated images
        :param torch.Tensor real_logits: discriminator source logits for real images;
            pass ``None`` (the default) to get the generator-side loss
        :rtype: torch.Tensor
        :returns: ``-mean(fake_logits)`` when ``real_logits`` is ``None``,
            otherwise ``-mean(real_logits) + mean(fake_logits)``
        """
        fake_score = torch.mean(fake_logits)
        if real_logits is not None:
            # Discriminator side: reward high scores on real, low scores on fake.
            return -torch.mean(real_logits) + fake_score
        # Generator side: reward high discriminator scores on generated images.
        return -fake_score
class DomainClassificationLoss:
    """
    For a given input image x and a target domain label c, the goal is to translate x into
    an output image y which is properly classified to the target domain c:
    Ex,c'[− log Dcls(c'|image)] or Ex,c[− log Dcls(c|G(x, c))].

    Implemented as binary cross-entropy on logits, summed over all elements
    and divided by the batch size (first dimension of ``logit``), i.e. the
    per-sample average of the summed per-label losses.
    """

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, logit, target: torch.Tensor) -> torch.Tensor:
        """
        :param torch.Tensor logit: logits from cls head of discriminator, batch dimension first
        :param torch.Tensor target: true labels, same shape as ``logit``
        :rtype: torch.Tensor
        :returns: Domain classification loss of image (scalar tensor)
        """
        # Sum first and normalise by the batch size exactly once. Using
        # reduction="mean" here would already divide by the number of
        # elements, and the extra division by the batch size would make the
        # loss shrink as the batch grows.
        summed = binary_cross_entropy_with_logits(logit, target, reduction="sum")
        return summed / logit.size(0)
class ReconstructionLoss:
    """
    Adversarial and classification losses train G to produce realistic images
    of the requested domain, but they do not guarantee that the translated
    image keeps the content of its input while changing only the
    domain-related part. This cycle-consistency term covers that gap:
    ReconstructionLoss = E(x,c,c') [||x − G(G(x, c), c')||1]
    """

    def __init__(self):
        super().__init__()

    def __call__(self, real: torch.Tensor, reconstructed: torch.Tensor) -> torch.Tensor:
        """
        :param torch.Tensor real: real input image x
        :param torch.Tensor reconstructed: reconstructed image G(G(x, c), c'), where
            c is the target domain and c' the original domain (from dataset)
        :rtype: torch.Tensor
        :returns: mean absolute difference between ``real`` and ``reconstructed``
        """
        residual = real - reconstructed
        return residual.abs().mean()
class GeneratorLoss:
    """
    Overall Generator loss: AdversarialLoss + λcls*DomainClassificationLoss + λrec*ReconstructionLoss
    """

    def __init__(
        self,
        discriminator: Discriminator,
        lambda_cls: float = 1.0,
        lambda_rec: float = 10.0,
    ) -> None:
        """
        :param Discriminator discriminator: object of discriminator network
            (stored on the instance; not called by this loss)
        :param float lambda_cls: control the relative importance of domain classification loss
        :param float lambda_rec: control the relative importance of reconstruction loss
        :rtype: None
        """
        super().__init__()
        self.discriminator = discriminator
        self.adversarial_loss = AdversarialLoss()
        self.domain_classification_loss = DomainClassificationLoss()
        self.reconstruction_Loss = ReconstructionLoss()
        self.lambda_cls = lambda_cls
        self.lambda_rec = lambda_rec

    def __call__(
        self,
        real: torch.Tensor,
        fake: torch.Tensor,
        reconstructed: torch.Tensor,
        labels_target: torch.Tensor,
        output_fake_src: torch.Tensor,
        output_fake_cls: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param torch.Tensor real: batch of real images
        :param torch.Tensor fake: batch of images generated by the generator
            (not read by this computation; kept for the call interface)
        :param torch.Tensor reconstructed: batch of reconstructed images G(G(x,c),c')
        :param torch.Tensor labels_target: target labels generated to train network
        :param torch.Tensor output_fake_src: output src head of discriminator on fake image
        :param torch.Tensor output_fake_cls: output cls head of discriminator on fake image;
            must hold as many elements as ``labels_target``
        :rtype: torch.Tensor
        :returns: Overall Generator Loss ready for backpropagation
        """
        loss_adv = self.adversarial_loss(fake_logits=output_fake_src, real_logits=None)

        # Align the classifier logits with the label tensor's shape instead of
        # calling torch.squeeze: squeeze drops *every* size-1 dimension, so a
        # batch containing a single sample would lose its batch dimension and
        # no longer match the labels.
        cls_logits = output_fake_cls.reshape(labels_target.shape)
        loss_dm_cls = self.domain_classification_loss(
            logit=cls_logits, target=labels_target
        )

        loss_rec = self.reconstruction_Loss(real=real, reconstructed=reconstructed)
        return loss_adv + self.lambda_cls * loss_dm_cls + self.lambda_rec * loss_rec
class DiscriminatorLoss:
    """
    Overall Discriminator loss:
    AdversarialLoss + λcls*DomainClassificationLoss + λgp*GradientPenalty
    """

    def __init__(
        self,
        discriminator: Discriminator,
        lambda_cls: float = 1.0,
        lambda_gp: float = 10.0,
    ) -> None:
        """
        :param Discriminator discriminator: discriminator network; called on the
            interpolated images to compute the gradient penalty and expected to
            return a pair whose first item is the source output
        :param float lambda_cls: control the relative importance of domain classification loss
        :param float lambda_gp: control the relative importance of the gradient penalty
        :rtype: None
        """
        super().__init__()
        self.adversarial_loss = AdversarialLoss()
        self.discriminator = discriminator
        self.domain_classification_loss = DomainClassificationLoss()
        self.lambda_cls = lambda_cls
        self.lambda_gp = lambda_gp

    def __call__(
        self,
        real_image: torch.Tensor,
        fake_image: torch.Tensor,
        output_src_real: torch.Tensor,
        output_src_fake: torch.Tensor,
        output_cls_real: torch.Tensor,
        labels_dataset: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param torch.Tensor real_image: real image
        :param torch.Tensor fake_image: image generated by generator
        :param torch.Tensor output_src_real: output src head of discriminator on real image
        :param torch.Tensor output_src_fake: output src head of discriminator on fake image
        :param torch.Tensor output_cls_real: output cls head of discriminator on real image;
            must hold as many elements as ``labels_dataset``
        :param torch.Tensor labels_dataset: original labels from dataset
        :rtype: torch.Tensor
        :returns: Overall discriminator loss ready for backpropagation
        """
        loss_adv = self.adversarial_loss(
            fake_logits=output_src_fake, real_logits=output_src_real
        )

        # Reshape to the label shape rather than torch.squeeze, which would
        # also remove the batch dimension when the batch has a single sample.
        cls_logits = output_cls_real.reshape(labels_dataset.shape)
        loss_dm_cls = self.domain_classification_loss(
            logit=cls_logits, target=labels_dataset
        )

        # One random mixing coefficient per sample, created on the same device
        # as the images it multiplies.
        alpha = torch.rand(real_image.size(0), 1, 1, 1, device=real_image.device)

        # The interpolate must be a fresh leaf that requires grad: detaching
        # keeps the penalty from back-propagating into whatever produced the
        # images (e.g. the generator), and requires_grad_ lets autograd.grad
        # differentiate the discriminator output with respect to it.
        x_hat = (
            alpha * real_image.detach() + (1 - alpha) * fake_image.detach()
        ).requires_grad_(True)

        out_src, _ = self.discriminator(x_hat)
        d_loss_gp = self.gradient_penalty(out_src, x_hat)
        return loss_adv + self.lambda_cls * loss_dm_cls + self.lambda_gp * d_loss_gp

    def gradient_penalty(self, y, x):
        """
        Compute gradient penalty: (L2_norm(dy/dx) - 1)**2, averaged over the batch.

        :param torch.Tensor y: discriminator output computed from ``x``
        :param torch.Tensor x: input tensor that requires grad, batch dimension first
        :rtype: torch.Tensor
        :returns: scalar penalty; built with ``create_graph=True`` so it can
            itself be back-propagated
        """
        # ones_like keeps the weight on the same device and dtype as y.
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=weight,
            retain_graph=True,
            create_graph=True,
        )[0]

        # Flatten everything but the batch dimension to take a per-sample norm.
        dydx = dydx.reshape(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)