-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path dbn_classify_v2.m
62 lines (51 loc) · 2.01 KB
/
dbn_classify_v2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
function [ErrorRate,tpos,tneg,fpos,fneg,conf,acc,err,recall,precision,avg_acc,avg_err,macro_f1,macro_recall,...
macro_precision,micro_f1,micro_recall,micro_precision] = dbn_classify_v2 (Xtrain, Ytrain, Xtest, Ytest, fl, dbnPath)
%DBN_CLASSIFY_V2 Train a 3-output deep belief network and evaluate it on a test set.
%   [ErrorRate, tpos, tneg, fpos, fneg, conf, acc, err, recall, precision,
%    avg_acc, avg_err, macro_f1, macro_recall, macro_precision, micro_f1,
%    micro_recall, micro_precision] = DBN_CLASSIFY_V2(Xtrain, Ytrain, Xtest, Ytest, fl)
%   builds a network with layer sizes [fl 3] via randDBN, pretrains it
%   (pretrainDBN), fits the output mapping (SetLinearMapping), fine-tunes it
%   (trainDBN), prints the training and test error rates from CalcErrorRate,
%   and finally scores the test predictions with feelings and
%   metrics_calculation.
%
%   DBN_CLASSIFY_V2(..., dbnPath) uses dbnPath as the DBN toolbox folder
%   instead of the default location. The folder is added to the MATLAB path
%   only if it exists; otherwise the toolbox is assumed to be on the path
%   already.
%
%   Inputs:
%     Xtrain, Xtest - feature matrices passed straight to the toolbox
%                     (presumably one sample per row - confirm against randDBN).
%     Ytrain, Ytest - 3-column label matrices, one sample per row. For decoding,
%                     column 1 == 1 -> class -1, else column 2 == 1 -> class 0,
%                     otherwise class +1. Presumably one-hot - confirm with caller.
%     fl            - layer sizes placed in front of the 3-unit output layer.
%     dbnPath       - (optional) DBN toolbox folder.
%
%   Outputs:
%     ErrorRate            - test-set error rate as returned by CalcErrorRate.
%     tpos, tneg, fpos, fneg, conf - outputs of feelings(trueLabels, predictedLabels).
%     acc ... micro_precision      - outputs of metrics_calculation(tpos,tneg,fpos,fneg).
%
%   Errors:
%     dbn_classify_v2:sizeMismatch - if CalcErrorRate returns a different
%     number of prediction rows than Ytest has label rows.

if nargin < 6 || isempty(dbnPath)
    dbnPath = 'C:\Users\Vasilis\Documents\DBN_v2\';
end
% Add the toolbox folder only when it exists on this machine, so the function
% also works where the toolbox is installed elsewhere and already on the path.
if exist(dbnPath, 'dir') == 7
    addpath(dbnPath);
end

% Network: the layers in fl followed by a 3-unit output layer (one per class).
nodes = [fl 3];
bbdbn = randDBN( nodes, 'BBDBN' );
nrbm = numel(bbdbn.rbm);

% Pretraining options.
opts.MaxIter = 300; % 1000 before
opts.BatchSize = 100;
opts.Verbose = true;
opts.StepRatio = 0.1;
opts.object = 'CrossEntropy';
% it was Layer before; presumably limits pretraining to all but the last RBM -
% NOTE(review): confirm against pretrainDBN
opts.LayerNum = nrbm-1;
bbdbn = pretrainDBN(bbdbn, Xtrain, opts);
bbdbn = SetLinearMapping(bbdbn, Xtrain, Ytrain);

% Fine-tuning reuses the same opts struct with Layer and MaxIter overridden.
% NOTE(review): meaning of opts.Layer = 0 is defined by trainDBN - confirm there.
opts.Layer = 0;
opts.MaxIter = 50; % not existed before
%opts.StepRatio = 0.6; % not existed before
bbdbn = trainDBN(bbdbn, Xtrain, Ytrain, opts);

%rmse= CalcRmse(bbdbn, Xtrain, Ytrain);
[ErrorRate1,~]= CalcErrorRate(bbdbn, Xtrain, Ytrain);
fprintf( 'For training data:\n' );
%fprintf( 'rmse: %g\n', rmse );
fprintf( 'ErrorRate: %g\n', ErrorRate1 );

%rmse= CalcRmse(bbdbn, Xtest, Ytest);
[ErrorRate,out]= CalcErrorRate(bbdbn, Xtest, Ytest);
fprintf( 'For test data:\n' );
%fprintf( 'rmse: %g\n', rmse );
fprintf( 'ErrorRate: %g\n', ErrorRate );

% Convert predictions and true labels to the -1/0/+1 codes that feelings.m
% expects. Each row must have a matching prediction; fail loudly instead of
% silently padding the true labels with zeros (0 is a valid class code).
if size(out,1) ~= size(Ytest,1)
    error('dbn_classify_v2:sizeMismatch', ...
        'CalcErrorRate returned %d prediction rows for %d test label rows.', ...
        size(out,1), size(Ytest,1));
end
% Column 1 takes priority over column 2; rows with no 1 in either become +1.
dbn_out = ones(size(out,1),1);
dbn_out(out(:,2) == 1) = 0;
dbn_out(out(:,1) == 1) = -1;
dbn_te = ones(size(Ytest,1),1);
dbn_te(Ytest(:,2) == 1) = 0;
dbn_te(Ytest(:,1) == 1) = -1;

[tpos,tneg,fpos,fneg,conf] = feelings(dbn_te, dbn_out);
[acc,err,recall,precision,avg_acc,avg_err,macro_f1,macro_recall,macro_precision,micro_f1,micro_recall,micro_precision]...
    = metrics_calculation(tpos,tneg,fpos,fneg);
return