diff --git a/NAMESPACE b/NAMESPACE index a6270a432..0f2bebeeb 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -86,9 +86,15 @@ S3method(coalesceGridTime,GridPrediction) S3method(coalesceGridTime,default) S3method(compileStanModel,JointModel) S3method(dim,Quantities) +S3method(enableGQ,JointModel) +S3method(enableGQ,LongitudinalGSF) +S3method(enableGQ,LongitudinalRandomSlope) +S3method(enableGQ,LongitudinalSteinFojo) +S3method(enableGQ,default) S3method(enableLink,LongitudinalGSF) S3method(enableLink,LongitudinalRandomSlope) S3method(enableLink,LongitudinalSteinFojo) +S3method(enableLink,default) S3method(extractVariableNames,DataSubject) S3method(extractVariableNames,DataSurvival) S3method(generateQuantities,JointModelSamples) @@ -202,6 +208,7 @@ export(as_stan_list) export(autoplot) export(brierScore) export(compileStanModel) +export(enableGQ) export(enableLink) export(generateQuantities) export(getParameters) diff --git a/R/JointModel.R b/R/JointModel.R index a25e4a1ad..183f130b3 100755 --- a/R/JointModel.R +++ b/R/JointModel.R @@ -80,16 +80,24 @@ JointModel <- function( } +#' @export +enableGQ.JointModel <- function(object, ...) { + merge( + enableGQ(object@survival), + enableGQ(object@longitudinal) + ) +} + + #' `JointModel` -> `StanModule` #' #' Converts a [`JointModel`] object to a [`StanModule`] object #' #' @inheritParams JointModel-Shared -#' @param include_gq (`logical`)\cr whether to include the generated quantities block. #' @family JointModel #' @family as.StanModule #' @export -as.StanModule.JointModel <- function(object, include_gq = FALSE, ...) { +as.StanModule.JointModel <- function(object, ...) { base_model <- read_stan("base/base.stan") @@ -113,9 +121,6 @@ as.StanModule.JointModel <- function(object, include_gq = FALSE, ...) { StanModule(stan_full) ) - if (!include_gq) { - x@generated_quantities <- "" - } return(x) } diff --git a/R/JointModelSamples.R b/R/JointModelSamples.R index 9b6ed1dc3..76a1f5b49 100644 --- a/R/JointModelSamples.R +++ b/R/JointModelSamples.R @@ -52,9 +52,13 @@ generateQuantities.JointModelSamples <- function(object, generator, type, ...) { ) |> StanModule() - stanobj <- merge( - as.StanModule(object@model, include_gq = TRUE), - quant_stanobj + stanobj <- Reduce( + merge, + list( + as.StanModule(object@model), + enableGQ(object@model), + quant_stanobj + ) ) model <- compileStanModel(stanobj) diff --git a/R/LongitudinalGSF.R b/R/LongitudinalGSF.R index f6a3aaab1..26611816e 100755 --- a/R/LongitudinalGSF.R +++ b/R/LongitudinalGSF.R @@ -126,6 +126,13 @@ LongitudinalGSF <- function( } + +#' @export +enableGQ.LongitudinalGSF <- function(object, ...) { + StanModule("lm-gsf/quantities.stan") +} + + #' @export enableLink.LongitudinalGSF <- function(object, ...) { object@stan <- merge( diff --git a/R/LongitudinalRandomSlope.R b/R/LongitudinalRandomSlope.R index e837ff3b4..57b70a45e 100755 --- a/R/LongitudinalRandomSlope.R +++ b/R/LongitudinalRandomSlope.R @@ -60,6 +60,11 @@ LongitudinalRandomSlope <- function( } +#' @export +enableGQ.LongitudinalRandomSlope <- function(object, ...) { + StanModule("lm-random-slope/quantities.stan") +} + #' @export enableLink.LongitudinalRandomSlope <- function(object, ...) { object@stan <- merge( diff --git a/R/LongitudinalSteinFojo.R b/R/LongitudinalSteinFojo.R index 6b6bbec5b..da5881407 100755 --- a/R/LongitudinalSteinFojo.R +++ b/R/LongitudinalSteinFojo.R @@ -111,6 +111,11 @@ LongitudinalSteinFojo <- function( +#' @export +enableGQ.LongitudinalSteinFojo <- function(object, ...) { + StanModule("lm-stein-fojo/quantities.stan") +} + #' @export enableLink.LongitudinalSteinFojo <- function(object, ...) { object@stan <- merge( diff --git a/R/brier_score.R b/R/brier_score.R index 3e1e13b22..c3b319e34 100644 --- a/R/brier_score.R +++ b/R/brier_score.R @@ -105,7 +105,7 @@ reverse_km_cen_first <- function(t, times, events) { survival::Surv(times, cen_events) ~ 1, data = dat ) - preds <- summary(mod, times = t, extend = TRUE)$surv + preds <- summary(mod, times = t[order(t)], extend = TRUE)$surv assert_that( length(preds) == length(t) diff --git a/R/generics.R b/R/generics.R index 367cd8a91..ec0e45ee7 100755 --- a/R/generics.R +++ b/R/generics.R @@ -410,11 +410,32 @@ resolvePromise.default <- function(object, ...) { enableLink <- function(object, ...) { UseMethod("enableLink") } +#' @export enableLink.default <- function(object, ...) { object } +#' Enable Generated Quantities Generic +#' +#' @param object ([`StanModel`])\cr to enable generated quantities for. +#' @param ... Not used. +#' +#' Optional hook method that is called on a [`StanModel`] if attempting to use +#' either [`LongitudinalQuantities`] or [`SurvivalQuantities`] +#' +#' @return [`StanModule`] object +#' +#' @export +enableGQ <- function(object, ...) { + UseMethod("enableGQ") +} +#' @export +enableGQ.default <- function(object, ...) { + StanModule() +} + + #' Get Prediction Names #' diff --git a/_pkgdown.yml b/_pkgdown.yml index 7b63965ea..868e3a636 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -168,6 +168,7 @@ reference: - sampleObservations - sampleSubjects - enableLink + - enableGQ - as_formula - getPredictionNames diff --git a/inst/stan/lm-gsf/functions.stan b/inst/stan/lm-gsf/functions.stan index 0b71cbec4..925abc45f 100644 --- a/inst/stan/lm-gsf/functions.stan +++ b/inst/stan/lm-gsf/functions.stan @@ -21,15 +21,5 @@ functions { ); return result; } - - vector lm_predict_value(vector time, matrix long_gq_parameters) { - return sld( - time, - long_gq_parameters[,1], - long_gq_parameters[,2], - long_gq_parameters[,3], - long_gq_parameters[,4] - ); - } } diff --git a/inst/stan/lm-gsf/model.stan b/inst/stan/lm-gsf/model.stan index 43f78bca1..560023d6b 100755 --- a/inst/stan/lm-gsf/model.stan +++ b/inst/stan/lm-gsf/model.stan @@ -94,21 +94,3 @@ model { lm_gsf_psi_phi ~ beta(lm_gsf_a_phi[subject_arm_index], lm_gsf_b_phi[subject_arm_index]); } - -generated quantities { - // - // Source - lm-gsf/model.stan - // - matrix[n_subjects, 4] long_gq_parameters; - long_gq_parameters[, 1] = lm_gsf_psi_bsld; - long_gq_parameters[, 2] = lm_gsf_psi_ks; - long_gq_parameters[, 3] = lm_gsf_psi_kg; - long_gq_parameters[, 4] = lm_gsf_psi_phi; - - - matrix[gq_n_quant, 4] long_gq_pop_parameters; - long_gq_pop_parameters[, 1] = exp(lm_gsf_mu_bsld[gq_long_pop_study_index]); - long_gq_pop_parameters[, 2] = exp(lm_gsf_mu_ks[gq_long_pop_arm_index]); - long_gq_pop_parameters[, 3] = exp(lm_gsf_mu_kg[gq_long_pop_arm_index]); - long_gq_pop_parameters[, 4] = lm_gsf_a_phi[gq_long_pop_arm_index] ./ (lm_gsf_a_phi[gq_long_pop_arm_index] + lm_gsf_b_phi[gq_long_pop_arm_index]); -} diff --git a/inst/stan/lm-gsf/quantities.stan b/inst/stan/lm-gsf/quantities.stan new file mode 100644 index 000000000..285d89511 --- /dev/null +++ b/inst/stan/lm-gsf/quantities.stan @@ -0,0 +1,29 @@ +functions { + vector lm_predict_value(vector time, matrix long_gq_parameters) { + return sld( + time, + long_gq_parameters[,1], + long_gq_parameters[,2], + long_gq_parameters[,3], + long_gq_parameters[,4] + ); + } +} + +generated quantities { + // + // Source - lm-gsf/quantities.stan + // + matrix[n_subjects, 4] long_gq_parameters; + long_gq_parameters[, 1] = lm_gsf_psi_bsld; + long_gq_parameters[, 2] = lm_gsf_psi_ks; + long_gq_parameters[, 3] = lm_gsf_psi_kg; + long_gq_parameters[, 4] = lm_gsf_psi_phi; + + + matrix[gq_n_quant, 4] long_gq_pop_parameters; + long_gq_pop_parameters[, 1] = exp(lm_gsf_mu_bsld[gq_long_pop_study_index]); + long_gq_pop_parameters[, 2] = exp(lm_gsf_mu_ks[gq_long_pop_arm_index]); + long_gq_pop_parameters[, 3] = exp(lm_gsf_mu_kg[gq_long_pop_arm_index]); + long_gq_pop_parameters[, 4] = lm_gsf_a_phi[gq_long_pop_arm_index] ./ (lm_gsf_a_phi[gq_long_pop_arm_index] + lm_gsf_b_phi[gq_long_pop_arm_index]); +} diff --git a/inst/stan/lm-random-slope/model.stan b/inst/stan/lm-random-slope/model.stan index ddbfe3595..65647c223 100755 --- a/inst/stan/lm-random-slope/model.stan +++ b/inst/stan/lm-random-slope/model.stan @@ -1,15 +1,4 @@ -functions { - // - // Source - lm-random-slope/model.stan - // - vector lm_predict_value(vector time, matrix long_gq_parameters) { - int nrow = rows(time); - return ( - long_gq_parameters[, 1] + long_gq_parameters[, 2] .* time - ); - } -} parameters { @@ -61,16 +50,3 @@ model { } -generated quantities { - // - // Source - lm-random-slope/model.stan - // - matrix[n_subjects, 2] long_gq_parameters; - long_gq_parameters[, 1] = lm_rs_ind_intercept; - long_gq_parameters[, 2] = lm_rs_ind_rnd_slope; - - - matrix[gq_n_quant, 2] long_gq_pop_parameters; - long_gq_pop_parameters[, 1] = to_vector(lm_rs_intercept[gq_long_pop_study_index]); - long_gq_pop_parameters[, 2] = to_vector(lm_rs_slope_mu[gq_long_pop_arm_index]); -} diff --git a/inst/stan/lm-random-slope/quantities.stan b/inst/stan/lm-random-slope/quantities.stan new file mode 100644 index 000000000..7ba6150b3 --- /dev/null +++ b/inst/stan/lm-random-slope/quantities.stan @@ -0,0 +1,26 @@ + +functions { + // + // Source - lm-random-slope/quantities.stan + // + vector lm_predict_value(vector time, matrix long_gq_parameters) { + int nrow = rows(time); + return ( + long_gq_parameters[, 1] + long_gq_parameters[, 2] .* time + ); + } +} + +generated quantities { + // + // Source - lm-random-slope/quantities.stan + // + matrix[n_subjects, 2] long_gq_parameters; + long_gq_parameters[, 1] = lm_rs_ind_intercept; + long_gq_parameters[, 2] = lm_rs_ind_rnd_slope; + + + matrix[gq_n_quant, 2] long_gq_pop_parameters; + long_gq_pop_parameters[, 1] = to_vector(lm_rs_intercept[gq_long_pop_study_index]); + long_gq_pop_parameters[, 2] = to_vector(lm_rs_slope_mu[gq_long_pop_arm_index]); +} diff --git a/inst/stan/lm-stein-fojo/functions.stan b/inst/stan/lm-stein-fojo/functions.stan index 1c69244e1..f02d2271b 100644 --- a/inst/stan/lm-stein-fojo/functions.stan +++ b/inst/stan/lm-stein-fojo/functions.stan @@ -22,13 +22,5 @@ functions { ); return result; } - vector lm_predict_value(vector time, matrix long_gq_parameters) { - return sld( - time, - long_gq_parameters[,1], - long_gq_parameters[,2], - long_gq_parameters[,3] - ); - } } diff --git a/inst/stan/lm-stein-fojo/model.stan b/inst/stan/lm-stein-fojo/model.stan index 4ce7d4f26..b368818f8 100755 --- a/inst/stan/lm-stein-fojo/model.stan +++ b/inst/stan/lm-stein-fojo/model.stan @@ -86,17 +86,3 @@ model { {%- endif -%} } -generated quantities { - // - // Source - lm-stein-fojo/model.stan - // - matrix[n_subjects, 3] long_gq_parameters; - long_gq_parameters[, 1] = lm_gsf_psi_bsld; - long_gq_parameters[, 2] = lm_gsf_psi_ks; - long_gq_parameters[, 3] = lm_gsf_psi_kg; - - matrix[gq_n_quant, 3] long_gq_pop_parameters; - long_gq_pop_parameters[, 1] = exp(lm_sf_mu_bsld[gq_long_pop_study_index]); - long_gq_pop_parameters[, 2] = exp(lm_sf_mu_ks[gq_long_pop_arm_index]); - long_gq_pop_parameters[, 3] = exp(lm_sf_mu_kg[gq_long_pop_arm_index]); -} diff --git a/inst/stan/lm-stein-fojo/quantities.stan b/inst/stan/lm-stein-fojo/quantities.stan new file mode 100644 index 000000000..92c07fdb1 --- /dev/null +++ b/inst/stan/lm-stein-fojo/quantities.stan @@ -0,0 +1,26 @@ + +functions { + vector lm_predict_value(vector time, matrix long_gq_parameters) { + return sld( + time, + long_gq_parameters[,1], + long_gq_parameters[,2], + long_gq_parameters[,3] + ); + } +} + +generated quantities { + // + // Source - lm-stein-fojo/quantities.stan + // + matrix[n_subjects, 3] long_gq_parameters; + long_gq_parameters[, 1] = lm_gsf_psi_bsld; + long_gq_parameters[, 2] = lm_gsf_psi_ks; + long_gq_parameters[, 3] = lm_gsf_psi_kg; + + matrix[gq_n_quant, 3] long_gq_pop_parameters; + long_gq_pop_parameters[, 1] = exp(lm_sf_mu_bsld[gq_long_pop_study_index]); + long_gq_pop_parameters[, 2] = exp(lm_sf_mu_ks[gq_long_pop_arm_index]); + long_gq_pop_parameters[, 3] = exp(lm_sf_mu_kg[gq_long_pop_arm_index]); +} diff --git a/man/as.StanModule.JointModel.Rd b/man/as.StanModule.JointModel.Rd index 168b278f7..a82198d19 100644 --- a/man/as.StanModule.JointModel.Rd +++ b/man/as.StanModule.JointModel.Rd @@ -4,13 +4,11 @@ \alias{as.StanModule.JointModel} \title{\code{JointModel} -> \code{StanModule}} \usage{ -\method{as.StanModule}{JointModel}(object, include_gq = FALSE, ...) +\method{as.StanModule}{JointModel}(object, ...) } \arguments{ \item{object}{(\code{\link{JointModel}}) \cr Joint model specification.} -\item{include_gq}{(\code{logical})\cr whether to include the generated quantities block.} - \item{...}{Not Used.} } \description{ diff --git a/man/enableGQ.Rd b/man/enableGQ.Rd new file mode 100644 index 000000000..0c808bb03 --- /dev/null +++ b/man/enableGQ.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/generics.R +\name{enableGQ} +\alias{enableGQ} +\title{Enable Generated Quantities Generic} +\usage{ +enableGQ(object, ...) +} +\arguments{ +\item{object}{(\code{\link{StanModel}})\cr to enable generated quantities for.} + +\item{...}{Not used. + +Optional hook method that is called on a \code{\link{StanModel}} if attempting to use +either \code{\link{LongitudinalQuantities}} or \code{\link{SurvivalQuantities}}} +} +\value{ +\code{\link{StanModule}} object +} +\description{ +Enable Generated Quantities Generic +} diff --git a/vignettes/custom-model-gq.stan b/vignettes/custom-model-gq.stan new file mode 100644 index 000000000..2fe67f241 --- /dev/null +++ b/vignettes/custom-model-gq.stan @@ -0,0 +1,29 @@ + +functions { + // Define required function for enabling generated quantities + vector lm_predict_value(vector time, matrix long_gq_parameters) { + return sld( + time, + long_gq_parameters[,1], // baseline + long_gq_parameters[,2], // shrinkage + long_gq_parameters[,3] // growth + ); + } +} + + +generated quantities { + // Enable individual subject predictions / quantities e.g. + // `GridFixed()` / `GridObservation()` / `GridGrouped()` / `GridEven + matrix[n_subjects, 3] long_gq_parameters; + long_gq_parameters[, 1] = baseline_idv; + long_gq_parameters[, 2] = shrinkage_idv; + long_gq_parameters[, 3] = growth_idv; + + // Enable Population level predictions / quantities by taking the median of the + // hierarchical distribution e.g. `GridPopulation()` + matrix[gq_n_quant, 3] long_gq_pop_parameters; + long_gq_pop_parameters[, 1] = rep_vector(mu_baseline, gq_n_quant); + long_gq_pop_parameters[, 2] = rep_vector(mu_shrinkage, gq_n_quant); + long_gq_pop_parameters[, 3] = rep_vector(mu_growth, gq_n_quant); +} diff --git a/vignettes/custom-model.Rmd b/vignettes/custom-model.Rmd index 75ab46440..77c3c4dc1 100644 --- a/vignettes/custom-model.Rmd +++ b/vignettes/custom-model.Rmd @@ -126,7 +126,7 @@ sampleSubjects.SimWang <- function(object, subjects_df) { subjects_df } -# Method to generate observations for each indiviudal subject +# Method to generate observations for each individual subject sampleObservations.SimWang <- function(object, times_df) { nobs <- nrow(times_df) calc_mu <- function(time, b, s, g) b * exp(-s * time) + g * time @@ -327,12 +327,26 @@ as.CmdStanMCMC(model_samples)$summary(vars) ## Generating Quantities of Interest -If you look at the Stan code for the above model you will notice that several of the -requirement elements for the generated quantities have been defined thus enabling us to -generate quantities both at a population and individual subject level. +In order to enable the generation of both population and individual level quantities of interest +we need to implement the required +generated quantity objects and functions as outlined in the "Extending jmpost" vignette. +This can be done as follows: -This can be done at the subject level via: +```{R} +enableGQ.WangModel <- function(object, ...) { + StanModule("custom-model-gq.stan") +} +``` + +Where the Stan code for the `custom-model-gq.stan` file is as follows: +```{R, results='asis', echo=FALSE} +x <- readLines("./custom-model-gq.stan") +cat(c("\```stan", x, "\```"), sep = "\n") +``` + +With the above in place we are now able to generate quantities as needed; +this can be done at the subject level via: ```{R} selected_subjects <- head(dat_os$subject, 4) diff --git a/vignettes/custom-model.stan b/vignettes/custom-model.stan index c21621bde..2935da7f5 100644 --- a/vignettes/custom-model.stan +++ b/vignettes/custom-model.stan @@ -7,16 +7,6 @@ functions { growth .* tumour_time; return tumour_value; } - - // Define required function for enabling generated quantities - vector lm_predict_value(vector time, matrix long_gq_parameters) { - return sld( - time, - long_gq_parameters[,1], // baseline - long_gq_parameters[,2], // shrinkage - long_gq_parameters[,3] // growth - ); - } } parameters{ @@ -64,19 +54,3 @@ model { growth_idv ~ lognormal(log(mu_growth), sigma_growth); } - -generated quantities { - // Enable individual subject predictions / quantities e.g. - // `GridFixed()` / `GridObservation()` / `GridGrouped()` / `GridEven - matrix[n_subjects, 3] long_gq_parameters; - long_gq_parameters[, 1] = baseline_idv; - long_gq_parameters[, 2] = shrinkage_idv; - long_gq_parameters[, 3] = growth_idv; - - // Enable Population level predictions / quantities by taking the median of the - // hierarchical distribution e.g. `GridPopulation()` - matrix[gq_n_quant, 3] long_gq_pop_parameters; - long_gq_pop_parameters[, 1] = rep_vector(mu_baseline, gq_n_quant); - long_gq_pop_parameters[, 2] = rep_vector(mu_shrinkage, gq_n_quant); - long_gq_pop_parameters[, 3] = rep_vector(mu_growth, gq_n_quant); -} diff --git a/vignettes/extending-jmpost.Rmd b/vignettes/extending-jmpost.Rmd index ce836002c..92a3aa79d 100644 --- a/vignettes/extending-jmpost.Rmd +++ b/vignettes/extending-jmpost.Rmd @@ -180,6 +180,9 @@ Note that the `long_gq_parameters` matrix should be structured as your `lm_predict_value()` function would expect it to be for the `long_gq_parameters` argument. +Please see "Custom Generated Quantities" section below for implementation details for inserting +custom generated quantity code. + ### 3) Population Generated Quantity Integration A common use case is to calculate the quantities based on the "population" level parameters which is @@ -217,6 +220,9 @@ Note that the `long_gq_pop_parameters` matrix should be structured as your `lm_predict_value()` function would expect it to be for the `long_gq_parameters` argument. +Please see "Custom Generated Quantities" section below for implementation details for inserting +custom generated quantity code. + ## Prior Specification When writing your own custom longitudinal or survival model it is important to understand @@ -241,6 +247,54 @@ as these are handled by the `parameters` slot of the model object as mentioned a The main reason for using this approach is that `jmpost` implements the priors in such a way that users can change them without having to re-compile the Stan model. + +## Custom Generated Quantities + +In order to avoid unnecessary processing, code that is solely used for generation of post sampling +quantities is +excluded from the Stan program when initially sampling from the joint model. Instead this code +is only included when generating quantities via the `LongitidunalQuantities` or `SurvivalQuantities` +constructors. + +To facilitate this, when using `LongitidunalQuantities` or `SurvivalQuantities`, +a dedicated model method `enableGQ()` is called on the user provided longitudinal and +survival models. +This model specific +method is responsible for returning a `StanModule` object that contains all the relevant +code required to generate the quantities for that given model. The following is a rough +implementation of this method for the Random-Slope model which implements both feature (2) and (3) +outlined in the above "Custom Longitudinal Model" section: + +```R +enableGQ.LongitudinalRandomSlope <- function() { + StanModule(" + functions { + vector lm_predict_value(vector time, matrix long_gq_parameters) { + int nrow = rows(time); + return ( + long_gq_parameters[, 1] + long_gq_parameters[, 2] .* time + ); + } + } + + generated quantities { + matrix[n_subjects, 2] long_gq_parameters; + long_gq_parameters[, 1] = lm_rs_ind_intercept; + long_gq_parameters[, 2] = lm_rs_ind_rnd_slope; + + matrix[gq_n_quant, 2] long_gq_pop_parameters; + long_gq_pop_parameters[, 1] = to_vector(lm_rs_intercept[gq_long_pop_study_index]); + long_gq_pop_parameters[, 2] = to_vector(lm_rs_slope_mu[gq_long_pop_arm_index]); + } + ") +} +``` + +Note that whilst it is possible to provide an `enableGQ()` method for the survival model it is not +required. This is because the underlying framework for creating survival quantities +is distribution agnostic and does not require any model specific code. + + ## Custom Link Functions