-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathvisualize_caffe.py
48 lines (39 loc) · 1.52 KB
/
visualize_caffe.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import matplotlib.pyplot as plt
def visualize_weights(net, layer_name, padding=4, filename=''):
    """Tile every 2-D filter of a layer into one grayscale image and show it.

    The weights are read from ``net.params[layer_name][0].data`` (the first
    entry of the layer's ``[weights, biases]`` list) and are expected to be a
    4-D array. Each ``data[n, c]`` slice is treated as one filter; the filters
    are laid out row by row on a grid that is ``ceil(sqrt(N))`` cells wide and
    tall, where ``N = data.shape[0] * data.shape[1]``.

    Args:
        net: Object exposing ``params[layer_name][0].data`` as described
            above (presumably a Caffe network -- confirm against callers).
        layer_name: Key of the layer inside ``net.params``.
        padding: Number of blank pixels left between neighbouring filters.
        filename: If non-empty, the figure is also saved to this path before
            it is shown.

    Raises:
        ValueError: If the weight array is not 4-D or contains no elements.
    """
    # The parameters are a list of [weights, biases]
    data = np.asarray(net.params[layer_name][0].data)
    if data.ndim != 4 or data.size == 0:
        raise ValueError(
            'expected a non-empty 4-D weight array for layer %r, got shape %r'
            % (layer_name, data.shape)
        )

    num_filters, num_channels, filter_height, filter_width = data.shape

    # N is the total number of 2-D filters; use a square grid of cells
    total = num_filters * num_channels
    filters_per_row = int(np.ceil(np.sqrt(total)))

    # Distance between the top-left corners of neighbouring cells
    stride_y = filter_height + padding
    stride_x = filter_width + padding

    # Canvas size including padding; there is no padding after the last cell
    canvas = np.zeros((filters_per_row * stride_y - padding,
                       filters_per_row * stride_x - padding))

    # Flattening the first two axes keeps the (n, c) visiting order, so the
    # running index maps directly onto a (row, column) cell of the grid.
    tiles = data.reshape(total, filter_height, filter_width)
    for index, tile in enumerate(tiles):
        grid_row, grid_col = divmod(index, filters_per_row)
        top = grid_row * stride_y
        left = grid_col * stride_x
        canvas[top:top + filter_height, left:left + filter_width] = tile

    # Normalize image to 0-1. The padding pixels take part in the min/max.
    # A constant canvas has zero range; skip the division so the image comes
    # out all zeros instead of NaN.
    lowest = canvas.min()
    value_range = canvas.max() - lowest
    canvas -= lowest
    if value_range > 0:
        canvas /= value_range

    # Plot figure
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(canvas, cmap='gray', interpolation='nearest')

    # Save plot if filename is set (an empty or None filename skips saving)
    if filename:
        plt.savefig(filename, bbox_inches='tight', pad_inches=0)

    plt.show()