Skip to content

Commit

Permalink
Harmonise link methods (#308)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Apr 15, 2024
1 parent 8e45d0b commit ba5a4a9
Show file tree
Hide file tree
Showing 46 changed files with 838 additions and 492 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Collate:
'LongitudinalQuantities.R'
'LongitudinalRandomSlope.R'
'LongitudinalSteinFojo.R'
'Promise.R'
'SimGroup.R'
'SimJointData.R'
'SimLongitudinal.R'
Expand All @@ -104,6 +105,7 @@ Collate:
'defaults.R'
'external-exports.R'
'jmpost-package.R'
'link_generics.R'
'settings.R'
'zzz.R'
VignetteBuilder: knitr
Expand Down
25 changes: 18 additions & 7 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ S3method(as.StanModule,LinkComponent)
S3method(as.StanModule,Parameter)
S3method(as.StanModule,ParameterList)
S3method(as.StanModule,Prior)
S3method(as.StanModule,PromiseLinkComponent)
S3method(as.character,JointModel)
S3method(as.character,Parameter)
S3method(as.character,Prior)
Expand Down Expand Up @@ -89,18 +90,25 @@ S3method(initialValues,StanModel)
S3method(length,Link)
S3method(length,QuantityCollapser)
S3method(linkDSLD,LongitudinalGSF)
S3method(linkDSLD,LongitudinalModel)
S3method(linkDSLD,LongitudinalRandomSlope)
S3method(linkDSLD,LongitudinalSteinFojo)
S3method(linkDSLD,PromiseLongitudinalModel)
S3method(linkDSLD,default)
S3method(linkIdentity,LongitudinalGSF)
S3method(linkIdentity,LongitudinalModel)
S3method(linkIdentity,LongitudinalRandomSlope)
S3method(linkIdentity,LongitudinalSteinFojo)
S3method(linkIdentity,PromiseLongitudinalModel)
S3method(linkIdentity,default)
S3method(linkTTG,LongitudinalGSF)
S3method(linkTTG,LongitudinalModel)
S3method(linkTTG,LongitudinalSteinFojo)
S3method(linkTTG,PromiseLongitudinalModel)
S3method(linkTTG,default)
S3method(names,LinkComponent)
S3method(names,Parameter)
S3method(names,ParameterList)
S3method(resolvePromise,Link)
S3method(resolvePromise,PromiseLinkComponent)
S3method(resolvePromise,default)
S3method(sampleObservations,SimLongitudinalGSF)
S3method(sampleObservations,SimLongitudinalRandomSlope)
S3method(sampleObservations,SimLongitudinalSteinFojo)
Expand Down Expand Up @@ -138,6 +146,8 @@ export(LongitudinalSteinFojo)
export(Parameter)
export(ParameterList)
export(Prior)
export(PromiseLinkComponent)
export(PromiseLongitudinalModel)
export(STAN_BLOCKS)
export(SimGroup)
export(SimJointData)
Expand Down Expand Up @@ -169,11 +179,8 @@ export(generateQuantities)
export(initialValues)
export(linkDSLD)
export(linkIdentity)
export(linkNone)
export(linkTTG)
export(link_dsld)
export(link_identity)
export(link_none)
export(link_ttg)
export(merge)
export(prior_beta)
export(prior_cauchy)
Expand All @@ -186,6 +193,7 @@ export(prior_normal)
export(prior_std_normal)
export(prior_student_t)
export(prior_uniform)
export(resolvePromise)
export(sampleObservations)
export(sampleStanModel)
export(sampleSubjects)
Expand All @@ -205,6 +213,9 @@ exportClasses(LongitudinalSteinFojo)
exportClasses(Parameter)
exportClasses(ParameterList)
exportClasses(Prior)
exportClasses(Promise)
exportClasses(PromiseLinkComponent)
exportClasses(PromiseLongitudinalModel)
exportClasses(SimGroup)
exportClasses(SimJointData)
exportClasses(SimLongitudinal)
Expand Down
7 changes: 3 additions & 4 deletions R/JointModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,10 @@ setClassUnion("SurvivalModel_OR_NULL", c("SurvivalModel", "NULL"))
JointModel <- function(
longitudinal = NULL,
survival = NULL,
link = link_none()
link = Link()
) {

# Ensure that it is a link object (e.g. wrap link components in a Link object)
link <- Link(link)
link <- resolvePromise(Link(link), longitudinal)

if (length(link) > 0) {
longitudinal <- enableLink(longitudinal)
Expand All @@ -81,7 +80,7 @@ JointModel <- function(
.x = base_model,
longitudinal = add_missing_stan_blocks(as.list(longitudinal)),
survival = add_missing_stan_blocks(as.list(survival)),
link = add_missing_stan_blocks(as.list(link, model = longitudinal)),
link = add_missing_stan_blocks(as.list(link)),
priors = add_missing_stan_blocks(as.list(parameters))
)
# Unresolved Jinja code within the longitudinal / Survival / Link
Expand Down
81 changes: 61 additions & 20 deletions R/Link.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#' @include LongitudinalModel.R
#' @include ParameterList.R
#' @include LinkComponent.R
#' @include Prior.R
NULL


Expand All @@ -19,21 +20,22 @@ NULL




#' `Link`
#'
#' @slot components (`list`)\cr a list of [`LinkComponent`] objects.
#' @slot components (`list`)\cr a list of [`LinkComponent`] or [`PromiseLinkComponent`] objects.
#' @slot resolved (`logical`)\cr indicates if all the `components` have been resolved.
#'
#' @param ... ([`LinkComponent`])\cr an arbitrary number of link components.
#' @param ... ([`LinkComponent`] or [`PromiseLinkComponent`])\cr
#' an arbitrary number of link components.
#'
#' @description
#' Simple container class to enable the use of multiple link components in a joint model.
#' Note that the constructor of this object is idempotent e.g. `Link(Link(x)) == Link(x)`
#'
#' @examples
#' Link(
#' link_dsld(),
#' link_ttg()
#' linkDSLD(),
#' linkTTG()
#' )
#'
#' @family Link
Expand All @@ -42,7 +44,8 @@ NULL
.Link <- setClass(
Class = "Link",
slots = list(
components = "list"
components = "list",
resolved = "logical"
)
)

Expand All @@ -51,26 +54,65 @@ NULL
Link <- function(...) {
components <- list(...)

# Enable copy constructor e.g. if passed a Link just return the Link
# If the input is already a Link object, return it (e.g. implement
# a constructor that is idempotent)
if (length(components) == 1 && is(components[[1]], "Link")) {
return(components[[1]])
}
.Link(components = components)

.Link(
components = components,
resolved = !any(vapply(components, \(x) is(x, "PromiseLinkComponent"), logical(1)))
)
}


#' Resolve any promises
#'
#' Loops over all components and ensures that any [`PromiseLinkComponent`] objects
#' are resolved to [`LinkComponent`] objects.
#'
#' @param object ([`Link`])\cr a link object.
#' @param model ([`LongitudinalModel`])\cr the model object.
#' @param ... Not Used.
#'
#' @export
resolvePromise.Link <- function(object, model, ...) {
if (length(object) == 0) {
return(object)
}
assert_that(
is(model, "LongitudinalModel"),
msg = "model must be of class `LongitudinalModel`"
)
do.call(Link, lapply(object@components, resolvePromise, model = model))
}


setValidity(
Class = "Link",
method = function(object) {
if (length(object@components) == 0) {
return(TRUE)
}
for (component in object@components) {
if (!is(component, "LinkComponent")) {
return("Link components must be of class `LinkComponent`.")

for (i in object@components) {
if (!(is(i, "LinkComponent") || is(i, "PromiseLinkComponent"))) {
return("All components must be of class `LinkComponent` or `PromiseLinkComponent`")
}
}

contains_promise <- any(
vapply(
object@components,
\(x) is(x, "PromiseLinkComponent"),
logical(1)
)
)
if (contains_promise & object@resolved) {
return("Object cannot be resolved if it contains promises")
}

if (length(object@resolved) > 1) {
return("The `resolved` slot must be a logical scalar")
}
return(TRUE)
}
)
Expand Down Expand Up @@ -108,8 +150,7 @@ as.StanModule.Link <- function(object, ...) {

stan_list <- lapply(
object@components,
as.StanModule,
...
as.StanModule
)

stan <- Reduce(
Expand Down Expand Up @@ -182,13 +223,13 @@ as_print_string.Link <- function(object, ...) {
if (length(object) == 0) {
return("\nNo Link")
}

strings <- vapply(object@components, as_print_string, character(1))

paste(
c(
"\nLink with the following components/parameters:",
paste0(
" ",
vapply(object@components, as_print_string, character(1))
)
paste0(" ", strings)
),
collapse = "\n"
)
Expand Down
Loading

0 comments on commit ba5a4a9

Please sign in to comment.