finetune.py
import torch
from torch import Tensor, nn

from .few_shot_classifier import FewShotClassifier


class Finetune(FewShotClassifier):
    """
    Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, Jia-Bin Huang
    "A Closer Look at Few-shot Classification" (ICLR 2019)
    https://arxiv.org/abs/1904.04232

    Fine-tune the prototypes based on the classification error on support images,
    then classify queries based on their cosine distance to the updated prototypes.
    As is, this method is incompatible with episodic training, because we freeze
    the backbone in order to perform the fine-tuning.

    This is an inductive method.
    """

    def __init__(
        self,
        *args,
        fine_tuning_steps: int = 200,
        fine_tuning_lr: float = 1e-4,
        temperature: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            fine_tuning_steps: number of fine-tuning steps
            fine_tuning_lr: learning rate for fine-tuning
            temperature: temperature applied to the logits before computing
                softmax or cross-entropy. Since the logits are multiplied by the
                temperature, a higher temperature yields sharper predictions.
        """
        super().__init__(*args, **kwargs)

        # Since we fine-tune the prototypes, we need them to be leaf variables,
        # i.e. we need to freeze the backbone.
        self.backbone.requires_grad_(False)

        self.fine_tuning_steps = fine_tuning_steps
        self.fine_tuning_lr = fine_tuning_lr
        self.temperature = temperature
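
    # A note on the temperature, with a small worked example: because the logits
    # are multiplied by the temperature, scaling it up sharpens the softmax.
    # For instance, softmax([2.0, 1.0]) ≈ [0.73, 0.27], whereas with a
    # temperature of 5, softmax([10.0, 5.0]) ≈ [0.99, 0.01].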
    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Overrides the forward method of FewShotClassifier.
        Fine-tune the prototypes based on the classification error on support
        images, then classify queries with respect to their cosine distance
        to the prototypes.
        """
        query_features = self.compute_features(query_images)

        # Fine-tuning updates the prototypes only; the backbone stays frozen.
        with torch.enable_grad():
            self.prototypes.requires_grad_()
            optimizer = torch.optim.Adam([self.prototypes], lr=self.fine_tuning_lr)
            for _ in range(self.fine_tuning_steps):
                support_logits = self.cosine_distance_to_prototypes(
                    self.support_features
                )
                loss = nn.functional.cross_entropy(
                    self.temperature * support_logits, self.support_labels
                )
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return self.softmax_if_specified(
            self.cosine_distance_to_prototypes(query_features),
            temperature=self.temperature,
        ).detach()

    @staticmethod
    def is_transductive() -> bool:
        return False
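

# A minimal usage sketch, assuming the conventions that the attributes used in
# forward() above suggest: FewShotClassifier takes a feature extractor as its
# backbone, and a process_support_set() method stores self.support_features,
# self.support_labels and self.prototypes. The backbone choice, image size and
# hyperparameters below are illustrative assumptions, not prescriptions.
# Because of the relative import at the top of the file, run this as a module
# (python -m <package>.finetune) rather than as a script.
if __name__ == "__main__":
    from torchvision.models import resnet18

    backbone = resnet18()
    backbone.fc = nn.Flatten()  # expose the pooled 512-d features as embeddings

    model = Finetune(backbone, fine_tuning_steps=50, fine_tuning_lr=1e-3)

    # A dummy 5-way 1-shot task with 10 query images (random data, shape check only)
    support_images = torch.randn(5, 3, 84, 84)
    support_labels = torch.arange(5)
    query_images = torch.randn(10, 3, 84, 84)

    model.process_support_set(support_images, support_labels)
    query_scores = model(query_images)
    print(query_scores.shape)  # expected: torch.Size([10, 5])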