-
Notifications
You must be signed in to change notification settings - Fork 0
/
reproducing_descrambler_paper_sigmoid.m
81 lines (75 loc) · 2.28 KB
/
reproducing_descrambler_paper_sigmoid.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
%% Here we'll fiddle with the descrambled data for layer 1
% Move into the folder holding the trained sigmoid DEERNet and its companion
% .mat files. Every later load()/saveas() in this script uses relative paths.
% The cd is skipped when the current folder already ends in that subfolder,
% so re-running the script in the same session does not try to descend into
% 2_Layer_DeerNet/sigmoid/2_Layer_DeerNet/sigmoid.
% cd('/home/ssule25/Documents/spinach_2_6_5625/My_experiments/2_Layer_DeerNet/')
data_subdir = fullfile('2_Layer_DeerNet', 'sigmoid');
if ~endsWith(pwd(), [filesep, data_subdir])
    cd(fullfile(pwd(), data_subdir))
end
% net = load('/home/ssule25/Documents/spinach_2_6_5625/My_experiments/2_Layer_DeerNet/sigmoid/2_layer_DEERNET_sigmoid.mat');
% load() returns a struct; the trained network is stored in its 'net' field.
net = load('2_layer_DEERNET_sigmoid.mat');
net = net.net;
%% Build truncated copies of the network
% Three sub-networks are assembled from the trained `net`. Each keeps the
% leading layers 1:k (k = 2, 3, 4) and has the original final layer appended:
%   net_1 <- layers 1:2 + last layer
%   net_2 <- layers 1:3 + last layer
%   net_3 <- layers 1:4 + last layer
% Presumably the last layer is re-attached so each stack ends in an output
% layer, as assembleNetwork expects - confirm against the layer list.
layers_1 = [net.Layers(1:2,1); net.Layers(end,1)];
net_1    = assembleNetwork(layers_1);

layers_2 = [net.Layers(1:3,1); net.Layers(end,1)];
net_2    = assembleNetwork(layers_2);

layers_3 = [net.Layers(1:4,1); net.Layers(end,1)];
net_3    = assembleNetwork(layers_3);
%% Set up data
% Loads the matrices P_1 / P_2 (presumably permutations - the titles and the
% name "descram" suggest so), the precomputed descrambled layer-1 weights and
% the recentered matrices D_80 / D_256. All are read from the current folder.
% The cell then extracts the weights of network layers 2 and 4 and takes
% full SVDs of the raw and descrambled layer-4 weights.
% Absolute-path variants, kept for reference:
% P = load('/home/ssule25/Documents/spinach_2_6_5625/My_experiments/2_Layer_DeerNet/sigmoid/P_1.mat');
% P_2 = load('/home/ssule25/Documents/spinach_2_6_5625/My_experiments/2_Layer_DeerNet/sigmoid/P_2.mat');
% descram_W = load('/home/ssule25/Documents/spinach_2_6_5625/My_experiments/2_Layer_DeerNet/sigmoid/Wd_1.mat');

% gather() copies the data back to host memory in case it was saved as a
% gpuArray; for ordinary arrays it returns the input unchanged.
P   = load('P_1.mat');
P_1 = gather(P.P);
P_2 = load('P_2.mat');
P_2 = gather(P_2.P_2);

% Each load() yields a struct; unwrap the single field we need.
descram_W = load('Wd_1.mat');
descram_W = descram_W.descrambled_weight_mat;
D_80  = load('recentered_80.mat');
D_80  = D_80.D_80;
D_256 = load('recentered_256.mat');
D_256 = D_256.D_256;

% Weight matrices of network layers 2 and 4.
W   = net.Layers(2,1).Weights;
W_2 = net.Layers(4,1).Weights;

% Right-multiplying by P_2' reorders the columns of W_2.
descram_W_2 = W_2*P_2';

% Full SVDs of the raw and the descrambled matrix.
[U_2_raw, Sigma_raw, V_2_raw] = svd(W_2);
[U_2, Sigma, V_2] = svd(descram_W_2);
%% Visualize!!!
% Layer-1 overview on a 2x2 tiled layout. The panels show the raw weights W,
% the loaded descrambled weights, and both of those sandwiched as
% abs(D_80' * M * D_256); the titles label the D matrices F^+ / F^-.
% Every panel is transposed before plotting and drawn with axis xy.
% Each row of layer1_panels holds: matrix to show, title, extra title() args.
figure();
tiledlayout(2,2);
layer1_panels = {
    W',                          "Raw Weight matrix",         {}
    descram_W',                  "Descrambled Weight Matrix", {}
    abs(D_80'*W*D_256)',         "$F^+ W F^-$",               {'Interpreter', 'latex'}
    abs(D_80'*descram_W*D_256)', "$F^+ PW F^-$",              {'Interpreter', 'latex'}
};
for iPanel = 1:size(layer1_panels, 1)
    nexttile
    imagesc(layer1_panels{iPanel, 1});
    title(layer1_panels{iPanel, 2}, layer1_panels{iPanel, 3}{:});
    axis xy
end
% Drop the loop temporaries so the workspace matches the unrolled version.
clear layer1_panels iPanel
%saveas(gcf, 'Layer1.png');
%% Visualize second layer!
% Layer-2 overview on a 2x2 tiled layout:
%   top row    - raw W_2 and the descrambled descram_W_2, both transposed
%                and drawn with axis xy;
%   bottom row - the V factors returned by svd() for the raw and the
%                descrambled matrix.
% NOTE(review): the V panels keep imagesc's default orientation (no axis xy),
% unlike every other panel - confirm this is intended.
% The figure handle is kept so the PNG export saves this figure even if a
% different figure has become current in the meantime.
fig_layer2 = figure();
tiledlayout(2,2);
nexttile
imagesc(W_2')
title("Raw weight matrix");
axis xy
nexttile
% descram_W_2 (= W_2*P_2') is defined earlier in this script; its transpose is
% exactly P_2*W_2', so reuse it instead of recomputing the product.
imagesc(descram_W_2');
title("Descrambled Weight matrix");
axis xy
nexttile
imagesc(V_2_raw);
title("Right singular vectors of raw weight");
nexttile
imagesc(V_2);
title("Right singular vectors of descrambled weight");
% Relative path: the PNG is written to the current folder.
saveas(fig_layer2, 'Layer2.png');