-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathdemgtm1.m
147 lines (133 loc) · 4.77 KB
/
demgtm1.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
%DEMGTM1 Demonstrate EM for GTM.
%
% Description
% This script demonstrates the use of the EM algorithm to fit a one-
% dimensional GTM to a two-dimensional set of data using maximum
% likelihood. The location and spread of the Gaussian kernels in the
% data space is shown during training.
%
% See also
% DEMGTM2, GTM, GTMEM, GTMPOST
%
% Copyright (c) Ian T Nabney (1996-2001)
% Demonstrates the GTM with a 2D target space and a 1D latent space.
%
% This script generates a simple data set in 2 dimensions,
% with an intrinsic dimensionality of 1, and trains a GTM
% with a 1-dimensional latent variable to model this data
% set, visually illustrating the training process
%
% Synopsis: demgtm1
% Build a 2-D data set with intrinsic dimensionality 1: x on a regular
% grid, y = x + 1.25*sin(2*x), then display the points as red circles.
data_min = 0.15;
data_max = 3.05;
T = (data_min:0.05:data_max)';
T = [T, T + 1.25*sin(2*T)];
fh1 = figure;
plot(T(:,1), T(:,2), 'ro');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
clc;
disp('This demonstration shows in detail how the EM algorithm works')
disp('for training a GTM with a one dimensional latent space.')
disp(' ')
fprintf([...
'The figure shows data generated by feeding a 1D uniform distribution\n', ...
'(on the X-axis) through a non-linear function (y = x + 1.25*sin(2*x))\n', ...
'\nPress any key to continue ...\n\n']);
pause;
% Template of 20 points on the unit circle, reused later to draw the
% 2-standard-deviation boundary of each mixture component.
theta = (0:(2*pi)/(20-1):2*pi)';
unitC = [sin(theta), cos(theta)];
% Create a GTM (1-D latent space, 2-D data space), initialise its RBF
% mapping onto the first principal component of the data, and plot the
% initial mixture: centres joined in latent order, each surrounded by a
% filled circle of radius 2 standard deviations.
clc;
num_latent_points = 20;
num_rbf_centres = 5;
net = gtm(1, num_latent_points, 2, num_rbf_centres, 'gaussian');
options = zeros(1, 18);
% NOTE(review): options(7) presumably selects an initialisation option
% inside gtminit -- confirm against the Netlab gtminit documentation.
options(7) = 1;
net = gtminit(net, options, T, 'regular', num_latent_points, ...
   num_rbf_centres);
mix = gtmfwd(net);   % Gaussian mixture induced by the GTM in data space
% Replot the figure with the initial mixture on top of the data.
% BUG FIX: the original replicated the centre coordinates with a mix of
% hard-coded ones(20,1) and ones(num_latent_points,1); both counts must
% equal the number of rows in the unit-circle template, so derive that
% count from unitC instead of relying on the two numbers coinciding.
num_circle_points = size(unitC, 1);
hold off;
plot(mix.centres(:,1), mix.centres(:,2), 'g');
hold on;
for i = 1:num_latent_points
  % Circle of radius 2*std centred on mixture component i.
  c = 2*unitC*sqrt(mix.covars(1)) + ...
      [ones(num_circle_points,1)*mix.centres(i,1) ...
       ones(num_circle_points,1)*mix.centres(i,2)];
  fill(c(:,1), c(:,2), [0.8 1 0.8]);
end
plot(T(:,1), T(:,2), 'ro');
plot(mix.centres(:,1), mix.centres(:,2), 'g+');
plot(mix.centres(:,1), mix.centres(:,2), 'g');
axis([data_min-0.05 data_max+0.05 data_min-0.05 data_max+0.05]);
drawnow;
title('Initial configuration');
disp(' ')
fprintf([...
'The figure shows the starting point for the GTM, before the training.\n', ...
'A discrete latent variable distribution of %d points in 1 dimension \n', ...
'is mapped to the 1st principal component of the target data by an RBF.\n', ...
'with %d basis functions. Each of the %d points defines the centre of\n', ...
'a Gaussian in a Gaussian mixture, marked by the green ''+''-signs. The\n', ...
'mixture components all have equal variance, illustrated by the filled\n', ...
'circle around each ''+''-sign, the radii corresponding to 2 standard\n', ...
'deviations. The ''+''-signs are connected with a line according to their\n', ...
'corresponding ordering in latent space.\n\n', ...
'Press any key to begin training ...\n\n'], num_latent_points, ...
num_rbf_centres, num_latent_points);
pause;
figure(fh1);
%%%% Train the GTM and plot it (along with the data) as training proceeds %%%%
options = foptions;
options(1) = -1; % Turn off all warning messages
% One EM iteration per gtmem call, so the plot can be refreshed between
% iterations.
options(14) = 1;
% BUG FIX: the original hard-coded 20 for both the number of mixture
% centres (for i=1:20) and the circle-template replication (ones(20,1)).
% Derive both counts from the data so the loop stays correct if the
% number of latent points or the circle resolution changes.
num_circle_points = size(unitC, 1);
for j = 1:15
  [net, options] = gtmem(net, T, options);
  hold off;
  mix = gtmfwd(net);
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  hold on;
  for i = 1:size(mix.centres, 1)
    % Circle of radius 2*std centred on mixture component i.
    c = 2*unitC*sqrt(mix.covars(1)) + ...
        [ones(num_circle_points,1)*mix.centres(i,1) ...
         ones(num_circle_points,1)*mix.centres(i,2)];
    fill(c(:,1), c(:,2), [0.8 1.0 0.8]);
  end
  plot(T(:,1), T(:,2), 'ro');
  plot(mix.centres(:,1), mix.centres(:,2), 'g+');
  plot(mix.centres(:,1), mix.centres(:,2), 'g');
  axis([0 3.5 0 3.5]);
  title(['After ', int2str(j),' iterations of training.']);
  drawnow;
  if (j == 4)
    fprintf([...
'The GTM initially adapts relatively quickly - already after \n', ...
'4 iterations of training, a rough fit is attained.\n\n', ...
'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  elseif (j == 8)
    fprintf([...
'After another 4 iterations of training: from now on further \n', ...
'training only makes small changes to the mapping, which combined with \n', ...
'decrements of the Gaussian mixture variance, optimize the fit in \n', ...
'terms of likelihood.\n\n', ...
'Press any key to continue training ...\n\n']);
    pause;
    figure(fh1);
  else
    pause(1);
  end
end
% Final summary message and cleanup.
% BUG FIX: corrected two typos in the displayed text:
% "Is has been" -> "It has been", "probabilty" -> "probability".
clc;
fprintf([...
'After 15 iterations of training the GTM can be regarded as converged. \n', ...
'It has been adapted to fit the target data distribution as well \n', ...
'as possible, given prior smoothness constraints on the mapping. It \n', ...
'captures the fact that the probability density is higher at the two \n', ...
'bends of the curve, and lower towards its end points.\n\n']);
disp(' ');
disp('Press any key to exit.');
pause;
close(fh1);
clear all;