gradnorm_postprocessor.py
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_postprocessor import BasePostprocessor
from .info import get_num_classes
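
# GradNorm postprocessor: scores a sample by the L1 norm of the gradient that
# backpropagating the KL divergence between the softmax prediction and a
# uniform distribution induces on the final fully connected layer
# (Huang et al., "On the Importance of Gradients for Detecting Distributional
# Shifts in the Wild", NeurIPS 2021). Larger gradient norms indicate
# in-distribution samples.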
class GradNormPostprocessor(BasePostprocessor):
    def __init__(self, config):
        super().__init__(config)
        self.args = self.config.postprocessor.postprocessor_args
        self.num_classes = get_num_classes(self.config.dataset.name)
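
    # Compute per-sample GradNorm confidence scores from feature vectors x and
    # the final-layer weights w and bias b (given as numpy arrays).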
    def gradnorm(self, x, w, b):
        # rebuild the classification head as a standalone linear layer so the
        # gradient can be taken with respect to a clean copy of its weights
        fc = torch.nn.Linear(*w.shape[::-1])
        fc.weight.data[...] = torch.from_numpy(w)
        fc.bias.data[...] = torch.from_numpy(b)
        fc.cuda()

        # uniform (all-ones) target distribution over the classes
        targets = torch.ones((1, self.num_classes)).cuda()

        confs = []
        for i in x:
            fc.zero_grad()
            # cross-entropy between the softmax prediction and the uniform target
            loss = torch.mean(
                torch.sum(-targets * F.log_softmax(fc(i[None]), dim=-1),
                          dim=-1))
            loss.backward()
            # score: L1 norm of the gradient accumulated on the FC weights
            layer_grad_norm = torch.sum(torch.abs(
                fc.weight.grad.data)).cpu().numpy()
            confs.append(layer_grad_norm)

        return np.array(confs)
    def setup(self, net: nn.Module, id_loader_dict, ood_loader_dict, id_loader_split=None):
        # GradNorm needs no fitting on in-distribution data
        pass
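
    # Class predictions come from the ordinary forward pass; OOD scores come
    # from gradnorm() applied to the extracted features, which needs gradients
    # even though inference itself runs under no_grad.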
    @torch.no_grad()
    def postprocess(self, net: nn.Module, data: Any):
        w, b = net.get_fc()
        logits, features = net.forward(data, return_feature=True)
        # re-enable gradients for the backward pass inside gradnorm()
        with torch.enable_grad():
            scores = self.gradnorm(features, w, b)
        _, preds = torch.max(logits, dim=1)
        return preds, torch.from_numpy(scores)
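
# Usage sketch (assumptions: an OpenOOD-style `config` exposing
# `postprocessor.postprocessor_args` and `dataset.name`, and a `net` that
# provides `get_fc()` and `forward(data, return_feature=True)`):
#
#     postprocessor = GradNormPostprocessor(config)
#     postprocessor.setup(net, id_loader_dict, ood_loader_dict)
#     preds, scores = postprocessor.postprocess(net, data.cuda())
#     # larger `scores` mark samples the detector treats as in-distribution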