Skip to content

Commit

Permalink
Enable Left-Censoring for Random Slope Model (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Mar 25, 2024
1 parent e4fe089 commit eae0339
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 58 deletions.
1 change: 0 additions & 1 deletion R/LongitudinalModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ LongitudinalModel <- function(
name = "<Unnamed>",
...
) {

base_stan <- read_stan("base/longitudinal.stan")

stan_full <- decorated_render(
Expand Down
17 changes: 17 additions & 0 deletions inst/stan/base/longitudinal.stan
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,24 @@ parameters {
}

transformed parameters {
//
// Source - base/longitudinal.stan
//
vector[Nta_total] Ypred_log_lik = rep_vector(0, Nta_total);

{{ stan.transformed_parameters }}

//
// Source - base/longitudinal.stan
//
log_lik += csr_matrix_times_vector(
Nind,
Nta_total,
w_mat_inds_all_y,
v_mat_inds_all_y,
u_mat_inds_all_y,
Ypred_log_lik
);
}

model {
Expand Down
32 changes: 9 additions & 23 deletions inst/stan/lm-gsf/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -66,31 +66,17 @@ transformed parameters{
lm_gsf_psi_phi[ind_index]
);

log_lik += csr_matrix_times_vector(
Nind,
Nta_obs_y,
w_mat_inds_obs_y,
v_mat_inds_obs_y,
u_mat_inds_obs_y,
vect_normal_log_dens(
Yobs[obs_y_index],
Ypred[obs_y_index],
Ypred[obs_y_index] * lm_gsf_sigma
)
);

Ypred_log_lik[obs_y_index] = vect_normal_log_dens(
Yobs[obs_y_index],
Ypred[obs_y_index],
Ypred[obs_y_index] * lm_gsf_sigma
);
if (Nta_cens_y > 0 ) {
log_lik += csr_matrix_times_vector(
Nind,
Nta_cens_y,
w_mat_inds_cens_y,
v_mat_inds_cens_y,
u_mat_inds_cens_y,
vect_normal_log_cum(
Ythreshold,
Ypred[cens_y_index],
Ypred[cens_y_index] * lm_gsf_sigma
)
Ypred_log_lik[cens_y_index] = vect_normal_log_cum(
Ythreshold,
Ypred[cens_y_index],
Ypred[cens_y_index] * lm_gsf_sigma
);
}
}
Expand Down
22 changes: 11 additions & 11 deletions inst/stan/lm-random-slope/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,18 @@ transformed parameters {

vector[Nta_total] Ypred = lm_rs_ind_intercept[ind_index] + lm_rs_rslope_ind .* Tobs;

log_lik += csr_matrix_times_vector(
Nind,
Nta_total,
w_mat_inds_all_y,
v_mat_inds_all_y,
u_mat_inds_all_y,
vect_normal_log_dens(
Yobs,
Ypred,
rep_vector(lm_rs_sigma, Nta_total)
)
Ypred_log_lik[obs_y_index] = vect_normal_log_dens(
Yobs[obs_y_index],
Ypred[obs_y_index],
rep_vector(lm_rs_sigma, Nta_obs_y)
);
if (Nta_cens_y > 0 ) {
Ypred_log_lik[cens_y_index] = vect_normal_log_cum(
Ythreshold,
Ypred[cens_y_index],
rep_vector(lm_rs_sigma, Nta_cens_y)
);
}
}


Expand Down
31 changes: 8 additions & 23 deletions inst/stan/lm-stein-fojo/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,16 @@ transformed parameters{
lm_sf_psi_kg[ind_index]
);

log_lik += csr_matrix_times_vector(
Nind,
Nta_obs_y,
w_mat_inds_obs_y,
v_mat_inds_obs_y,
u_mat_inds_obs_y,
vect_normal_log_dens(
Yobs[obs_y_index],
Ypred[obs_y_index],
Ypred[obs_y_index] * lm_sf_sigma
)
Ypred_log_lik[obs_y_index] = vect_normal_log_dens(
Yobs[obs_y_index],
Ypred[obs_y_index],
Ypred[obs_y_index] * lm_sf_sigma
);

if (Nta_cens_y > 0 ) {
log_lik += csr_matrix_times_vector(
Nind,
Nta_cens_y,
w_mat_inds_cens_y,
v_mat_inds_cens_y,
u_mat_inds_cens_y,
vect_normal_log_cum(
Ythreshold,
Ypred[cens_y_index],
Ypred[cens_y_index] * lm_sf_sigma
)
Ypred_log_lik[cens_y_index] = vect_normal_log_cum(
Ythreshold,
Ypred[cens_y_index],
Ypred[cens_y_index] * lm_sf_sigma
);
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/_snaps/JointModelSamples.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Variables:
Ypred[2800]
Ypred_log_lik[2800]
beta_os_cov[3]
lm_rs_ind_intercept[400]
lm_rs_ind_rnd_slope[400]
Expand Down
102 changes: 102 additions & 0 deletions tests/testthat/test-LongitudinalRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,105 @@ test_that("Random Slope Model can recover known parameter values", {
)
expect_gt(lmer_cor, 0.99)
})



test_that("Random Slope Model left-censoring works as expected", {
## Generate data with known parameters
set.seed(739)

jlist <- SimJointData(
design = list(
SimGroup(250, "Arm-A", "Study-X"),
SimGroup(150, "Arm-B", "Study-X")
),
survival = SimSurvivalExponential(1 / 100),
longitudinal = SimLongitudinalRandomSlope(
times = c(-200, -150, -100, -50, 0, 1, 100, 125, 200, 300, 350, 400, 500, 600),
intercept = 30,
sigma = 3,
slope_mu = c(1, 3),
slope_sigma = 0.2
),
.silent = TRUE
)


# Trash the data for negative values
# As the data is censored this shouldn't impact the sampled predictions
dat_lm <- jlist@longitudinal
negative_index <- dat_lm$sld < 0
dat_lm$sld[negative_index] <- runif(sum(negative_index), -999, -100)

jm <- JointModel(
longitudinal = LongitudinalRandomSlope(
intercept = prior_normal(30, 2),
slope_sigma = prior_lognormal(log(0.2), sigma = 0.5),
sigma = prior_lognormal(log(3), sigma = 0.5)
)
)

jdat <- DataJoint(
subject = DataSubject(
data = jlist@survival,
subject = "pt",
arm = "arm",
study = "study"
),
longitudinal = DataLongitudinal(
data = dat_lm,
formula = sld ~ time,
threshold = 0
)
)

mp <- run_quietly({
sampleStanModel(
jm,
data = jdat,
iter_sampling = 200,
iter_warmup = 400,
chains = 1,
refresh = 0,
parallel_chains = 1
)
})


vars <- c(
"lm_rs_intercept", # 30
"lm_rs_slope_mu", # 1 , 3
"lm_rs_slope_sigma", # 0.2
"lm_rs_sigma" # 3
)

pars <- mp@results$summary(vars)


## Check that we can recover main effects parameters
z_score <- (c(30, 1, 3, 0.2, 3) - pars$mean) / pars$sd
expect_true(all(abs(z_score) < qnorm(0.95)))


## Check that we can recover random effects parameters
pars <- suppressWarnings({
mp@results$summary("lm_rs_ind_rnd_slope")$mean
})

## Extract real random effects per patient
## We store them as (random effect + mean)
## thus need to subtract mean for comparison
## to nle4
group_mean <- c(1, 3)

## Check for consistency of random effects with lmer
mod <- lme4::lmer(
sld ~ time:arm + (time - 1 | pt),
jlist@longitudinal
)
lmer_cor <- cor(
lme4::ranef(mod)$pt$time,
pars - group_mean[as.numeric(jlist@survival$arm)]
)
expect_gt(lmer_cor, 0.99)
})

0 comments on commit eae0339

Please sign in to comment.