Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Individual Loo for submodels #408

Merged
merged 8 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ as.StanModule.JointModel <- function(object, ...) {
longitudinal = add_missing_stan_blocks(as.list(object@longitudinal)),
survival = add_missing_stan_blocks(as.list(object@survival)),
link = add_missing_stan_blocks(as.list(object@link)),
priors = add_missing_stan_blocks(as.list(object@parameters))
priors = add_missing_stan_blocks(as.list(object@parameters)),
has_os_submodel = !is.null(object@survival),
has_long_submodel = !is.null(object@longitudinal)
)
# Unresolved Jinja code within the longitudinal / Survival / Link
# models won't be resolved by the above call to `decorated_render`.
Expand Down
2 changes: 1 addition & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ sim_data <- SimJointData(
joint_data <- DataJoint(
subject = DataSubject(
data = sim_data@survival,
subject = "pt",
subject = "subject",
arm = "arm",
study = "study"
),
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

<!-- markdownlint-disable-file -->

<!-- README.md needs to be generated from README.Rmd. Please edit that file -->

# jmpost <a href="https://genentech.github.io/jmpost/"><img src="man/figures/logo.png" align="right" height="139" /></a>
Expand Down Expand Up @@ -93,7 +94,7 @@ sim_data <- SimJointData(
gamma = 0.97
)
)
#> INFO: 1 subject did not die before max(times)
#> INFO: 1 subjects did not die before max(times)

joint_data <- DataJoint(
subject = DataSubject(
Expand Down
23 changes: 8 additions & 15 deletions inst/stan/base/base.stan
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ data{
{{ survival.data }}
{{ link.data }}
{{ longitudinal.data }}

{{ priors.data }}

}


Expand All @@ -45,31 +43,26 @@ parameters{


transformed parameters{
//
// Source - base/base.stan
//

// Log-likelihood values for using the loo package.
vector[n_subjects] log_lik = rep_vector(0.0, n_subjects);

{{ longitudinal.transformed_parameters }}
{{ link.transformed_parameters }}
{{ survival.transformed_parameters }}

//
// Source - base/base.stan
//
{% if has_os_submodel and not has_long_submodel -%}
vector[n_subjects] log_lik = os_subj_log_lik;
{% else if has_long_submodel and not has_os_submodel -%}
vector[n_tumour_all] log_lik = long_obvs_log_lik;
{%- endif -%}
}


model{
{{ longitudinal.model }}
{{ link.model }}
{{ survival.model }}

{{ priors.model }}

//
// Source - base/base.stan
//
target += sum(log_lik);
}

generated quantities{
Expand Down
29 changes: 11 additions & 18 deletions inst/stan/base/longitudinal.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

functions {
{{ stan.functions }}
{{ stan.functions }}
}


Expand Down Expand Up @@ -43,42 +43,35 @@ data{
array[n_mat_inds_all_y[2]] int v_mat_inds_all_y;
array[n_mat_inds_all_y[3]] int u_mat_inds_all_y;

{{ stan.data }}
{{ stan.data }}
}

transformed data {
{{ stan.transformed_data }}
{{ stan.transformed_data }}
}

parameters {
{{ stan.parameters }}
{{ stan.parameters }}
}

transformed parameters {
//
// Source - base/longitudinal.stan
//
vector[n_tumour_all] Ypred_log_lik = rep_vector(0, n_tumour_all);
// Vector to store per observation log-likelihood values
vector[n_tumour_all] long_obvs_log_lik = rep_vector(0, n_tumour_all);

{{ stan.transformed_parameters }}
}

model {
{{ stan.model }}
//
// Source - base/longitudinal.stan
//
log_lik += csr_matrix_times_vector(
n_subjects,
n_tumour_all,
w_mat_inds_all_y,
v_mat_inds_all_y,
u_mat_inds_all_y,
Ypred_log_lik
);
}

model {
{{ stan.model }}
target += sum(long_obvs_log_lik);
}

generated quantities {
{{ stan.generated_quantities }}
{{ stan.generated_quantities }}
}
18 changes: 11 additions & 7 deletions inst/stan/base/survival.stan
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,8 @@ parameters {
//
// Source - base/survival.stan
//

// Covariate coefficients.

vector[p_os_cov_design] beta_os_cov;

{{ stan.parameters }}
}

Expand All @@ -140,6 +137,9 @@ transformed parameters {
// Source - base/survival.stan
//

// Vector to store per subject log-likelihood values
vector[n_subjects] os_subj_log_lik = rep_vector(0, n_subjects);

// Calculate covariate contributions to log hazard function
vector[n_subjects] os_cov_contribution = get_os_cov_contribution(
os_cov_design,
Expand All @@ -166,10 +166,10 @@ transformed parameters {
);

// We always add the log-survival to the log-likelihood.
log_lik += log_surv_fit_at_obs_times;
os_subj_log_lik += log_surv_fit_at_obs_times;

// In case of death we add the log-hazard on top.
log_lik[subject_event_index] += to_vector(
os_subj_log_lik[subject_event_index] += to_vector(
log_hazard(
to_matrix(event_times[subject_event_index]),
pars_os,
Expand All @@ -182,9 +182,13 @@ transformed parameters {


model{
{{ stan.model }}
{{ stan.model }}
//
// Source - base/survival.stan
//
target += sum(os_subj_log_lik);
}

generated quantities {
{{ stan.generated_quantities }}
{{ stan.generated_quantities }}
}
4 changes: 2 additions & 2 deletions inst/stan/lm-claret-bruno/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ transformed parameters{
);


Ypred_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
long_obvs_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
tumour_value[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs] * lm_clbr_sigma
);
if (n_tumour_cens > 0 ) {
Ypred_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
long_obvs_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
tumour_value_lloq,
Ypred[subject_tumour_index_cens],
Ypred[subject_tumour_index_cens] * lm_clbr_sigma
Expand Down
4 changes: 2 additions & 2 deletions inst/stan/lm-gsf/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ transformed parameters{
);


Ypred_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
long_obvs_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
tumour_value[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs] * lm_gsf_sigma
);
if (n_tumour_cens > 0 ) {
Ypred_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
long_obvs_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
tumour_value_lloq,
Ypred[subject_tumour_index_cens],
Ypred[subject_tumour_index_cens] * lm_gsf_sigma
Expand Down
4 changes: 2 additions & 2 deletions inst/stan/lm-random-slope/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ transformed parameters {

vector[n_tumour_all] Ypred = lm_rs_ind_intercept[subject_tumour_index] + lm_rs_rslope_ind .* tumour_time;

Ypred_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
long_obvs_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
tumour_value[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs],
rep_vector(lm_rs_sigma, n_tumour_obs)
);
if (n_tumour_cens > 0 ) {
Ypred_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
long_obvs_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
tumour_value_lloq,
Ypred[subject_tumour_index_cens],
rep_vector(lm_rs_sigma, n_tumour_cens)
Expand Down
4 changes: 2 additions & 2 deletions inst/stan/lm-stein-fojo/model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ transformed parameters{
lm_sf_psi_kg[subject_tumour_index]
);

Ypred_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
long_obvs_log_lik[subject_tumour_index_obs] = vect_normal_log_dens(
tumour_value[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs],
Ypred[subject_tumour_index_obs] * lm_sf_sigma
);
if (n_tumour_cens > 0 ) {
Ypred_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
long_obvs_log_lik[subject_tumour_index_cens] = vect_normal_log_cum(
tumour_value_lloq,
Ypred[subject_tumour_index_cens],
Ypred[subject_tumour_index_cens] * lm_sf_sigma
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/_snaps/JointModelSamples.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

Variables:
Ypred[2800]
Ypred_log_lik[2800]
beta_os_cov[3]
lm_rs_ind_intercept[400]
lm_rs_ind_rnd_slope[400]
Expand All @@ -20,10 +19,11 @@
lm_rs_sigma
lm_rs_slope_mu[2]
lm_rs_slope_sigma
log_lik[400]
log_surv_fit_at_obs_times[400]
long_obvs_log_lik[2800]
lp__
os_cov_contribution[400]
os_subj_log_lik[400]
pars_os
sm_exp_lambda

Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,35 @@ test_that("JointModel print method works as expected", {
print(x)
})
})



test_that("Log_Lik variables are produced correctly", {
x <- JointModel(
longitudinal = LongitudinalRandomSlope(),
survival = SurvivalWeibullPH()
)
stan_code <- as.character(x)
expect_true(grepl("target \\+= sum\\(long_obvs_log_lik\\)", stan_code))
expect_true(grepl("target \\+= sum\\(os_subj_log_lik\\)", stan_code))
expect_false(grepl("log_lik = long_obvs_log_lik", stan_code))
expect_false(grepl("log_lik = os_subj_log_lik", stan_code))

x <- JointModel(
longitudinal = LongitudinalRandomSlope()
)
stan_code <- as.character(x)
expect_true(grepl("target \\+= sum\\(long_obvs_log_lik\\)", stan_code))
expect_false(grepl("target \\+= sum\\(os_subj_log_lik\\)", stan_code))
expect_true(grepl("log_lik = long_obvs_log_lik", stan_code))
expect_false(grepl("log_lik = os_subj_log_lik", stan_code))

x <- JointModel(
survival = SurvivalWeibullPH()
)
stan_code <- as.character(x)
expect_false(grepl("target \\+=sum\\(long_obvs_log_lik\\)", stan_code))
expect_true(grepl("target \\+= sum\\(os_subj_log_lik\\)", stan_code))
expect_false(grepl("log_lik = long_obvs_log_lik", stan_code))
expect_true(grepl("log_lik = os_subj_log_lik", stan_code))
})
14 changes: 8 additions & 6 deletions tests/testthat/test-brierScore.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ test_that("brierScore(SurvivalQuantities) returns same results as survreg", {

dat_os <- test_data_1$dat_os
mp <- test_data_1$jsamples

set.seed(1231)
### Get our internal bayesian estimate
t_grid <- c(1, 25, 60, 425, 750)
t_grid <- c(1, 30, 45, 60, 425, 750)
sq <- SurvivalQuantities(
mp,
grid = GridFixed(times = t_grid),
Expand Down Expand Up @@ -52,12 +52,14 @@ test_that("brierScore(SurvivalQuantities) returns same results as survreg", {
pred_mat = pred_mat
)

expect_equal(
round(bs_survquant, 3),
round(bs_survreg, 3)
# Expect all values are approximately equal with a 2% relative tolerance
expect_true(
all(
((abs(bs_survquant - bs_survreg) / bs_survreg) * 100) < 2
)
)
})

22

test_that("brier score weight matrix is correctly calculated", {
# nolint start
Expand Down
2 changes: 1 addition & 1 deletion vignettes/custom-model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ transformed parameters{

// Calculate per observation log-likelihood for {loo} integration
// These values are automatically added to the target for you
Ypred_log_lik = vect_normal_log_dens(
long_obvs_log_lik = vect_normal_log_dens(
tumour_value,
Ypred,
rep_vector(sigma, n_tumour_all) // broadcast sigma to the length of Ypred
Expand Down
4 changes: 2 additions & 2 deletions vignettes/extending-jmpost.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ interfaces if you want to enable them for your model.
### 1) `loo` integration

If you want to use the `loo` package to calculate the leave-one-out cross-validation
then you need to populate the `Ypred_log_lik` vector. This vector should contain the log-likelihood
then you need to populate the `long_obvs_log_lik` vector. This vector should contain the log-likelihood
contribution for each individual tumour observation. This vector
is automatically 0-initialised, thus all your code needs to do is populate it.
```stan
transformed parameters {
Ypred_log_lik = vect_normal_log_dens(
long_obvs_log_lik = vect_normal_log_dens(
tumour_value,
expected_tumour_value,
rep_vector(lm_rs_sigma, n_tumour_obs)
Expand Down
Loading