forked from wmvanvliet/viswordrec-baseline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_model_layer_activity.py
92 lines (79 loc) · 3.14 KB
/
get_model_layer_activity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
Run the stimuli used in the MEG experiment through the model and record the
activity in each layer.
"""
import torch
from torchvision import transforms
import numpy as np
import pickle
import pandas as pd
import mkl
mkl.set_num_threads(4)
from PIL import Image
from tqdm import tqdm
import network
import dataloader
# Root folder that holds the datasets, stimulus images and model checkpoints
# read by this script, and that receives the recorded layer activity.
data_path = './data'

# Class labels of the training dataset. NOTE(review): assumed to be a pandas
# Series that maps class index -> text label -- confirm against dataloader.
classes = dataloader.WebDataset(f'{data_path}/datasets/epasana-10kwords').classes

# Register the extra 'noise' label under class index 10000. Series.append()
# never modified the Series in place (it returned a new one) and was removed
# in pandas 2.0, so build the extended Series with pd.concat() and assign the
# result back.
classes = pd.concat([classes, pd.Series(['noise'], index=[10000])])

# In order to get word2vec vectors, the KORAANI class was replaced with
# KORAANIN. Switch this back, otherwise, this word will be erroneously flagged
# as being misclassified.
classes[classes == 'KORAANIN'] = 'KORAANI'
# Read the list of stimuli presented in the MEG experiment. Each TIFF image is
# pasted onto a gray 224x224 RGB canvas, after which the ImageNet
# preprocessing transformation (conversion to a tensor followed by per-channel
# normalization) is applied to it.
stimuli = pd.read_csv('stimuli.csv')
preproc = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

images = []
for fname in tqdm(stimuli['tif_file'], desc='Reading images'):
    canvas = Image.new('RGB', (224, 224), '#696969')
    with Image.open(f'{data_path}/stimulus_images/{fname}') as orig:
        # Place the stimulus at a fixed offset on the gray canvas.
        canvas.paste(orig, (12, 62))
    # Add a leading axis so all images can be concatenated into one tensor.
    images.append(preproc(canvas).unsqueeze(0))
images = torch.cat(images, 0)
# Load the model from its checkpoint (onto the CPU) and feed all images
# through it. The images are deliberately fed through as one single batch,
# because the BatchNormalization2d layers are not trusted to behave
# predictably otherwise.
checkpoint_fname = f'{data_path}/models/vgg11_first_imagenet_then_epasana-10kwords_noise.pth.tar'
checkpoint = torch.load(checkpoint_fname, map_location='cpu')
model = network.VGG11.from_checkpoint(checkpoint, freeze=True)

# Record the output of the selected feature and classifier layers.
layer_outputs = model.get_layer_activations(
    images,
    feature_layers=[2, 6, 13, 20, 27],
    classifier_layers=[1, 4, 6],
)
# Collect the output of every recorded layer. An output that has 10001 units
# along its last axis gets the unit at index 10000 (the nontext class)
# removed, so that only the first 10000 units remain.
layer_activity = []
for layer_output in layer_outputs:
    n_units = layer_output.shape[-1]
    if n_units == 10001:
        print('Removing nontext class')
        kept_units = (layer_output[:, :10000], layer_output[:, 10001:])
        layer_output = np.hstack(kept_units)
        print('New output shape:', layer_output.shape)
    layer_activity.append(layer_output)
# Names stored alongside the activity. NOTE(review): presumably one name per
# entry of `layer_activity`, in the same order -- confirm.
layer_names = [
    'conv1_relu',
    'conv2_relu',
    'conv3_relu',
    'conv4_relu',
    'conv5_relu',
    'fc1_relu',
    'fc2_relu',
    'word_relu',
]

# Summarize each layer as the mean of the squared activity across all of its
# units, computed separately for every stimulus (one row per layer).
n_stimuli = len(stimuli)
mean_activity = []
for activity in layer_activity:
    per_stimulus = activity.reshape(n_stimuli, -1)
    mean_activity.append(np.square(per_stimulus).mean(axis=1))
mean_activity = np.array(mean_activity)

# Store the layer names together with the summarized activity.
results = {'layer_names': layer_names, 'mean_activity': mean_activity}
with open(f'{data_path}/model_layer_activity.pkl', 'wb') as f:
    pickle.dump(results, f)
# Translate the output of the last recorded layer into text predictions: for
# each stimulus, the unit with the highest activity is the predicted class.
predicted_class = layer_activity[-1].argmax(axis=1)
predictions = stimuli.copy()
predictions['predicted_text'] = classes[predicted_class].values
predictions['predicted_class'] = predicted_class
predictions.to_csv('predictions.csv')

# How many of the word stimuli did we classify correctly?
word_predictions = predictions.query('type=="word"')
is_correct = word_predictions['text'] == word_predictions['predicted_text']
n_correct = is_correct.sum()
accuracy = n_correct / len(word_predictions)
print(f'Word prediction accuracy: {n_correct}/{len(word_predictions)} = {accuracy * 100:.1f}%')  # 113/118, not bad at all!