-
Notifications
You must be signed in to change notification settings - Fork 12
/
utils.py
100 lines (91 loc) · 4.36 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
# code for kNN prediction from here:
# https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb
def knn_predict(feature, feature_bank, feature_labels, classes: int, knn_k: int, knn_t: float):
    """Helper method to run kNN predictions on features based on a feature bank.

    Each query is classified by a weighted vote of its ``knn_k`` most similar
    entries in the feature bank. Each neighbour votes for its own label with
    weight ``exp(similarity / knn_t)``.

    Args:
        feature: Tensor of shape [B, D] consisting of B D-dimensional query
            features.
        feature_bank: Tensor of shape [D, N] holding the N database features
            used for kNN (transposed: one feature per column).
        feature_labels: Tensor of shape [N] with the integer class label of
            each column of ``feature_bank``.
        classes: Number of classes (e.g. 10 for CIFAR-10).
        knn_k: Number of nearest neighbours used for the vote. Clamped to the
            number of entries in the feature bank.
        knn_t: Temperature used to reweight the similarities. Smaller values
            give the most similar neighbours a larger share of the vote.

    Returns:
        Tensor of shape [B, classes] holding class indices sorted by
        descending vote score; column 0 is the top-1 prediction.
    """
    batch_size = feature.size(0)
    # dot-product similarity between each query and every bank entry ---> [B, N]
    # (this is the cosine similarity when both sides are L2-normalised)
    sim_matrix = torch.mm(feature, feature_bank)
    # topk raises if k exceeds the number of bank entries, so clamp it
    k = min(knn_k, sim_matrix.size(-1))
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # [B, K]; scatter below requires int64 indices, so normalise the label dtype
    sim_labels = torch.gather(feature_labels.expand(batch_size, -1), dim=-1, index=sim_indices).long()
    # temperature reweighting of the similarities
    sim_weight = (sim_weight / knn_t).exp()
    # one-hot encode the neighbour labels ---> [B*K, C]
    one_hot_label = torch.zeros(batch_size * k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted vote ---> [B, C]; explicit k keeps the view valid for an empty bank
    pred_scores = torch.sum(one_hot_label.view(batch_size, k, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
class BenchmarkModule(pl.LightningModule):
    """A PyTorch Lightning Module for automated kNN callback.

    At the end of every training epoch we create a feature bank by running
    the backbone on the dataloader passed to the module (``dataloader_kNN``).
    At every validation step we embed the validation batch and classify it
    with a kNN classifier against that feature bank. After all validation
    steps (``validation_epoch_end``) the kNN top-1 accuracy is logged as
    ``kNN_accuracy`` (in percent).

    The highest kNN accuracy seen so far is available through the
    ``max_accuracy`` attribute, stored as a fraction in [0, 1].

    ``self.backbone`` is initialised as an empty ``nn.Module`` placeholder
    with no forward pass. It must be replaced (e.g. by a subclass) with a
    real network before training.

    Args:
        dataloader_kNN: Iterable yielding ``(images, targets, _)`` batches,
            used to build the feature bank.
        gpus: GPU setting passed by the caller. It is stored for backward
            compatibility; tensors are moved to ``self.device`` instead.
        classes: Number of classes, forwarded to ``knn_predict``.
        knn_k: Number of neighbours, forwarded to ``knn_predict``.
        knn_t: Similarity temperature, forwarded to ``knn_predict``.
    """

    def __init__(self, dataloader_kNN, gpus, classes, knn_k, knn_t):
        super().__init__()
        self.backbone = nn.Module()
        self.max_accuracy = 0.0
        self.dataloader_kNN = dataloader_kNN
        self.gpus = gpus
        self.classes = classes
        self.knn_k = knn_k
        self.knn_t = knn_t

    def _embed(self, images):
        """Return L2-normalised backbone features of shape [B, D] for ``images``."""
        # flatten(start_dim=1) instead of squeeze(): squeeze() would also drop
        # the batch dimension for a batch of size 1 and break normalize(dim=1)
        feature = self.backbone(images).flatten(start_dim=1)
        return F.normalize(feature, dim=1)

    def training_epoch_end(self, outputs):
        """Rebuild the kNN feature bank from ``dataloader_kNN``.

        The backbone is put in eval mode while embedding and switched back to
        train mode afterwards, also if embedding raises. The bank attributes
        are only assigned once every batch has been embedded, so a failure
        never leaves a half-built bank behind. An empty dataloader keeps the
        previous bank, if any.
        """
        self.backbone.eval()
        features, targets = [], []
        try:
            with torch.no_grad():
                for img, target, _ in self.dataloader_kNN:
                    # follow the module's device rather than parsing ``gpus``
                    img = img.to(self.device)
                    target = target.to(self.device)
                    features.append(self._embed(img))
                    targets.append(target)
        finally:
            self.backbone.train()
        if not features:
            # torch.cat cannot concatenate an empty list
            return
        # [D, N]: transposed so knn_predict can use a single matrix product
        self.feature_bank = torch.cat(features, dim=0).t().contiguous()
        # [N] label vector aligned with the columns of feature_bank
        self.targets_bank = torch.cat(targets, dim=0).flatten()

    def validation_step(self, batch, batch_idx):
        """Return ``(batch_size, num_top1_correct)`` for the batch.

        Returns None while no feature bank exists yet (e.g. during the sanity
        check before the first training epoch).
        """
        if not (hasattr(self, 'feature_bank') and hasattr(self, 'targets_bank')):
            return None
        images, targets, _ = batch
        feature = self._embed(images)
        pred_labels = knn_predict(
            feature, self.feature_bank, self.targets_bank,
            self.classes, self.knn_k, self.knn_t,
        )
        num = images.size(0)
        top1 = (pred_labels[:, 0] == targets).float().sum().item()
        return (num, top1)

    def validation_epoch_end(self, outputs):
        """Aggregate per-step results, update ``max_accuracy`` and log it."""
        # steps that ran before a feature bank existed produce None
        results = [out for out in outputs if out is not None]
        total_num = sum(num for num, _ in results)
        if total_num == 0:
            return
        total_top1 = sum(top1 for _, top1 in results)
        acc = float(total_top1 / total_num)
        if acc > self.max_accuracy:
            self.max_accuracy = acc
        self.log('kNN_accuracy', acc * 100.0, prog_bar=True)