
Commit

figure 6c
anne-urai committed Feb 15, 2019
1 parent 64ce9d6 commit 16c9e1b
Showing 6 changed files with 210 additions and 231 deletions.
8 changes: 4 additions & 4 deletions barplots_modelcomparison_regression.m
@@ -93,24 +93,24 @@ function getPlotModelIC(mdls, s, d, colors)

% move together in clusters of 3
if contains(mdls{i}, '_z_'),
-        xpos = i+0.2;
+        xpos = i+0.18;
thiscolor = colors(1, :);
elseif contains(mdls{i}, '_dc_'),
xpos = i;
thiscolor = colors(2, :);

elseif contains(mdls{i}, '_dcz_'),
-        xpos = i-0.2;
+        xpos = i-0.18;
thiscolor = colors(3, :);

end

% best fit with outline
if i == bestMdl,
-        b = bar(xpos, mdldic(i), 'facecolor', thiscolor, 'barwidth', 0.8, 'BaseValue', 0, ...
+        b = bar(xpos, mdldic(i), 'facecolor', thiscolor-0.1, 'barwidth', 0.8, 'BaseValue', 0, ...
'edgecolor', 'k');
else
-        b = bar(xpos, mdldic(i), 'facecolor', thiscolor, 'barwidth', 0.8, 'BaseValue', 0, ...
+        b = bar(xpos, mdldic(i), 'facecolor', thiscolor+0.1, 'barwidth', 0.8, 'BaseValue', 0, ...
'edgecolor', 'none');
end
end
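The change above shades the bars by fit: the best-fitting model's bar is darkened (thiscolor-0.1) and outlined in black, the rest are lightened (thiscolor+0.1). A minimal sketch of that idiom, with a clamp added here because MATLAB RGB values must stay in [0,1] (the clamp is not in the committed code):

thiscolor = [0.2 0.4 0.6];           % hypothetical RGB triplet
darker  = max(thiscolor - 0.1, 0);   % emphasised, best-fitting bar
lighter = min(thiscolor + 0.1, 1);   % de-emphasised bars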
182 changes: 182 additions & 0 deletions correlations_regression_lags.m
@@ -0,0 +1,182 @@
function allresults = correlations_regression_lags

global mypath datasets datasetnames
addpath(genpath('~/code/Tools'));
warning off; close all;
cols = cbrewer('qual', 'Paired', 10);

numlags = 6;
vars = {'z_correct', 'z_error', 'v_correct', 'v_error', 'repeat_correct', 'repeat_error'};
cnt = 1;

for d = 1:length(datasets),

dat = readtable(sprintf('%s/summary/%s/allindividualresults.csv', mypath, datasets{d}));
dat = dat(dat.session == 0, :);

for m = 1:length(vars),
alldata.(vars{m}) = nan(numlags, size(dat, 1));
end

% ALL MODELS THAT WERE RUN
mdls = {'regress_nohist', ...
'regress_z_lag1', ...
'regress_dc_lag1', ...
'regress_dcz_lag1', ...
'regress_z_lag2', ...
'regress_dc_lag2', ...
'regress_dcz_lag2', ...
'regress_z_lag3', ...
'regress_dc_lag3', ...
'regress_dcz_lag3', ...
'regress_z_lag4', ...
'regress_dc_lag4', ...
'regress_dcz_lag4', ...
'regress_z_lag5', ...
'regress_dc_lag5', ...
'regress_dcz_lag5', ...
'regress_z_lag6', ...
'regress_dc_lag6', ...
'regress_dcz_lag6'};
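% 19 models in total: the no-history baseline plus z / dc / dcz
% variants at each of lags 1-6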

% ============================= %
% 1. DETERMINE THE BEST MODEL
% ============================= %

mdldic = nan(1, length(mdls));
for m = 1:length(mdls),
try
modelcomp = readtable(sprintf('%s/%s/%s/model_comparison.csv', ...
mypath, datasets{d}, mdls{m}), 'readrownames', true);
mdldic(m) = modelcomp.aic;
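% note: despite the variable name, what is stored here is the
% AIC column of model_comparison.csv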
catch
fprintf('%s/%s/%s/model_comparison.csv NOT FOUND\n', ...
mypath, datasets{d}, mdls{m})
end
end

% everything relative to the no-history baseline model (mdls{1})
mdldic = bsxfun(@minus, mdldic, mdldic(1));
mdldic = mdldic(2:end);
mdls = mdls(2:end);
[~, bestMdl] = min(mdldic);
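% mdldic now holds AIC differences from the baseline, so the most
% negative entry marks the best-fitting lag/model combination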

% now take the hybrid model for this best-fitting lag
bestmodelname = sprintf('regressdczlag%s', mdls{bestMdl}(end));
disp(bestmodelname);
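% taking mdls{bestMdl}(end) as the lag only works for single-digit
% lags, which holds here since numlags = 6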

% ========================================================== %
% 2. FOR THE BEST-FITTING MODEL, GET HISTORY WEIGHTS
% ========================================================== %

% skip lag 1 - average over lags 2 up to the best-fitting lag
for l = 2:str2double(bestmodelname(end)),
lname = num2str(l);

% get regression weights
for v = 1:length(vars),
switch vars{v}
case 'z_correct'
alldata.(vars{v})(l,:) = ...
(dat.(['z_prev' lname 'resp__' bestmodelname]) + ...
dat.(['z_prev' lname 'stim__' bestmodelname]));
case 'z_error'
alldata.z_error(l,:) = ...
(dat.(['z_prev' lname 'resp__' bestmodelname]) - ...
dat.(['z_prev' lname 'stim__' bestmodelname]));
case 'v_correct'
alldata.v_correct(l,:) = ...
(dat.(['v_prev' lname 'resp__' bestmodelname]) + ...
dat.(['v_prev' lname 'stim__' bestmodelname]));
case 'v_error'
alldata.v_error(l,:) = ...
(dat.(['v_prev' lname 'resp__' bestmodelname]) - ...
dat.(['v_prev' lname 'stim__' bestmodelname]));

case 'repeat_error'
alldata.(vars{v})(l,:) = dat.(['repetition_error' num2str(l)])...
- arrayfun(@trivial_probabilities, dat.repetition_error1, repmat(l, size(dat, 1), 1));

case 'repeat_correct'
alldata.(vars{v})(l,:) = dat.(['repetition_correct' num2str(l)]) ...
- arrayfun(@trivial_probabilities, dat.repetition_correct1, repmat(l, size(dat, 1), 1));
end

end

end

% assign to structure - correct choices
allresults(1).z_prevresp = nanmean(alldata.z_correct);
allresults(1).v_prevresp = nanmean(alldata.v_correct);
allresults(1).criterionshift = nanmean(alldata.repeat_correct);
alltitles{1} = datasetnames{d};
allresults(1).marker = 'o';
allresults(1).meancolor = [0 0 0];
allresults(1).scattercolor = [0.5 0.5 0.5];

% also after error choices
allresults(2).z_prevresp = nanmean(alldata.z_error);
allresults(2).v_prevresp = nanmean(alldata.v_error);
allresults(2).criterionshift = nanmean(alldata.repeat_error);
alltitles{2} = datasetnames{d};
allresults(2).marker = 's';
allresults(2).meancolor = cols(6, :);
allresults(2).scattercolor = cols(5, :);

% ========================================================== %
% COMPUTE CORRELATIONS
% ========================================================== %

for a = 1:length(allresults),

% SAVE CORRELATIONS FOR OVERVIEW PLOT
% COMPUTE SPEARMAN'S CORRELATION AND ITS CONFIDENCE INTERVAL
[alldat(cnt).corrz, alldat(cnt).corrz_ci, alldat(cnt).pz, alldat(cnt).bfz] = ...
spearmans(allresults(a).z_prevresp(:), allresults(a).criterionshift(:));

[alldat(cnt).corrv, alldat(cnt).corrv_ci, alldat(cnt).pv, alldat(cnt).bfv] = ...
spearmans(allresults(a).v_prevresp(:), allresults(a).criterionshift(:));

alldat(cnt).datasets = datasets{d};
alldat(cnt).datasetnames = alltitles{a};

% also compute the difference between the two correlations, Steiger's test
r = spearmans(allresults(a).v_prevresp(:), allresults(a).z_prevresp(:)); % only rho is needed below

[rhodiff, rhodiffci, pval] = rddiffci(alldat(cnt).corrz, alldat(cnt).corrv, ...
r, numel(allresults(a).v_prevresp), 0.05);

alldat(cnt).corrdiff = rhodiff;
alldat(cnt).corrdiff_ci = rhodiffci;
alldat(cnt).pdiff = pval;
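% spearmans and rddiffci are repository helpers not shown in this
% commit; inferred from the calls above, their signatures appear to be
%   [rho, ci, p, bf10] = spearmans(x, y)
%   [rhodiff, ci, p] = rddiffci(r1, r2, r12, n, alpha)  % Steiger-style test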

% plotting layout for forestPlot
alldat(cnt).marker = allresults(a).marker;
alldat(cnt).scattercolor = allresults(a).scattercolor;
alldat(cnt).meancolor = allresults(a).meancolor;

cnt = cnt + 1;
end
end

% ========================================================== %
% FOREST PLOTS, SPLIT BY PREVIOUS-TRIAL OUTCOME
% ========================================================== %
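% alldat interleaves the two outcomes: odd entries hold post-correct
% results, even entries post-error results (see the cnt loop above)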

forestPlot(alldat(1:2:end));
print(gcf, '-dpdf', '~/Data/serialHDDM/forestplot_regressionHDDM_prevcorrect.pdf');
forestPlot(alldat(2:2:end));
print(gcf, '-dpdf', '~/Data/serialHDDM/forestplot_regressionHDDM_preverror.pdf');

end

function [ vec_repeat ] = trivial_probabilities(p, lag)
% given a lag-1 repetition probability p, return the repetition
% probability expected at the given lag if each step repeats
% independently with probability p - the "trivial" baseline
% subtracted from the measured repetition curves above
vec_repeat(1) = p;
for i = 2:lag,
    vec_repeat(i) = p*vec_repeat(i-1) + (1-p)*(1-vec_repeat(i-1));
end

vec_repeat = vec_repeat(end);
end
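A quick sanity check on trivial_probabilities: the recursion r(i) = p*r(i-1) + (1-p)*(1-r(i-1)) with r(1) = p has the closed form r(lag) = (1 + (2p-1)^lag)/2. A minimal sketch, with hypothetical values:

p = 0.6; lag = 4;                        % hypothetical inputs
closed_form = (1 + (2*p - 1)^lag) / 2;   % = 0.5008
assert(abs(trivial_probabilities(p, lag) - closed_form) < 1e-12);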

79 changes: 11 additions & 68 deletions define_behavioral_metrics.m
@@ -19,7 +19,6 @@

nrSess = length(unique(alldata.session)) + 1;
results = array2table(nan(length(unique(alldata.subj_idx))*nrSess, length(varnames)), 'variablenames', varnames);
-% results.drug = repmat({'NaN'}, length(unique(alldata.subj_idx))*nrSess, 1);

% preallocate dprime for different coherence levels
if sum(strcmp(alldata.Properties.VariableNames, 'coherence')) > 0,
@@ -47,21 +46,6 @@
alldata.correct = (tmpstim == tmpresp);
end

-% % only MEG-PL data has starthand
-% if isfield(alldata, 'startHand'),
-%     alldata.startHand(alldata.startHand > 20) = nan;
-% else
-%     alldata.startHand = nan(size(alldata.subj_idx));
-% end

-% for criterion shift
-alldata.nextstim = circshift(alldata.stimulus, -1);
-alldata.nextresp = circshift(alldata.response, -1);
-try
-    alldata.nextstim((diff(alldata.trial) ~= 1)) = NaN;
-    alldata.nextresp((diff(alldata.trial) ~= 1)) = NaN;
-end

% for mulder et al. analysis
alldata.prevstim = circshift(alldata.stimulus, 1);
alldata.prevresp = circshift(alldata.response, 1);
@@ -122,10 +106,12 @@
% data.stimrepeat = [~(abs(diff(data.stimulus)) > 0); NaN];

% 01.10.2017, use the same metric as in MEG, A1c_writeCSV.m
-for l = 1:16,
-    data.(['repeat' num2str(l)]) = double(data.response == circshift(data.response, l));
+for l = 1:6,
+    data.(['prev' num2str(l) 'resp']) = circshift(data.response, l);
+    data.(['prev' num2str(l) 'stim']) = circshift(data.stimulus, l);
+    data.(['prev' num2str(l) 'resp'])(data.(['prev' num2str(l) 'resp']) == 0) = -1;

data.(['repeat' num2str(l)]) = double(data.response == circshift(data.response, l));
wrongTrls = ((data.trial - circshift(data.trial, l)) ~= l);
data.(['repeat' num2str(l)])(wrongTrls) = NaN;
end
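For reference, a minimal sketch of what the added prev-columns contain (toy 0/1 responses, hypothetical values). circshift wraps around, which is why the wrongTrls mask above discards comparisons that do not span exactly l consecutive trials:

resp = [0; 1; 1; 0; 1];        % hypothetical 0/1 responses
prev1 = circshift(resp, 1);    % [1; 0; 1; 1; 0] - first entry wrapped around
prev1(prev1 == 0) = -1;        % recode to -1/1 so signs can be compared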
@@ -158,64 +144,21 @@
% for figure 6c
% ======================================= %

-% ALSO REMOVE THE EFFECT OF MORE RECENT LAGS, TAKE THE RESIDUALS
-repetitions_mat = data{:, {'repeat1', 'repeat2', 'repeat3', 'repeat4', ...
-    'repeat5', 'repeat6', 'repeat7', 'repeat8', 'repeat9', 'repeat10', ...
-    'repeat11', 'repeat12', 'repeat13', 'repeat14', 'repeat15', 'repeat16'}};
-repetitions_mat(repetitions_mat == 0) = -1;
-repetitions_mat(isnan(repetitions_mat)) = 0;
-qr_mat = qr(repetitions_mat);

-for l = 1:size(repetitions_mat, 2),

-    usetrls = find((repetitions_mat(:, l) ~= 0));
-    % use QR decomposition, only on valid trials, to remove the effect of more recent choice sequences
-    cleaned = qr_mat(usetrls, l);

-    % put back
-    data.(['repeat_corrected' num2str(l)]) = nan(size(data.(['repeat' num2str(l)])));
-    data.(['repeat_corrected' num2str(l)])(usetrls) = cleaned(:, end);

-end

-for l = 1:16,
+for l = 1:6,
results.(['repetition' num2str(l)])(icnt) = nanmean(data.(['repeat' num2str(l)]));
-end
-for l = 1:16,
-    results.(['repetition_corrected' num2str(l)])(icnt) = nanmean(data.(['repeat_corrected' num2str(l)]));
-end

-% ======================================= %
-% logistic regression weights
-% ======================================= %

-X = [data.stimulus data.prev1resp data.prev2resp data.prev3resp data.prev4resp ...
-    data.prev5resp data.prev6resp data.prev7resp data.prev8resp data.prev9resp ...
-    data.prev10resp data.prev11resp data.prev12resp data.prev13resp data.prev14resp data.prev15resp];
-X(X == 0) = -1;
-b = glmfit(X, data.response, 'binomial', 'constant', 'on');
-b = b(3:end); % remove overall bias and stimulus weight
-b2 = glmfit(qr(X), data.response, 'binomial', 'constant', 'on');
-b2 = b2(3:end); % remove overall bias and stimulus
-for l = 1:length(b)
-    results.(['logistic' num2str(l)])(icnt) = b(l);
-    results.(['logistic_orth' num2str(l)])(icnt) = b2(l);
-end
+    results.(['repetition_correct' num2str(l)])(icnt) = ...
+        nanmean(data.(['repeat' num2str(l)])((data.(['prev' num2str(l) 'stim']) > 0) == (data.(['prev' num2str(l) 'resp']) > 0)));
+    results.(['repetition_error' num2str(l)])(icnt) = ...
+        nanmean(data.(['repeat' num2str(l)])((data.(['prev' num2str(l) 'stim']) > 0) ~= (data.(['prev' num2str(l) 'resp']) > 0)));
end

-% for the first lag, no correction (perhaps not necessary?)
-%results.repetition_corrected1(icnt) = results.repetition1(icnt);
results.repetition(icnt) = nanmean(data.repeat1);

% also compute this after error and correct trials
results.repetition_prevcorrect(icnt) = nanmean(data.repeat((data.prevstim > 0) == (data.prevresp > 0)));
results.repetition_preverror(icnt) = nanmean(data.repeat((data.prevstim > 0) ~= (data.prevresp > 0)));

-% % criterion based on repetition and stimulus sequences
-% [~, c] = dprime(data.stimrepeat, data.repeat);
-% results.repetitioncrit(icnt) = -c;

-% criterion based on next trial bias, then collapsed
-results.criterionshift(icnt) = criterionshift(data.response, data.nextstim, data.nextresp);

end
end
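The post-correct / post-error split added above reduces to logical masks over the signs of the previous stimulus and response. A minimal sketch with hypothetical toy vectors:

prevstim = [1; -1; 1; -1]; prevresp = [1; 1; -1; -1];  % hypothetical signs
rep = [1; 0; 0; 1];                                    % repeated this trial?
prevcorrect = (prevstim > 0) == (prevresp > 0);        % [1; 0; 0; 1]
rep_after_correct = nanmean(rep(prevcorrect));         % after correct trials
rep_after_error   = nanmean(rep(~prevcorrect));        % after error trials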

16 changes: 8 additions & 8 deletions forestPlot.m
@@ -72,13 +72,13 @@ function forestPlot(alldat)
disp('z bayes factor');
bf = prod([alldat(ds).bfz])
    if bf < 1/100, % 1/100, not 100: otherwise this branch catches every bf below 100
-        t = title(sprintf('BF_{10} < 1/100'));
+        t = title(sprintf('BF_{10} < 1/100'), 'fontweight', 'normal', 'fontangle', 'italic');
    elseif bf > 100,
-        t = title(sprintf('BF_{10} > 100'));
+        t = title(sprintf('BF_{10} > 100'), 'fontweight', 'normal', 'fontangle', 'italic');
    elseif bf < 1,
-        t = title(sprintf('BF_{10} = 1/%.2f', 1/bf));
+        t = title(sprintf('BF_{10} = 1/%.2f', 1/bf), 'fontweight', 'normal', 'fontangle', 'italic');
    elseif bf > 1,
-        t = title(sprintf('BF_{10} = %.2f', bf));
+        t = title(sprintf('BF_{10} = %.2f', bf), 'fontweight', 'normal', 'fontangle', 'italic');
    end
t.Position(2) = t.Position(2) - 1.2;

@@ -133,13 +133,13 @@ function forestPlot(alldat)

bf = prod([alldat(ds).bfv]);
    if bf < 1/100, % 1/100, not 100, as above
-        t = title(sprintf('BF_{10} < 1/100'));
+        t = title(sprintf('BF_{10} < 1/100'), 'fontweight', 'normal', 'fontangle', 'italic');
    elseif bf > 100,
-        t = title(sprintf('BF_{10} > 100'));
+        t = title(sprintf('BF_{10} > 100'), 'fontweight', 'normal', 'fontangle', 'italic');
    elseif bf < 1,
-        t = title(sprintf('BF_{10} = 1/%.2f', 1/bf));
+        t = title(sprintf('BF_{10} = 1/%.2f', 1/bf), 'fontweight', 'normal', 'fontangle', 'italic');
    elseif bf > 1,
-        t = title(sprintf('BF_{10} = %.2f', bf));
+        t = title(sprintf('BF_{10} = %.2f', bf), 'fontweight', 'normal', 'fontangle', 'italic');
    end
t.Position(2) = t.Position(2) - 1.2;
% move closer together
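forestPlot combines evidence across datasets by multiplying the per-dataset Bayes factors (the prod([alldat(ds).bfz]) and prod([alldat(ds).bfv]) calls above). A minimal sketch of that combination, assuming independent datasets and hypothetical values:

bfs = [2.5 4.0 0.8 10.0];   % hypothetical per-dataset BF10 values
combined = prod(bfs);       % = 80, the joint evidence under independence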
