forked from probml/pmtk3
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathkernelClassifierComparison.html
1027 lines (997 loc) · 72.6 KB
/
kernelClassifierComparison.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
<h1>Empirical comparison of some kernelized classifiers</h1>
This was written by Kevin Murphy in about 2010 and relies on PMTK3, which is deprecated.
This page was retrieved from the
<a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/docs/tutorial/html/tutKernelClassif.html">Wayback Machine cached copy</a>, and is missing all the figures.
Also, the code no longer runs, and many hyperlinks may be broken.
But hopefully there is still some value in the content.
<h2>Kernel functions</h2>
<p>One common form of basis function expansion is to define a new feature vector <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq31990.png" alt="$\phi(x)$"> by comparing the input <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq43551.png" alt="$x$"> to a set of prototypes or exemplars <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq62509.png" alt="$\mu_k$"> as follows:</p><p><img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq09998.png" alt="$$\phi(x) = (K(x,\mu_1), ..., K(x,\mu_D))$$"></p><p>Here <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq94202.png" alt="$K(x,\mu)$"> is a 'kernel function', which in this context just means a function of two arguments. A common example is the Gaussian or RBF kernel</p><p><img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq89797.png" alt="$$K(x,\mu) = \exp(-\frac{||x-\mu||^2}{2\sigma^2})$$"></p><p>where <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq24873.png" alt="$\sigma$"> is the 'bandwidth'. This can be created using <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelRbfSigma.m">kernelRbfSigma.m</a> . Alternatively, we can write</p><p><img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq61895.png" alt="$$K(x,\mu) = \exp(-\gamma ||x-\mu||^2)$$"></p><p>The quantity <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq89526.png" alt="$\gamma=1/(2\sigma^2)$"> is known as the scale or precision.
This can be created using <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelRbfGamma.m">kernelRbfGamma.m</a> . Most software packages use this latter parameterization.</p><p>Another common example is the polynomial kernel <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelPolyPmtk.m">kernelPolyPmtk.m</a> :</p><p><img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq61193.png" alt="$$K(x,\mu) = (1+x^T \mu)^d$$"></p><p>where d is the degree.</p><p>Another common example is the linear kernel <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelLinearPmtk.m">kernelLinearPmtk.m</a> :</p><p><img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq18816.png" alt="$$K(x,\mu) = x^T \mu$$"></p><p>(The reason for the 'pmtk' suffix is to distinguish these functions from other implementations of the same concept.)</p><p>Often we take the prototypes <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq62509.png" alt="$\mu_k$"> to be the training vectors (rows of <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq03598.png" alt="$X$">), but we don't have to. Some methods require that the kernel be a Mercer (positive definite) kernel. 
All of the above kernels are Mercer kernels, but not every kernel function is.</p><p>The advantages of using kernels include the following:</p><div><ul><li>We can apply standard parametric models (e.g., linear and logistic regression) to non-vectorial inputs (e.g., strings, molecular structures, etc.), by defining <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq94202.png" alt="$K(x,\mu)$"> to be some kind of function for comparing structured inputs.</li><li>We can increase the flexibility of the model by working in an enlarged feature space.</li></ul></div><p>Below we show an example where we fit the XOR data using kernelized logistic regression, with various kernels and prototypes (from <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/bookDemos/Introduction/logregXorDemo.m">logregXorDemo.m</a>)
.</p><pre class="codeinput">clear <span class="string">all</span>; close <span class="string">all</span>
[X, y] = createXORdata();
rbfSigma = 1;
polydeg = 2;
protoTypes = [1 1; 1 5; 5 1; 5 5];
protoTypesStnd = standardizeCols(protoTypes);
kernels = {@(X1, X2)kernelRbfSigma(X1, protoTypesStnd, rbfSigma)
@(X1, X2)kernelRbfSigma(X1, X2, rbfSigma)
@(X1, X2)kernelPolyPmtk(X1, X2, polydeg)};
titles = {<span class="string">'rbf'</span>, <span class="string">'rbf prototypes'</span>, <span class="string">'poly'</span>};
<span class="keyword">for</span> i=1:numel(kernels)
preproc = preprocessorCreate(<span class="string">'kernelFn'</span>, kernels{i}, <span class="string">'standardizeX'</span>, true, <span class="string">'addOnes'</span>, true);
model = logregFit(X, y, <span class="string">'preproc'</span>, preproc);
yhat = logregPredict(model, X);
errorRate = mean(yhat ~= y);
fprintf(<span class="string">'Error rate using %s features: %2.f%%\n'</span>, titles{i}, 100*errorRate);
predictFcn = @(Xtest)logregPredict(model, Xtest);
plotDecisionBoundary(X, y, predictFcn);
<span class="keyword">if</span> i==1
hold <span class="string">on</span>;
plot(protoTypes(:, 1), protoTypes(:, 2), <span class="string">'*k'</span>, <span class="string">'linewidth'</span>, 2, <span class="string">'markersize'</span>, 10)
<span class="keyword">end</span>
title(titles{i});
<span class="keyword">end</span>
</pre><pre class="codeoutput">Error rate using rbf features: 0%
Error rate using rbf prototypes features: 0%
Error rate using poly features: 0%
</pre><img vspace="5" hspace="5" src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_01.png" alt=""> <img vspace="5" hspace="5" src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_02.png" alt=""> <img vspace="5" hspace="5" src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_03.png" alt=""> <p>In the first example, we use an RBF kernel with centers at 4 manually chosen points, shown with black stars. In the second and third examples, we use an RBF and polynomial kernel, centered at all the training data. This is an example of a non-parametric model, since the number of parameters grows with the size of the training set (which makes training slow on large datasets). We can use sparsity promoting priors to select a subset of the training data, as we illustrate below.</p>
<h2>Using grid search plus cross validation to choose the kernel parameters<a name="14"></a></h2><p>We can create a grid of models, with different kernels and different regularizers, as shown in the example below ( from <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/logregKernelDemo.m">logregKernelDemo.m</a> ). If CV does not pick a point on the edge of the grid, we can be fairly confident we have searched over a reasonable range. For this reason, it is helpful to plot the cost surface.</p><pre class="codeinput">loadData(<span class="string">'fglass'</span>); <span class="comment">% 6 classes, X is 214*9</span>
X = [Xtrain; Xtest];
y = canonizeLabels([ytrain; ytest]); <span class="comment">% class 4 is missing, so relabel 1:6</span>
setSeed(0);
split = 0.7;
[X, y] = shuffleRows(X, y);
X = rescaleData(standardizeCols(X));
N = size(X, 1);
nTrain = floor(split*N);
nTest = N - nTrain;
Xtrain = X(1:nTrain, :);
Xtest = X(nTrain+1:end, :);
ytrain = y(1:nTrain);
ytest = y(nTrain+1:end);
<span class="comment">% 2D CV</span>
lambdaRange = logspace(-6, 1, 5);
gammaRange = logspace(-4, 4, 5);
paramRange = crossProduct(lambdaRange, gammaRange);
regtypes = {<span class="string">'L2'</span>}; <span class="comment">%L1 is a bit better but a bit slower</span>
<span class="keyword">for</span> r=1:length(regtypes)
regtype = regtypes{r};
fitFn = @(X, y, param)<span class="keyword">...</span>
logregFit(X, y, <span class="string">'lambda'</span>, param(1), <span class="string">'regType'</span>, regtype, <span class="string">'preproc'</span>, <span class="keyword">...</span>
preprocessorCreate(<span class="string">'kernelFn'</span>, @(X1, X2)kernelRbfGamma(X1, X2, param(2))));
predictFn = @logregPredict;
lossFn = @(ytest, yhat)mean(yhat ~= ytest);
nfolds = 5;
useSErule = true;
plotCv = true;
tic;
[LRmodel, bestParam, LRmu, LRse] = <span class="keyword">...</span>
fitCv(paramRange, fitFn, predictFn, lossFn, Xtrain, ytrain, nfolds, <span class="keyword">...</span>
<span class="string">'useSErule'</span>, useSErule, <span class="string">'doPlot'</span>, plotCv, <span class="string">'params1'</span>, lambdaRange, <span class="string">'params2'</span>, gammaRange);
time(r) = toc
yhat = logregPredict(LRmodel, Xtest);
nerrors(r) = sum(yhat ~= ytest);
<span class="keyword">end</span>
errRate = nerrors/nTest
</pre><pre class="codeoutput">time =
31.2589
errRate =
0.4154
</pre><img vspace="5" hspace="5" src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_04.png" alt=""> <p>In the example above, we just use a 5x5 grid for speed, but in practice one might use a 10x10 grid for a coarse search (possibly on a subset of the data), followed by a more refined search in a promising part of hyper-parameter space. This could all be handed off to a generic discrete optimization algorithm, but this is not yet supported. (One big advantage of Gaussian processes, which we will discuss later, is that we can use continuous optimization algorithms to tune the kernel parameters.)</p><h2>Sparse multinomial logistic regression (SMLR)<a name="17"></a></h2><p>We can select a subset of the training examples by using an L1 regularizer. This is called Sparse multinomial logistic regression (SMLR). If we use an L2 regularizer instead of L1, we call the method 'ridged multinomial logistic regression' or RMLR. (This terminology is from the paper <a href="https://web.archive.org/web/20160506233505/http://www.lx.it.pt/~mtf/Krishnapuram_Carin_Figueiredo_Hartemink_2005.pdf">"Learning sparse Bayesian classifiers: multi-class formulation, fast algorithms, and generalization bounds"</a>, Krishnapuram et al, PAMI 2005.)</p><p>One way to implement <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/smlr/smlrFit.m">smlrFit.m</a> is to kernelize the data, and then pick the best lambda on the regularization path using <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/logreg/logregFitPathCv.m">logregFitPathCv.m</a> (which uses <a href="https://web.archive.org/web/20160506233505/http://www-stat.stanford.edu/~tibs/glmnet-matlab/">glmnet</a>). 
Another way is to call <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/matlabTools/stats/fitCv.m">fitCv.m</a> , which lets us use a different kernel basis for each fold. This is much slower but gives much better results. See <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/smlrPathDemo.m">smlrPathDemo.m</a> for a comparison of these two approaches.</p><p>To fit an SMLR model with an RBF kernel, and to cross validate over lambdaRange, use</p><p><tt>model = smlrFit(X,y, 'kernelFn', @(X1, X2)kernelRbfGamma(X1, X2, gamma), ... 'regType', 'L1', 'lambdaRange', lambdaRange)</tt></p><p>regType defaults to 'L1', and lambdaRange defaults to logspace(-5, 2, 10), so both these parameters can be omitted. The kernelFn is mandatory, however. After fitting, use <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/smlr/smlrPredict.m">smlrPredict.m</a> to predict.</p><h2>Relevance vector machines (RVM)<a name="18"></a></h2><p>An alternative approach to achieving sparsity is to use automatic relevance determination (ARD). The combination of kernel basis function expansion and ARD is known as the relevance vector machine (RVM). This can be used for classification or regression.</p><p>One way to fit an RVM (implemented in <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/sub/rvmSimpleFit.m">rvmSimpleFit.m</a> ) is to use kernel basis expansion followed by the ARD fitting feature in <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/linreg/linregFitBayes.m">linregFitBayes.m</a> ; however, this is rather slow. 
Instead, <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/rvmFit.m">rvmFit.m</a> provides a wrapper to Mike Tipping's <a href="https://web.archive.org/web/20160506233505/http://www.vectoranomaly.com/downloads/downloads.htm">SparseBayes 2.0</a> Matlab library, which implements a greedy algorithm that adds basis functions one at a time.</p><p>To fit an RVM with an RBF kernel, use</p><p><tt>model = rvmFit(X,y, 'kernelFn', @(X1, X2)kernelRbfGamma(X1, X2, gamma))</tt></p><p>There is no need to specify lambdaRange, since the method uses ARD to estimate the hyper-parameters. After fitting, use <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/rvmPredict.m">rvmPredict.m</a> to predict.</p><p>Currently Tipping's package does not support multi-class classification. Therefore we convert the base binary classifier into a multi-class one using <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/oneVsRestClassif/oneVsRestClassifFit.m">oneVsRestClassifFit.m</a> . This is done internally by <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/rvmFit.m">rvmFit.m</a> .</p><h2>Support vector machines (SVM)<a name="19"></a></h2><p>SVMs are a very popular form of non-probabilistic kernelized discriminative classifier. 
They achieve sparsity not by using a sparsity-promoting prior, but instead by using a hinge loss function when training.</p><p><a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/svm/svmFit.m">svmFit.m</a> (which handles multi-class classification and regression) is a wrapper to several different implementations of SVMs:</p><div><ul><li>svmQP: our own Matlab code (based on code originally written by Steve Gunn), which uses the quadprog.m function in the optimization toolbox.</li><li><a href="https://web.archive.org/web/20160506233505/http://svmlight.joachims.org/">svmlight</a>, which is a C library</li><li><a href="https://web.archive.org/web/20160506233505/http://www.csie.ntu.edu.tw/~cjlin/libsvm">libsvm</a>, which is a C library</li><li><a href="https://web.archive.org/web/20160506233505/http://www.csie.ntu.edu.tw/~cjlin/liblinear/">liblinear</a>, which is a C library</li></ul></div><p>The appropriate library is determined automatically based on the type of kernel, as follows: If you use a linear kernel, it calls liblinear; if you use an RBF kernel, it calls libsvm; if you use an arbitrary kernel (e.g., a string kernel), it calls our QP code. (Thus it never calls svmlight by default, since libsvm seems to be much faster.)</p><p>The function <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/svmFitTest.m">svmFitTest.m</a> checks that all these implementations give the same results, up to numerical error. 
(This should be the case since the objective is convex; however, some packages only solve the problem to a very low precision.)</p><p><a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/svm/svmFit.m">svmFit.m</a> calls <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/matlabTools/stats/fitCv.m">fitCv.m</a> internally to choose the appropriate regularization constant <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq19278.png" alt="$C = 1/\lambda$">. It can also choose the best kernel parameter. Here is an example of the calling syntax.</p><p><tt>model = svmFit(Xtrain, ytrain, 'C', logspace(-5, 1, 10),... 'kernel', 'rbf', 'kernelParam', logspace(-2,2,5));</tt></p><p>After fitting, use <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/svm/svmPredict.m">svmPredict.m</a> to predict.</p><h2>Comparison of SVM, RVM, SMLR<a name="20"></a></h2><p>Let us compare various kernelized classifiers. Below we show the characteristics of some data sets to which we will apply the various classifiers. Colon and AML/ALL are gene microarray datasets, which is why the number of features is so large. Soy and forensic glass are standard datasets from the <a href="https://web.archive.org/web/20160506233505/http://archive.ics.uci.edu/ml/">UCI repository</a>. (All data is locally stored in <a href="https://web.archive.org/web/20160506233505/http://code.google.com/p/pmtkdata/">pmtkdata</a>.)</p><p>
<table border="3" cellpadding="5" width="100%">
<tbody><tr align="left">
<th bgcolor="#00CCFF"><font color="000000"></font></th>
<th bgcolor="#00CCFF"><font color="000000">nClasses</font></th>
<th bgcolor="#00CCFF"><font color="000000">nFeatures</font></th>
<th bgcolor="#00CCFF"><font color="000000">nTrain</font></th>
<th bgcolor="#00CCFF"><font color="000000">nTest</font></th>
</tr>
<tr>
<td bgcolor="#00CCFF"><font color="000000">crabs</font>
</td><td> 2
</td><td> 5
</td><td> 140
</td><td> 60
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">iris</font>
</td><td> 3
</td><td> 4
</td><td> 105
</td><td> 45
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">bankruptcy</font>
</td><td> 2
</td><td> 2
</td><td> 46
</td><td> 20
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">pima</font>
</td><td> 2
</td><td> 7
</td><td> 140
</td><td> 60
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">soy</font>
</td><td> 3
</td><td> 35
</td><td> 214
</td><td> 93
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">Fglass</font>
</td><td> 6
</td><td> 9
</td><td> 149
</td><td> 65
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">colon</font>
</td><td> 2
</td><td> 2000
</td><td> 43
</td><td> 19
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">AML/ALL</font>
</td><td> 2
</td><td> 7129
</td><td> 50
</td><td> 22
</td></tr>
</tbody></table>
</p><p>In <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/classificationShootout.m">classificationShootout.m</a> we compare SVM, RVM, SMLR and RMLR on the lowdim datasets using RBF kernels. For each split, we use 70% of the data for training and 30% for testing. Cross validation on the training set is then used internally, if necessary, to tune the regularization parameter. The results are shown below. (This table is modelled after Table 2 of <a href="https://web.archive.org/web/20160506233505/http://www.lx.it.pt/~mtf/Krishnapuram_Carin_Figueiredo_Hartemink_2005.pdf">Learning sparse Bayesian classifiers: multi-class formulation, fast algorithms, and generalization bounds</a>, Krishnapuram et al, PAMI 2005.) We show the total number of misclassifications, and in brackets, the total number of retained kernel basis functions (- means not computed). The bottom row shows the total number of test cases, and the total number of possible basis functions, which is <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq23831.png" alt="$N \times C$">.</p><p>
<font size="4"></font><table bgcolor="grey" align="left" cellpadding="9" valign="top"><tbody><tr><td></td><th bgcolor="white" align="center" valign="top"><font size="3">Crabs</font></th><th bgcolor="white" align="center" valign="top"><font size="3">Iris</font></th><th bgcolor="white" align="center" valign="top"><font size="3">Bankruptcy</font></th><th bgcolor="white" align="center" valign="top"><font size="3">Pima</font></th><th bgcolor="white" align="center" valign="top"><font size="3">Soy</font></th><th bgcolor="white" align="center" valign="top"><font size="3">Fglass</font></th><th bgcolor="white" align="center" valign="top"><font size="3">train(minutes)</font></th><th bgcolor="white" align="center" valign="top"><font size="3">test(seconds)</font></th></tr>
<tr>
<th bgcolor="white" align="left" valign="center"><font size="3">SVM</font></th><td bgcolor="white" align="left" valign="top"><font size="3">4 (40)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">4 (32)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">2 (12)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">15 (81)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">7 (96)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">25 (99)</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> 7.3</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> 0.024</font></td>
</tr><tr>
<th bgcolor="white" align="left" valign="center"><font size="3">RVM</font></th><td bgcolor="white" align="left" valign="top"><font size="3">6 (8)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">5 (12)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">2 (2)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">13 (3)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">9 (31)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">23 (67)</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> 38</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> 0.013</font></td>
</tr><tr>
<th bgcolor="white" align="left" valign="center"><font size="3">SMLR</font></th><td bgcolor="white" align="left" valign="top"><font size="3">2 (140)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">5 (210)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">1 (46)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">14 (140)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">9 (400)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">22 (730)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">2.5e+002</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> 0.01</font></td>
</tr><tr>
<th bgcolor="white" align="left" valign="center"><font size="3">RMLR</font></th><td bgcolor="white" align="left" valign="top"><font size="3">3 (280)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">6 (315)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">1 (92)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">16 (280)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">7 (642)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">23 (894)</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> 48</font></td><td bgcolor="white" align="left" valign="top"><font size="3">0.0097</font></td>
</tr><tr>
<th bgcolor="white" align="left" valign="center"><font size="3">Out
of</font></th><td bgcolor="white" align="left" valign="top"><font size="3">60 (280)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">45 (315)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">20 (92)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">60 (280)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">93 (642)</font></td><td bgcolor="white" align="left" valign="top"><font size="3">65 (894)</font></td><td bgcolor="white" align="left" valign="top"><font size="3"> </font></td><td bgcolor="white" align="left" valign="top"><font size="3"> </font></td>
</tr></tbody></table><br>
</p><p>The training time above is total time in minutes, including cross validation. But beware, we are comparing apples with oranges here, since the packages are in different languages:</p><div><ul><li>svm is a wrapper to C code (libsvm)</li><li>rvm is optimized Matlab (SparseBayes)</li><li>SMLR and RMLR are unoptimized Matlab (very slow).</li></ul></div><p>The total time to make the above table is about 8 hours! Since it is very slow to cross validate over the kernel bandwidth <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq41477.png" alt="$\gamma$"> and the regularization penalty <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq23351.png" alt="$\lambda$">, we made a faster version of this demo, called <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/classificationShootoutCvLambdaOnly.m">classificationShootoutCvLambdaOnly.m</a>. Here we first picked <img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/tutKernelClassif_eq41477.png" alt="$\gamma$"> using CV for an SVM; we then used this same kernel parameter for all methods. (For the high dimensional datasets, we used a linear kernel.) The results are shown below. We see that performance is worse than using CV to pick the RBF param for each method separately.</p><p>
<table valign="top" align="left" bgcolor="grey" cellpadding="9">
<caption align="bottom"><font size="4"></font></caption>
<tbody><tr><td></td><th align="center" bgcolor="white" valign="top">
<font size="3">Crabs</font></th><th align="center" bgcolor="white" valign="top">
<font size="3">Iris</font></th><th align="center" bgcolor="white" valign="top">
<font size="3">Bankruptcy</font></th><th align="center" bgcolor="white" valign="top"><font size="3">Pima</font></th><th align="center" bgcolor="white" valign="top"><font size="3">Soy</font></th><th align="center" bgcolor="white" valign="top"><font size="3">Fglass</font></th><th align="center" bgcolor="white" valign="top"><font size="3">colon (linear)</font></th><th align="center" bgcolor="white" valign="top"><font size="3">amlAll (linear)</font></th><th align="center" bgcolor="white" valign="top"><font size="3">train(seconds)</font></th><th align="center" bgcolor="white" valign="top"><font size="3">test(seconds)</font></th></tr>
<tr>
<th align="left" bgcolor="white" valign="center"><font size="3">SVM</font></th><td align="left" bgcolor="white" valign="top"><font size="3">6 (106)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (25)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (17)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">16 (87)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (143)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">24 (120)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (0)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">8 (0)</font></td><td align="left" bgcolor="white" valign="top"><font size="3"> 32</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.078</font></td>
</tr><tr>
<th align="left" bgcolor="white" valign="center"><font size="3">RVM</font></th><td align="left" bgcolor="white" valign="top"><font size="3">5 (6)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (8)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (2)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">22 (1)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">7 (32)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">36 (28)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">3 (3)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">4 (1)</font></td><td align="left" bgcolor="white" valign="top"><font size="3"> 4.2</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.047</font></td>
</tr><tr>
<th align="left" bgcolor="white" valign="center"><font size="3">SMLR</font></th><td align="left" bgcolor="white" valign="top"><font size="3">3 (140)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (209)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (39)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">13 (140)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">7 (376)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">25 (743)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">3 (28)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (49)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">4.8e+002</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.029</font></td>
</tr><tr>
<th align="left" bgcolor="white" valign="center"><font size="3">RMLR</font></th><td align="left" bgcolor="white" valign="top"><font size="3">3 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (315)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (92)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">15 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">7 (642)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">22 (894)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">8 (86)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">4 (100)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">1.1e+002</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.028</font></td>
</tr><tr>
<th align="left" bgcolor="white" valign="center"><font size="3">Out
of</font></th><td align="left" bgcolor="white" valign="top"><font size="3">60 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">45 (315)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">20 (92)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">60 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">93 (642)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">65 (894)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">19 (86)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">22 (100)</font></td><td align="left" bgcolor="white" valign="top"><font size="3"> </font></td><td align="left" bgcolor="white" valign="top"><font size="3"> </font></td>
</tr></tbody></table><br>
</p><p>In the spirit of reproducible research, we created a simpler demo, called <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/linearKernelDemo.m">linearKernelDemo.m</a> , which only uses linear kernels (so we don't have to cross validate over gamma in the RBF kernel) and only runs on a few datasets. This is much faster, allowing us to perform multiple trials. Below we show the median misclassification rates on the different data sets, computed over 3 random splits. We also added logregL1path and logregL2path to the mix; these are written in Fortran (glmnet). The results are shown below.</p><p>
<table border="3" cellpadding="5" width="100%">
<tbody><tr><th colspan="7" align="center"> test error rate (median over 3 trials) </th></tr>
<tr align="left">
<th bgcolor="#00CCFF"><font color="000000"></font></th>
<th bgcolor="#00CCFF"><font color="000000">SVM</font></th>
<th bgcolor="#00CCFF"><font color="000000">RVM</font></th>
<th bgcolor="#00CCFF"><font color="000000">SMLR</font></th>
<th bgcolor="#00CCFF"><font color="000000">RMLR</font></th>
<th bgcolor="#00CCFF"><font color="000000">logregL2</font></th>
<th bgcolor="#00CCFF"><font color="000000">logregL1</font></th>
</tr>
<tr>
<td bgcolor="#00CCFF"><font color="000000">soy</font>
</td><td> 0.108
</td><td> 0.108
</td><td> 0.118
</td><td> 0.129
</td><td> 0.710
</td><td> 0.108
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">fglass</font>
</td><td> 0.477
</td><td> 0.554
</td><td> 0.400
</td><td> 0.431
</td><td> 0.708
</td><td> 0.492
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">colon</font>
</td><td> 0.211
</td><td> 0.211
</td><td> 0.158
</td><td> 0.211
</td><td> 0.316
</td><td> 0.211
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">amlAll</font>
</td><td> 0.455
</td><td> 0.227
</td><td> 0.136
</td><td> 0.182
</td><td> 0.364
</td><td> 0.182
</td></tr></tbody></table>
</p><p>Before reading too much into these results, let's look at the boxplots, which show that the differences are probably not significant (we don't plot L2 lest it distort the scale).</p><p>
<img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/linearKernelBoxplotErr.png">
</p><p>Below are the training times in seconds (median over 3 trials)</p><p>
<table border="3" cellpadding="5" width="100%">
<tbody><tr><th colspan="7" align="center"> training time in seconds (median over 3 trials) </th></tr>
<tr align="left">
<th bgcolor="#00CCFF"><font color="000000"></font></th>
<th bgcolor="#00CCFF"><font color="000000">SVM</font></th>
<th bgcolor="#00CCFF"><font color="000000">RVM</font></th>
<th bgcolor="#00CCFF"><font color="000000">SMLR</font></th>
<th bgcolor="#00CCFF"><font color="000000">RMLR</font></th>
<th bgcolor="#00CCFF"><font color="000000">logregL2</font></th>
<th bgcolor="#00CCFF"><font color="000000">logregL1</font></th>
</tr>
<tr>
<td bgcolor="#00CCFF"><font color="000000">soy</font>
</td><td> 0.566
</td><td> 0.549
</td><td> 43.770
</td><td> 24.193
</td><td> 0.024
</td><td> 0.720
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">fglass</font>
</td><td> 0.586
</td><td> 0.146
</td><td> 67.552
</td><td> 30.204
</td><td> 0.043
</td><td> 0.684
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">colon</font>
</td><td> 1.251
</td><td> 0.028
</td><td> 2.434
</td><td> 2.618
</td><td> 0.021
</td><td> 0.418
</td></tr><tr>
<td bgcolor="#00CCFF"><font color="000000">amlAll</font>
</td><td> 3.486
</td><td> 0.017
</td><td> 2.337
</td><td> 2.569
</td><td> 0.097
</td><td> 1.674
</td></tr></tbody></table>
</p><p>And here are the boxplots</p><p>
<img src="./Supervised learning using non-parametric discriminative models in pmtk3_files/linearKernelBoxplotTime.png">
</p><p>We see that the RVM is consistently the fastest, which is somewhat surprising since the SVM code is in C. However, the SVM needs to use cross validation, whereas RVM uses empirical Bayes.</p><p>Reproducing the above results using <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/linearKernelDemo.m">linearKernelDemo.m</a> takes about 10 minutes (on my laptop). However, we can run a simplified version of the demo, which only uses 1 random fold, and only uses the last two datasets (with smaller sample size). This just takes 20 seconds, so makes a suitable demo for publishing.</p><pre class="codeinput">clear <span class="string">all</span>
tic
split = 0.7;
d = 1;
loadData(<span class="string">'colon'</span>) <span class="comment">% 2 class, X is 62*2000</span>
dataSets(d).X = X;
dataSets(d).y = y;
dataSets(d).name = <span class="string">'colon'</span>;
d=d+1;
loadData(<span class="string">'amlAll'</span>); <span class="comment">% 2 class, X is 72*7129</span>
X = [Xtrain; Xtest];
y = [ytrain; ytest];
dataSets(d).X = X;
dataSets(d).y = y;
dataSets(d).name = <span class="string">'amlAll'</span>;
d=d+1;
dataNames = {dataSets.name};
nDataSets = numel(dataSets);
methods = {<span class="string">'SVM'</span>, <span class="string">'RVM'</span>, <span class="string">'SMLR'</span>, <span class="string">'RMLR'</span>, <span class="string">'logregL2path'</span>, <span class="string">'logregL1path'</span>};
nMethods = numel(methods);
<span class="keyword">for</span> d=1:nDataSets
X = dataSets(d).X;
y = dataSets(d).y;
setSeed(0); s=1;
[X, y] = shuffleRows(X, y);
X = rescaleData(standardizeCols(X));
N = size(X, 1);
nTrain = floor(split*N);
nTest = N - nTrain;
Xtrain = X(1:nTrain, :);
Xtest = X(nTrain+1:end, :);
ytrain = y(1:nTrain);
ytest = y(nTrain+1:end);
<span class="keyword">for</span> m=1:nMethods
method = methods{m};
<span class="keyword">switch</span> lower(method)
<span class="keyword">case</span> <span class="string">'svm'</span>
Crange = logspace(-6, 1, 20); <span class="comment">% if too small, libsvm crashes!</span>
model = svmFit(Xtrain, ytrain, <span class="string">'C'</span>, Crange, <span class="string">'kernel'</span>, <span class="string">'linear'</span>);
predFn = @(m,X) svmPredict(m,X);
<span class="keyword">case</span> <span class="string">'rvm'</span>
model = rvmFit(Xtrain, ytrain, <span class="string">'kernelFn'</span>, @kernelLinear);
predFn = @(m,X) rvmPredict(m,X);
<span class="keyword">case</span> <span class="string">'smlr'</span>
model = smlrFit(Xtrain, ytrain, <span class="string">'kernelFn'</span>, @kernelLinear);
predFn = @(m,X) smlrPredict(m,X);
<span class="keyword">case</span> <span class="string">'smlrpath'</span>
model = smlrFit(Xtrain, ytrain, <span class="string">'kernelFn'</span>, @kernelLinear, <span class="string">'usePath'</span>, 1);
predFn = @(m,X) smlrPredict(m,X);
<span class="keyword">case</span> <span class="string">'rmlr'</span>
model = smlrFit(Xtrain, ytrain, <span class="string">'kernelFn'</span>, @kernelLinear, <span class="keyword">...</span>
<span class="string">'regtype'</span>, <span class="string">'L2'</span>);
predFn = @(m,X) smlrPredict(m,X);
<span class="keyword">case</span> <span class="string">'rmlrpath'</span>
model = smlrFit(Xtrain, ytrain, <span class="string">'kernelFn'</span>, @kernelLinear, <span class="keyword">...</span>
<span class="string">'regtype'</span>, <span class="string">'L2'</span>, <span class="string">'usePath'</span>, 1);
predFn = @(m,X) smlrPredict(m,X);
<span class="keyword">case</span> <span class="string">'logregl2path'</span>
model = logregFitPathCv(Xtrain, ytrain, <span class="string">'regtype'</span>, <span class="string">'L2'</span>);
predFn = @(m,X) logregPredict(m,X);
<span class="keyword">case</span> <span class="string">'logregl1path'</span>
model = logregFitPathCv(Xtrain, ytrain, <span class="string">'regtype'</span>, <span class="string">'L1'</span>);
predFn = @(m,X) logregPredict(m,X);
<span class="keyword">end</span>
saveModel{d,m,s} = model;
yHat = predFn(model, Xtest);
nerrs = sum(yHat ~= ytest);
testErrRate(d,m,s) = nerrs/nTest;
numErrors(d,m,s) = nerrs;
maxNumErrors(d) = nTest;
<span class="keyword">end</span>
<span class="keyword">end</span>
toc
fprintf(<span class="string">'test err\n'</span>);
disp(testErrRate)
</pre><pre class="codeoutput">Warning: In the directory "C:\kmurphy\GoogleCode\pmtkSupport\glmnet-matlab", glmnetMex.mexw32 now shadows glmnetMex.dll.
Please see the MATLAB 7.1 Release Notes.
Elapsed time is 18.752526 seconds.
test err
0.1579 0.2105 0.3158 0.2105 0.4737 0.2632
0.5000 0.1364 0.1364 0.1818 0.1364 0.0909
</pre><p>It is easy to add other classifiers and data sets to this comparison.</p><p>For more extensive comparison of different classifiers on different datasets, see <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/docs/tutorial/html/tutMLcomp.html">tutMLcomp.html</a> .</p><h2>Gaussian processes<a name="37"></a></h2><p>GPs are discussed in more detail <a href="https://web.archive.org/web/20160506233505/http://pmtk3.googlecode.com/svn/trunk/docs/tutorial/html/tutGP.html">here</a>.</p><p>
</p><hr>
<p></p><p>This page was auto-generated by calling <i>pmtkPublish(tutKernelClassif)</i> on 20-Nov-2010 14:43:44</p><p class="footer"><br>
Published with MATLAB® 7.9<br></p></div><!--
##### SOURCE BEGIN #####
%% Supervised learning using non-parametric discriminative models in pmtk3
%
%
%% Kernel functions
%
% One common form of basis function expansion
% is to define a new feature vector $\phi(x)$ by comparing the input
% $x$ to a set of prototypes or exemplars $\mu_k$ as follows:
%%
% $$\phi(x) = (K(x,\mu_1), ..., K(x,\mu_D))$$
%%
% Here $K(x,\mu)$ is a 'kernel function',
% which in this context just means a function of two arguments.
% A common example is the Gaussian or RBF kernel
%%
% $$K(x,\mu) = \exp(-\frac{||x-\mu||^2}{2\sigma^2})$$
%%
% where $\sigma$ is the 'bandwidth'.
% This can be created using <http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelRbfSigma.m kernelRbfSigma.m> .
% Alternatively, we can write
%%
% $$K(x,\mu) = \exp(-\gamma ||x-\mu||^2)$$
%%
% The quantity $\gamma=1/(2\sigma^2)$ is known as
% the scale or precision. This can be created using <http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelRbfGamma.m kernelRbfGamma.m> .
% Most software packages use this latter parameterization.
%
% Another common example is the polynomial kernel
% <http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelPolyPmtk.m kernelPolyPmtk.m> :
%%
% $$K(x,\mu) = (1+x^T \mu)^d$$
%%
% where d is the degree.
%
% Another common example is the linear kernel
% <http://pmtk3.googlecode.com/svn/trunk/toolbox/Algorithms/kernels/kernelLinearPmtk.m kernelLinearPmtk.m> :
%%
% $$K(x,\mu) = x^T \mu$$
%%
% (The reason for the 'pmtk' suffix is to distinguish
% these functions from other implementations of the same concept.)
%
% Often we take the prototypes $\mu_k$ to be the training vectors (rows of $X$),
% but we don't have to.
% Some methods require that the kernel be a Mercer (positive definite)
% kernel. All of the above kernels are Mercer kernels,
% but this is not always the case.
%
% The advantages of using kernels include the following
%
% * We can apply standard parametric models (e.g., linear and logistic
% regression) to non-vectorial inputs (e.g., strings, molecular structures, etc.),
% by defining $K(x,\mu)$ to be some
% kind of function for comparing structured inputs.
% * We can increase the flexibility of the model by working in an
% enlarged feature space.
%
% Below we show an example where we fit the XOR data using kernelized
% logistic regression, with various kernels and prototypes
% (from <http://pmtk3.googlecode.com/svn/trunk/demos/bookDemos/Introduction/logregXorDemo.m logregXorDemo.m> ).
%%
clear all; close all
% Fit kernelized logistic regression to the XOR data with three different
% basis expansions, report the training error of each, and plot the
% resulting decision boundaries.
[X, y] = createXORdata();
rbfSigma = 1;   % bandwidth handed to kernelRbfSigma
polydeg  = 2;   % degree handed to kernelPolyPmtk
protoTypes = [1 1; 1 5; 5 1; 5 5];
% Standardized copy of the prototypes: this is what the first kernel
% compares against, while the unstandardized ones are what gets plotted.
protoTypesStnd = standardizeCols(protoTypes);
% kernels{1} ignores its second argument and always compares against the
% fixed prototypes; kernels{2} and kernels{3} compare against X2.
kernels = {@(X1, X2)kernelRbfSigma(X1, protoTypesStnd, rbfSigma)
           @(X1, X2)kernelRbfSigma(X1, X2, rbfSigma)
           @(X1, X2)kernelPolyPmtk(X1, X2, polydeg)};
% Titles must be listed in the same order as the kernels. The prototype
% based RBF kernel comes first (it is the plot that receives the prototype
% stars below); previously the first two labels were swapped.
titles  = {'rbf prototypes', 'rbf', 'poly'};
for i=1:numel(kernels)
    preproc = preprocessorCreate('kernelFn', kernels{i}, 'standardizeX', true, 'addOnes', true);
    model = logregFit(X, y, 'preproc', preproc);
    yhat = logregPredict(model, X);
    errorRate = mean(yhat ~= y);   % fraction of training cases misclassified
    fprintf('Error rate using %s features: %2.f%%\n', titles{i}, 100*errorRate);
    predictFcn = @(Xtest)logregPredict(model, Xtest);
    plotDecisionBoundary(X, y, predictFcn);
    if i==1
        % Only the first kernel uses the hand-picked prototypes, so only
        % its figure gets them overlaid as black stars.
        hold on;
        plot(protoTypes(:, 1), protoTypes(:, 2), '*k', 'linewidth', 2, 'markersize', 10)
    end
    title(titles{i});
end
%%
% In the first example, we use an RBF kernel with centers at 4
% manually chosen points, shown with black stars.
% In the second and third examples, we use an RBF and polynomial kernel,
% centered at all the training data.
% This is an example of a non-parametric model,
% since the number of parameters grows with the size of
% the training set (which makes training slow on large datasets).
% We can use sparsity promoting priors to select a subset of the training
% data, as we illustrate below.
%% Using grid search plus cross validation to choose the kernel parameters
% We can create a grid of models, with different kernels
% and different regularizers, as shown in the example
% below ( from <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/logregKernelDemo.m logregKernelDemo.m> ).
% If CV does not pick a point on the edge of the grid,
% we can be fairly confident we have searched over
% a reasonable range. For this reason,
% it is helpful to plot the cost surface.
%
%%
loadData('fglass'); % forensic glass data: 6 classes, X is 214*9
% Merge the partition that ships with the data; a fresh split is drawn below.
y = canonizeLabels([ytrain; ytest]); % class 4 is missing, so relabel 1:6
X = [Xtrain; Xtest];
split = 0.7;   % fraction of the cases used for training
setSeed(0);
[X, y] = shuffleRows(X, y);
X = rescaleData(standardizeCols(X));
N = size(X, 1);
nTrain = floor(split*N);
nTest  = N - nTrain;
% The first nTrain shuffled rows are used for training, the rest for testing.
trainRows = 1:nTrain;
testRows  = nTrain+1:N;
Xtrain = X(trainRows, :);
ytrain = y(trainRows);
Xtest  = X(testRows, :);
ytest  = y(testRows);
% 2D CV over all (lambda, gamma) pairs
lambdaRange = logspace(-6, 1, 5);
gammaRange  = logspace(-4, 4, 5);
paramRange  = crossProduct(lambdaRange, gammaRange);
regtypes = {'L2'}; %L1 is a bit better but a bit slower
% Settings that are the same for every regularizer type.
predictFn = @logregPredict;
lossFn    = @(ytest, yhat)mean(yhat ~= ytest);   % misclassification rate
nfolds    = 5;
useSErule = true;
plotCv    = true;
for r=1:numel(regtypes)
    regtype = regtypes{r};
    % param(1) is the regularizer strength lambda, param(2) the RBF gamma.
    rbfPreproc = @(gam)preprocessorCreate('kernelFn', @(X1, X2)kernelRbfGamma(X1, X2, gam));
    fitFn = @(X, y, param)logregFit(X, y, 'lambda', param(1), ...
        'regType', regtype, 'preproc', rbfPreproc(param(2)));
    tic;
    [LRmodel, bestParam, LRmu, LRse] = fitCv(paramRange, fitFn, predictFn, lossFn, ...
        Xtrain, ytrain, nfolds, 'useSErule', useSErule, 'doPlot', plotCv, ...
        'params1', lambdaRange, 'params2', gammaRange);
    time(r) = toc   % no semicolon: the elapsed time is echoed on purpose
    yhat = logregPredict(LRmodel, Xtest);
    nerrors(r) = sum(yhat ~= ytest);
end
errRate = nerrors/nTest
%%
%
% In the example above, we just use a 5x5 grid for speed,
% but in practice one might use a 10x10 grid for a coarse
% search (possibly on a subset of the data), followed by a
% more refined search in a promising part of hyper-parameter space.
% This could all be handed off to a generic discrete optimization
% algorithm, but this is not yet supported.
% (One big advantage of Gaussian processes,
% which we will discuss later,
% is that we can use continuous optimization algorithms
% to tune the kernel parameters.)
%% Sparse multinomial logistic regression (SMLR)
% We can select a subset of the training examples
% by using an L1 regularizer.
% This is called Sparse multinomial logistic regression (SMLR).
% If we use an L2 regularizer instead of L1,
% we call the method 'ridged multinomial logistic regression' or RMLR.
% (This terminology is from the paper
% <http://www.lx.it.pt/~mtf/Krishnapuram_Carin_Figueiredo_Hartemink_2005.pdf "Learning sparse Bayesian classifiers: multi-class formulation, fast
% algorithms, and generalization bounds">, Krishnapuram et al, PAMI 2005.)
%
% One way to implement <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/smlr/smlrFit.m smlrFit.m> is to
% kernelize the data,
% and then pick the best lambda on the regularization path
% using <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/logreg/logregFitPathCv.m logregFitPathCv.m> (which uses
% <http://www-stat.stanford.edu/~tibs/glmnet-matlab/ glmnet>).
% Another way is to call <http://pmtk3.googlecode.com/svn/trunk/matlabTools/stats/fitCv.m fitCv.m> , which
% lets us use a different kernel basis for each fold.
% This is much slower but gives much better results.
% See <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/smlrPathDemo.m smlrPathDemo.m> for a comparison of these two approaches.
%
% To fit an SMLR model with an RBF kernel, and to
% cross validate over lambdaRange, use
%
% |model = smlrFit(X,y, 'kernelFn', @(X1, X2)kernelRbfGamma(X1, X2, gamma), ...
% 'regType', 'L1', 'lambdaRange', lambdaRange)|
%
% regType defaults to 'L1',
% and lambdaRange defaults to logspace(-5, 2, 10),
% so both these parameters can be omitted. The kernelFn is mandatory,
% however.
% After fitting, use <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/smlr/smlrPredict.m smlrPredict.m> to predict.
%% Relevance vector machines (RVM)
% An alternative approach to achieving sparsity is to
% use automatic relevance determination (ARD).
% The combination of kernel basis function expansion
% and ARD is known as the relevance vector machine (RVM).
% This can be used for classification or regression.
%
% One way to fit an RVM (implemented in <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/sub/rvmSimpleFit.m rvmSimpleFit.m> )
% is to use kernel basis expansion followed by the ARD
% fitting feature in
% <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/linreg/linregFitBayes.m linregFitBayes.m> ; however,
% this is rather slow.
% Instead, <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/rvmFit.m rvmFit.m> provides a wrapper to
% Mike Tipping's
% <http://www.vectoranomaly.com/downloads/downloads.htm SparseBayes 2.0>
% Matlab library, which implements a greedy algorithm
% that adds basis functions one at a time.
%
% To fit an RVM with an RBF kernel, use
%
% |model = rvmFit(X,y, 'kernelFn', @(X1, X2)kernelRbfGamma(X1, X2,
% gamma))|
%
% There is no need to specify lambdaRange, since the method
% uses ARD to estimate the hyper-parameters.
% After fitting, use <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/rvmPredict.m rvmPredict.m> to predict.
%
% Currently Tipping's package does not support multi-class
% classification. Therefore we convert the base binary classifier
% into a multi-class one using <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/oneVsRestClassif/oneVsRestClassifFit.m oneVsRestClassifFit.m> .
% This is done internally by <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/rvm/rvmFit.m rvmFit.m> .
%
%% Support vector machines (SVM)
% SVMs are a very popular form of non-probabilistic kernelized
% discriminative classifier. They achieve sparsity not by using
% a sparsity-promoting prior, but instead by using a hinge loss
% function when training.
%
% <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/svm/svmFit.m svmFit.m> (which handles multi-class classification and regression)
% is a wrapper to several different implementations of SVMs:
%
% * svmQP: our own Matlab code (based on code originally written by Steve Gunn),
% which uses the quadprog.m function in the optimization toolbox.
% * <http://svmlight.joachims.org/ svmlight>, which is a C library
% * <http://www.csie.ntu.edu.tw/~cjlin/libsvm libsvm>, which is a C library
% * <http://www.csie.ntu.edu.tw/~cjlin/liblinear/ liblinear>, which is a C
% library
%
% The appropriate library is determined automatically based on the type
% of kernel, as follows: If you use a linear kernel, it calls liblinear;
% if you use an RBF kernel, it calls libsvm; if you use an arbitrary
% kernel (eg. a string kernel), it calls our QP code.
% (Thus it never calls svmlight by default, since libsvm seems to be
% much faster.)
%
% The function <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/svmFitTest.m svmFitTest.m> checks that all these implementations
% give the same results, up to numerical error.
% (This should be the case since
% the objective is convex; however, some
% packages only solve the problem to a very low precision.)
%
% <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/svm/svmFit.m svmFit.m> calls <http://pmtk3.googlecode.com/svn/trunk/matlabTools/stats/fitCv.m fitCv.m> internally to choose the appropriate
% regularization constant $C = 1/\lambda$.
% It can also choose the best kernel parameter.
% Here is an example of the calling syntax.
%
% |model = svmFit(Xtrain, ytrain, 'C', logspace(-5, 1, 10),...
% 'kernel', 'rbf', 'kernelParam', logspace(-2,2,5));|
%
% After fitting, use <http://pmtk3.googlecode.com/svn/trunk/toolbox/SupervisedModels/svm/svmPredict.m svmPredict.m> to predict.
%% Comparison of SVM, RVM, SMLR
% Let us compare various kernelized classifiers.
% Below we show the characteristics of some data sets
% to which we will apply the various classifiers.
% Colon and AML/ALL are gene microarray datasets,
% which is why the number of features is so large.
% Soy and forensic glass are standard datasets
% from the <http://archive.ics.uci.edu/ml/ UCI repository>.
% (All data is locally stored in
% <http://code.google.com/p/pmtkdata/ pmtkdata>.)
%
%%
% <html>
% <TABLE BORDER=3 CELLPADDING=5 WIDTH="100%" >
% <TR ALIGN=left>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000></FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>nClasses</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>nFeatures</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>nTrain</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>nTest</FONT></TH>
% </TR>
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>crabs</FONT>
% <td> 2
% <td> 5
% <td> 140
% <td> 60
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>iris</FONT>
% <td> 3
% <td> 4
% <td> 105
% <td> 45
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>bankruptcy</FONT>
% <td> 2
% <td> 2
% <td> 46
% <td> 20
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>pima</FONT>
% <td> 2
% <td> 7
% <td> 140
% <td> 60
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>soy</FONT>
% <td> 3
% <td> 35
% <td> 214
% <td> 93
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>Fglass</FONT>
% <td> 6
% <td> 9
% <td> 149
% <td> 65
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>colon</FONT>
% <td> 2
% <td> 2000
% <td> 43
% <td> 19
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>AML/ALL</FONT>
% <td> 2
% <td> 7129
% <td> 50
% <td> 22
% <tr>
% </table>
% </html>
%%
%
% In <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/classificationShootout.m classificationShootout.m>
% we compare SVM, RVM, SMLR and RMLR
% on the lowdim datasets using RBF kernels.
% For each split, we use 70% of the data for training and 30% for testing.
% Cross validation on the training set is then used internally,
% if necessary, to tune the regularization parameter.
% The results are shown below.
% (This table is modelled after Table 2 of
% <http://www.lx.it.pt/~mtf/Krishnapuram_Carin_Figueiredo_Hartemink_2005.pdf Learning
% sparse Bayesian classifiers: multi-class formulation, fast
% algorithms, and generalization bounds>, Krishnapuram et al, PAMI 2005.)
% We show the total number of misclassifications, and in brackets, the
% total number of retained kernel basis functions (- means not computed).
% The bottom row shows the total number of test cases, and the total
% number of possible basis functions, which is $N \times C$.
%%
% <html>
% <TABLE BGCOLOR=grey ALIGN=left CELLPADDING=9 VALIGN=top><CAPTION ALIGN=bottom><font size=4></font></CAPTION><TR><TD></TD><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>Crabs</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>Iris</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>Bankruptcy</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>Pima</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>Soy</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>Fglass</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>train(minutes)</font></TH><TH BGCOLOR=white ALIGN=center VALIGN=top><font size=3>test(seconds)</font></TH></TR>
% <TR>
% <TH BGCOLOR=white ALIGN=left VALIGN=center ><font
% size=3>SVM</font></TH><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>4 (40)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>4 (32)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>2 (12)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>15 (81)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>7 (96)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>25 (99)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> 7.3</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> 0.024</font></TD>
% </TR><TR>
% <TH BGCOLOR=white ALIGN=left VALIGN=center ><font size=3>RVM</font></TH><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>6 (8)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>5 (12)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>2 (2)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>13 (3)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>9 (31)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>23 (67)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> 38</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> 0.013</font></TD>
% </TR><TR>
% <TH BGCOLOR=white ALIGN=left VALIGN=center ><font
% size=3>SMLR</font></TH><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>2 (140)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>5 (210)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>1 (46)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>14 (140)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>9 (400)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>22 (730)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>2.5e+002</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> 0.01</font></TD>
% </TR><TR>
% <TH BGCOLOR=white ALIGN=left VALIGN=center ><font size=3>RMLR</font></TH><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>3 (280)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>6 (315)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>1 (92)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>16 (280)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>7 (642)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>23 (894)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> 48</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>0.0097</font></TD>
% </TR><TR>
% <TH BGCOLOR=white ALIGN=left VALIGN=center ><font size=3>Out
% of</font></TH><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>60 (280)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>45 (315)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>20 (92)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>60 (280)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>93 (642)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3>65 (894)</font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> </font></TD><TD BGCOLOR=white ALIGN=left VALIGN=top><font size=3> </font></TD>
% </TR></TABLE><br>
% </html>
%%
%
% The training time above is total time in minutes, including cross
% validation.
% But beware, we are comparing apples with oranges here,
% since the packages are in different languages:
%
% * svm is a wrapper to C code (libsvm)
% * rvm is optimized Matlab (SparseBayes)
% * SMLR and RMLR are unoptimized Matlab (very slow).
%
% The total time to make the above table is about 8 hours!
% Since it is very slow to cross validate over the
% kernel bandwidth $\gamma$ and the regularization penalty $\lambda$,
% we made a faster version of this demo, called
% <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/classificationShootoutCvLambdaOnly.m classificationShootoutCvLambdaOnly.m>
% Here we first
% picked $\gamma$ using CV for an SVM; we then used this same kernel
% parameter for all methods. (For the high dimensional datasets,
% we used a linear kernel.) The results are shown below.
% We see that performance is worse than using CV to pick
% the RBF param for each method separately.
%
%%
% <html>
% <table valign="top" align="left" bgcolor="grey" cellpadding="9">
% <caption align="bottom"><font size="4"></font></caption>
% <tbody><tr><td></td><th align="center" bgcolor="white" valign="top">
% <font size="3">Crabs</font></th><th align="center" bgcolor="white" valign="top">
% <font size="3">Iris</font></th><th align="center" bgcolor="white" valign="top">
% <font size="3">Bankruptcy</font></th><th align="center" bgcolor="white" valign="top"><font size="3">Pima</font></th><th align="center" bgcolor="white" valign="top"><font size="3">Soy</font></th><th align="center" bgcolor="white" valign="top"><font size="3">Fglass</font></th><th align="center" bgcolor="white" valign="top"><font size="3">colon (linear)</font></th><th align="center" bgcolor="white" valign="top"><font size="3">amlAll (linear)</font></th><th align="center" bgcolor="white" valign="top"><font size="3">train(seconds)</font></th><th align="center" bgcolor="white" valign="top"><font size="3">test(seconds)</font></th></tr>
% <tr>
% <th align="left" bgcolor="white" valign="center"><font
% size="3">SVM</font></th><td align="left" bgcolor="white" valign="top"><font size="3">6 (106)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (25)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (17)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">16 (87)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (143)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">24 (120)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (0)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">8 (0)</font></td><td align="left" bgcolor="white" valign="top"><font size="3"> 32</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.078</font></td>
% </tr><tr>
% <th align="left" bgcolor="white" valign="center"><font size="3">RVM</font></th><td align="left" bgcolor="white" valign="top"><font size="3">5 (6)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (8)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (2)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">22 (1)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">7 (32)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">36 (28)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">3 (3)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">4 (1)</font></td><td align="left" bgcolor="white" valign="top"><font size="3"> 4.2</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.047</font></td>
% </tr><tr>
% <th align="left" bgcolor="white" valign="center"><font
% size="3">SMLR</font></th><td align="left" bgcolor="white" valign="top"><font size="3">3 (140)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (209)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (39)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">13 (140)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">7 (376)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">25 (743)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">3 (28)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (49)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">4.8e+002</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.029</font></td>
% </tr><tr>
% <th align="left" bgcolor="white" valign="center"><font size="3">RMLR</font></th><td align="left" bgcolor="white" valign="top"><font size="3">3 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">5 (315)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">2 (92)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">15 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">7 (642)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">22 (894)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">8 (86)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">4 (100)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">1.1e+002</font></td><td align="left" bgcolor="white" valign="top"><font size="3">0.028</font></td>
% </tr><tr>
% <th align="left" bgcolor="white" valign="center"><font size="3">Out
% of</font></th><td align="left" bgcolor="white" valign="top"><font
% size="3">60 (280)</font></td><td align="left" bgcolor="white"
% valign="top"><font size="3">45 (315)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">20 (92)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">60 (280)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">93 (642)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">65 (894)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">19 (86)</font></td><td align="left" bgcolor="white" valign="top"><font size="3">22 (100)</font></td><td align="left" bgcolor="white" valign="top"><font size="3"> </font></td><td align="left" bgcolor="white" valign="top"><font size="3"> </font></td>
% </tr></tbody></table><br>
% </html>
%%
%
% In the spirit of reproducible research,
% we created a simpler demo, called
% <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/linearKernelDemo.m linearKernelDemo.m> ,
% which only uses linear kernels (so we don't have to cross validate
% over gamma in the RBF kernel) and only runs on a few datasets.
% This is much faster, allowing us to perform multiple trials.
% Below we show the median misclassification rates on the different data sets,
% computed over 3 random splits.
% We also added logregL1path and logregL2path to the mix;
% these are written in Fortran (glmnet).
% The results are shown below.
%%
% <html>
% <TABLE BORDER=3 CELLPADDING=5 WIDTH="100%" >
% <TR><TH COLSPAN=7 ALIGN=center> test error rate (median over 3 trials) </TH></TR>
% <TR ALIGN=left>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000></FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>SVM</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>RVM</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>SMLR</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>RMLR</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>logregL2</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>logregL1</FONT></TH>
% </TR>
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>soy</FONT>
% <td> 0.108
% <td> 0.108
% <td> 0.118
% <td> 0.129
% <td> 0.710
% <td> 0.108
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>fglass</FONT>
% <td> 0.477
% <td> 0.554
% <td> 0.400
% <td> 0.431
% <td> 0.708
% <td> 0.492
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>colon</FONT>
% <td> 0.211
% <td> 0.211
% <td> 0.158
% <td> 0.211
% <td> 0.316
% <td> 0.211
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>amlAll</FONT>
% <td> 0.455
% <td> 0.227
% <td> 0.136
% <td> 0.182
% <td> 0.364
% <td> 0.182
% </table>
% </html>
%%
% Before reading too much into these results,
% let's look at the boxplots, which show that
% the differences are probably not significant
% (we don't plot L2 lest it distort the scale)
%%
% <html>
% <img
% src="http://pmtk3.googlecode.com/svn/trunk/docs/tutorial/extraFigures/linearKernelBoxplotErr.png">
% </html>
%%
% Below are the training times in seconds (median over 3 trials)
%
%%
% <html>
% <TABLE BORDER=3 CELLPADDING=5 WIDTH="100%" >
% <TR><TH COLSPAN=7 ALIGN=center> training time in seconds (median over 3 trials) </TH></TR>
% <TR ALIGN=left>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000></FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>SVM</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>RVM</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>SMLR</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>RMLR</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>logregL2</FONT></TH>
% <TH BGCOLOR=#00CCFF><FONT COLOR=000000>logregL1</FONT></TH>
% </TR>
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>soy</FONT>
% <td> 0.566
% <td> 0.549
% <td> 43.770
% <td> 24.193
% <td> 0.024
% <td> 0.720
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>fglass</FONT>
% <td> 0.586
% <td> 0.146
% <td> 67.552
% <td> 30.204
% <td> 0.043
% <td> 0.684
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>colon</FONT>
% <td> 1.251
% <td> 0.028
% <td> 2.434
% <td> 2.618
% <td> 0.021
% <td> 0.418
% <tr>
% <td BGCOLOR=#00CCFF><FONT COLOR=000000>amlAll</FONT>
% <td> 3.486
% <td> 0.017
% <td> 2.337
% <td> 2.569
% <td> 0.097
% <td> 1.674
% </table>
% </html>
%%
% And here are the boxplots
%%
% <html>
% <img
% src="http://pmtk3.googlecode.com/svn/trunk/docs/tutorial/extraFigures/linearKernelBoxplotTime.png">
% </html>
%%
% We see that the RVM is consistently the fastest of the kernel methods,
% which is somewhat surprising since the SVM code is in C.
% However, the SVM needs to use cross validation, whereas RVM uses
% empirical Bayes.
%
% Reproducing the above results using
% <http://pmtk3.googlecode.com/svn/trunk/demos/otherDemos/supervisedModels/linearKernelDemo.m linearKernelDemo.m>
% takes about 10 minutes (on my laptop).
% However, we can run a simplified version of the demo,
% which only uses 1 random fold, and only uses the last
% two datasets (with smaller sample size). This takes just 20 seconds,
% so it makes a suitable demo for publishing.
%%
clear all
tic
% Simplified linear-kernel comparison: a single random train/test split on
% the 'colon' and 'amlAll' datasets. Every classifier listed in "methods"
% is fitted on the training rows and scored on the held-out rows.
split = 0.7; % fraction of the rows used for training

% Collect the datasets in a struct array with fields X, y and name.
% loadData is relied on to create X and y (colon), or Xtrain/Xtest/ytrain/
% ytest (amlAll), in the workspace; nothing else defines them after the clear.
d = 1;
loadData('colon'); % 2 class, X is 62*2000
dataSets(d).X = X;
dataSets(d).y = y;
dataSets(d).name = 'colon';
d = d+1;
loadData('amlAll'); % 2 class, X is 72*7129
% amlAll arrives pre-split; merge it so that it is re-split like colon below
X = [Xtrain; Xtest];
y = [ytrain; ytest];
dataSets(d).X = X;
dataSets(d).y = y;
dataSets(d).name = 'amlAll';
d = d+1;
dataNames = {dataSets.name};
nDataSets = numel(dataSets);
% NOTE: this variable shadows the built-in function "methods"; the name is
% kept because later parts of the script may refer to it.
methods = {'SVM', 'RVM', 'SMLR', 'RMLR', 'logregL2path', 'logregL1path'};
nMethods = numel(methods);

% Preallocate the result containers (one trial, so the third index is 1)
saveModel = cell(nDataSets, nMethods);
testErrRate = zeros(nDataSets, nMethods);
numErrors = zeros(nDataSets, nMethods);
maxNumErrors = zeros(1, nDataSets);

for d=1:nDataSets
    X = dataSets(d).X;
    y = dataSets(d).y;
    % Fixed seed so the shuffle (and hence the split) is reproducible;
    % s is the trial index, and this simplified demo runs only one trial.
    setSeed(0); s=1;
    [X, y] = shuffleRows(X, y);
    X = rescaleData(standardizeCols(X));
    N = size(X, 1);
    nTrain = floor(split*N);
    nTest = N - nTrain;
    Xtrain = X(1:nTrain, :);
    Xtest = X(nTrain+1:end, :);
    ytrain = y(1:nTrain);
    ytest = y(nTrain+1:end);
    % Worst case is misclassifying every test row; same for all methods
    maxNumErrors(d) = nTest;

    for m=1:nMethods
        method = methods{m};
        % Each case fits a model and selects the matching prediction function
        switch lower(method)
            case 'svm'
                Crange = logspace(-6, 1, 20); % if too small, libsvm crashes!
                model = svmFit(Xtrain, ytrain, 'C', Crange, 'kernel', 'linear');
                predFn = @(m,X) svmPredict(m,X);
            case 'rvm'
                model = rvmFit(Xtrain, ytrain, 'kernelFn', @kernelLinear);
                predFn = @(m,X) rvmPredict(m,X);
            case 'smlr'
                model = smlrFit(Xtrain, ytrain, 'kernelFn', @kernelLinear);
                predFn = @(m,X) smlrPredict(m,X);
            case 'smlrpath'
                model = smlrFit(Xtrain, ytrain, 'kernelFn', @kernelLinear, 'usePath', 1);
                predFn = @(m,X) smlrPredict(m,X);
            case 'rmlr'
                model = smlrFit(Xtrain, ytrain, 'kernelFn', @kernelLinear, ...
                    'regtype', 'L2');
                predFn = @(m,X) smlrPredict(m,X);
            case 'rmlrpath'
                model = smlrFit(Xtrain, ytrain, 'kernelFn', @kernelLinear, ...
                    'regtype', 'L2', 'usePath', 1);
                predFn = @(m,X) smlrPredict(m,X);
            case 'logregl2path'
                model = logregFitPathCv(Xtrain, ytrain, 'regtype', 'L2');
                predFn = @(m,X) logregPredict(m,X);
            case 'logregl1path'
                model = logregFitPathCv(Xtrain, ytrain, 'regtype', 'L1');
                predFn = @(m,X) logregPredict(m,X);
            otherwise
                % Without this branch an unknown name would silently reuse
                % the model and predFn left over from the previous method.
                error('pmtk3:unknownMethod', 'Unrecognized method ''%s''.', method);
        end
        saveModel{d,m,s} = model;
        yHat = predFn(model, Xtest);
        nerrs = sum(yHat ~= ytest);
        testErrRate(d,m,s) = nerrs/nTest;
        numErrors(d,m,s) = nerrs;
    end
end