From dcce0f81e8305e9e5aa5b57a6fd321fec5d80d00 Mon Sep 17 00:00:00 2001 From: Craig Gower-Page Date: Thu, 22 Feb 2024 16:44:30 +0000 Subject: [PATCH] Enhanced Comparison Examples - Added example of log-logistic distribution - Added examples of how to calculate AIC / BIC / LOO - Added examples of how to calculate survival quantities + plots - Added brms examples --- NAMESPACE | 2 + R/StanModel.R | 1 + R/defaults.R | 1 + design/examples/loglogistic.R | 157 +++++++++++++++++ design/examples/loglogistic.stan | 51 ++++++ design/examples/weibull.R | 287 +++++++++++++++++++++++++------ design/examples/weibull.stan | 22 +-- design/examples/weibull_std.R | 141 --------------- vignettes/extending-jmpost.Rmd | 2 + 9 files changed, 458 insertions(+), 206 deletions(-) create mode 100644 design/examples/loglogistic.R create mode 100644 design/examples/loglogistic.stan delete mode 100644 design/examples/weibull_std.R diff --git a/NAMESPACE b/NAMESPACE index 0e82c450..3768a4a7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -59,6 +59,8 @@ S3method(extractVariableNames,DataSurvival) S3method(generateQuantities,JointModelSamples) S3method(getParameters,Link) S3method(getParameters,LinkComponent) +S3method(getParameters,StanModel) +S3method(getParameters,default) S3method(initialValues,JointModel) S3method(initialValues,Link) S3method(initialValues,LinkComponent) diff --git a/R/StanModel.R b/R/StanModel.R index f05e07af..c1cb7f6e 100644 --- a/R/StanModel.R +++ b/R/StanModel.R @@ -68,6 +68,7 @@ as.list.StanModel <- function(x, ...) { # getParameters-StanModel ---- #' @rdname getParameters +#' @export getParameters.StanModel <- function(object) object@parameters diff --git a/R/defaults.R b/R/defaults.R index 955d179c..dda61bdf 100755 --- a/R/defaults.R +++ b/R/defaults.R @@ -10,6 +10,7 @@ NULL #' @rdname getParameters +#' @export getParameters.default <- function(object) { if (missing(object) || is.null(object)) { return(NULL) diff --git a/design/examples/loglogistic.R b/design/examples/loglogistic.R new file mode 100644 index 00000000..c8d0fbda --- /dev/null +++ b/design/examples/loglogistic.R @@ -0,0 +1,157 @@ + + + +library(dplyr) +library(flexsurv) +library(survival) +library(cmdstanr) +library(posterior) +library(bayesplot) +library(here) + + +dat <- flexsurv::bc |> + as_tibble() |> + mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", 1:n())) + + +# +# +# Have specified no covariates in this models so that they are comparable +# JMpost uses a PH model where as the others use AFT thus adding covariates +# would result in different models. Without covariates they are just fitting +# the base distribution which should be identical. +# +# + + + + +################################ +# +# Flexsurv parametric Regression +# + + +mod_flex <- flexsurvreg( + Surv(recyrs, censrec) ~ 1, + data = dat, + dist = "llogis" +) +mod_flex + +logLik(mod_flex) +AIC(mod_flex) +BIC(mod_flex) + + +################################ +# +# Bayesian Weibull Regression +# + +mod <- cmdstan_model( + stan_file = here("design/examples/loglogistic.stan"), + exe_file = here("design/examples/models/loglogistic") +) + +design_mat <- model.matrix(~ 1, data = dat) + + +stan_data <- list( + n = nrow(dat), + design = design_mat, + p = ncol(design_mat), + times = dat$recyrs, + event_fl = dat$censrec +) +fit <- mod$sample( + data = stan_data, + chains = 2, + parallel_chains = 2, + refresh = 200, + iter_warmup = 1000, + iter_sampling = 1500 +) +vars <- c( + "beta_design", + "alpha_0", + "beta_0" +) +fit$summary(vars) + + +# Log Likelihood +log_lik <- fit$draws("log_lik", format = "draws_matrix") |> + apply(1, sum) |> + mean() +log_lik + +# AIC +k <- 2 +-2 * log_lik + k * (stan_data$p + 1) # +1 for the scale parameter + +# BIC +((stan_data$p + 1) * log(stan_data$n)) + (-2 * log_lik) + + + +################################ +# +# JMpost +# + +devtools::load_all() +# library(jmpost) + +jm <- JointModel( + survival = SurvivalLogLogistic() +) + +jdat <- DataJoint( + subject = DataSubject( + data = dat, + subject = "pt", + arm = "arm", + study = "study" + ), + survival = DataSurvival( + data = dat, + formula = Surv(recyrs, censrec) ~ 1 + ) +) + +mp <- sampleStanModel( + jm, + data = jdat, + iter_warmup = 1000, + iter_sampling = 1500, + chains = 2, + parallel_chains = 2 +) + +vars <- c( + "sm_logl_lambda", + "sm_logl_p" +) + +x <- mp@results$summary(vars) + +c( + "scale" = 1 / x$mean[1], + "shape" = x$mean[2] +) + + +# Log Likelihood +log_lik <- mp@results$draws("log_lik", format = "draws_matrix") |> + apply(1, sum) |> + mean() +log_lik + +# AIC +k <- 2 +-2 * log_lik + k * 2 + +# BIC +(2 * log(nrow(dat))) + (-2 * log_lik) diff --git a/design/examples/loglogistic.stan b/design/examples/loglogistic.stan new file mode 100644 index 00000000..80a61af5 --- /dev/null +++ b/design/examples/loglogistic.stan @@ -0,0 +1,51 @@ + + + +data { + int n; // Number of subjects + int p; // Number of covariates (including intercept) + vector[n] times; // Event|Censor times + array[n] int event_fl; // 1=event 0=censor + matrix[n, p] design; // Design matrix +} + +transformed data { + // Assuming that the first term is an intercept column which + // will conflict with the alpha_0 term so remove it + matrix[n, p-1] design_reduced; + if (p > 1 ) { + design_reduced = design[, 2:p]; + }} + +parameters { + vector[p-1] beta_design; + real alpha_0; + real beta_0; +} + +transformed parameters { + vector[n] alpha; + if (p == 1) { + alpha = rep_vector(alpha_0, n); + } else { + alpha = alpha_0 .* exp(design_reduced * beta_design); + } + + // Likelihood + vector[n] log_lik; + for (i in 1:n) { + if (event_fl[i] == 1) { + log_lik[i] = loglogistic_lpdf(times[i] | alpha[i], beta_0); + } else { + log_lik[i] = log(1 - loglogistic_cdf(times[i] | alpha[i], beta_0)); + } + } +} + +model { + // Priors + beta_design ~ normal(0, 3); + alpha_0 ~ lognormal(log(2), 1); + beta_0 ~ lognormal(log(2), 1); + target += sum(log_lik); +} diff --git a/design/examples/weibull.R b/design/examples/weibull.R index 372621d0..ac0aaa72 100644 --- a/design/examples/weibull.R +++ b/design/examples/weibull.R @@ -7,42 +7,60 @@ library(cmdstanr) library(posterior) library(bayesplot) library(here) +library(brms) +library(tidyr) +# library(jmpost) +devtools::load_all() -n <- 1000 +dat <- flexsurv::bc |> + as_tibble() |> + mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", 1:n())) -log_hr_trt <- c( - "placebo" = 0, - "active" = -0.3 -) -log_hr_sex <- c( - "M" = 0, - "F" = 0.2 -) -log_hr_age <- 0.1 -lambda_bl <- 1 / 200 -gamma_bl <- 0.95 +# +# +# Commented out code below generates a simulated dataset with known +# Parameter values. +# +# + + +# n <- 1000 + +# log_hr_trt <- c( +# "placebo" = 0, +# "active" = -0.3 +# ) +# log_hr_sex <- c( +# "M" = 0, +# "F" = 0.2 +# ) +# log_hr_age <- 0.1 + +# lambda_bl <- 1 / 200 +# gamma_bl <- 0.95 + +# dat <- tibble( +# pt = sprintf("pt-%05d", 1:n), +# trt = sample(names(log_hr_trt), size = n, replace = TRUE, prob = c(0.5, 0.5)), +# age = rnorm(n), +# sex = sample(names(log_hr_sex), size = n, replace = TRUE, prob = c(0.4, 0.6)), +# HR = exp( +# log(lambda_bl) + +# log_hr_age * age + +# log_hr_sex[sex] + +# log_hr_trt[trt] +# ), +# time = rweibullPH(n, scale = HR, shape = gamma_bl), +# centime = rexp(n, 1 / 400) +# ) |> +# mutate(event = ifelse(time <= centime, 1, 0)) |> +# mutate(time = ifelse(time <= centime, time, centime)) |> +# mutate(sex = factor(sex, levels = names(log_hr_sex))) |> +# mutate(trt = factor(trt, levels = names(log_hr_trt))) |> +# mutate(study = "Study-1") -dat <- tibble( - pt = sprintf("pt-%05d", 1:n), - trt = sample(names(log_hr_trt), size = n, replace = TRUE, prob = c(0.5, 0.5)), - age = rnorm(n), - sex = sample(names(log_hr_sex), size = n, replace = TRUE, prob = c(0.4, 0.6)), - HR = exp( - log(lambda_bl) + - log_hr_age * age + - log_hr_sex[sex] + - log_hr_trt[trt] - ), - time = rweibullPH(n, scale = HR, shape = gamma_bl), - centime = rexp(n, 1 / 400) -) |> - mutate(event = ifelse(time <= centime, 1, 0)) |> - mutate(time = ifelse(time <= centime, time, centime)) |> - mutate(sex = factor(sex, levels = names(log_hr_sex))) |> - mutate(trt = factor(trt, levels = names(log_hr_trt))) |> - mutate(study = "Study-1") @@ -51,8 +69,8 @@ dat <- tibble( # Cox Regression # -coxph( - Surv(time, event) ~ age + sex + trt, +mod_cox <- coxph( + Surv(recyrs, censrec) ~ group, data = dat ) @@ -63,12 +81,14 @@ coxph( # -flexsurvreg( - Surv(time, event) ~ sex + trt, +mod_flex <- flexsurvreg( + Surv(recyrs, censrec) ~ group, data = dat, dist = "weibullPH" ) - +AIC(mod_flex) +BIC(mod_flex) +logLik(mod_flex) ################################ # @@ -76,19 +96,58 @@ flexsurvreg( # -mod <- survreg( - Surv(time, event) ~ age + sex + trt, +mod_surv <- survreg( + Surv(recyrs, censrec) ~ group, data = dat, dist = "weibull" ) -gamma <- 1 / mod$scale -lambda <- exp(-mod$coefficients[1] * gamma) -param_log_coefs <- -mod$coefficients[-1] * gamma +gamma <- 1 / mod_surv$scale +lambda <- exp(-mod_surv$coefficients[1] * gamma) +param_log_coefs <- -mod_surv$coefficients[-1] * gamma + c(gamma, lambda, param_log_coefs) ## Need to use delta method to get standard errors +AIC(mod_surv) +BIC(mod_surv) +logLik(mod_surv) + + + +################################ +# +# brms Weibull Regression +# + +mod_brms <- brm( + recyrs | cens(censored) ~ group, + dat |> mutate(censored = if_else(censrec == 1, "none", "right")), + family = weibull(), + cores = 4, + warmup = 1000, + iter = 2000, + seed = 7819 +) + +brms_pars <- as_draws_matrix(mod_brms)[, c("shape", "b_Intercept", "b_groupMedium", "b_groupPoor")] |> + apply(2, mean) + +brms_gamma <- brms_pars[["shape"]] +brms_lambda <- exp(-brms_pars[["b_Intercept"]] * brms_gamma) +brms_param_log_coefs <- -brms_pars[c("b_groupMedium", "b_groupPoor")] * brms_gamma + +c( + "gamma"= brms_gamma, + "lambda" = brms_lambda, + brms_param_log_coefs +) + +# Leave one out CV +loo(mod_brms) + + ################################ # @@ -100,15 +159,15 @@ mod <- cmdstan_model( exe_file = here("design/examples/models/weibull") ) -design_mat <- model.matrix(~ sex + trt, data = dat) +design_mat <- model.matrix(~ group, data = dat) stan_data <- list( n = nrow(dat), design = design_mat, p = ncol(design_mat), - times = dat$time, - event_fl = dat$event + times = dat$recyrs, + event_fl = dat$censrec ) fit <- mod$sample( data = stan_data, @@ -118,8 +177,86 @@ fit <- mod$sample( iter_warmup = 1000, iter_sampling = 1500 ) +vars_stan <- c( + "gamma_0", + "lambda_0", + "beta" +) +fit$summary(vars_stan) + + +# Log Likelihood +log_lik <- fit$draws("log_lik", format = "draws_matrix") |> + apply(1, sum) |> + mean() +log_lik + +# AIC +k <- 2 +-2 * log_lik + k * (stan_data$p + 1) # +1 for the scale parameter + +# BIC +((stan_data$p + 1) * log(stan_data$n)) + (-2 * log_lik) + + +# Leave one out CV +fit$loo() + + +#### Extract Desired Quantities + +# Lambda here represents the fitted lambda for each individual subject separately +lambda <- fit$draws("lambda", format = "draws_df") |> + as_tibble() |> + gather(KEY, lambda, -.draw, -.iteration, -.chain) |> + mutate(pt_index = str_extract(KEY, "\\d+")) + +# gamma_0 is common to all subjects +gamma <- fit$draws("gamma_0", format = "draws_df") |> + as_tibble() + +# Combine gamma_0 and lambda so we have 1 row per subject per sample +# Also re-attach the PT variable to have a clear subject label +pt_map <- levels(factor(dat$pt)) +samples <- left_join(lambda, gamma, by = c(".chain", ".draw", ".iteration")) |> + mutate(pt = pt_map[as.numeric(pt_index)]) + + +# Reduce the dataset down to just 2 subjects that we will create predictions for +samples_reduced <- samples |> + filter(pt %in% c("pt-00681", "pt-00002")) + +# Time points to evaluate their predictions at +target_times <- seq(min(dat$recyrs), max(dat$recyrs), length.out = 20) + +# Duplicate dataset once per desired timepoint +samples_all_times <- bind_rows( + lapply(target_times, \(t) samples_reduced |> mutate(time = t)) +) + +# Calculate the survival distribution for each subject at each desired timepoint +# To get different quantities change the `pweibullPH` to the desired distribution +# function e.g. hweibullPH / HweibullPH +survival_times <- samples_all_times |> + mutate(surv = flexsurv::pweibullPH(time, gamma_0, lambda, lower.tail = FALSE)) |> + group_by(pt, time) |> + summarise( + lci = quantile(surv, 0.025), + med = quantile(surv, 0.5), + uci = quantile(surv, 0.975), + .groups = "drop" + ) + +ggplot( + data = survival_times, + aes(x = time, y = med, ymin = lci, ymax = uci, group = pt, col = pt, fill = pt) +) + + geom_line() + + geom_ribbon(alpha = 0.4, col = NA) + + theme_bw() + + -fit$summary() ################################ @@ -127,13 +264,11 @@ fit$summary() # JMpost # -devtools::load_all() -# library(jmpost) jm <- JointModel( survival = SurvivalWeibullPH( - lambda = prior_lognormal(log(1 / 200), 0.5), - gamma = prior_lognormal(log(0.95), 0.5) + lambda = prior_lognormal(log(1/200), 1.3), + gamma = prior_lognormal(log(1), 1.3) ) ) @@ -141,12 +276,12 @@ jdat <- DataJoint( subject = DataSubject( data = dat, subject = "pt", - arm = "trt", + arm = "arm", study = "study" ), survival = DataSurvival( data = dat, - formula = Surv(time, event) ~ age + sex + trt + formula = Surv(recyrs, censrec) ~ group ) ) @@ -155,14 +290,56 @@ mp <- sampleStanModel( data = jdat, iter_warmup = 1000, iter_sampling = 1500, - chains = 2 + chains = 2, + parallel_chains = 2 ) vars <- c( - "sm_weibull_ph_lambda", # 0.005 - "sm_weibull_ph_gamma", # 0.95 - "beta_os_cov" # 0.1, 0.3, -0.2 + "sm_weibull_ph_lambda", + "sm_weibull_ph_gamma", + "beta_os_cov" ) mp@results$summary(vars) +# Log Likelihood +log_lik <- mp@results$draws("log_lik", format = "draws_matrix") |> + apply(1, sum) |> + mean() +log_lik + +# AIC +k <- 2 +-2 * log_lik + k * 4 + +# BIC +(4 * log(nrow(dat))) + (-2 * log_lik) + +# Leave one out CV +mp@results$loo() + + +#### Extract Desired Quantities + +prediction_times <- seq(min(dat$recyrs), max(dat$recyrs), length.out = 20) +selected_patients <- c("pt-00681", "pt-00002") + +# Survival plots +sq_surv <- SurvivalQuantities( + mp, + time_grid = prediction_times, + groups = selected_patients, + type = "surv" +) +autoplot(sq_surv, add_km = FALSE, add_wrap = FALSE) +summary(sq_surv) +# as.data.frame(sq_surv) # Raw sample data + +# Hazard +sq_haz <- SurvivalQuantities( + mp, + time_grid = prediction_times, + groups = selected_patients, + type = "haz" +) +autoplot(sq_haz, add_km = FALSE, add_wrap = FALSE) diff --git a/design/examples/weibull.stan b/design/examples/weibull.stan index d3b29994..366a6e38 100644 --- a/design/examples/weibull.stan +++ b/design/examples/weibull.stan @@ -32,21 +32,23 @@ parameters { real gamma_0; } -model { - +transformed parameters { + vector[n] log_lik; vector[n] lambda = lambda_0 .* exp(design_reduced * beta); - - // Priors - beta ~ normal(0, 3); - gamma_0 ~ lognormal(log(1), 1.5); - lambda_0 ~ lognormal(log(0.05), 1.5); - // Likelihood for (i in 1:n) { if (event_fl[i] == 1) { - target += weibull_ph_lpdf(times[i] | lambda[i], gamma_0); + log_lik[i] = weibull_ph_lpdf(times[i] | lambda[i], gamma_0); } else { - target += weibull_ph_lccdf(times[i] | lambda[i], gamma_0); + log_lik[i] = weibull_ph_lccdf(times[i] | lambda[i], gamma_0); } } } + +model { + // Priors + beta ~ normal(0, 3); + gamma_0 ~ lognormal(log(1), 1.5); + lambda_0 ~ lognormal(log(0.05), 1.5); + target += sum(log_lik); +} diff --git a/design/examples/weibull_std.R b/design/examples/weibull_std.R deleted file mode 100644 index 84712758..00000000 --- a/design/examples/weibull_std.R +++ /dev/null @@ -1,141 +0,0 @@ - - -library(dplyr) -library(flexsurv) -library(survival) -library(cmdstanr) -library(posterior) -library(bayesplot) -library(here) - - -dat <- flexsurv::bc |> - as_tibble() |> - mutate(arm = "A", study = "S", pt = sprintf("pt-%05d", 1:n())) - - - -################################ -# -# Cox Regression -# - -coxph( - Surv(recyrs, censrec) ~ group, - data = dat -) - - -################################ -# -# Flexsurv parametric Regression -# - - -flexsurvreg( - Surv(recyrs, censrec) ~ group, - data = dat, - dist = "weibullPH" -) - - -################################ -# -# Survreg parametric Regression -# - - -mod <- survreg( - Surv(recyrs, censrec) ~ group, - data = dat, - dist = "weibull" -) - -gamma <- 1 / mod$scale -lambda <- exp(-mod$coefficients[1] * gamma) -param_log_coefs <- -mod$coefficients[-1] * gamma - -c(gamma, lambda, param_log_coefs) -## Need to use delta method to get standard errors - - - -################################ -# -# Bayesian Weibull Regression -# - -mod <- cmdstan_model( - stan_file = here("design/examples/weibull.stan"), - exe_file = here("design/examples/models/weibull") -) - -design_mat <- model.matrix(~ group, data = dat) - - -stan_data <- list( - n = nrow(dat), - design = design_mat, - p = ncol(design_mat), - times = dat$recyrs, - event_fl = dat$censrec -) -fit <- mod$sample( - data = stan_data, - chains = 2, - parallel_chains = 2, - refresh = 200, - iter_warmup = 1000, - iter_sampling = 1500 -) - -fit$summary() - - -bayesplot::mcmc_pairs(fit$draws()) - -################################ -# -# JMpost -# - -devtools::load_all() -# library(jmpost) - -jm <- JointModel( - survival = SurvivalWeibullPH( - lambda = prior_lognormal(log(1/200), 1.3), - gamma = prior_lognormal(log(1), 1.3) - ) -) - -jdat <- DataJoint( - subject = DataSubject( - data = dat, - subject = "pt", - arm = "arm", - study = "study" - ), - survival = DataSurvival( - data = dat, - formula = Surv(recyrs, censrec) ~ group - ) -) - -mp <- sampleStanModel( - jm, - data = jdat, - iter_warmup = 1000, - iter_sampling = 1500, - chains = 2, - parallel_chains = 2 -) - -vars <- c( - "sm_weibull_ph_lambda", - "sm_weibull_ph_gamma", - "beta_os_cov" -) - -mp@results$summary(vars) - diff --git a/vignettes/extending-jmpost.Rmd b/vignettes/extending-jmpost.Rmd index aff43029..c4153684 100644 --- a/vignettes/extending-jmpost.Rmd +++ b/vignettes/extending-jmpost.Rmd @@ -31,6 +31,7 @@ contain complete information for this package yet. ## Custom Link Functions + Users can define custom link functions in several ways based upon the level of customisation required. In order to explain this process it is first important to understand how the link functions are implemented under the hood. @@ -147,6 +148,7 @@ Link( Note that there are a few families of link functions that are common across all models for example identity, dSLD, TTG. Users can access these by simply using the inbuilt `link_identity()`, `link_dsld()` and `link_ttg()` functions respectively. + These functions are responsible for loading the correct `LinkComponent` to implement that link for a particular model. That is if the user wants to specify both the