From a28d91bd4a0f719268750dcd91cc041973dc0bca Mon Sep 17 00:00:00 2001 From: Craig Gower-Page Date: Wed, 26 Jun 2024 12:29:29 +0100 Subject: [PATCH] Fix SF Quantities (#370) --- NAMESPACE | 1 + R/JointModel.R | 9 +++-- R/JointModelSamples.R | 34 ++++++++++++++----- R/SimLongitudinalClaretBruno.R | 2 +- R/generics.R | 5 +-- R/utilities.R | 5 +++ _pkgdown.yml | 1 + inst/stan/lm-stein-fojo/quantities.stan | 6 ++-- man/as.StanModule.JointModelSamples.Rd | 24 +++++++++++++ man/clbr_sld.Rd | 4 +-- man/write_stan.Rd | 8 +++-- tests/testthat/test-LongitudinalClaretBruno.R | 22 ++++++++++++ tests/testthat/test-LongitudinalGSF.R | 23 +++++++++++++ tests/testthat/test-LongitudinalRandomSlope.R | 22 ++++++++++++ tests/testthat/test-LongitudinalSteinFojo.R | 25 +++++++++++++- vignettes/model_fitting.Rmd | 2 +- 16 files changed, 168 insertions(+), 25 deletions(-) create mode 100644 man/as.StanModule.JointModelSamples.Rd diff --git a/NAMESPACE b/NAMESPACE index d9247e2b0..62ec805dd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -18,6 +18,7 @@ S3method(as.QuantityGenerator,GridObserved) S3method(as.QuantityGenerator,GridPopulation) S3method(as.QuantityGenerator,GridPrediction) S3method(as.StanModule,JointModel) +S3method(as.StanModule,JointModelSamples) S3method(as.StanModule,Link) S3method(as.StanModule,LinkComponent) S3method(as.StanModule,Parameter) diff --git a/R/JointModel.R b/R/JointModel.R index 183f130b3..19aa11155 100755 --- a/R/JointModel.R +++ b/R/JointModel.R @@ -142,9 +142,12 @@ as.character.JointModel <- function(x, ...) { #' @rdname write_stan #' @export -write_stan.JointModel <- function(object, file_path) { - fi <- file(file_path, open = "w") - writeLines(as.character(object), fi) +write_stan.JointModel <- function(object, destination, ...) { + if (is_connection(destination)) { + return(writeLines(as.character(object), con = destination)) + } + fi <- file(destination, open = "w") + writeLines(as.character(object), con = fi) close(fi) } diff --git a/R/JointModelSamples.R b/R/JointModelSamples.R index 76a1f5b49..c0df83118 100644 --- a/R/JointModelSamples.R +++ b/R/JointModelSamples.R @@ -38,7 +38,30 @@ generateQuantities.JointModelSamples <- function(object, generator, type, ...) { append(as_stan_list(object@model@parameters)) |> append(as_stan_list(generator, data = object@data, model = object@model)) + stanobj <- as.StanModule(object, generator = generator, type = type) + model <- compileStanModel(stanobj) + + devnull <- utils::capture.output( + results <- model$generate_quantities( + data = data, + fitted_params = object@results + ) + ) + return(results) +} + + +#' `JointModelSamples` -> `StanModule` +#' +#' Converts a `JointModelSamples` object into a `StanModule` object ensuring +#' that the resulting `StanModule` object is able to generate post sampling +#' quantities. +#' +#' @inheritParams generateQuantities +#' @export +as.StanModule.JointModelSamples <- function(object, generator, type, ...) { assert_that( + is(generator, "QuantityGenerator"), length(type) == 1, type %in% c("survival", "longitudinal") ) @@ -60,17 +83,10 @@ generateQuantities.JointModelSamples <- function(object, generator, type, ...) { quant_stanobj ) ) + stanobj +} - model <- compileStanModel(stanobj) - devnull <- utils::capture.output( - results <- model$generate_quantities( - data = data, - fitted_params = object@results - ) - ) - return(results) -} #' `JointModelSamples` -> Printable `Character` #' diff --git a/R/SimLongitudinalClaretBruno.R b/R/SimLongitudinalClaretBruno.R index 5ca97a1da..7c949c349 100644 --- a/R/SimLongitudinalClaretBruno.R +++ b/R/SimLongitudinalClaretBruno.R @@ -175,7 +175,7 @@ sampleSubjects.SimLongitudinalClaretBruno <- function(object, subjects_df) { #' Claret-Bruno Functionals #' -#' @param time (`numeric`)\cr time grid. +#' @param t (`numeric`)\cr time grid. #' @param b (`number`)\cr baseline sld. #' @param g (`number`)\cr growth rate. #' @param c (`number`)\cr resistance rate. diff --git a/R/generics.R b/R/generics.R index ec0e45ee7..cb5082edf 100755 --- a/R/generics.R +++ b/R/generics.R @@ -44,10 +44,11 @@ NULL #' Write the Stan code for a Stan module. #' #' @param object the module. -#' @param file_path (`string`)\cr output file. +#' @param destination (`character` or `connection`)\cr Where to write stan code to. +#' @param ... Additional arguments #' #' @export -write_stan <- function(object, file_path) { +write_stan <- function(object, destination, ...) { UseMethod("write_stan") } diff --git a/R/utilities.R b/R/utilities.R index 039582566..87b9de8ad 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -329,3 +329,8 @@ decompose_subjects <- function(subjects, all_subjects) { is_cmdstanr_available <- function() { requireNamespace("cmdstanr", quietly = TRUE) } + + +is_connection <- function(obj) { + inherits(obj, "connection") +} diff --git a/_pkgdown.yml b/_pkgdown.yml index 73e7d457c..f22443b4c 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -143,6 +143,7 @@ reference: - as.StanModule.JointModel - as.StanModule.Link - as.StanModule.LinkComponent + - as.StanModule.JointModelSamples - as.character.Parameter - as.character.Prior - show-object diff --git a/inst/stan/lm-stein-fojo/quantities.stan b/inst/stan/lm-stein-fojo/quantities.stan index 92c07fdb1..7352540c6 100644 --- a/inst/stan/lm-stein-fojo/quantities.stan +++ b/inst/stan/lm-stein-fojo/quantities.stan @@ -15,9 +15,9 @@ generated quantities { // Source - lm-stein-fojo/quantities.stan // matrix[n_subjects, 3] long_gq_parameters; - long_gq_parameters[, 1] = lm_gsf_psi_bsld; - long_gq_parameters[, 2] = lm_gsf_psi_ks; - long_gq_parameters[, 3] = lm_gsf_psi_kg; + long_gq_parameters[, 1] = lm_sf_psi_bsld; + long_gq_parameters[, 2] = lm_sf_psi_ks; + long_gq_parameters[, 3] = lm_sf_psi_kg; matrix[gq_n_quant, 3] long_gq_pop_parameters; long_gq_pop_parameters[, 1] = exp(lm_sf_mu_bsld[gq_long_pop_study_index]); diff --git a/man/as.StanModule.JointModelSamples.Rd b/man/as.StanModule.JointModelSamples.Rd new file mode 100644 index 000000000..fe70d01bb --- /dev/null +++ b/man/as.StanModule.JointModelSamples.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/JointModelSamples.R +\name{as.StanModule.JointModelSamples} +\alias{as.StanModule.JointModelSamples} +\title{\code{JointModelSamples} -> \code{StanModule}} +\usage{ +\method{as.StanModule}{JointModelSamples}(object, generator, type, ...) +} +\arguments{ +\item{object}{object to obtain generated quantities from} + +\item{generator}{(\code{QuantityGenerator})\cr object that specifies which subjects and time points +to calculate the quantities at} + +\item{type}{(\code{character})\cr type of quantities to be generated, must be either "survival" or +"longitudinal".} + +\item{...}{additional options.} +} +\description{ +Converts a \code{JointModelSamples} object into a \code{StanModule} object ensuring +that the resulting \code{StanModule} object is able to generate post sampling +quantities. +} diff --git a/man/clbr_sld.Rd b/man/clbr_sld.Rd index e194cf304..fea332cd2 100644 --- a/man/clbr_sld.Rd +++ b/man/clbr_sld.Rd @@ -13,6 +13,8 @@ clbr_ttg(t, b, g, c, p) clbr_dsld(t, b, g, c, p) } \arguments{ +\item{t}{(\code{numeric})\cr time grid.} + \item{b}{(\code{number})\cr baseline sld.} \item{g}{(\code{number})\cr growth rate.} @@ -20,8 +22,6 @@ clbr_dsld(t, b, g, c, p) \item{c}{(\code{number})\cr resistance rate.} \item{p}{(\code{number})\cr growth inhibition.} - -\item{time}{(\code{numeric})\cr time grid.} } \value{ The function results. diff --git a/man/write_stan.Rd b/man/write_stan.Rd index 5b64623b0..86d2933d3 100644 --- a/man/write_stan.Rd +++ b/man/write_stan.Rd @@ -5,14 +5,16 @@ \alias{write_stan.JointModel} \title{\code{write_stan}} \usage{ -write_stan(object, file_path) +write_stan(object, destination, ...) -\method{write_stan}{JointModel}(object, file_path) +\method{write_stan}{JointModel}(object, destination, ...) } \arguments{ \item{object}{the module.} -\item{file_path}{(\code{string})\cr output file.} +\item{destination}{(\code{character} or \code{connection})\cr Where to write stan code to.} + +\item{...}{Additional arguments} } \description{ Write the Stan code for a Stan module. diff --git a/tests/testthat/test-LongitudinalClaretBruno.R b/tests/testthat/test-LongitudinalClaretBruno.R index 4e059c8b2..de2b35b8e 100644 --- a/tests/testthat/test-LongitudinalClaretBruno.R +++ b/tests/testthat/test-LongitudinalClaretBruno.R @@ -273,3 +273,25 @@ test_that("Can recover known distributional parameters from a SF joint model", { expect_true(all(dat$q99 >= true_values)) expect_true(all(dat$ess_bulk > 100)) }) + + +test_that("Quantity models pass the parser", { + mock_samples <- .JointModelSamples( + model = JointModel(longitudinal = LongitudinalClaretBruno()), + data = structure(1, class = "DataJoint"), + results = structure(1, class = "CmdStanMCMC") + ) + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorPopulation(1, "A", "B"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) + + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorSubject(1, "A"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) +}) diff --git a/tests/testthat/test-LongitudinalGSF.R b/tests/testthat/test-LongitudinalGSF.R index 22f211a85..0e2f05537 100644 --- a/tests/testthat/test-LongitudinalGSF.R +++ b/tests/testthat/test-LongitudinalGSF.R @@ -197,3 +197,26 @@ test_that("Can recover known distributional parameters from a full GSF joint mod expect_true(all(dat$q99 >= true_values)) expect_true(all(dat$ess_bulk > 100)) }) + + + +test_that("Quantity models pass the parser", { + mock_samples <- .JointModelSamples( + model = JointModel(longitudinal = LongitudinalGSF(centred = FALSE)), + data = structure(1, class = "DataJoint"), + results = structure(1, class = "CmdStanMCMC") + ) + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorPopulation(1, "A", "B"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) + + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorSubject(1, "A"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) +}) diff --git a/tests/testthat/test-LongitudinalRandomSlope.R b/tests/testthat/test-LongitudinalRandomSlope.R index ff60e437d..4afbb6ffc 100644 --- a/tests/testthat/test-LongitudinalRandomSlope.R +++ b/tests/testthat/test-LongitudinalRandomSlope.R @@ -295,3 +295,25 @@ test_that("Random Slope Model left-censoring works as expected", { ) expect_gt(lmer_cor, 0.99) }) + + +test_that("Quantity models pass the parser", { + mock_samples <- .JointModelSamples( + model = JointModel(longitudinal = LongitudinalRandomSlope()), + data = structure(1, class = "DataJoint"), + results = structure(1, class = "CmdStanMCMC") + ) + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorPopulation(1, "A", "B"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) + + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorSubject(1, "A"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) +}) diff --git a/tests/testthat/test-LongitudinalSteinFojo.R b/tests/testthat/test-LongitudinalSteinFojo.R index 9af172365..03d42b09c 100644 --- a/tests/testthat/test-LongitudinalSteinFojo.R +++ b/tests/testthat/test-LongitudinalSteinFojo.R @@ -1,5 +1,4 @@ - test_that("LongitudinalSteinFojo works as expected with default arguments", { result <- expect_silent(LongitudinalSteinFojo()) expect_s4_class(result, "LongitudinalSteinFojo") @@ -90,6 +89,7 @@ test_that("Non-Centralised parameterisation compiles without issues", { + test_that("Can recover known distributional parameters from a SF joint model", { skip_if_not(is_full_test()) @@ -388,3 +388,26 @@ test_that("Can recover known distributional parameters from a SF joint model wit expect_true(all(dat$q99 >= true_values)) expect_true(all(dat$ess_bulk > 100)) }) + + + +test_that("Quantity models pass the parser", { + mock_samples <- .JointModelSamples( + model = JointModel(longitudinal = LongitudinalSteinFojo(centred = TRUE)), + data = structure(1, class = "DataJoint"), + results = structure(1, class = "CmdStanMCMC") + ) + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorPopulation(1, "A", "B"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) + + stanmod <- as.StanModule( + mock_samples, + generator = QuantityGeneratorSubject(1, "A"), + type = "longitudinal" + ) + expect_stan_syntax(stanmod) +}) diff --git a/vignettes/model_fitting.Rmd b/vignettes/model_fitting.Rmd index 9c12088eb..c1d8f5dfb 100644 --- a/vignettes/model_fitting.Rmd +++ b/vignettes/model_fitting.Rmd @@ -237,7 +237,7 @@ It is always possible to read out the Stan code that is contained in the ```{r debug_stan} tmp <- tempfile() -write_stan(simple_model, file_path = tmp) +write_stan(simple_model, destination = tmp) first_part <- head(readLines(tmp), 10) cat(paste(first_part, collapse = "\n")) ```