Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Different initial values for each mcmc chain #254

Merged
merged 22 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Collate:
'DataLongitudinal.R'
'DataSurvival.R'
'DataJoint.R'
'constants.R'
'StanModule.R'
'Prior.R'
'Parameter.R'
Expand All @@ -88,6 +89,7 @@ Collate:
'defaults.R'
'external-exports.R'
'jmpost-package.R'
'settings.R'
'simulations.R'
'simulations_gsf.R'
'simulations_os.R'
Expand Down
12 changes: 11 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ S3method(extractVariableNames,DataSurvival)
S3method(generateQuantities,JointModelSamples)
S3method(getParameters,default)
S3method(initialValues,JointModel)
S3method(initialValues,Link)
S3method(initialValues,Parameter)
S3method(initialValues,ParameterList)
S3method(initialValues,Prior)
S3method(initialValues,StanModel)
S3method(names,Parameter)
S3method(names,ParameterList)
S3method(sampleStanModel,JointModel)
Expand Down Expand Up @@ -100,6 +102,7 @@ export(generateQuantities)
export(gsf_dsld)
export(gsf_sld)
export(gsf_ttg)
export(initialValues)
export(link_gsf_abstract)
export(link_gsf_dsld)
export(link_gsf_identity)
Expand All @@ -112,7 +115,6 @@ export(prior_invgamma)
export(prior_logistic)
export(prior_loglogistic)
export(prior_lognormal)
export(prior_none)
export(prior_normal)
export(prior_std_normal)
export(prior_student_t)
Expand Down Expand Up @@ -162,5 +164,13 @@ importFrom(ggplot2,autoplot)
importFrom(ggplot2.utils,geom_km)
importFrom(glue,as_glue)
importFrom(stats,acf)
importFrom(stats,rbeta)
importFrom(stats,rcauchy)
importFrom(stats,rgamma)
importFrom(stats,rlnorm)
importFrom(stats,rlogis)
importFrom(stats,rnorm)
importFrom(stats,rt)
importFrom(stats,runif)
importFrom(survival,Surv)
importFrom(tibble,add_case)
74 changes: 64 additions & 10 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#' @include LongitudinalModel.R
#' @include SurvivalModel.R
#' @include Link.R
#' @include constants.R
NULL


Expand Down Expand Up @@ -136,6 +137,8 @@ compileStanModel.JointModel <- function(object) {
#' @export
sampleStanModel.JointModel <- function(object, data, ...) {

assert_class(data, "DataJoint")

if (!is.null(object@survival)) {
assert_that(
!is.null(data@survival),
Expand All @@ -150,21 +153,37 @@ sampleStanModel.JointModel <- function(object, data, ...) {
}

args <- list(...)

args[["data"]] <- append(
as_stan_list(data),
as_stan_list(object@parameters)
)

if (!"init" %in% names(args)) {
values_initial <- initialValues(object)
values_sizes <- size(object@parameters)
values_sizes_complete <- replace_with_lookup(values_sizes, args[["data"]])
values_initial_expanded <- expand_initial_values(values_initial, values_sizes_complete)
args[["init"]] <- function() values_initial_expanded
args[["chains"]] <- if ("chains" %in% names(args)) {
args[["chains"]]
} else {
# Magic constant from R/constants.R
CMDSTAN_DEFAULT_CHAINS
}

initial_values <- if ("init" %in% names(args)) {
args[["init"]]
} else {
initialValues(object, n_chains = args[["chains"]])
}

args[["init"]] <- ensure_initial_values(
initial_values,
args[["data"]],
object@parameters
)

model <- compileStanModel(object)
results <- do.call(model$sample, args)

results <- do.call(
model$sample,
args
)

.JointModelSamples(
model = object,
Expand All @@ -174,12 +193,47 @@ sampleStanModel.JointModel <- function(object, data, ...) {
}


# initialValues-JointModel ----
#' Ensure that initial values are correctly specified
#'
#' @param initial_values (`list`)\cr A list of lists containing the initial values
#' must be 1 list per desired chain. All elements should have identical names
#' @param data (`list`)\cr specifies the size to expand each of our initial values to be.
#' That is elements of size 1 in `initial_values` will be expanded to be the same
#' size as the corresponding element in `data` by broadcasting the value.
#' @param parameters ([`ParameterList`])\cr the parameters object
#'
#' @details
#' This function is mostly a thin wrapper around `expand_initial_values` to
#' enable easier unit testing.
#'
#' @keywords internal
ensure_initial_values <- function(initial_values, data, parameters) {
if (is.function(initial_values)) {
return(initial_values)
}

assert_class(data, "list")
assert_class(parameters, "ParameterList")
assert_class(initial_values, "list")

values_sizes <- size(parameters)
values_sizes_complete <- replace_with_lookup(
values_sizes,
data
)
lapply(
initial_values,
expand_initial_values,
sizes = values_sizes_complete
)
}



#' @rdname initialValues
#' @export
initialValues.JointModel <- function(object) {
initialValues(object@parameters)
initialValues.JointModel <- function(object, n_chains, ...) {
initialValues(object@parameters, n_chains)
}


Expand Down
5 changes: 3 additions & 2 deletions R/Link.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ setMethod(
# initialValues-Link ----

#' @rdname initialValues
initialValues.Link <- function(object) {
initialValues(object@parameters)
#' @export
initialValues.Link <- function(object, n_chains, ...) {
initialValues(object@parameters, n_chains)
}


Expand Down
6 changes: 3 additions & 3 deletions R/LinkGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ link_gsf_abstract <- function(
#'
#' @export
link_gsf_ttg <- function(
gamma = prior_normal(0, 5, init = 0)
gamma = prior_normal(0, 5)
) {
.link_gsf_ttg(
name = "TTG",
Expand Down Expand Up @@ -182,7 +182,7 @@ link_gsf_ttg <- function(
#'
#' @export
link_gsf_dsld <- function(
beta = prior_normal(0, 5, init = 0)
beta = prior_normal(0, 5)
) {
.link_gsf_dsld(
name = "dSLD",
Expand Down Expand Up @@ -215,7 +215,7 @@ link_gsf_dsld <- function(
#' @param tau (`Prior`)\cr prior for the link coefficient `tau`.
#'
#' @export
link_gsf_identity <- function(tau = prior_normal(0, 5, init = 0)) {
link_gsf_identity <- function(tau = prior_normal(0, 5)) {
.link_gsf_identity(
name = "Identity",
stan = StanModule("lm-gsf/link_identity.stan"),
Expand Down
2 changes: 1 addition & 1 deletion R/LinkRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ NULL
#'
#' @export
LinkRandomSlope <- function(
link_lm_phi = prior_normal(0.2, 0.5, init = 0.02)
link_lm_phi = prior_normal(0.2, 0.5)
) {
.LinkRandomSlope(
Link(
Expand Down
56 changes: 29 additions & 27 deletions R/LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,23 @@ NULL
#' @param a_phi (`Prior`)\cr for the alpha parameter for the fraction of cells that respond to treatment.
#' @param b_phi (`Prior`)\cr for the beta parameter for the fraction of cells that respond to treatment.
#'
#' @param psi_bsld (`Prior`)\cr for the baseline value random effect `psi_bsld`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_ks (`Prior`)\cr for the shrinkage rate random effect `psi_ks`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_kg (`Prior`)\cr for the growth rate random effect `psi_kg`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_phi (`Prior`)\cr for the shrinkage proportion random effect `psi_phi`. Only used in the
#' centered parameterization to set the initial value.
#'
#' @param centered (`logical`)\cr whether to use the centered parameterization.
#'
#' @export
LongitudinalGSF <- function(

mu_bsld = prior_normal(log(60), 1, init = 60),
mu_ks = prior_normal(log(0.5), 1, init = 0.5),
mu_kg = prior_normal(log(0.3), 1, init = 0.3),

omega_bsld = prior_lognormal(log(0.2), 1, init = 0.2),
omega_ks = prior_lognormal(log(0.2), 1, init = 0.2),
omega_kg = prior_lognormal(log(0.2), 1, init = 0.2),
mu_bsld = prior_normal(log(60), 1),
mu_ks = prior_normal(log(0.5), 1),
mu_kg = prior_normal(log(0.3), 1),

a_phi = prior_lognormal(log(5), 1, init = 5),
b_phi = prior_lognormal(log(5), 1, init = 5),
omega_bsld = prior_lognormal(log(0.2), 1),
omega_ks = prior_lognormal(log(0.2), 1),
omega_kg = prior_lognormal(log(0.2), 1),

sigma = prior_lognormal(log(0.1), 1, init = 0.1),
a_phi = prior_lognormal(log(5), 1),
b_phi = prior_lognormal(log(5), 1),

psi_bsld = prior_none(init = 60),
psi_ks = prior_none(init = 0.5),
psi_kg = prior_none(init = 0.5),
psi_phi = prior_none(init = 0.5),
sigma = prior_lognormal(log(0.1), 1),

centered = FALSE
) {
Expand All @@ -87,17 +73,33 @@ LongitudinalGSF <- function(

Parameter(name = "lm_gsf_a_phi", prior = a_phi, size = "n_arms"),
Parameter(name = "lm_gsf_b_phi", prior = b_phi, size = "n_arms"),
Parameter(name = "lm_gsf_psi_phi", prior = psi_phi, size = "Nind"),
Parameter(
name = "lm_gsf_psi_phi",
prior = prior_init_only(prior_beta(a_phi@init, b_phi@init)),
size = "Nind"
),

Parameter(name = "lm_gsf_sigma", prior = sigma, size = 1)
)

assert_flag(centered)
parameters_extra <- if (centered) {
list(
Parameter(name = "lm_gsf_psi_bsld", prior = psi_bsld, size = "Nind"),
Parameter(name = "lm_gsf_psi_ks", prior = psi_ks, size = "Nind"),
Parameter(name = "lm_gsf_psi_kg", prior = psi_kg, size = "Nind")
Parameter(
name = "lm_gsf_psi_bsld",
prior = prior_init_only(prior_lognormal(mu_bsld@init, omega_bsld@init)),
size = "Nind"
),
Parameter(
name = "lm_gsf_psi_ks",
prior = prior_init_only(prior_lognormal(mu_ks@init, omega_ks@init)),
size = "Nind"
),
Parameter(
name = "lm_gsf_psi_kg",
prior = prior_init_only(prior_lognormal(mu_kg@init, omega_kg@init)),
size = "Nind"
)
)
} else {
list(
Expand Down
23 changes: 9 additions & 14 deletions R/LongitudinalRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,19 @@ NULL
#' @param slope_mu (`Prior`)\cr for the population slope `slope_mu`.
#' @param slope_sigma (`Prior`)\cr for the random slope standard deviation `slope_sigma`.
#' @param sigma (`Prior`)\cr for the variance of the longitudinal values `sigma`.
#' @param random_slope (`Prior`)\cr must be `prior_none()`, just used to specify initial values.
#'
#' @export
LongitudinalRandomSlope <- function(
intercept = prior_normal(30, 10, init = 30),
slope_mu = prior_normal(0, 15, init = 0.001),
slope_sigma = prior_lognormal(1, 5, init = 1),
sigma = prior_lognormal(1, 5, init = 1),
random_slope = prior_none(init = 0)
intercept = prior_normal(30, 10),
slope_mu = prior_normal(0, 15),
slope_sigma = prior_lognormal(0, 1.5),
sigma = prior_lognormal(0, 1.5)
) {

stan <- StanModule(
x = "lm-random-slope/model.stan"
)

assert_that(
inherits(random_slope, "Prior"),
random_slope@repr_data == "",
random_slope@repr_model == "",
msg = "`random_slope` must be a `prior_none()`"
)

.LongitudinalRandomSlope(
LongitudinalModel(
name = "Random Slope",
Expand All @@ -55,7 +46,11 @@ LongitudinalRandomSlope <- function(
Parameter(name = "lm_rs_slope_mu", prior = slope_mu, size = "n_arms"),
Parameter(name = "lm_rs_slope_sigma", prior = slope_sigma, size = 1),
Parameter(name = "lm_rs_sigma", prior = sigma, size = 1),
Parameter(name = "lm_rs_ind_rnd_slope", prior = random_slope, size = "Nind")
Parameter(
name = "lm_rs_ind_rnd_slope",
prior = prior_init_only(prior_normal(slope_mu@init, slope_sigma@init)),
size = "Nind"
)
)
)
)
Expand Down
3 changes: 2 additions & 1 deletion R/Parameter.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ as_stan_list.Parameter <- function(object, ...) {
#'
#' @param x (`Paramater`) \cr A model parameter
#' @param object (`Paramater`) \cr A model parameter
#' @param ... Not used.
#'
#' @description
#' Getter functions for the slots of a [`Parameter`] object
Expand All @@ -115,7 +116,7 @@ names.Parameter <- function(x) x@name

#' @describeIn Parameter-Getter-Methods The parameter's initial values
#' @export
initialValues.Parameter <- function(object) initialValues(object@prior)
initialValues.Parameter <- function(object, ...) initialValues(object@prior)

#' @describeIn Parameter-Getter-Methods The parameter's dimensionality
#' @export
Expand Down
19 changes: 14 additions & 5 deletions R/ParameterList.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ as.list.ParameterList <- function(x, ...) {
#' Getter functions for the slots of a [`ParameterList`] object
#' @inheritParams ParameterList-Shared
#' @family ParameterList
#' @param n_chains (`integer`) \cr the number of chains.
#' @name ParameterList-Getter-Methods
NULL

Expand All @@ -145,11 +146,19 @@ names.ParameterList <- function(x) {

#' @describeIn ParameterList-Getter-Methods The parameter-list's parameter initial values
#' @export
initialValues.ParameterList <- function(object) {
vals <- lapply(object@parameters, initialValues)
name <- vapply(object@parameters, names, character(1))
names(vals) <- name
return(vals)
initialValues.ParameterList <- function(object, n_chains, ...) {
# Generate initial values as a list of lists. This is to ensure it is in the required
# format as specified by cmdstanr see the `init` argument of
# `help("model-method-sample", "cmdstanr")` for more details
lapply(
seq_len(n_chains),
\(i) {
vals <- lapply(object@parameters, initialValues)
name <- vapply(object@parameters, names, character(1))
names(vals) <- name
vals
}
)
}


Expand Down
Loading
Loading