-
Notifications
You must be signed in to change notification settings - Fork 0
/
typhon_model.py
146 lines (115 loc) · 4.94 KB
/
typhon_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from architecture_loader import ArchitectureLoader
import torch.nn as nn
import torch
import copy
import utils
class TyphonModel(nn.Module):
    """Multi-head model used for Typhon-style parallel transfer learning.

    One shared feature extractor (FE) is followed by one decision module
    (DM) per dataset; the DM is selected by dataset name at forward time.

    Args:
        dropout_fe: dropout probability applied inside the FE.
        dropouts_dm: dict mapping dataset name -> dropout probability for
            that dataset's DM.
        architecture: name of the architecture to load via
            ``ArchitectureLoader``.
        dsets_names: iterable of dataset names; one DM is built per name.
        n_classes: dict mapping dataset name -> number of output classes
            (used only for the 'classification' task).
        training_task: task identifier; 'classification' selects the
            classification DM head, anything else the generic DM.
    """

    # Bump this when the (de)serialization layout below changes; it is
    # checked in from_state_dict so stale checkpoints fail loudly.
    version = 1

    def __init__(self,
                 dropout_fe,
                 dropouts_dm,
                 architecture,
                 dsets_names,
                 n_classes,
                 training_task
                 ):
        assert isinstance(architecture, str), "Provide an architecture"
        assert dsets_names is not None, "Provide names for the datasets"
        super().__init__()
        self.dropout_fe = dropout_fe
        self.dropouts_dm = dropouts_dm
        self.architecture = architecture
        self.dsets_names = dsets_names
        self.n_classes = n_classes
        self.training_task = training_task
        self.fe = ArchitectureLoader.get_fe(self.architecture, self.dropout_fe)
        self.dms = self.init_dms()
        # Recursively (re)initialize every weight-bearing submodule.
        self.fe.apply(self.init_weights)
        self.dms.apply(self.init_weights)
        self.set_dropout(self.dropout_fe, self.dropouts_dm)

    def init_weights(self, module):
        """Kaiming-initialize ``module.weight`` if it has one (>= 2D only)."""
        # Skip modules without weights (activations, dropout, ...).
        if hasattr(module, 'weight'):
            # Xavier/Kaiming initializations require at least 2D weights,
            # so 1D weights (e.g. norm layers) keep their defaults.
            if len(module.weight.shape) > 1:
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')

    def reload_dm(self, dset_name):
        """Re-randomize the weights of the DM associated with ``dset_name``."""
        self.dms[dset_name].apply(self.init_weights)

    def init_dms(self):
        """Build one decision module per dataset name.

        Returns:
            nn.ModuleDict mapping dataset name -> DM, so the DMs are
            registered as proper submodules.
        """
        dms = nn.ModuleDict({})
        for dset_name in self.dsets_names:
            if self.training_task == 'classification':
                new_dm = ArchitectureLoader.get_classification_dm(self.architecture, self.dropouts_dm[dset_name], self.n_classes[dset_name])
            else:
                new_dm = ArchitectureLoader.get_dm(self.architecture, self.dropouts_dm[dset_name])
            dms[dset_name] = new_dm
        return dms

    def forward(self, x, dset_name):
        """Full pass: shared FE, then the DM selected by ``dset_name``."""
        x = self.fe(x)
        return self.dms[dset_name](x)

    def forward_fe(self, x):
        """Feature-extractor pass only."""
        return self.fe(x)

    def forward_dm(self, x, dset_name):
        """Decision-module pass only (``x`` is assumed to be FE output)."""
        return self.dms[dset_name](x)

    # Sets dropout on model -- IMPORTANT when loading from target state!
    def set_dropout(self, dropout_fe=None, dropouts_dm=None):
        """Update dropout probabilities on the FE and/or the DMs.

        Args:
            dropout_fe: new FE dropout probability, or None to leave it.
            dropouts_dm: dict of dataset name -> new DM dropout
                probability, or None to leave them.
        """
        # Compare against None (not truthiness) so that a dropout of 0.0
        # is accepted as a valid value instead of being silently skipped.
        assert dropout_fe is not None or dropouts_dm is not None, \
            "Need new dropout for DMs and/or FE"
        if dropout_fe is not None:
            self.dropout_fe = dropout_fe
            for mod in self.fe:
                if isinstance(mod, nn.Dropout):
                    mod.p = self.dropout_fe
        if dropouts_dm is not None:
            self.dropouts_dm = dropouts_dm
            for dset_name, dm in self.dms.items():
                for mod in dm:
                    if isinstance(mod, nn.Dropout):
                        mod.p = self.dropouts_dm[dset_name]

    # To freeze and unfreeze the feature extractor during hydra
    def freeze_fe(self):
        """Disable gradients for all FE parameters.

        Iterates ``self.fe`` directly instead of substring-matching 'fe'
        in parameter names, which could accidentally freeze DM parameters
        whose names happen to contain 'fe'.
        """
        for param in self.fe.parameters():
            param.requires_grad = False

    def unfreeze_fe(self):
        """Re-enable gradients for all FE parameters."""
        for param in self.fe.parameters():
            param.requires_grad = True

    def print_stats(self):
        """Print the number of trainable parameters in the FE and each DM."""
        utils.print_time("Model statistics")
        fe_params = sum(p.numel() for p in self.fe.parameters() if p.requires_grad)
        print(f"> The model has {fe_params} trainable parameters in the feature extractor")
        for dset_name, dm in self.dms.items():
            dm_params = sum(p.numel() for p in dm.parameters() if p.requires_grad)
            print(f"> The model has {dm_params} trainable parameters in the {dset_name} head")
        print()

    # Return separate models, with one FE and one DM each (used for specialization)
    def split_typhon(self):
        """Split into per-dataset single-head models.

        Returns:
            dict mapping dataset name -> an independent TyphonModel that
            owns a deep copy of the FE and of that dataset's DM.
        """
        models = {}
        for dset_name in self.dsets_names:
            # Use deepcopy to have a new object with new reference
            model = copy.deepcopy(self)
            # Keep the COPY's DM (not self.dms[dset_name]) so the split
            # model does not share parameters with this model.
            model.dms = nn.ModuleDict({dset_name: model.dms[dset_name]})
            model.dsets_names = [dset_name]
            models[dset_name] = model
        return models

    def to_state_dict(self):
        """Serialize weights plus the constructor arguments.

        Returns:
            dict with 'fe' and 'dms' state dicts and 'variables', the
            non-private instance attributes needed to rebuild the model.
        """
        variables = {k: v for k, v in vars(self).items() if not k.startswith('_')}
        # 'training' (nn.Module flag) is not a constructor argument and
        # throws an error when re-passed with the double splat operator.
        del variables['training']
        return {
            'fe': self.fe.state_dict(),
            'dms': self.dms.state_dict(),
            'variables': variables
        }

    # Generate new model from state_dict
    @staticmethod
    def from_state_dict(trg_model_state):
        """Rebuild a TyphonModel from the dict made by ``to_state_dict``.

        Raises:
            AssertionError: if the stored model version does not match
                this class's version.
        """
        ret = TyphonModel(**trg_model_state['variables'])
        # Check the version of the model
        assert TyphonModel.version == ret.version, "Version not corresponding"
        ret.fe.load_state_dict(trg_model_state['fe'])
        ret.dms.load_state_dict(trg_model_state['dms'])
        # Re-apply dropout: loaded Dropout modules may carry stale p values.
        ret.set_dropout(ret.dropout_fe, ret.dropouts_dm)
        return ret