forked from PumpkinPop/StdVisualModel
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy paths2_fit_all_cluster.m
98 lines (78 loc) · 3.6 KB
/
s2_fit_all_cluster.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
%% Set up the dataset and the models we are going to test
% For running on HPC, execute
%   sbatch hpc_solve_models.sh
% Each hyperparameter below keeps any value that already exists in the
% workspace (e.g. one set by the launcher before running this script);
% otherwise the default shown here is used.
%% Parse the hyperparameters
if ~exist('doCross', 'var'), doCross = false; end % true: cross validate; false: no cross validation
if ~exist('target', 'var'), target = 'target'; end % 'target' or 'All'
if ~exist('optimizer', 'var'), optimizer = 'classic'; end % 'classic' or 'reparam'
if ~exist('start_idx', 'var'), start_idx = 1; end % cache: start index forwarded to model.fit via save_info
if ~exist('choose_model', 'var'), choose_model = 'all'; end
if ~exist('fittime', 'var'), fittime = 40; end % how many initializations (integer)
if ~exist('verbose', 'var'), verbose = 'off'; end % show the fit details?
% Map the cross-validation flag onto the fitter's mode string
% (value space: 'one', 'cross_valid') and the output folder
% (value space: 'noCross', 'Cross').
switch doCross
    case false
        cross_valid = 'one';         % 'one' means no cross validation
        data_folder = 'noCross';
    case true
        cross_valid = 'cross_valid'; % cross validate
        data_folder = 'Cross';
    otherwise
        % Without this branch an unexpected value would leave cross_valid and
        % data_folder undefined, or stale from an earlier run in the same
        % workspace, and results would land in the wrong folder.
        error('doCross must be true or false; got %s.', mat2str(doCross));
end
%% generate save address and choose data
% Results are written under
%   <stdnormRootPath>/Data/<data_folder>/<target>/<optimizer>
% and the folder is created on first use.
save_address = fullfile(stdnormRootPath, 'Data', data_folder, target, optimizer);
if exist(save_address, 'dir') == 0
    mkdir(save_address);
end
% Build the job list. It is indexed below by the array task id; presumably
% there is one entry per (dataset, roi, model) job (TODO confirm in chooseData).
T = chooseData(choose_model, optimizer, fittime);
%% Start fit
% Assign this process its job. On the cluster, SLURM_ARRAY_TASK_ID selects
% the entry of T. When the variable is unset, getenv returns '' and
% str2double('') is NaN; in that case fall back to job 2 for a local run.
hpc_job_number = str2double(getenv('SLURM_ARRAY_TASK_ID'));
if isnan(hpc_job_number), hpc_job_number = 2; end
% Fail with a clear message instead of an opaque indexing error when the
% task id does not address an existing job.
n_jobs = numel(T.dataset);
if hpc_job_number < 1 || hpc_job_number > n_jobs || hpc_job_number ~= fix(hpc_job_number)
    error('Job index %g is not valid: chooseData returned %d jobs.', hpc_job_number, n_jobs);
end
dataset   = T.dataset(hpc_job_number);
roi       = T.roiNum(hpc_job_number);
model_idx = T.modelNum(hpc_job_number);
model     = T.modelLoader{hpc_job_number};
% Set the save info. It is passed to model.fit and helps in continuing a job
% after a broken pipe (intermediate results go to <save_address>/temp).
save_temp = fullfile(save_address, 'temp');
if ~exist(save_temp, 'dir'), mkdir(save_temp); end
save_info.dir       = save_temp;
save_info.roi       = roi;
save_info.model_idx = model_idx;
save_info.dataset   = dataset;
save_info.start_idx = start_idx;
% Print which job is running so fitting can be tracked in the job log.
% Named job_msg, not "display", to avoid shadowing MATLAB's builtin display.
job_msg = [ 'dataset: ' num2str(dataset), ' roi: ',num2str(roi), ' model: ', num2str(model_idx) ];
disp(job_msg)
% Load the training label (the BOLD target the model is fit to).
BOLD_target = dataloader(stdnormRootPath, 'BOLD_target', target, dataset, roi);
% Choose which contrast-energy input to load from the model family.
switch model.model_type
    case 'orientation'
        which_obj = 'E_ori'; % CE, NOA
    case 'space'
        which_obj = 'E_xy';  % SOC, OTS
    otherwise
        % Without this branch, which_obj would be undefined. Worse, it could
        % be left over from a previous run in the same workspace, silently
        % loading the wrong input.
        error('Unknown model_type "%s".', model.model_type);
end
% Load contrast energy; it is always the first model input.
E = dataloader(stdnormRootPath, which_obj, target, dataset, roi);
x = {E};
disp(model.legend)
% Some models also take a pre-calculated normalized energy as a second input.
switch model.legend
    case 'OTS'
        Z = dataloader(stdnormRootPath, 'Z1', target, dataset, roi);
        x{end + 1} = Z;
    case 'DN'
        Z = dataloader(stdnormRootPath, 'Z2', target, dataset, roi);
        x{end + 1} = Z;
    otherwise
        % Other models use only the contrast-energy input.
end
% Fit the model. cross_valid selects the mode:
%   'one'         - no cross validation
%   'cross_valid' - cross-validated fit (leave-one-out per an earlier note;
%                   TODO confirm in model.fit)
% model.fit also returns the updated model object.
[BOLD_pred, params, Rsquare, model] = ...
    model.fit(model, x, BOLD_target, verbose, cross_valid, save_info);
% The loss log is only read from the model for the non-cross-validated fit.
if strcmp(cross_valid, 'one')
    loss_log = model.loss_log;
end
% Save each result to its own .mat file. All files share the suffix
% data-<dataset>_roi-<roi>_model-<model_idx>.mat.
file_tag = sprintf('data-%d_roi-%d_model-%d.mat', dataset, roi, model_idx);
save(fullfile(save_address, ['parameters_' file_tag]), 'params');
save(fullfile(save_address, ['prediction_' file_tag]), 'BOLD_pred');
save(fullfile(save_address, ['Rsquare_' file_tag]), 'Rsquare');
if strcmp(cross_valid, 'one')
    save(fullfile(save_address, ['loss_log_' file_tag]), 'loss_log');
end