# loss_modules.py
import torch
from torch.nn import functional as F


def cll_v1(args, cam_cnn, transformer_embed, label_bg):
    '''
    Semantic Aware Projection
    '''
    # *********Hyperparams********* #
    top_num = args.top_bot_k[0]
    bottom_num = args.top_bot_k[1]
    tau = args.tau
    # ***************************** #
    B, C, N, _ = cam_cnn.shape
    N2 = N * N  # number of spatial positions, e.g. 14 * 14 = 196

    scores = F.softmax(cam_cnn * label_bg, dim=1)  # [B, 21, 14, 14]; softmax over classes
    pseudo_score, pseudo_label = torch.max(scores, dim=1)  # [B, 14, 14]; best class per pixel on the CNN CAMs

    cam_cnn = cam_cnn.reshape(B, C, -1)  # [B, C, 196]
    pseudo_label = pseudo_label.reshape(B, -1)  # [B, 196]
    cam = [cam_cnn[i, pseudo_label[i]] for i in range(B)]  # B x [196, 196]
    cam = torch.stack(cam, dim=0)  # [B, 196, 196]

    top_values, top_indices = torch.topk(
        cam, k=top_num, dim=-1, largest=True)  # [B, 196, k]
    bottom_values, bottom_indices = torch.topk(
        cam, k=bottom_num, dim=-1, largest=False)  # [B, 196, k]

    transformer_embed = transformer_embed.transpose(1, 2)  # [B, 196, 128]
    pos_init = []
    neg_init = []
    for i in range(B):
        pos_init.append(transformer_embed[i, top_indices[i]])
        neg_init.append(transformer_embed[i, bottom_indices[i]])
    pos = torch.stack(pos_init, dim=0)  # [B, 196, k, 128]
    neg = torch.stack(neg_init, dim=0)  # [B, 196, k, 128]

    # Computing the loss: for each anchor position the per-position term is
    # -log(X / (X + Y)), where X comes from positives and Y from negatives.
    loss = torch.zeros(1).cuda()
    for i in range(N2):
        main_vector_tf = transformer_embed[:, i].unsqueeze(-1)  # [B, 128, 1]
        # X: numerator (similarity to positives)
        pos_inner = pos[:, i] @ main_vector_tf  # [B, k, 1]
        X = torch.exp(pos_inner.squeeze(-1) / tau)
        # Y: denominator (similarity to negatives)
        neg_inner = neg[:, i] @ main_vector_tf  # [B, k, 1]
        Y = torch.sum(torch.exp(neg_inner.squeeze(-1) / tau),
                      dim=-1, keepdim=True)
        # -log(X / (X + Y))
        loss += torch.sum(-torch.log(X / (X + Y)))
    return loss / (N2 * top_num * B)
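
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test for cll_v1, assuming `args` carries `top_bot_k`
# (a list of four ints) and `tau`, and using the tensor shapes noted in the
# comments above (14x14 CAMs over 21 classes, 128-d transformer tokens).
# The helper name `_demo_cll_v1` is hypothetical; a GPU is required because
# the loss accumulator is allocated with `.cuda()`.
def _demo_cll_v1():
    from types import SimpleNamespace
    args = SimpleNamespace(top_bot_k=[20, 20, 20, 20], tau=0.5)
    B, C, N, D = 2, 21, 14, 128
    cam_cnn = torch.randn(B, C, N, N).cuda()             # CNN class activation maps
    transformer_embed = torch.randn(B, D, N * N).cuda()  # transformer token embeddings
    label_bg = torch.ones(B, C, 1, 1).cuda()             # image-level labels incl. background channel
    return cll_v1(args, cam_cnn, transformer_embed, label_bg)
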
def cll_v2(args, attn_weights, cnn_embed):
    '''
    Class Aware Projection
    '''
    # *********Hyperparams********* #
    top_num = args.top_bot_k[2]
    bottom_num = args.top_bot_k[3]
    tau = args.tau
    # ***************************** #
    attn_weights = attn_weights[:, 1:, 1:]  # patch-to-patch attention scores, background token excluded
    B, N, _ = attn_weights.shape

    top_values, top_indices = torch.topk(
        attn_weights, k=top_num, dim=-1, largest=True)  # [B, 196, k]
    bottom_values, bottom_indices = torch.topk(
        attn_weights, k=bottom_num, dim=-1, largest=False)  # [B, 196, k]

    cnn_embed = cnn_embed.transpose(1, 2)  # [B, 196, 128]
    pos_init = []
    neg_init = []
    for i in range(B):
        pos_init.append(cnn_embed[i, top_indices[i]])
        neg_init.append(cnn_embed[i, bottom_indices[i]])
    pos = torch.stack(pos_init, dim=0)  # [B, 196, k, 128]
    neg = torch.stack(neg_init, dim=0)  # [B, 196, k, 128]

    # Computing the loss: same -log(X / (X + Y)) form as in cll_v1
    loss = torch.zeros(1).cuda()
    for i in range(N):
        main_vector_tf = cnn_embed[:, i].unsqueeze(-1)  # [B, 128, 1]; anchor vector for the whole batch
        # X: numerator (similarity to positives)
        pos_inner = pos[:, i] @ main_vector_tf  # [B, k, 1]
        X = torch.exp(pos_inner.squeeze(-1) / tau)
        # Y: denominator (similarity to negatives)
        neg_inner = neg[:, i] @ main_vector_tf  # [B, k, 1]
        Y = torch.sum(torch.exp(neg_inner.squeeze(-1) / tau),
                      dim=-1, keepdim=True)
        # -log(X / (X + Y))
        loss += torch.sum(-torch.log(X / (X + Y)))
    loss /= N * B * top_num
    return loss
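
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test for cll_v2 under the same assumptions about `args`.
# The attention map is assumed to carry one extra background token ahead of
# the 196 patch tokens, which the function strips off. The helper name
# `_demo_cll_v2` is hypothetical; a GPU is required.
def _demo_cll_v2():
    from types import SimpleNamespace
    args = SimpleNamespace(top_bot_k=[20, 20, 20, 20], tau=0.5)
    B, N2, D = 2, 196, 128
    attn_weights = F.softmax(torch.randn(B, N2 + 1, N2 + 1), dim=-1).cuda()  # incl. background token
    cnn_embed = torch.randn(B, D, N2).cuda()  # CNN embeddings projected to 128-d
    return cll_v2(args, attn_weights, cnn_embed)
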
def loss_CAM(cam_cnn_224, cam_tf_224, label):
    cam_cnn_224_classes = cam_cnn_224[:, 1:]
    cam_tf_224_classes = cam_tf_224[:, 1:]
    loss_interCAM = 0
    for i in range(len(cam_cnn_224_classes)):
        valid_cat = torch.nonzero(label[i])[:, 0]  # indices of classes present in image i
        cam_cnn_224_class = cam_cnn_224_classes[i, valid_cat]
        cam_tf_224_class = cam_tf_224_classes[i, valid_cat]
        loss_bg = torch.mean(torch.abs(cam_cnn_224[i][0] - cam_tf_224[i][0]))  # background term, currently unused
        loss_inter = torch.mean(torch.abs(cam_cnn_224_class - cam_tf_224_class))
        # loss = (loss_bg + loss_inter) / 2
        loss = loss_inter
        loss_interCAM += loss
    loss_interCAM = loss_interCAM / len(cam_cnn_224)
    return loss_interCAM
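
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal CPU smoke test for loss_CAM, assuming 224x224 CAMs with 21
# channels (background + 20 foreground classes) and a multi-hot label
# vector over the 20 foreground classes. The helper name `_demo_loss_CAM`
# is hypothetical.
def _demo_loss_CAM():
    B, C, H = 2, 21, 224
    cam_cnn_224 = torch.rand(B, C, H, H)
    cam_tf_224 = torch.rand(B, C, H, H)
    label = torch.zeros(B, C - 1)
    label[:, 0] = 1  # mark one foreground class as present in every image
    return loss_CAM(cam_cnn_224, cam_tf_224, label)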