Skip to content

Commit

Permalink
Re-factored the stan code for longitudinal generated quantities (#289)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Mar 25, 2024
1 parent 2c10109 commit 1db0fc4
Show file tree
Hide file tree
Showing 14 changed files with 138 additions and 46 deletions.
2 changes: 1 addition & 1 deletion R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ JointModel <- function(
)
)

base_model <- paste0(read_stan("base/base.stan"), collapse = "\n")
base_model <- read_stan("base/base.stan")

stan_full <- decorated_render(
.x = base_model,
Expand Down
2 changes: 1 addition & 1 deletion R/Link.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ as.StanModule.Link <- function(object, ...) {

base_stan <- StanModule(
decorated_render(
.x = paste(read_stan("base/link.stan"), collapse = "\n"),
.x = read_stan("base/link.stan"),
items = as.list(keys)
)
)
Expand Down
2 changes: 1 addition & 1 deletion R/LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ LongitudinalGSF <- function(
) {

gsf_model <- StanModule(decorated_render(
.x = paste0(read_stan("lm-gsf/model.stan"), collapse = "\n"),
.x = read_stan("lm-gsf/model.stan"),
centred = centred
))

Expand Down
9 changes: 6 additions & 3 deletions R/LongitudinalModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@ LongitudinalModel <- function(
...
) {

base_long <- StanModule(
x = "base/longitudinal.stan"
base_stan <- read_stan("base/longitudinal.stan")

stan_full <- decorated_render(
.x = base_stan,
stan = add_missing_stan_blocks(as.list(stan))
)

.LongitudinalModel(
StanModel(
stan = merge(base_long, stan),
stan = StanModule(stan_full),
parameters = parameters,
name = name,
...
Expand Down
2 changes: 1 addition & 1 deletion R/LongitudinalSteinFojo.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ LongitudinalSteinFojo <- function(
) {

sf_model <- StanModule(decorated_render(
.x = paste0(read_stan("lm-stein-fojo/model.stan"), collapse = "\n"),
.x = read_stan("lm-stein-fojo/model.stan"),
centred = centred
))

Expand Down
4 changes: 3 additions & 1 deletion R/StanModule.R
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,11 @@ read_stan <- function(string) {
files <- c(local_file, local_inst_file, system_file)
for (fi in files) {
if (is_file(fi)) {
return(readLines(fi))
string <- readLines(fi)
break
}
}
string <- paste0(string, collapse = "\n")
return(string)
}

Expand Down
2 changes: 1 addition & 1 deletion R/SurvivalModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SurvivalModel <- function(
name = "<Unnamed>",
...
) {
base_stan <- paste0(read_stan("base/survival.stan"), collapse = "\n")
base_stan <- read_stan("base/survival.stan")
stan_full <- decorated_render(
.x = base_stan,
stan = add_missing_stan_blocks(as.list(stan))
Expand Down
39 changes: 39 additions & 0 deletions inst/stan/base/longitudinal.stan
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@

functions {
{{ stan.functions }}
}


data{
//
Expand Down Expand Up @@ -38,5 +42,40 @@ data{
vector[n_mat_inds_all_y[1]] w_mat_inds_all_y;
array[n_mat_inds_all_y[2]] int v_mat_inds_all_y;
array[n_mat_inds_all_y[3]] int u_mat_inds_all_y;

{{ stan.data }}
}

transformed data {
{{ stan.transformed_data }}
}

parameters {
{{ stan.parameters }}
}

transformed parameters {
{{ stan.transformed_parameters }}
}

model {
{{ stan.model }}
}

generated quantities {
{{ stan.generated_quantities }}

//
// Source - base/longitudinal.stan
//
matrix[n_pt_select_index, n_lm_time_grid] y_fit_at_time_grid;
if (n_lm_time_grid > 0) {
for (i in 1:n_pt_select_index) {
int current_pt_index = pt_select_index[i];
y_fit_at_time_grid[i, ] = lm_predict_individual_patient(
lm_time_grid,
long_gq_parameters[current_pt_index, ]
);
}
}
}
11 changes: 11 additions & 0 deletions inst/stan/lm-gsf/functions.stan
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,16 @@ functions {
);
return result;
}

row_vector lm_predict_individual_patient(vector time, row_vector long_gq_parameters) {
int nrow = rows(time);
return sld(
time,
rep_vector(long_gq_parameters[1], nrow),
rep_vector(long_gq_parameters[2], nrow),
rep_vector(long_gq_parameters[3], nrow),
rep_vector(long_gq_parameters[4], nrow)
)';
}
}

18 changes: 5 additions & 13 deletions inst/stan/lm-gsf/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,9 @@ generated quantities {
//
// Source - lm-gsf/model.stan
//
matrix[n_pt_select_index, n_lm_time_grid] y_fit_at_time_grid;
if (n_lm_time_grid > 0) {
for (i in 1:n_pt_select_index) {
int current_pt_index = pt_select_index[i];
y_fit_at_time_grid[i, ] = to_row_vector(sld(
lm_time_grid,
rep_vector(lm_gsf_psi_bsld[current_pt_index], n_lm_time_grid),
rep_vector(lm_gsf_psi_ks[current_pt_index], n_lm_time_grid),
rep_vector(lm_gsf_psi_kg[current_pt_index], n_lm_time_grid),
rep_vector(lm_gsf_psi_phi[current_pt_index], n_lm_time_grid)
));
}
}
matrix[Nind, 4] long_gq_parameters;
long_gq_parameters[, 1] = lm_gsf_psi_bsld;
long_gq_parameters[, 2] = lm_gsf_psi_ks;
long_gq_parameters[, 3] = lm_gsf_psi_kg;
long_gq_parameters[, 4] = lm_gsf_psi_phi;
}
26 changes: 15 additions & 11 deletions inst/stan/lm-random-slope/model.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@


functions {
//
// Source - lm-random-slope/model.stan
//
row_vector lm_predict_individual_patient(vector time, row_vector long_gq_parameters) {
int nrow = rows(time);
return (
rep_vector(long_gq_parameters[1], nrow) +
rep_vector(long_gq_parameters[2], nrow) .* time
)';
}
}


parameters {
Expand Down Expand Up @@ -55,14 +66,7 @@ generated quantities {
//
// Source - lm-random-slope/model.stan
//
matrix[n_pt_select_index, n_lm_time_grid] y_fit_at_time_grid;
if (n_lm_time_grid > 0) {
for (i in 1:n_pt_select_index) {
int current_pt_index = pt_select_index[i];
y_fit_at_time_grid[i, ] =
lm_rs_ind_intercept[current_pt_index] +
lm_rs_ind_rnd_slope[current_pt_index] .*
to_row_vector(lm_time_grid);
}
}
matrix[Nind, 2] long_gq_parameters;
long_gq_parameters[, 1] = lm_rs_ind_intercept;
long_gq_parameters[, 2] = lm_rs_ind_rnd_slope;
}
9 changes: 9 additions & 0 deletions inst/stan/lm-stein-fojo/functions.stan
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,14 @@ functions {
);
return result;
}
row_vector lm_predict_individual_patient(vector time, row_vector long_gq_parameters) {
int nrow = rows(time);
return sld(
time,
rep_vector(long_gq_parameters[1], nrow),
rep_vector(long_gq_parameters[2], nrow),
rep_vector(long_gq_parameters[3], nrow)
)';
}
}

17 changes: 4 additions & 13 deletions inst/stan/lm-stein-fojo/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,12 @@ model {
{%- endif -%}
}


generated quantities {
//
// Source - lm-stein-fojo/model.stan
//
matrix[n_pt_select_index, n_lm_time_grid] y_fit_at_time_grid;
if (n_lm_time_grid > 0) {
for (i in 1:n_pt_select_index) {
int current_pt_index = pt_select_index[i];
y_fit_at_time_grid[i, ] = to_row_vector(sld(
lm_time_grid,
rep_vector(lm_sf_psi_bsld[current_pt_index], n_lm_time_grid),
rep_vector(lm_sf_psi_ks[current_pt_index], n_lm_time_grid),
rep_vector(lm_sf_psi_kg[current_pt_index], n_lm_time_grid)
));
}
}
matrix[Nind, 3] long_gq_parameters;
long_gq_parameters[, 1] = lm_gsf_psi_bsld;
long_gq_parameters[, 2] = lm_gsf_psi_ks;
long_gq_parameters[, 3] = lm_gsf_psi_kg;
}
41 changes: 41 additions & 0 deletions tests/testthat/test-LongitudinalQuantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,44 @@ test_that("LongitudinalQuantities can recover known results", {

expect_true(all(dat_all$correl > 0.99))
})


test_that("LongitudinalQuantities correctly subsets patients and rebuilds correct value for each sample", {
set.seed(101)
ensure_test_data_1()
times <- c(-100, 0, 1, 100, 200)

longsamps <- LongitudinalQuantities(
test_data_1$jsamples,
groups = c("pt_010", "pt_011", "pt_099"),
time_grid = times
)

map_me <- function(time, mat) {
dplyr::tibble(time = as.data.frame(mat[, 1] + mat[, 2] * time)[, 1])
}

vars_10 <- c("lm_rs_ind_intercept[10]", "lm_rs_ind_rnd_slope[10]")
mat1 <- test_data_1$jsamples@results$draws(vars_10, format = "draws_matrix")
dat1 <- dplyr::bind_rows(lapply(times, map_me, mat = mat1))

vars_11 <- c("lm_rs_ind_intercept[11]", "lm_rs_ind_rnd_slope[11]")
mat2 <- test_data_1$jsamples@results$draws(vars_11, format = "draws_matrix")
dat2 <- dplyr::bind_rows(lapply(times, map_me, mat = mat2))

vars_99 <- c("lm_rs_ind_intercept[99]", "lm_rs_ind_rnd_slope[99]")
mat3 <- test_data_1$jsamples@results$draws(vars_99, format = "draws_matrix")
dat3 <- dplyr::bind_rows(lapply(times, map_me, mat = mat3))

vec_actual <- as.data.frame(longsamps)[["values"]]
vec_expected <- c(dat1$time, dat2$time, dat3$time)

# cmdstanr rounds the generated samples to 6 s.f.
# Stan then uses these rounded samples when calculating the generated quantiles
# The generated quantities themselves are then rounded to 6 s.f. when being stored on disk
# This makes direct comparison of values (even with rounding or tolerance) impossible
# Instead we just test for an extremely high correlation
# For reference even changing a single number in one of the vectors from say 34 to 35
# is enough to cause this test to fail
expect_gt(cor(vec_actual, vec_expected), 0.9999999999)
})

0 comments on commit 1db0fc4

Please sign in to comment.