-
Notifications
You must be signed in to change notification settings - Fork 86
/
Copy pathrnn_stack2params.m
57 lines (56 loc) · 2.25 KB
/
rnn_stack2params.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
function [ params ] = rnn_stack2params( stack, eI, W_t, sum_tied )
%RNN_STACK2PARAMS converts a stack structure of RNN weights to a single vector
%   params = RNN_STACK2PARAMS(stack, eI, W_t, sum_tied)
%
%   stack     - cell array; stack{l}.W and stack{l}.b hold the weight matrix
%               and bias of layer l. stack{end}.W_ss must be present exactly
%               when eI.shortCircuit is set.
%   eI        - struct describing the network. Fields read:
%                 layerSizes    (required) number of rows of stack{l}.W / b
%                 inputDim      (required) number of columns of stack{1}.W
%                 temporalLayer (optional, default 0) index of the layer
%                               whose size the square W_t must match;
%                               0 means there is no temporal matrix
%                 tieWeights    (optional, default 0) when set, the W of
%                               layers past the midpoint (other than the
%                               output layer) is tied and not stored
%                 shortCircuit  (optional, default 0) append stack{end}.W_ss
%   W_t       - temporal weight matrix (optional, default []).
%   sum_tied  - (optional, default false; [] is treated as false) with tied
%               weights, store W + W_dec' for each encoder layer, where
%               W_dec is the W of its mirrored tied layer. This is useful
%               for aggregating gradients.
%
%   params    - column vector laid out as: W(:) then b(:) for each layer in
%               order (W omitted for tied layers), then W_t(:) if present,
%               then W_ss(:) if eI.shortCircuit.
%
%   Every dimension is verified with assert against eI.

%% defaults for optional arguments and eI fields
if nargin < 3
    W_t = [];
end
if nargin < 4 || isempty(sum_tied)
    sum_tied = false;
end
if ~isfield(eI, 'tieWeights')
    eI.tieWeights = 0;    % assume no weight tying if unset
end
if ~isfield(eI, 'shortCircuit')
    eI.shortCircuit = 0;  % short circuits default to off
end
if ~isfield(eI, 'temporalLayer')
    eI.temporalLayer = 0; % no temporal matrix unless a layer is named
end

numLayers = numel(eI.layerSizes);

% the short-circuit flag and the presence of W_ss must agree
assert(~xor(eI.shortCircuit, isfield(stack{end}, 'W_ss')));

% Collect the pieces and concatenate once at the end. Growing params layer
% by layer would re-copy the whole vector on every step.
pieces = {};

%% first layer: maps inputDim inputs to layerSizes(1) units
assert(size(stack{1}.W,1) == eI.layerSizes(1));
assert(size(stack{1}.W,2) == eI.inputDim);
assert(size(stack{1}.b,1) == eI.layerSizes(1));
pieces{end+1} = stack{1}.W(:);
pieces{end+1} = stack{1}.b(:);

%% remaining layers; the output layer gets no special treatment
for l = 2 : numLayers
    assert(size(stack{l}.W,1) == eI.layerSizes(l));
    assert(size(stack{l}.W,2) == eI.layerSizes(l-1));
    assert(size(stack{l}.b,1) == eI.layerSizes(l));
    % with tied weights, layers past the midpoint (except the output layer)
    % are tied to a mirrored earlier layer, so their W is not stored
    isTied = eI.tieWeights && l > numLayers/2 && l ~= numLayers;
    if ~isTied
        if eI.tieWeights && sum_tied && l < numLayers
            % encoder layer: add the transposed W of its mirrored tied layer
            lDec = numLayers - l + 1;
            pieces{end+1} = reshape(stack{l}.W + stack{lDec}.W', [], 1);
        else
            pieces{end+1} = stack{l}.W(:);
        end
    end
    % biases are always stored, tied or not
    pieces{end+1} = stack{l}.b(:);
end

%% temporal weight matrix
if ~isempty(W_t) || eI.temporalLayer
    % W_t without a valid temporalLayer (e.g. 0) used to fail with an
    % opaque indexing error on layerSizes(0); report it explicitly instead
    assert(eI.temporalLayer >= 1 && eI.temporalLayer <= numLayers, ...
        'rnn_stack2params: eI.temporalLayer must index a layer in eI.layerSizes when W_t is given');
    temporalSize = eI.layerSizes(eI.temporalLayer);
    assert(size(W_t,1) == temporalSize);
    assert(size(W_t,2) == temporalSize);
    pieces{end+1} = W_t(:);
end

%% short circuit matrix
if eI.shortCircuit
    pieces{end+1} = stack{end}.W_ss(:);
end

params = vertcat(pieces{:});