Skip to content

Commit

Permalink
Fix SF Quantities (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Jun 26, 2024
1 parent 46d9b2e commit a28d91b
Show file tree
Hide file tree
Showing 16 changed files with 168 additions and 25 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ S3method(as.QuantityGenerator,GridObserved)
S3method(as.QuantityGenerator,GridPopulation)
S3method(as.QuantityGenerator,GridPrediction)
S3method(as.StanModule,JointModel)
S3method(as.StanModule,JointModelSamples)
S3method(as.StanModule,Link)
S3method(as.StanModule,LinkComponent)
S3method(as.StanModule,Parameter)
Expand Down
9 changes: 6 additions & 3 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,12 @@ as.character.JointModel <- function(x, ...) {

#' @rdname write_stan
#' @export
write_stan.JointModel <- function(object, file_path) {
fi <- file(file_path, open = "w")
writeLines(as.character(object), fi)
write_stan.JointModel <- function(object, destination, ...) {
if (is_connection(destination)) {
return(writeLines(as.character(object), con = destination))
}
fi <- file(destination, open = "w")
writeLines(as.character(object), con = fi)
close(fi)
}

Expand Down
34 changes: 25 additions & 9 deletions R/JointModelSamples.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,30 @@ generateQuantities.JointModelSamples <- function(object, generator, type, ...) {
append(as_stan_list(object@model@parameters)) |>
append(as_stan_list(generator, data = object@data, model = object@model))

stanobj <- as.StanModule(object, generator = generator, type = type)
model <- compileStanModel(stanobj)

devnull <- utils::capture.output(
results <- model$generate_quantities(
data = data,
fitted_params = object@results
)
)
return(results)
}


#' `JointModelSamples` -> `StanModule`
#'
#' Converts a `JointModelSamples` object into a `StanModule` object ensuring
#' that the resulting `StanModule` object is able to generate post sampling
#' quantities.
#'
#' @inheritParams generateQuantities
#' @export
as.StanModule.JointModelSamples <- function(object, generator, type, ...) {
assert_that(
is(generator, "QuantityGenerator"),
length(type) == 1,
type %in% c("survival", "longitudinal")
)
Expand All @@ -60,17 +83,10 @@ generateQuantities.JointModelSamples <- function(object, generator, type, ...) {
quant_stanobj
)
)
stanobj
}

model <- compileStanModel(stanobj)

devnull <- utils::capture.output(
results <- model$generate_quantities(
data = data,
fitted_params = object@results
)
)
return(results)
}

#' `JointModelSamples` -> Printable `Character`
#'
Expand Down
2 changes: 1 addition & 1 deletion R/SimLongitudinalClaretBruno.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ sampleSubjects.SimLongitudinalClaretBruno <- function(object, subjects_df) {

#' Claret-Bruno Functionals
#'
#' @param time (`numeric`)\cr time grid.
#' @param t (`numeric`)\cr time grid.
#' @param b (`number`)\cr baseline sld.
#' @param g (`number`)\cr growth rate.
#' @param c (`number`)\cr resistance rate.
Expand Down
5 changes: 3 additions & 2 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ NULL
#' Write the Stan code for a Stan module.
#'
#' @param object the module.
#' @param file_path (`string`)\cr output file.
#' @param destination (`character` or `connection`)\cr Where to write stan code to.
#' @param ... Additional arguments
#'
#' @export
write_stan <- function(object, file_path) {
write_stan <- function(object, destination, ...) {
UseMethod("write_stan")
}

Expand Down
5 changes: 5 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,8 @@ decompose_subjects <- function(subjects, all_subjects) {
is_cmdstanr_available <- function() {
requireNamespace("cmdstanr", quietly = TRUE)
}


is_connection <- function(obj) {
inherits(obj, "connection")
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ reference:
- as.StanModule.JointModel
- as.StanModule.Link
- as.StanModule.LinkComponent
- as.StanModule.JointModelSamples
- as.character.Parameter
- as.character.Prior
- show-object
Expand Down
6 changes: 3 additions & 3 deletions inst/stan/lm-stein-fojo/quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ generated quantities {
// Source - lm-stein-fojo/quantities.stan
//
matrix[n_subjects, 3] long_gq_parameters;
long_gq_parameters[, 1] = lm_gsf_psi_bsld;
long_gq_parameters[, 2] = lm_gsf_psi_ks;
long_gq_parameters[, 3] = lm_gsf_psi_kg;
long_gq_parameters[, 1] = lm_sf_psi_bsld;
long_gq_parameters[, 2] = lm_sf_psi_ks;
long_gq_parameters[, 3] = lm_sf_psi_kg;

matrix[gq_n_quant, 3] long_gq_pop_parameters;
long_gq_pop_parameters[, 1] = exp(lm_sf_mu_bsld[gq_long_pop_study_index]);
Expand Down
24 changes: 24 additions & 0 deletions man/as.StanModule.JointModelSamples.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/clbr_sld.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions man/write_stan.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions tests/testthat/test-LongitudinalClaretBruno.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,25 @@ test_that("Can recover known distributional parameters from a SF joint model", {
expect_true(all(dat$q99 >= true_values))
expect_true(all(dat$ess_bulk > 100))
})


test_that("Quantity models pass the parser", {
mock_samples <- .JointModelSamples(
model = JointModel(longitudinal = LongitudinalClaretBruno()),
data = structure(1, class = "DataJoint"),
results = structure(1, class = "CmdStanMCMC")
)
stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorPopulation(1, "A", "B"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)

stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorSubject(1, "A"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)
})
23 changes: 23 additions & 0 deletions tests/testthat/test-LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,26 @@ test_that("Can recover known distributional parameters from a full GSF joint mod
expect_true(all(dat$q99 >= true_values))
expect_true(all(dat$ess_bulk > 100))
})



test_that("Quantity models pass the parser", {
mock_samples <- .JointModelSamples(
model = JointModel(longitudinal = LongitudinalGSF(centred = FALSE)),
data = structure(1, class = "DataJoint"),
results = structure(1, class = "CmdStanMCMC")
)
stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorPopulation(1, "A", "B"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)

stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorSubject(1, "A"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)
})
22 changes: 22 additions & 0 deletions tests/testthat/test-LongitudinalRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,25 @@ test_that("Random Slope Model left-censoring works as expected", {
)
expect_gt(lmer_cor, 0.99)
})


test_that("Quantity models pass the parser", {
mock_samples <- .JointModelSamples(
model = JointModel(longitudinal = LongitudinalRandomSlope()),
data = structure(1, class = "DataJoint"),
results = structure(1, class = "CmdStanMCMC")
)
stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorPopulation(1, "A", "B"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)

stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorSubject(1, "A"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)
})
25 changes: 24 additions & 1 deletion tests/testthat/test-LongitudinalSteinFojo.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


test_that("LongitudinalSteinFojo works as expected with default arguments", {
result <- expect_silent(LongitudinalSteinFojo())
expect_s4_class(result, "LongitudinalSteinFojo")
Expand Down Expand Up @@ -90,6 +89,7 @@ test_that("Non-Centralised parameterisation compiles without issues", {




test_that("Can recover known distributional parameters from a SF joint model", {

skip_if_not(is_full_test())
Expand Down Expand Up @@ -388,3 +388,26 @@ test_that("Can recover known distributional parameters from a SF joint model wit
expect_true(all(dat$q99 >= true_values))
expect_true(all(dat$ess_bulk > 100))
})



test_that("Quantity models pass the parser", {
mock_samples <- .JointModelSamples(
model = JointModel(longitudinal = LongitudinalSteinFojo(centred = TRUE)),
data = structure(1, class = "DataJoint"),
results = structure(1, class = "CmdStanMCMC")
)
stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorPopulation(1, "A", "B"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)

stanmod <- as.StanModule(
mock_samples,
generator = QuantityGeneratorSubject(1, "A"),
type = "longitudinal"
)
expect_stan_syntax(stanmod)
})
2 changes: 1 addition & 1 deletion vignettes/model_fitting.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ It is always possible to read out the Stan code that is contained in the

```{r debug_stan}
tmp <- tempfile()
write_stan(simple_model, file_path = tmp)
write_stan(simple_model, destination = tmp)
first_part <- head(readLines(tmp), 10)
cat(paste(first_part, collapse = "\n"))
```
Expand Down

0 comments on commit a28d91b

Please sign in to comment.