forked from ghurault/mbml-eczema
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Autoregression.stan
140 lines (120 loc) · 4.06 KB
/
Autoregression.stan
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Observed data: severity scores of several patients stored one after the other
// (patient 1 days, then patient 2 days, ...), with some days missing.
data {
  int<lower = 0> N; // Total number of observations (missing and non-missing)
  int<lower = 0, upper = N> N_obs; // Number of non-missing observations (cannot exceed N)
  int<lower = 0> N_pt; // Number of patients
  int<lower = 1> t_max[N_pt]; // Time-series length (number of days) of each patient; at least one day is needed to initialise the time-series
  int<lower = 1, upper = N> idx_obs[N_obs]; // Index (in 1:N) of non-missing observations
  real<lower = 0, upper = 10> S_obs[N_obs]; // Observed severity score
  real<lower = 0, upper = 1> Treat[N]; // Daily treatment usage
  int<lower = 0> horizon; // Time horizon (in days) for prediction
}
transformed data {
  int N_pred = N + N_pt * horizon; // Size of S_pred: all fitted days plus `horizon` prediction days per patient
  int start[N_pt]; // Index of the first observation of each patient
  int end[N_pt]; // Index of the last observation of each patient
  int N_mis = N - N_obs; // Number of missing observations
  int idx_mis[N_mis]; // Index of missing observations
  if (N != sum(t_max)) {
    reject("N should be equal to sum(t_max)");
  }
  // Start and end of each time-series: patients are stored consecutively
  for (k in 1:N_pt) {
    start[k] = (k == 1) ? 1 : end[k - 1] + 1;
    end[k] = start[k] - 1 + t_max[k];
  }
  // Index of missing observations: every index in 1:N not listed in idx_obs
  {
    int id = 1; // Next free position in idx_mis
    int obs[N] = rep_array(0, N); // obs[i] == 1 when observation i is non-missing
    obs[idx_obs] = rep_array(1, N_obs);
    // Duplicates in idx_obs would leave more than N_mis zeros in obs and overflow
    // idx_mis below; fail with an explicit message instead.
    if (sum(obs) != N_obs) {
      reject("idx_obs should not contain duplicated indices");
    }
    for (i in 1:N) {
      if (obs[i] == 0) {
        idx_mis[id] = i;
        id += 1;
      }
    }
  }
}
parameters {
  // Latent severity and rounding
  real<lower = 0, upper = 10> S_mis[N_mis]; // Latent severity on days without an observation
  real<lower = -0.5, upper = 0.5> err[N_obs]; // Rounding error of each observed score
  // Autoregression
  real<lower = 0> sigma_S; // Standard deviation of the (truncated) Gaussian day-to-day transition
  real b_S; // Intercept
  // Hierarchical autocorrelation (logit scale)
  real mu_wS; // Population autocorrelation logit mean
  real<lower = 0> sigma_wS; // Population autocorrelation logit standard deviation
  real eta_wS[N_pt]; // Non-centered parametrisation for autocorrelation (standardised patient effects)
  // Hierarchical responsiveness to treatment
  real mu_T; // Population mean for responsiveness to treatment
  real<lower = 0> sigma_T; // Population standard deviation for responsiveness to treatment
  real eta_T[N_pt]; // Non-centered parametrisation for responsiveness to treatment (standardised patient effects)
}
transformed parameters {
  real S[N]; // Latent severity for every day, observed or missing
  real wS[N_pt]; // Patient autocorrelation
  real wT[N_pt]; // Patient responsiveness to treatment
  // Missing days come straight from the S_mis parameter
  S[idx_mis] = S_mis;
  // Observed days: recorded score plus rounding error err in [-0.5, 0.5].
  // At the bounds 0 and 10 the error is rescaled to [0, 0.5] and applied
  // inwards, so that S stays within [0, 10].
  for (n in 1:N_obs) {
    if (S_obs[n] == 0 || S_obs[n] == 10) {
      real inward = 0.25 + 0.5 * err[n]; // in [0, 0.5]
      S[idx_obs[n]] = (S_obs[n] == 0) ? S_obs[n] + inward : S_obs[n] - inward;
    } else {
      S[idx_obs[n]] = S_obs[n] + err[n];
    }
  }
  // Non-centred patient-level parameters (autocorrelation mapped to (0, 1))
  wS = to_array_1d(inv_logit(mu_wS + sigma_wS * to_vector(eta_wS)));
  wT = to_array_1d(mu_T + sigma_T * to_vector(eta_T));
}
model {
  // Standardised patient effects of the non-centred parametrisation
  eta_wS ~ std_normal();
  eta_T ~ std_normal();
  // Population-level priors
  b_S ~ normal(1, 1);
  sigma_S ~ normal(0, 1.5);
  mu_wS ~ normal(0, 1);
  sigma_wS ~ normal(0, 1.5);
  mu_T ~ normal(0, 1);
  sigma_T ~ normal(0, 0.5);
  // Likelihood: first-order autoregression of the latent severity, truncated to [0, 10]
  for (p in 1:N_pt) {
    real autocorr = wS[p]; // Patient autocorrelation
    real effect = wT[p]; // Patient responsiveness to treatment
    // Constant term of a fully treated day (Treat = 1) should not be too big
    (b_S + effect) ~ normal(0, 2);
    // The first day of a series has no previous day, so transitions start on day 2
    for (d in (start[p] + 1):end[p]) {
      S[d] ~ normal(autocorr * S[d - 1] + effect * Treat[d - 1] + b_S, sigma_S) T[0, 10];
    }
  }
}
generated quantities {
  // Posterior predictive draws of the latent severity, stacked patient after patient
  // (t_max[k] + horizon values for patient k). Day 1 copies the latent S. Days
  // 2..t_max[k] + 1 are one-step-ahead draws given the latent S and treatment of
  // the previous observed day. The remaining days recurse on the previous draw and
  // assume no treatment. Draws use the same normal(...) T[0, 10] transition as the
  // model block, so they stay within [0, 10].
  // NOTE(review): assumes t_max[k] >= 1 for every patient (day 1 initialises the series).
  vector[N_pred] S_pred;
  {
    int i_pred = 1; // Indexing S_pred
    for (k in 1:N_pt) {
      S_pred[i_pred] = S[start[k]]; // Initialisation
      for (t in 2:(t_max[k] + horizon)) {
        real S_prev; // Severity on day t - 1
        real T_prev; // Treatment on day t - 1
        real mu; // Mean of the transition
        real p_lo; // Normal CDF at the lower bound 0
        real p_hi; // Normal CDF at the upper bound 10
        i_pred += 1;
        if (t <= t_max[k] + 1) {
          // Fit (t <= t_max[k]) and first prediction (t == t_max[k] + 1):
          // day t - 1 belongs to the observed series
          S_prev = S[start[k] + t - 2];
          T_prev = Treat[start[k] + t - 2];
        } else {
          S_prev = S_pred[i_pred - 1]; // Remaining predictions
          T_prev = 0; // Assume no treatment
        }
        mu = wS[k] * S_prev + wT[k] * T_prev + b_S;
        p_lo = Phi((0 - mu) / sigma_S);
        p_hi = Phi((10 - mu) / sigma_S);
        if (p_hi > p_lo) {
          // Inverse-CDF draw from normal(mu, sigma_S) truncated to [0, 10]
          S_pred[i_pred] = mu + sigma_S * inv_Phi(uniform_rng(p_lo, p_hi));
        } else {
          // Both CDF values saturated (mu far outside [0, 10]): the truncated
          // distribution is concentrated at the nearest bound
          S_pred[i_pred] = mu < 5 ? 0.0 : 10.0;
        }
        // inv_Phi returns -inf / +inf at 0 / 1; keep the draw inside the support
        S_pred[i_pred] = fmin(fmax(S_pred[i_pred], 0), 10);
      }
      i_pred += 1;
    }
  }
}