-
Notifications
You must be signed in to change notification settings - Fork 0
/
weirdpredict.m
63 lines (55 loc) · 2.48 KB
/
weirdpredict.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
% function predictions = weirdpredict(y, X, model, verbose)
%
% This functions computes predictions for a dataset X given a model
% computed by weirdtrain.
%
% INPUT:
% y: label vector (n_samples x 1)
% X: testing data (n_samples x n_features)
% model: struct as returned by weirdtrain
% distance_type: 'manhattan' or 'euclidean' (default: euclidean)
% verbose: if true, prints accuracy (default: true)
%
% OUTPUT:
% predictions: predicted label for each sample in X
%
% WeiRD stands for "Weighted Robust Distance" and is a fast and simple
% classification algorithm that assigns class labels based on the distance
% to class prototypes. The distance is the euclidian distance between a
% current sample and a prototype in a space, in which each feature
% dimension is scaled by the two-sample t-value of the respective feature
% in the training data. Class prototypes correspond to the averages of
% each feature in the training data. The current implementation works for
% two-class problems only.
%
% by Matthias Guggenmos 06/07/16
% _________________________________________________________________________
% Matthias Guggenmos, Katharina Schmack and Philipp Sterzer, "WeiRD - a fast
% and performant multivariate pattern classifier", 2016, International Workshop
% on Pattern Recognition in Neuroimaging (PRNI), Trento, Italy, 2016, pp. 1-4.
% doi: 10.1109/PRNI.2016.7552349
function predictions = weirdpredict(y, X, model, distance_type, verbose)
if nargin <= 3 || isempty(distance_type)
distance_type = 'euclidean';
end
if nargin <= 4 || isempty(verbose)
verbose = false;
end
if strcmp(distance_type, 'manhattan')
% compute votes of each feature based on distance to centroid
votes = abs(X - repmat(model.x1, size(X, 1), 1)) - ...
abs(X - repmat(model.x2, size(X, 1), 1));
% compute decision values as a weighted sum of votes and feature
% importances
fi_matrix = repmat(model.feature_importances_, size(votes, 1), 1);
dec = dot(votes, fi_matrix, 2) / size(votes, 2);
elseif strcmp(distance_type, 'euclidean')
dec = sum((model.feature_importances_ .* (X - model.x1)) .^ 2, 2) -...
sum((model.feature_importances_ .* (X - model.x2)) .^ 2, 2);
end
% compute predictions based on the sign of the decision value
predictions = arrayfun(@(x)model.classes(x+1), dec > 0);
if verbose
fprintf('Accuracy: %.1f%%\n', 100*mean(predictions == y))
end
end