-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
2,197 additions
and
196 deletions.
There are no files selected for viewing
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
This file was deleted.
Oops, something went wrong.
204 changes: 98 additions & 106 deletions
204
..._project/scripts/delay_discounting_main.R → ...ming_project/scripts/learning_task_main.R
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,106 +1,98 @@ | ||
# ============================================================================= | ||
#### Info #### | ||
# ============================================================================= | ||
# Hierachical models for delay discounting models | ||
# | ||
|
||
# ============================================================================= | ||
#### Construct Data #### | ||
# ============================================================================= | ||
# clear workspace | ||
rm(list=ls(all=TRUE)) | ||
library(rstan) | ||
library(loo) | ||
library(ggplot2) | ||
|
||
#### read raw ----------------------------------------------------------------- | ||
rawdata = # complete this line for reading raw data | ||
|
||
#### Preprocess the data ------------------------------------------------------ | ||
subjList = unique(rawdata[,"subjID"]) | ||
nSubjects = length(subjList) | ||
|
||
Tsubj = as.vector( rep( 0, nSubjects ) ) # number of valid trials per subj | ||
|
||
for ( s in 1:nSubjects ) { | ||
curSubj = subjList[ s ] | ||
Tsubj[s] = sum( rawdata$subjID == curSubj ) | ||
} | ||
|
||
maxTrials = max(Tsubj) | ||
delay_later = array(0, c(nSubjects, maxTrials) ) | ||
amount_later = array(0, c(nSubjects, maxTrials) ) | ||
delay_sooner = array(0, c(nSubjects, maxTrials) ) | ||
amount_sooner = array(0, c(nSubjects, maxTrials) ) | ||
choice = array(0, c(nSubjects, maxTrials) ) | ||
|
||
for (s in 1:nSubjects) { | ||
curSubj = subjList[s] | ||
useTrials = Tsubj[s] | ||
tmp = subset(rawdata, rawdata$subjID == curSubj) | ||
delay_later[s, 1:useTrials] = tmp$delay_later | ||
amount_later[s, 1:useTrials] = tmp$amount_later | ||
delay_sooner[s, 1:useTrials] = tmp$delay_sooner | ||
amount_sooner[s, 1:useTrials] = tmp$amount_sooner | ||
choice[s, 1:useTrials] = tmp$choice | ||
} | ||
|
||
dataList = list( | ||
nSubjects = nSubjects, | ||
nTrials = maxTrials, | ||
Tsubj = Tsubj, | ||
choice = choice, | ||
amount_later = amount_later, | ||
delay_later = delay_later, | ||
amount_sooner = amount_sooner, | ||
delay_sooner = delay_sooner | ||
) | ||
|
||
# ============================================================================= | ||
#### Running Stan #### | ||
# ============================================================================= | ||
rstan_options(auto_write = TRUE) | ||
options(mc.cores = 2) | ||
|
||
nIter = 2000 | ||
nChains = 4 | ||
nWarmup = floor(nIter/2) | ||
nThin = 1 | ||
|
||
#### run the hyperbolic model ---------------------------------------- | ||
modelFile1 = 'scripts/hyperbolic.stan' | ||
|
||
cat("Estimating", modelFile1, "model... \n") | ||
startTime = Sys.time(); print(startTime) | ||
cat("Calling", nChains, "simulations in Stan... \n") | ||
|
||
fit_hyperbolic = stan() # complete this line for calling Stan | ||
|
||
cat("Finishing", modelFile1, "model simulation ... \n") | ||
endTime = Sys.time(); print(endTime) | ||
cat("It took",as.character.Date(endTime - startTime), "\n") | ||
|
||
#### run the simple heuristic model --------------------------------------- | ||
modelFile2 = 'scripts/heuristic.stan' | ||
|
||
cat("Estimating", modelFile1, "model... \n") | ||
startTime = Sys.time(); print(startTime) | ||
cat("Calling", nChains, "simulations in Stan... \n") | ||
|
||
fit_heuristic = stan() # complete this line for calling Stan | ||
|
||
cat("Finishing", modelFile2, "model simulation ... \n") | ||
endTime = Sys.time(); print(endTime) | ||
cat("It took",as.character.Date(endTime - startTime), "\n") | ||
|
||
# ============================================================================= | ||
#### Model selection #### | ||
# ============================================================================= | ||
LL_hyperbolic = # complete this line for extreact log-likelihood | ||
LL_heuristic = # complete this line for extreact log-likelihood | ||
|
||
waic_hyperbolic = waic(LL1) | ||
waic_heuristic = waic(LL2) | ||
|
||
|
||
#### End of file | ||
# ============================================================================= | ||
#### Info #### | ||
# ============================================================================= | ||
# Hierachical models for two-armed bandit learning task | ||
# (C) Lei Zhang <[email protected]> | ||
|
||
# ============================================================================= | ||
#### Construct Data #### | ||
# ============================================================================= | ||
# clear workspace | ||
rm(list=ls(all=TRUE)) | ||
library(rstan) | ||
library(loo) | ||
library(ggplot2) | ||
|
||
#### read raw ----------------------------------------------------------------- | ||
rawdata = # complete this line for reading raw data | ||
|
||
# write a line here to remove missing trials | ||
|
||
#### Preprocess the data ------------------------------------------------------ | ||
subjList = unique(rawdata[,"subjID"]) | ||
nSubjects = length(subjList) | ||
|
||
Tsubj = as.vector( rep( 0, nSubjects ) ) # number of valid trials per subj | ||
|
||
for ( s in 1:nSubjects ) { | ||
curSubj = subjList[ s ] | ||
Tsubj[s] = sum( rawdata$subjID == curSubj ) | ||
} | ||
|
||
maxTrials = max(Tsubj) | ||
choice = array(0, c(nSubjects, maxTrials) ) | ||
reward = array(0, c(nSubjects, maxTrials) ) | ||
|
||
for (s in 1:nSubjects) { | ||
curSubj = subjList[s] | ||
useTrials = Tsubj[s] | ||
tmp = subset(rawdata, rawdata$subjID == curSubj) | ||
choice[s, 1:useTrials] = tmp$choice | ||
reward[s, 1:useTrials] = tmp$reward | ||
} | ||
|
||
dataList = list( | ||
nSubjects = nSubjects, | ||
nTrials = maxTrials, | ||
Tsubj = Tsubj, | ||
choice = choice, | ||
reward = reward | ||
) | ||
|
||
# ============================================================================= | ||
#### Running Stan #### | ||
# ============================================================================= | ||
rstan_options(auto_write = TRUE) | ||
options(mc.cores = 2) # <-- adjust if you want to run 4 cores in parallel | ||
|
||
nIter = 2000 | ||
nChains = 4 | ||
nWarmup = floor(nIter/2) | ||
nThin = 1 | ||
|
||
#### run the Rescorla-Wagner model ---------------------------------------- | ||
modelFile1 = 'scripts/rw.stan' | ||
|
||
cat("Estimating", modelFile1, "model... \n") | ||
startTime = Sys.time(); print(startTime) | ||
cat("Calling", nChains, "simulations in Stan... \n") | ||
|
||
fit_hyperbolic = stan() # complete this line for calling Stan | ||
|
||
cat("Finishing", modelFile1, "model simulation ... \n") | ||
endTime = Sys.time(); print(endTime) | ||
cat("It took",as.character.Date(endTime - startTime), "\n") | ||
|
||
#### run the reward-punishment model --------------------------------------- | ||
modelFile2 = 'scripts/rp.stan' | ||
|
||
cat("Estimating", modelFile1, "model... \n") | ||
startTime = Sys.time(); print(startTime) | ||
cat("Calling", nChains, "simulations in Stan... \n") | ||
|
||
fit_heuristic = stan() # complete this line for calling Stan | ||
|
||
cat("Finishing", modelFile2, "model simulation ... \n") | ||
endTime = Sys.time(); print(endTime) | ||
cat("It took",as.character.Date(endTime - startTime), "\n") | ||
|
||
# ============================================================================= | ||
#### Model selection #### | ||
# ============================================================================= | ||
LL_rw = # complete this line for extreact log-likelihood | ||
LL_rp = # complete this line for extreact log-likelihood | ||
|
||
waic_rw = waic(LL_rw) | ||
waic_rp = waic(LL_rp) | ||
|
||
#### End of file |
31 changes: 17 additions & 14 deletions
31
Programing_project/scripts/heuristic.stan → Programing_project/scripts/rp.stan
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,17 @@ | ||
data { | ||
} | ||
|
||
parameters { | ||
} | ||
|
||
transformed parameters { | ||
} | ||
|
||
model { | ||
} | ||
|
||
generated quantities { | ||
} | ||
data { | ||
} | ||
|
||
transformed data { | ||
} | ||
|
||
parameters { | ||
} | ||
|
||
transformed parameters { | ||
} | ||
|
||
model { | ||
} | ||
|
||
generated quantities { | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
data { | ||
int<lower=1> nSubjects; | ||
int<lower=1> nTrials; | ||
int<lower=1,upper=2> choice; | ||
real<lower=-1, upper=1> reward[nSubjects, nTrials]; | ||
} | ||
|
||
transformed data { | ||
vector[2] initV; // initial values for V | ||
initV = rep_vector(0.0, 2); | ||
} | ||
|
||
parameters { | ||
// group-level parameters | ||
real lr_mu_rae; | ||
real tau_mu_raw; | ||
real lr_sd_raw; | ||
real tau_sd_raw; | ||
|
||
// subject-level raw parameters | ||
vector[nSubjects] lr_raw; | ||
vector[nSubjects] tau_raw; | ||
} | ||
|
||
transformed parameters { | ||
vector<lower=0,upper=1>[nSubjects] lr; | ||
vector<lower=0,upper=3>[nSubjects] tau; | ||
|
||
lr = Phi_approx( lr_mu_raw + lr_sd_raw * lr_raw[s] ) | ||
tau = Phi_approx( tau_mu_raw + tau_sd_raw * tau_raw ) * 5; | ||
} | ||
|
||
|
||
model { | ||
// group-level prior | ||
lr_mu_raw ~ normal(0,1); | ||
tau_mu_raw ~ normal(0,1); | ||
|
||
// individual-level prior | ||
lr_raw ~ normal(0,1); | ||
tau_raw ~ normal(0,1); | ||
|
||
for (s in 1:nSubjects) { | ||
vector[2] v; | ||
real pe; | ||
v = initV; | ||
|
||
for (t in 1:nTrials) { | ||
Choice[s,t] ~ categorical( tau[s] * v ); | ||
|
||
pe = Reward - v[choice[s,t]]; | ||
v[choice[s,t]] = v[choice[s,t]] + lr[s] * pe; | ||
} | ||
} | ||
} | ||
|
||
generated quantities { | ||
real<lower=0,upper=1> lr_mu; | ||
real<lower=0,upper=5> tau_mu; | ||
|
||
real log_lik[nSubjects]; | ||
|
||
lr_mu = Phi_approx(lr_mu_raw); | ||
tau_mu = Phi_approx(tau_mu_raw) * 5; | ||
|
||
{ // local section, this saves time and space | ||
for (s in 1:nSubjects) { | ||
vector[2] v; | ||
real pe; | ||
|
||
v = initV; | ||
|
||
for (t in 1:nTrials) { | ||
log_lik[s] = log_lik[s] + categorical_logit_lpdf(choice[s,t] | tau[s] * v); | ||
|
||
pe = reward[s,t] - v[choice[s,t]]; | ||
v[choice[s,t]] = v[choice[s,t]] + lr[s] * pe; | ||
} | ||
} | ||
} | ||
} |
Binary file renamed
BIN
+38.8 KB
...aming_project/short_summary_template.docx → Programing_project/short_summary.docx
Binary file not shown.