# classifier.py
"""
@author: Baixu Chen
@contact: [email protected]
"""
from typing import Optional

import torch
import torch.nn as nn


class Classifier(nn.Module):
"""A generic Classifier class for domain adaptation.
Args:
backbone (torch.nn.Module): Any backbone to extract 2-d features from data
num_classes (int): Number of classes
bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default
bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1
head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default
finetune (bool): Whether finetune the classifier or train from scratch. Default: True
.. note::
Different classifiers are used in different domain adaptation algorithms to achieve better accuracy
respectively, and we provide a suggested `Classifier` for different algorithms.
Remember they are not the core of algorithms. You can implement your own `Classifier` and combine it with
the domain adaptation algorithm in this algorithm library.
.. note::
The learning rate of this classifier is set 10 times to that of the feature extractor for better accuracy
by default. If you have other optimization strategies, please over-ride :meth:`~Classifier.get_parameters`.
Inputs:
- x (tensor): input data fed to `backbone`
Outputs:
- predictions: classifier's predictions
- features: features after `bottleneck` layer and before `head` layer
Shape:
- Inputs: (minibatch, *) where * means, any number of additional dimensions
- predictions: (minibatch, `num_classes`)
- features: (minibatch, `features_dim`)
"""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None,
                 finetune: bool = True, pool_layer: Optional[nn.Module] = None):
super(Classifier, self).__init__()
self.backbone = backbone
self.num_classes = num_classes
if pool_layer is None:
self.pool_layer = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=(1, 1)),
nn.Flatten()
)
else:
self.pool_layer = pool_layer
if bottleneck is None:
self.bottleneck = nn.Identity()
self._features_dim = backbone.out_features
else:
self.bottleneck = bottleneck
assert bottleneck_dim > 0
self._features_dim = bottleneck_dim
if head is None:
self.head = nn.Linear(self._features_dim, num_classes)
else:
self.head = head
self.finetune = finetune

    @property
def features_dim(self) -> int:
"""The dimension of features before the final `head` layer"""
return self._features_dim

    def forward(self, x: torch.Tensor):
        """Compute predictions, and additionally return features when in training mode."""
        f = self.backbone(x)
        # Pool spatial feature maps (e.g. N x C x H x W) down to 2-d features.
        if len(f.shape) > 2:
            f = self.pool_layer(f)
        f = self.bottleneck(f)
        predictions = self.head(f)
        if self.training:
            return predictions, f
        else:
            return predictions

    def get_parameters(self, base_lr=1.0):
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        params = [
            # When fine-tuning, the pretrained backbone runs at one tenth of the base
            # learning rate; the bottleneck and head run at the full rate.
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
        return params

    def get_head_parameters(self, base_lr=1.0):
        """A parameter list containing only the parameters of the `head` layer,
        for optimizing the classifier head alone
        """
        params = [
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
return params
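

# A minimal usage sketch, included here for illustration only and not part of the
# original module. It shows how `Classifier` wraps a backbone and how
# `get_parameters` feeds per-group learning rates to an optimizer. The torchvision
# `resnet18` backbone, the manually attached `out_features` attribute, and the
# hyper-parameters are all hypothetical choices made for this demo.
def _demo_classifier():
    import torch.optim as optim
    from torchvision import models

    backbone = models.resnet18(weights=None)
    backbone.out_features = backbone.fc.in_features  # read by Classifier when no bottleneck is given
    backbone.fc = nn.Identity()  # expose pooled 512-d features instead of logits
    model = Classifier(backbone, num_classes=10)
    # Each parameter group carries its own "lr"; with finetune=True the backbone
    # group runs at 0.1 * base_lr while the bottleneck and head run at base_lr.
    optimizer = optim.SGD(model.get_parameters(base_lr=0.01), lr=0.01, momentum=0.9)
    predictions, features = model(torch.randn(2, 3, 224, 224))  # training mode returns both
    return predictions, features, optimizer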


class ImageClassifier(Classifier):
    """ImageClassifier specific for reproducing results of `DomainBed <https://github.com/facebookresearch/DomainBed>`_.

    You are free to freeze all `BatchNorm2d` layers and insert one additional `Dropout` layer; this can
    achieve better results for some datasets, such as PACS, but may be worse for others.

    Args:
        backbone (torch.nn.Module): Any backbone to extract features from data
        num_classes (int): Number of classes
        freeze_bn (bool, optional): Whether to freeze all `BatchNorm2d` layers. Default: False
        dropout_p (float, optional): Dropout ratio for the additional `Dropout` layer, which is only
            used when `freeze_bn` is True. Default: 0.1
    """
def __init__(self, backbone: nn.Module, num_classes: int, freeze_bn: Optional[bool] = False,
dropout_p: Optional[float] = 0.1, **kwargs):
super(ImageClassifier, self).__init__(backbone, num_classes, **kwargs)
self.freeze_bn = freeze_bn
if freeze_bn:
self.feature_dropout = nn.Dropout(p=dropout_p)

    def forward(self, x: torch.Tensor):
f = self.backbone(x)
if len(f.shape) > 2:
f = self.pool_layer(f)
f = self.bottleneck(f)
if self.freeze_bn:
f = self.feature_dropout(f)
predictions = self.head(f)
if self.training:
return predictions, f
else:
return predictions

    def train(self, mode=True):
        """Switch training mode while keeping all `BatchNorm2d` layers in eval mode when `freeze_bn` is True."""
        super(ImageClassifier, self).train(mode)
if self.freeze_bn:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
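

# A second illustrative sketch, again hypothetical and not part of the original
# module: with `freeze_bn=True`, `ImageClassifier.train()` keeps every
# `BatchNorm2d` in eval mode, so running statistics stay frozen while the rest
# of the model trains. The torchvision `resnet18` backbone and the manual
# `out_features` attribute are assumptions made for this demo only.
def _demo_freeze_bn():
    from torchvision import models

    backbone = models.resnet18(weights=None)
    backbone.out_features = backbone.fc.in_features
    backbone.fc = nn.Identity()
    model = ImageClassifier(backbone, num_classes=7, freeze_bn=True, dropout_p=0.1)
    model.train()
    # Despite training mode, all BatchNorm2d modules remain in eval mode.
    assert all(not m.training for m in model.modules() if isinstance(m, nn.BatchNorm2d))
    predictions, features = model(torch.randn(2, 3, 224, 224))
    return predictions.shape, features.shape  # (2, 7), (2, 512)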