|
| 1 | +% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% |
| 2 | +% Copyright (c) 2020 University of Southern California |
| 3 | +% See full notice in LICENSE.md |
| 4 | +% Omid G. Sani and Maryam M. Shanechi |
| 5 | +% Shanechi Lab, University of Southern California |
| 6 | +% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% |
| 7 | + |
| 8 | +%% Add PSID to the path (or call init.m) |
| 9 | +addpath(genpath('../source')); |
| 10 | + |
| 11 | +%% Load data |
| 12 | +data = load('./sample_data.mat'); |
| 13 | +% This data is generated from a system (shown in Supplementary Fig. 2) with |
| 14 | +% (a) 2 behaviorally relevant latent states, |
| 15 | +% (b) 2 behaviorally irrelevant latent states, and |
| 16 | +% (c) 2 states that drive behavior but are not represented in neural activity |
| 17 | + |
| 18 | +% Separate data into training and test data: |
| 19 | +trainInds = (1:round(0.5*size(data.y, 1)))'; |
| 20 | +testInds = ((1+trainInds(end)):size(data.y, 1))'; |
| 21 | +yTrain = data.y(trainInds, :); |
| 22 | +yTest = data.y(testInds, :); |
| 23 | +zTrain = data.z(trainInds, :); |
| 24 | +zTest = data.z(testInds, :); |
| 25 | +%% (Example 1) PSID can be used to dissociate and extract only the |
| 26 | +% behaviorally relevant latent states (with nx = n1 = 2) |
| 27 | +idSys1 = PSID(yTrain', zTrain', 2, 2, 10); |
| 28 | + |
| 29 | +% Predict behavior using the learned model |
| 30 | +[zTestPred1, xTestPred1] = PSIDPredict(idSys1, yTest); |
| 31 | + |
| 32 | +% Compute CC of decoding |
| 33 | +nz = size(zTest, 2); |
| 34 | +CC = arrayfun( @(i)( corr(zTestPred1(:, i), zTest(:, i)) ), 1:nz ); |
| 35 | + |
| 36 | +% Predict behavior using the true model for comparison |
| 37 | +[zTestPredIdeal, xTestIdeal] = PSIDPredict(data.trueSys, yTest); |
| 38 | +CCIdeal = arrayfun( @(i)( corr(zTestPredIdeal(:, i), zTest(:, i)) ), 1:nz ); % Compute CC of ideal decoding |
| 39 | + |
| 40 | +fprintf('PSID decoding CC = %.3g, ideal decoding CC using true model = %.3g\n', mean(CC), mean(CCIdeal)); |
| 41 | +%% (Example 2) Optionally, PSID can additionally also learn the |
| 42 | +% behaviorally irrelevant latent states (with nx = 4, n1 = 2) |
| 43 | +idSys2 = PSID(yTrain', zTrain', 4, 2, 10); |
| 44 | + |
| 45 | +%% (Example 3) PSID can be used if data is available in discontinious segments (e.g. different trials) |
| 46 | +% In this case, y and z data segments must be provided as elements of a cell array |
| 47 | +% Here, for example assume that trials start at every 1000 samples. |
| 48 | +% And each each trial has a random length of 500 to 900 samples |
| 49 | +trialStartInds = (1:1000:(size(data.y, 1)-1000))'; |
| 50 | +trialDurRange = [900 990]; |
| 51 | +trialDur = trialDurRange(1)-1 + randi(diff(trialDurRange)+1, size(trialStartInds)); |
| 52 | +trialInds = arrayfun( @(ti)( (trialStartInds(ti)-1+(1:trialDur(ti)))' ), (1:numel(trialStartInds))', 'UniformOutput', false ); |
| 53 | +yTrials = arrayfun( @(tInds)( data.y(tInds{1}, :)' ), trialInds, 'UniformOutput', false ); |
| 54 | +zTrials = arrayfun( @(tInds)( data.z(tInds{1}, :)' ), trialInds, 'UniformOutput', false ); |
| 55 | + |
| 56 | +% Separate data into training and test data: |
| 57 | +trainInds = (1:round(0.5*numel(yTrials)))'; |
| 58 | +testInds = ((1+trainInds(end)):numel(yTrials))'; |
| 59 | +yTrain = yTrials(trainInds, :); |
| 60 | +yTest = yTrials(testInds, :); |
| 61 | +zTrain = zTrials(trainInds, :); |
| 62 | +zTest = zTrials(testInds, :); |
| 63 | + |
| 64 | +idSys3 = PSID(yTrain, zTrain, 2, 2, 10); |
| 65 | + |
| 66 | +yTestT = arrayfun( @(yt)( yt{1}.' ), yTest, 'UniformOutput', false); |
| 67 | +% yTestCat = cell2mat( yTestT ); % Data can also be concatenated for |
| 68 | + % decoding if taking last state in a previous trial as the |
| 69 | + % initial state in the next trial makes sense |
| 70 | +[zTestPred1, xTestPred1Cell] = PSIDPredict(idSys3, yTestT); |
| 71 | + |
| 72 | +zTestPred1Cat = cell2mat( zTestPred1 ); |
| 73 | +% zTestPred1Cat = zTestPred1; |
| 74 | + |
| 75 | +zTestT = arrayfun( @(zt)( zt{1}.' ), zTest, 'UniformOutput', false); |
| 76 | +zTestCat = cell2mat( zTestT ); |
| 77 | +CCTrialBased = arrayfun( @(i)( corr(zTestPred1Cat(:, i), zTestCat(:, i)) ), 1:nz ); |
| 78 | + |
| 79 | +fprintf('Trial-based PSID decoding CC = %.3g, ideal decoding CC using true model = %.3g\n', mean(CCTrialBased), mean(CCIdeal)); |
| 80 | + |
| 81 | +%% |
| 82 | +% Plot the true and identified eigenvalues |
| 83 | + |
| 84 | +% (Example 1) Eigenvalues when only learning behaviorally relevant states |
| 85 | +idEigs1 = eig(idSys1.A); |
| 86 | + |
| 87 | +% (Example 2) Additional eigenvalues when also learning behaviorally irrelevant states |
| 88 | +% The identified model is already in form of Eq. 4, with behaviorally irrelevant states |
| 89 | +% coming as the last 2 dimensions of the states in the identified model |
| 90 | +idEigs2 = eig(idSys2.A(3:4, 3:4)); |
| 91 | + |
| 92 | +relevantDims = data.trueSys.zDims; % Dimensions that drive both behavior and neural activity |
| 93 | +irrelevantDims = find(~ismember(1:size(data.trueSys.a,1), data.trueSys.zDims)); % Dimensions that only drive the neural activity |
| 94 | +trueEigsRelevant = eig(data.trueSys.a(relevantDims, relevantDims)); |
| 95 | +trueEigsIrrelevant = eig(data.trueSys.a(irrelevantDims, irrelevantDims)); |
| 96 | +nonEncodedEigs = eig(data.epsSys.a); % Eigenvalues for states that only drive behavior |
| 97 | + |
| 98 | +figure; zplane([], []); ax = gca; hold(ax, 'on'); |
| 99 | +h1 = scatter(ax, real(nonEncodedEigs), imag(nonEncodedEigs), 'o', 'MarkerEdgeColor', 'b', 'DisplayName', 'Not encoded in neural signals'); |
| 100 | +h2 = scatter(ax, real(trueEigsIrrelevant), imag(trueEigsIrrelevant), 'o', 'MarkerEdgeColor', 'r', 'DisplayName', 'Behaviorally irrelevant'); |
| 101 | +h3 = scatter(ax, real(trueEigsRelevant), imag(trueEigsRelevant), 'o', 'MarkerEdgeColor', 'g', 'DisplayName', 'Behaviorally relevant'); |
| 102 | +h4 = scatter(ax, real(idEigs1), imag(idEigs1), 'x', 'MarkerEdgeColor', [0 0.5 0], 'DisplayName', 'PSID Identified (stage 1)'); |
| 103 | +h5 = scatter(ax, real(idEigs2), imag(idEigs2), 'x', 'MarkerEdgeColor', [0.5 0 0], 'DisplayName', '(optional) PSID Identified (stage 2)'); |
| 104 | +legend(ax, [h1, h2, h3, h4, h5], 'Location', 'EO'); |
0 commit comments