Skip to content

Commit

Permalink
Different initial values for each mcmc chain
Browse files Browse the repository at this point in the history
Fixes #4
  • Loading branch information
gowerc committed Feb 2, 2024
1 parent e8a56d4 commit 4c892b6
Show file tree
Hide file tree
Showing 55 changed files with 551 additions and 284 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ S3method(extractVariableNames,DataSurvival)
S3method(generateQuantities,JointModelSamples)
S3method(getParameters,default)
S3method(initialValues,JointModel)
S3method(initialValues,Link)
S3method(initialValues,Parameter)
S3method(initialValues,ParameterList)
S3method(initialValues,Prior)
S3method(initialValues,StanModel)
S3method(names,Parameter)
S3method(names,ParameterList)
S3method(sampleStanModel,JointModel)
Expand Down Expand Up @@ -100,6 +102,7 @@ export(generateQuantities)
export(gsf_dsld)
export(gsf_sld)
export(gsf_ttg)
export(initialValues)
export(link_gsf_abstract)
export(link_gsf_dsld)
export(link_gsf_identity)
Expand Down
72 changes: 62 additions & 10 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ compileStanModel.JointModel <- function(object) {
#' @export
sampleStanModel.JointModel <- function(object, data, ...) {

assert_class(data, "DataJoint")

if (!is.null(object@survival)) {
assert_that(
!is.null(data@survival),
Expand All @@ -150,21 +152,36 @@ sampleStanModel.JointModel <- function(object, data, ...) {
}

args <- list(...)

args[["data"]] <- append(
as_stan_list(data),
as_stan_list(object@parameters)
)

if (!"init" %in% names(args)) {
values_initial <- initialValues(object)
values_sizes <- size(object@parameters)
values_sizes_complete <- replace_with_lookup(values_sizes, args[["data"]])
values_initial_expanded <- expand_initial_values(values_initial, values_sizes_complete)
args[["init"]] <- function() values_initial_expanded
args[["chains"]] <- if ("chains" %in% names(args)) {
args[["chains"]]
} else {
4
}

initial_values <- if ("init" %in% names(args)) {
args[["init"]]
} else {
initialValues(object, nchains = args[["chains"]])
}

args[["init"]] <- ensure_initial_values(
initial_values,
args[["data"]],
object@parameters
)

model <- compileStanModel(object)
results <- do.call(model$sample, args)

results <- do.call(
model$sample,
args
)

.JointModelSamples(
model = object,
Expand All @@ -174,12 +191,47 @@ sampleStanModel.JointModel <- function(object, data, ...) {
}


# initialValues-JointModel ----
#' Ensure that initial values are correctly specified
#'
#' @param initial_values (`list`)\cr A list of lists containing the initial values
#' must be 1 list per desired chain. All elements should have identical names
#' @param data (`list`)\cr specifies the size to expand each of our initial values to be.
#' That is elements of size 1 in `initial_values` will be expanded to be the same
#' size as the corresponding element in `data` by broadcasting the value.
#' @param parameters ([`ParameterList`])\cr the parameters object
#'
#' @details
#' This function is mostly a thin wrapper around `expand_initial_values` to
#' enable easier unit testing.
#'
#' @keywords internal
ensure_initial_values <- function(initial_values, data, parameters) {
if (is.function(initial_values)) {
return(initial_values)
}

assert_class(data, "list")
assert_class(parameters, "ParameterList")
assert_class(initial_values, "list")

values_sizes <- size(parameters)
values_sizes_complete <- replace_with_lookup(
values_sizes,
data
)
lapply(
initial_values,
expand_initial_values,
sizes = values_sizes_complete
)
}



#' @rdname initialValues
#' @export
initialValues.JointModel <- function(object) {
initialValues(object@parameters)
initialValues.JointModel <- function(object, nchains, ...) {
initialValues(object@parameters, nchains)
}


Expand Down
5 changes: 3 additions & 2 deletions R/Link.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ setMethod(
# initialValues-Link ----

#' @rdname initialValues
initialValues.Link <- function(object) {
initialValues(object@parameters)
#' @export
initialValues.Link <- function(object, nchains, ...) {
initialValues(object@parameters, nchains)
}


Expand Down
6 changes: 3 additions & 3 deletions R/LinkGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ link_gsf_abstract <- function(
#'
#' @export
link_gsf_ttg <- function(
gamma = prior_normal(0, 5, init = 0)
gamma = prior_normal(0, 5)
) {
.link_gsf_ttg(
name = "TTG",
Expand Down Expand Up @@ -182,7 +182,7 @@ link_gsf_ttg <- function(
#'
#' @export
link_gsf_dsld <- function(
beta = prior_normal(0, 5, init = 0)
beta = prior_normal(0, 5)
) {
.link_gsf_dsld(
name = "dSLD",
Expand Down Expand Up @@ -215,7 +215,7 @@ link_gsf_dsld <- function(
#' @param tau (`Prior`)\cr prior for the link coefficient `tau`.
#'
#' @export
link_gsf_identity <- function(tau = prior_normal(0, 5, init = 0)) {
link_gsf_identity <- function(tau = prior_normal(0, 5)) {
.link_gsf_identity(
name = "Identity",
stan = StanModule("lm-gsf/link_identity.stan"),
Expand Down
2 changes: 1 addition & 1 deletion R/LinkRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ NULL
#'
#' @export
LinkRandomSlope <- function(
link_lm_phi = prior_normal(0.2, 0.5, init = 0.02)
link_lm_phi = prior_normal(0.2, 0.5)
) {
.LinkRandomSlope(
Link(
Expand Down
40 changes: 13 additions & 27 deletions R/LongitudinalGSF.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,37 +36,23 @@ NULL
#' @param a_phi (`Prior`)\cr for the alpha parameter for the fraction of cells that respond to treatment.
#' @param b_phi (`Prior`)\cr for the beta parameter for the fraction of cells that respond to treatment.
#'
#' @param psi_bsld (`Prior`)\cr for the baseline value random effect `psi_bsld`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_ks (`Prior`)\cr for the shrinkage rate random effect `psi_ks`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_kg (`Prior`)\cr for the growth rate random effect `psi_kg`. Only used in the
#' centered parameterization to set the initial value.
#' @param psi_phi (`Prior`)\cr for the shrinkage proportion random effect `psi_phi`. Only used in the
#' centered parameterization to set the initial value.
#'
#' @param centered (`logical`)\cr whether to use the centered parameterization.
#'
#' @export
LongitudinalGSF <- function(

mu_bsld = prior_normal(log(60), 1, init = 60),
mu_ks = prior_normal(log(0.5), 1, init = 0.5),
mu_kg = prior_normal(log(0.3), 1, init = 0.3),

omega_bsld = prior_lognormal(log(0.2), 1, init = 0.2),
omega_ks = prior_lognormal(log(0.2), 1, init = 0.2),
omega_kg = prior_lognormal(log(0.2), 1, init = 0.2),
mu_bsld = prior_normal(log(60), 1),
mu_ks = prior_normal(log(0.5), 1),
mu_kg = prior_normal(log(0.3), 1),

a_phi = prior_lognormal(log(5), 1, init = 5),
b_phi = prior_lognormal(log(5), 1, init = 5),
omega_bsld = prior_lognormal(log(0.2), 1),
omega_ks = prior_lognormal(log(0.2), 1),
omega_kg = prior_lognormal(log(0.2), 1),

sigma = prior_lognormal(log(0.1), 1, init = 0.1),
a_phi = prior_lognormal(log(5), 1),
b_phi = prior_lognormal(log(5), 1),

psi_bsld = prior_none(init = 60),
psi_ks = prior_none(init = 0.5),
psi_kg = prior_none(init = 0.5),
psi_phi = prior_none(init = 0.5),
sigma = prior_lognormal(log(0.1), 1),

centered = FALSE
) {
Expand All @@ -87,17 +73,17 @@ LongitudinalGSF <- function(

Parameter(name = "lm_gsf_a_phi", prior = a_phi, size = "n_arms"),
Parameter(name = "lm_gsf_b_phi", prior = b_phi, size = "n_arms"),
Parameter(name = "lm_gsf_psi_phi", prior = psi_phi, size = "Nind"),
Parameter(name = "lm_gsf_psi_phi", prior = prior_none(), size = "Nind"),

Parameter(name = "lm_gsf_sigma", prior = sigma, size = 1)
)

assert_flag(centered)
parameters_extra <- if (centered) {
list(
Parameter(name = "lm_gsf_psi_bsld", prior = psi_bsld, size = "Nind"),
Parameter(name = "lm_gsf_psi_ks", prior = psi_ks, size = "Nind"),
Parameter(name = "lm_gsf_psi_kg", prior = psi_kg, size = "Nind")
Parameter(name = "lm_gsf_psi_bsld", prior = prior_none(), size = "Nind"),
Parameter(name = "lm_gsf_psi_ks", prior = prior_none(), size = "Nind"),
Parameter(name = "lm_gsf_psi_kg", prior = prior_none(), size = "Nind")
)
} else {
list(
Expand Down
10 changes: 5 additions & 5 deletions R/LongitudinalRandomSlope.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ NULL
#'
#' @export
LongitudinalRandomSlope <- function(
intercept = prior_normal(30, 10, init = 30),
slope_mu = prior_normal(0, 15, init = 0.001),
slope_sigma = prior_lognormal(1, 5, init = 1),
sigma = prior_lognormal(1, 5, init = 1),
random_slope = prior_none(init = 0)
intercept = prior_normal(30, 10),
slope_mu = prior_normal(0, 15),
slope_sigma = prior_lognormal(0, 1.5),
sigma = prior_lognormal(0, 1.5),
random_slope = prior_none()
) {

stan <- StanModule(
Expand Down
2 changes: 1 addition & 1 deletion R/Parameter.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ names.Parameter <- function(x) x@name

#' @describeIn Parameter-Getter-Methods The parameter's initial values
#' @export
initialValues.Parameter <- function(object) initialValues(object@prior)
initialValues.Parameter <- function(object, ...) initialValues(object@prior)

#' @describeIn Parameter-Getter-Methods The parameter's dimensionality
#' @export
Expand Down
16 changes: 11 additions & 5 deletions R/ParameterList.R
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,17 @@ names.ParameterList <- function(x) {

#' @describeIn ParameterList-Getter-Methods The parameter-list's parameter initial values
#' @export
initialValues.ParameterList <- function(object) {
vals <- lapply(object@parameters, initialValues)
name <- vapply(object@parameters, names, character(1))
names(vals) <- name
return(vals)
initialValues.ParameterList <- function(object, nchains, ...) {
x <- lapply(
seq_len(nchains),
\(i) {
vals <- lapply(object@parameters, initialValues)
name <- vapply(object@parameters, names, character(1))
names(vals) <- name
vals
}
)
return(x)
}


Expand Down
Loading

0 comments on commit 4c892b6

Please sign in to comment.