-
Notifications
You must be signed in to change notification settings - Fork 1
/
probe_utils.py
115 lines (102 loc) · 4.47 KB
/
probe_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright (c) 2022 Robert Bosch GmbH
# SPDX-License-Identifier: AGPL-3.0
import torch.nn as nn
from timm_future_imports import Mlp
class LinearProbeLayer(nn.Module):
    """A single linear probe: one fully-connected layer that maps a
    feature vector to class logits."""

    def __init__(self, num_features, num_classes):
        """Build the probe.

        Args:
            num_features: dimensionality of the input feature vector.
            num_classes:  number of output classes (logit dimension).
        """
        super().__init__()
        # Kept under the attribute name `lin_layer`: other code in this
        # file reaches into it directly (e.g. `.lin_layer.cuda()`).
        self.lin_layer = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits for the input features `x`."""
        return self.lin_layer(x)
class LinearProbeCollection():
    """Grid of probe heads, one per (transformer depth, attention head,
    target task, representation) combination.

    Depending on `args`, each probe is either a `LinearProbeLayer` or a
    small `Mlp`. The probes are held in nested plain Python lists (not an
    `nn.ModuleList`), so device placement, train/eval mode and gradient
    flags must be managed through the helper methods below.
    """

    def __init__(self, model, nb_classes, ntargets, args):
        """Build `depth x num_heads x ntargets x num_reps` probes.

        Args:
            model:      probed backbone; only `depth` and `num_heads` are read.
            nb_classes: int or list of ints, one class count per target.
            ntargets:   number of prediction targets (tasks).
            args:       run configuration; reads `model`,
                        `cls_token_linprobe`, `embedding_dim_per_head`,
                        `return_qkv`, `return_intermed_x`, `mlp_probes`,
                        `num_heads`.
        """
        # Per-probe input width. 4160 for 'deit' models and the
        # 196/197 multipliers (presumably patch-token counts with/without
        # a CLS token — TODO confirm against the backbones) are inherited
        # magic numbers.
        if 'deit' in args.model:
            self.nfeat = 4160
        elif args.cls_token_linprobe:
            self.nfeat = args.embedding_dim_per_head
        elif 'vanilla' in args.model:
            self.nfeat = args.embedding_dim_per_head * 197
        else:
            self.nfeat = args.embedding_dim_per_head * 196
        self.depth = model.depth
        self.num_heads = model.num_heads
        self.ntargets = ntargets
        # Allow a bare int for the single-target case.
        if isinstance(nb_classes, int):
            nb_classes = [nb_classes]
        # One probe per returned representation: q, k, v, attn-out when
        # qkv is returned, otherwise just one; plus one more for the
        # intermediate x (which concatenates all heads, hence the wider
        # input below).
        self.num_reps = 4 if args.return_qkv else 1
        if args.return_intermed_x:
            self.num_reps += 1
        self.mlp_probes = args.mlp_probes
        probes = []
        for d in range(self.depth):
            d_probes = []
            for h in range(self.num_heads):
                h_probes = []
                for target in range(self.ntargets):
                    r_probes = []
                    for rep in range(self.num_reps):
                        # The last rep slot is the intermediate x, whose
                        # features span all heads.
                        wide = rep == self.num_reps - 1 and args.return_intermed_x
                        nfeat = self.nfeat * args.num_heads if wide else self.nfeat
                        if args.mlp_probes:
                            # Hidden width 10 is an inherited magic number.
                            r_probes.append(Mlp(nfeat, 10, nb_classes[target]))
                        else:
                            r_probes.append(LinearProbeLayer(nfeat, nb_classes[target]))
                    h_probes.append(r_probes)
                d_probes.append(h_probes)
            probes.append(d_probes)
        self.probe_list = probes

    def to_gpu(self):
        """Move every probe to the default CUDA device."""
        # `.cuda()` on the wrapper module moves all its parameters, so no
        # need to special-case `.lin_layer` for linear probes.
        for d in range(self.depth):
            for h in range(self.num_heads):
                for t in range(self.ntargets):
                    for r in range(self.num_reps):
                        self.probe_list[d][h][t][r].cuda()

    def set_trainable(self):
        """Enable gradients on every probe's parameters.

        Bug fixes vs. the previous version: (1) `requires_grad` was being
        assigned on the `nn.Module` itself, which only creates an inert
        attribute — the flag lives on each Parameter; (2) the MLP branch
        always indexed probe [0][0][0][0], leaving all other probes
        untouched.
        """
        for d in range(self.depth):
            for h in range(self.num_heads):
                for t in range(self.ntargets):
                    for r in range(self.num_reps):
                        for p in self.probe_list[d][h][t][r].parameters():
                            p.requires_grad = True

    def train(self):
        """Put every probe into training mode."""
        # `.train(True)` recurses into submodules, so the wrapper call
        # covers `.lin_layer` too.
        for d in range(self.depth):
            for h in range(self.num_heads):
                for t in range(self.ntargets):
                    for r in range(self.num_reps):
                        self.probe_list[d][h][t][r].train(True)

    def eval(self):
        """Put every probe into evaluation mode."""
        for d in range(self.depth):
            for h in range(self.num_heads):
                for t in range(self.ntargets):
                    for r in range(self.num_reps):
                        self.probe_list[d][h][t][r].eval()

    def get_optimizers(self, args, create_optimizer_fn, **kwargs):
        """Return a nested list of optimizers mirroring `probe_list`.

        Args:
            args: run configuration forwarded to `create_optimizer_fn`.
            create_optimizer_fn: factory called as (args, module, kwargs).
        """
        # NOTE(review): `kwargs` is forwarded as a positional dict, not
        # expanded with `**`. That looks like a bug, but the signature of
        # `create_optimizer_fn` is not visible here — confirm with the
        # caller before changing.
        optims = []
        for d in range(self.depth):
            d_optims = []
            for h in range(self.num_heads):
                h_optims = []
                for t in range(self.ntargets):
                    r_optims = []
                    for r in range(self.num_reps):
                        r_optims.append(create_optimizer_fn(args, self.probe_list[d][h][t][r], kwargs))
                    h_optims.append(r_optims)
                d_optims.append(h_optims)
            optims.append(d_optims)
        return optims