From 6ee69af9bac7c965fee6ee2c9d57488b08e9f797 Mon Sep 17 00:00:00 2001 From: Santos Safrao Date: Fri, 8 Dec 2023 18:27:23 +0900 Subject: [PATCH] added Kernel Mutual Subspace Method --- examples/kmsm.asv | 19 +++++++++++ examples/kmsm.m | 71 +++++++++------------------------------- src/classes/@CMSM/CMSM.m | 6 ++-- src/classes/@KMSM/KMSM.m | 58 ++++++++++++++++++++++++++++++++ src/classes/@MSM/MSM.m | 24 +++++++------- 5 files changed, 107 insertions(+), 71 deletions(-) create mode 100644 examples/kmsm.asv create mode 100644 src/classes/@KMSM/KMSM.m diff --git a/examples/kmsm.asv b/examples/kmsm.asv new file mode 100644 index 0000000..07f53a8 --- /dev/null +++ b/examples/kmsm.asv @@ -0,0 +1,19 @@ +%% Load data +testDataUsage = TestDataUsage.Subsets; +n = 2; +[trainData, testData, testLabels] = prepareData(testDataUsage, n); + +%% Train model +numDimReferenceSubspace = 10; +numDimInputSubspace = 4; +sigma = 1; + +model = KMSM(trainData,... + numDimReferenceSubspace,... + numDimInputSubspace,... + sigma,... + testLabels); + +%% Evaluate model +modelEvaluation = model.evaluate(testData); +modelEvaluation.printResults(); \ No newline at end of file diff --git a/examples/kmsm.m b/examples/kmsm.m index eaffc58..07f53a8 100644 --- a/examples/kmsm.m +++ b/examples/kmsm.m @@ -1,60 +1,19 @@ -%% Load Data -clear; -clc; -load('TsukubaHandDigitsDataset24x24.mat') -% Check if the variable 'trainData' and 'testData' do not exist in the workspace -% if ~(exist('trainData', 'var') == 1 && exist('testData', 'var') == 1) -% % If they don't exist, load the data from the .mat file -% load('TsukubaHandDigitsDataset24x24.mat'); -% end -% you can use the following code to convert the test data format from 3d to 4d -testData = subsetTestData(testData, 2); -% specific_class = 5; -%% to accomodate for MATLAB indexing -% specific_class = specific_class + 1; -training_data = cvlNormalize(trainData); -testing_data = cvlNormalize(testData); -% testing_data = testData(:, :, specific_class); -size_of_test_data = size(testing_data); +%% Load data +testDataUsage = TestDataUsage.Subsets; +n = 2; +[trainData, testData, testLabels] = prepareData(testDataUsage, n); -% get number of elements of size_of_test_data -array_size = numel(size_of_test_data); - -if array_size == 4 - % do nothing - num_sets = size_of_test_data(3); - num_classes = size_of_test_data(4); -elseif array_size == 3 - num_sets = 1; - num_classes = size_of_test_data(3); -else - num_classes = 1; - num_sets = 1; -end - -%% Train Model -num_dim_reference_subspaces = 10; -num_dim_input_subpaces = 5; +%% Train model +numDimReferenceSubspace = 10; +numDimInputSubspace = 4; sigma = 1; -reference_subspaces = cvlKernelBasisVector(training_data, num_dim_reference_subspaces, sigma); -input_subspaces = cvlKernelBasisVector(testing_data, num_dim_input_subpaces, sigma); -% save('reference_subspaces.mat', 'reference_subspaces'); -% reference_subspaces = reference_subspaces(:, :, 1); -tic; -%% Recognition Phase -similarities = cvlKernelCanonicalAngles(training_data,reference_subspaces,... - testing_data, input_subspaces, sigma); -similarities = similarities(:, :, end, end); -% End timing and display the elapsed time -elapsedTime = toc; -fprintf('The code block executed in %.5f seconds.\n', elapsedTime); -model_evaluation = ModelEvaluation(similarities, generateLabels(num_classes, num_sets)); - -displayModelResults('Kernel Mutual Subspace Methods', model_evaluation); +model = KMSM(trainData,... + numDimReferenceSubspace,... + numDimInputSubspace,... + sigma,... + testLabels); -%% Print preditions -% disp(model_evaluation.predicted_labels); -% disp(model_evaluation.true_labels); -% disp(similarities) -% plotSimilarities(similarities) \ No newline at end of file +%% Evaluate model +modelEvaluation = model.evaluate(testData); +modelEvaluation.printResults(); \ No newline at end of file diff --git a/src/classes/@CMSM/CMSM.m b/src/classes/@CMSM/CMSM.m index e35023f..5f53d04 100644 --- a/src/classes/@CMSM/CMSM.m +++ b/src/classes/@CMSM/CMSM.m @@ -1,7 +1,6 @@ classdef CMSM properties name = 'Constrained Mutual Subspace Method'; - trainData; generalizedDifferenceSubspace; % GDS referenceSubspaces; numDimInputSubspace; @@ -13,11 +12,11 @@ function obj = CMSM(trainData, numDimReferenceSubspace, numDimInputSubspace, indexOfEigsToKeep, labels) numDim = size(trainData, 1); numClasses = size(trainData, 3); - obj.trainData = trainData; + trainData = cvlNormalize(trainData); obj.numDimReferenceSubspace = numDimReferenceSubspace; obj.numDimInputSubspace = numDimInputSubspace; obj.trueTestLabels = labels; - subspaces = cvlBasisVector(obj.trainData,... + subspaces = cvlBasisVector(trainData,... obj.numDimReferenceSubspace); P = zeros(numDim, numDim); for I=1:numClasses @@ -61,6 +60,7 @@ function scores = getSimilarityScores(obj, testData) testDatasize = size(testData); testDatasizeNumElements = numel(testDatasize); + testData = cvlNormalize(testData); if testDatasizeNumElements == 4 numSets = testDatasize(3); diff --git a/src/classes/@KMSM/KMSM.m b/src/classes/@KMSM/KMSM.m new file mode 100644 index 0000000..da46de8 --- /dev/null +++ b/src/classes/@KMSM/KMSM.m @@ -0,0 +1,58 @@ +classdef KMSM + properties + name = 'Kernel Mutual Subspace Method'; + trainData; + referenceSubspaces; + numDimInputSubspace; + numDimReferenceSubspace; + sigma; + trueTestLabels; + end + + methods + function obj = KMSM(trainData, numDimReferenceSubspace, numDimInputSubspace, sigma, labels) + obj.trainData = cvlNormalize(trainData); + obj.sigma = sigma; + obj.numDimReferenceSubspace = numDimReferenceSubspace; + obj.numDimInputSubspace = numDimInputSubspace; + obj.trueTestLabels = labels; + obj.referenceSubspaces = cvlKernelBasisVector(obj.trainData,... + obj.numDimReferenceSubspace,... + obj.sigma); + end + + + % Returns the predicted labels for the test data + function prediction = predict(obj, testData) + similarityScores = obj.getSimilarityScores(testData); + eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name); + prediction = eval.predicted_labels; + end + + % Returns the similarity scores for the test data (same as probabilities) + function probabilities = predictProb(obj, testData) + probabilities = obj.getSimilarityScores(testData); + end + + % Returns the evaluation object for the test data + function eval = evaluate(obj, testData) + similarityScores = obj.getSimilarityScores(testData); + eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name); + end + end + + methods (Access = private) + function scores = getSimilarityScores(obj, testData) + testData = cvlNormalize(testData); + inputSubspace = cvlKernelBasisVector(testData,... + obj.numDimInputSubspace,... + obj.sigma); + similarities = cvlKernelCanonicalAngles(obj.trainData,... + obj.referenceSubspaces,... + testData,... + inputSubspace,... + obj.sigma); + scores = similarities(:, :, end, end); + end + end +end diff --git a/src/classes/@MSM/MSM.m b/src/classes/@MSM/MSM.m index d10bdac..5ec14ed 100644 --- a/src/classes/@MSM/MSM.m +++ b/src/classes/@MSM/MSM.m @@ -1,50 +1,50 @@ classdef MSM properties name = 'Mutual Subspace Method'; - trainData; referenceSubspaces; numDimInputSubspace; numDimReferenceSubspace; trueTestLabels; end - + methods function obj = MSM(trainData, numDimReferenceSubspace, numDimInputSubspace, labels) - obj.trainData = trainData; + trainData = cvlNormalize(trainData); obj.numDimReferenceSubspace = numDimReferenceSubspace; obj.numDimInputSubspace = numDimInputSubspace; obj.trueTestLabels = labels; - subspaces = cvlBasisVector(obj.trainData,... - obj.numDimReferenceSubspace); + subspaces = cvlBasisVector(trainData,... + obj.numDimReferenceSubspace); obj.referenceSubspaces = subspaces; end - - + + % Returns the predicted labels for the test data function prediction = predict(obj, testData) similarityScores = obj.getSimilarityScores(testData); eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name); prediction = eval.predicted_labels; end - + % Returns the similarity scores for the test data (same as probabilities) function probabilities = predictProb(obj, testData) probabilities = obj.getSimilarityScores(testData); end - + % Returns the evaluation object for the test data function eval = evaluate(obj, testData) similarityScores = obj.getSimilarityScores(testData); eval = ModelEvaluation(similarityScores, obj.trueTestLabels, obj.name); end end - + methods (Access = private) function scores = getSimilarityScores(obj, testData) + testData = cvlNormalize(testData); inputSubspace = cvlBasisVector(testData,... - obj.numDimInputSubspace); + obj.numDimInputSubspace); similarities = cvlCanonicalAngles(obj.referenceSubspaces,... - inputSubspace); + inputSubspace); scores = similarities(:, :, end, end); end end