Skip to content

Commit

Permalink
Added user option to adjust #nodes in Gaussian Quadrature (#272)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Mar 1, 2024
1 parent b0579b0 commit 1c60a69
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 4 deletions.
5 changes: 4 additions & 1 deletion R/DataSurvival.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,10 @@ as_stan_list.DataSurvival <- function(object, ...) {
rownames(design_mat) <- NULL

# Parameters for efficient integration of hazard function -> survival function
gh_parameters <- statmod::gauss.quad(n = 15, kind = "legendre")
gh_parameters <- statmod::gauss.quad(
n = getOption("jmpost.gauss_quad_n"),
kind = getOption("jmpost.gauss_quad_kind")
)

model_data <- list(
Nind_dead = sum(df[[vars$event]]),
Expand Down
18 changes: 17 additions & 1 deletion R/settings.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@
#' Directory to store compiled stan models in. If not set, a temporary directory is used for
#' the given R session. Can also be set via the environment variable `JMPOST_CACHE_DIR`.
#'
#'
#'
#' ## `jmpost.gauss_quad_n`
#'
#' Default = 15
#'
#' In most cases the survival function of the joint model does not have a closed form
#' and as such it is calculated by integrating the hazard function. `jmpost` estimates this
#' via Gaussian Quadrature, in particular it uses [`statmod::gauss.quad`] with
#' `kind = "legendre"` to create the nodes and weights.
#'
#' This option specifies the `n` argument in the call to [`statmod::gauss.quad`]. In general
#' higher values of `n` lead to better accuracy of the approximation but at the cost of
#' increased computational time.
#'
#' @examples
#' \dontrun{
#' options(jmpost.prior_shrinkage = 0.5)
Expand All @@ -46,7 +61,8 @@ set_options <- function() {
current_opts <- names(options())
jmpost_opts <- list(
jmpost.cache_dir = cache_dir,
jmpost.prior_shrinkage = 0.5
jmpost.prior_shrinkage = 0.5,
jmpost.gauss_quad_n = 15
)
for (opt in names(jmpost_opts)) {
if (!opt %in% current_opts) {
Expand Down
2 changes: 1 addition & 1 deletion design/interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ matrix link_contribution(matrix time, matrix pars_lm)
```

- must return same dimensions as time
- where time is 1 row per subject (Nind) and 1 column per required timepoint (nodes from the guasian quadrature) (i.e. if multiple timepoints need to be evaluated for each subject it will still be 1 row but with multiple columns, if only 1 timepoint needs to be evaluated then there will only be 1 column)
- where time is 1 row per subject (Nind) and 1 column per required timepoint (nodes from the Gaussian quadrature) (i.e. if multiple timepoints need to be evaluated for each subject it will still be 1 row but with multiple columns, if only 1 timepoint needs to be evaluated then there will only be 1 column)
- pars_lm is a matrix with 1 row per subject and as many columns as you need defined in transformed_parameters()


Expand Down
2 changes: 1 addition & 1 deletion inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ frac
funder
gi
gl
guasian
Gaussian
hardcoded
ie
ig
Expand Down
14 changes: 14 additions & 0 deletions man/jmpost-settings.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test-options.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@




test_that("Can alter Gaussian Quadrature arguments", {
x <- data.frame(
vpt = c("b", "a", "c", "d", "e"),
vtime = c(10, 20, 30, 25, 15),
vevent = c(1, 1, 0, 1, 0),
vcov1 = c("A", "A", "B", "B", "A"),
vcov2 = rnorm(5)
)

## Test defaults 15 + "legendre"
df <- DataSurvival(
data = x,
formula = Surv(vtime, vevent) ~ vcov1 * vcov2
)
li <- as.list(df)
expect_equal(li$n_nodes, 15)
expect_equal(
li[c("nodes", "weights")],
statmod::gauss.quad(15, "legendre")
)


## Test modified values
options("jmpost.gauss_quad_n" = 20)
df <- DataSurvival(
data = x,
formula = Surv(vtime, vevent) ~ vcov1 * vcov2
)
li <- as.list(df)
expect_equal(li$n_nodes, 20)
expect_equal(
li[c("nodes", "weights")],
statmod::gauss.quad(20, "legendre")
)


## Reset back to default to not impact additional tests
options("jmpost.gauss_quad_n" = 15)
})

0 comments on commit 1c60a69

Please sign in to comment.