Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use cmdstanr generic #397

Merged
merged 8 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ Collate:
'jmpost-package.R'
'link_generics.R'
'settings.R'
'standalone-s3-register.R'
'zzz.R'
VignetteBuilder: knitr
RdMacros: Rdpack
2 changes: 0 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Generated by roxygen2: do not edit by hand

S3method(as.CmdStanMCMC,JointModelSamples)
S3method(as.QuantityCollapser,GridEven)
S3method(as.QuantityCollapser,GridEvent)
S3method(as.QuantityCollapser,GridFixed)
Expand Down Expand Up @@ -220,7 +219,6 @@ export(SurvivalLogLogistic)
export(SurvivalModel)
export(SurvivalQuantities)
export(SurvivalWeibullPH)
export(as.CmdStanMCMC)
export(as.QuantityCollapser)
export(as.QuantityGenerator)
export(as_formula)
Expand Down
9 changes: 4 additions & 5 deletions R/JointModelSamples.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ as.StanModule.JointModelSamples <- function(object, generator, type, ...) {
#' @export
as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
sizes <- vapply(
as.CmdStanMCMC(object)$metadata()[["stan_variable_sizes"]],
cmdstanr::as.CmdStanMCMC(object)$metadata()[["stan_variable_sizes"]],
\(x) {
if (length(x) == 1 && x == 1) return("")
paste0("[", paste(x, collapse = ", "), "]")
Expand All @@ -107,7 +107,7 @@ as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
)
variable_string <- paste0(
" ",
as.CmdStanMCMC(object)$metadata()[["stan_variables"]],
cmdstanr::as.CmdStanMCMC(object)$metadata()[["stan_variables"]],
sizes
)
template <- c(
Expand All @@ -123,8 +123,8 @@ as_print_string.JointModelSamples <- function(object, indent = 1, ...) {
template_padded <- paste(pad, template)
sprintf(
paste(template_padded, collapse = "\n"),
as.CmdStanMCMC(object)$metadata()$iter_sampling,
as.CmdStanMCMC(object)$num_chains()
cmdstanr::as.CmdStanMCMC(object)$metadata()$iter_sampling,
cmdstanr::as.CmdStanMCMC(object)$num_chains()
)
}

Expand All @@ -141,7 +141,6 @@ setMethod(


#' @rdname as.CmdStanMCMC
#' @export
as.CmdStanMCMC.JointModelSamples <- function(object, ...) {
return(object@results)
}
13 changes: 13 additions & 0 deletions R/external-exports.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,16 @@ NULL
#' @export autoplot
#' @family autoplot
NULL


#' Coerce to `CmdStanMCMC`
#'
#' @param object to be converted
#' @param ... additional options
#'
#' @description
#' Coerces an object to a [`cmdstanr::CmdStanMCMC`] object
#'
#' @name as.CmdStanMCMC
#'
NULL
14 changes: 0 additions & 14 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,6 @@ hazardWindows <- function(object, ...) {
UseMethod("hazardWindows")
}


#' Coerce to `CmdStanMCMC`
#'
#' @param object to be converted
#' @param ... additional options
#'
#' @description
#' Coerces an object to a [`cmdstanr::CmdStanMCMC`] object
#'
#' @export
as.CmdStanMCMC <- function(object, ...) {
UseMethod("as.CmdStanMCMC")
}

#' @rdname Quant-Dev
#' @export
as.QuantityGenerator <- function(object, ...) {
Expand Down
162 changes: 162 additions & 0 deletions R/standalone-s3-register.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@


# Original source repo = r-lib/rlang
# file: standalone-s3-register.R
# last-updated: 2024-05-14
# license: https://unlicense.org
# ---
#
# ## Changelog
#
# 2024-05-14:
#
# * Mentioned `usethis::use_standalone()`.
#
# nocov start

#' Register a method for a suggested dependency
#'
#' Generally, the recommended way to register an S3 method is to use the
#' `S3Method()` namespace directive (often generated automatically by the
#' `@export` roxygen2 tag). However, this technique requires that the generic
#' be in an imported package, and sometimes you want to suggest a package,
#' and only provide a method when that package is loaded. `s3_register()`
#' can be called from your package's `.onLoad()` to dynamically register
#' a method only if the generic's package is loaded.
#'
#' For R 3.5.0 and later, `s3_register()` is also useful when demonstrating
#' class creation in a vignette, since method lookup no longer always involves
#' the lexical scope. For R 3.6.0 and later, you can achieve a similar effect
#' by using "delayed method registration", i.e. placing the following in your
#' `NAMESPACE` file:
#'
#' ```
#' if (getRversion() >= "3.6.0") {
#' S3method(package::generic, class)
#' }
#' ```
#'
#' @section Usage in other packages:
#' To avoid taking a dependency on rlang, you copy the source of
#' [`s3_register()`](https://github.com/r-lib/rlang/blob/main/R/standalone-s3-register.R)
#' into your own package or with
#' `usethis::use_standalone("r-lib/rlang", "s3-register")`. It is licensed under
#' the permissive [unlicense](https://choosealicense.com/licenses/unlicense/) to
#' make it crystal clear that we're happy for you to do this. There's no need to
#' include the license or even credit us when using this function.
#'
#' @param generic Name of the generic in the form `"pkg::generic"`.
#' @param class Name of the class
#' @param method Optionally, the implementation of the method. By default,
#' this will be found by looking for a function called `generic.class`
#' in the package environment.
#' @examples
#' # A typical use case is to dynamically register tibble/pillar methods
#' # for your class. That way you avoid creating a hard dependency on packages
#' # that are not essential, while still providing finer control over
#' # printing when they are used.
#'
#' .onLoad <- function(...) {
#' s3_register("pillar::pillar_shaft", "vctrs_vctr")
#' s3_register("tibble::type_sum", "vctrs_vctr")
#' }
#' @keywords internal
#' @noRd
s3_register <- function(generic, class, method = NULL) {
stopifnot(is.character(generic), length(generic) == 1)
stopifnot(is.character(class), length(class) == 1)

pieces <- strsplit(generic, "::")[[1]]
stopifnot(length(pieces) == 2)
package <- pieces[[1]]
generic <- pieces[[2]]

caller <- parent.frame()

get_method_env <- function() {
top <- topenv(caller)
if (isNamespace(top)) {
asNamespace(environmentName(top))
} else {
caller
}
}
get_method <- function(method) {
if (is.null(method)) {
get(paste0(generic, ".", class), envir = get_method_env())
} else {
method
}
}

register <- function(...) {
envir <- asNamespace(package)

# Refresh the method each time, it might have been updated by devtools
method_fn <- get_method(method)
stopifnot(is.function(method_fn))


# Only register if generic can be accessed
if (exists(generic, envir)) {
registerS3method(generic, class, method_fn, envir = envir)
} else if (identical(Sys.getenv("NOT_CRAN"), "true")) {
warn <- .rlang_s3_register_compat("warn")

warn(c(
sprintf(
"Can't find generic `%s` in package %s to register S3 method.",
generic,
package
),
"i" = "This message is only shown to developers using devtools.",
"i" = sprintf("Do you need to update %s to the latest version?", package)
))
}
}

# Always register hook in case package is later unloaded & reloaded
setHook(packageEvent(package, "onLoad"), function(...) {
register()
})

# For compatibility with R < 4.1.0 where base isn't locked
is_sealed <- function(pkg) {
identical(pkg, "base") || environmentIsLocked(asNamespace(pkg))
}

# Avoid registration failures during loading (pkgload or regular).
# Check that environment is locked because the registering package
# might be a dependency of the package that exports the generic. In
# that case, the exports (and the generic) might not be populated
# yet (#1225).
if (isNamespaceLoaded(package) && is_sealed(package)) {
register()
}

invisible()
}

.rlang_s3_register_compat <- function(fn) {
# Compats that behave the same independently of rlang's presence
out <- switch(fn,
is_installed = return(function(pkg) requireNamespace(pkg, quietly = TRUE))
)

is_interactive_compat <- function() {
interactive()
}

format_msg <- function(x) paste(x, collapse = "\n")
switch(
fn,
is_interactive = return(is_interactive_compat),
abort = return(function(msg) stop(format_msg(msg), call. = FALSE)),
warn = return(function(msg) warning(format_msg(msg), call. = FALSE)),
inform = return(function(msg) message(format_msg(msg)))
)

stop(sprintf("Internal error in rlang shims: Unknown function `%s()`.", fn))
}

# nocov end
8 changes: 4 additions & 4 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@

.onLoad <- function(libname, pkgname) {
set_options()
}

.onAttach <- function(libname, pkgname) {
if (!is_cmdstanr_available()) {
packageStartupMessage(
Expand Down Expand Up @@ -38,6 +34,10 @@
return(invisible(NULL))
}

.onLoad <- function(...) {
set_options()
s3_register("cmdstanr::as.CmdStanMCMC", "JointModelSamples")
}

# This only exists to silence the false positive R CMD CHECK warning about
# importing but not using the posterior package. posterior is a dependency
Expand Down
10 changes: 4 additions & 6 deletions man/as.CmdStanMCMC.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/test-Grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ test_that("GridObservered + Constructs correct quantities", {



pred_mat <- as.CmdStanMCMC(fixtures_gsf$mp)$draws("Ypred", format = "draws_matrix")
pred_mat <- cmdstanr::as.CmdStanMCMC(fixtures_gsf$mp)$draws("Ypred", format = "draws_matrix")

fdat <- fixtures_gsf$dat_lm |>
dplyr::arrange(subject, time, sld) |>
Expand Down Expand Up @@ -512,7 +512,7 @@ test_that("GridObservered + Constructs correct quantities", {
#
design <- model.matrix(~ cov_cat + cov_cont, data = fixtures_gsf$dat_os)

beta_coefs <- as.CmdStanMCMC(fixtures_gsf$mp)$draws(
beta_coefs <- cmdstanr::as.CmdStanMCMC(fixtures_gsf$mp)$draws(
c("sm_exp_lambda", "beta_os_cov"),
format = "draws_matrix"
)
Expand Down Expand Up @@ -636,7 +636,7 @@ test_that("GridPopulation() works as expected for GSF models", {
b * (phi * exp(-s * time) + (1 - phi) * exp(g * time))
}

samples_df <- as.CmdStanMCMC(fixtures_gsf$mp)$draws(
samples_df <- cmdstanr::as.CmdStanMCMC(fixtures_gsf$mp)$draws(
c("lm_gsf_mu_ks", "lm_gsf_mu_kg", "lm_gsf_mu_bsld", "lm_gsf_mu_phi"),
format = "draws_df"
) |>
Expand Down Expand Up @@ -705,7 +705,7 @@ test_that("GridPopulation() works as expected for Longitudinal models", {
#
# Derive values by hand
#
samples_df <- as.CmdStanMCMC(fixtures_rs$mp)$draws(
samples_df <- cmdstanr::as.CmdStanMCMC(fixtures_rs$mp)$draws(
c("lm_rs_intercept", "lm_rs_slope_mu"),
format = "draws_df"
) |>
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-GridPrediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ test_that("GridPrediction() works as expected for Survival models", {
#
# Derive values by hand
#
samples_df <- as.CmdStanMCMC(fixtures_gsf_link$mp)$draws(
samples_df <- cmdstanr::as.CmdStanMCMC(fixtures_gsf_link$mp)$draws(
c("beta_os_cov", "link_dsld", "link_ttg", "sm_exp_lambda"),
format = "draws_df"
) |>
Expand Down Expand Up @@ -399,7 +399,7 @@ test_that("GridPrediction() works for survival only models", {
dplyr::as_tibble()

# Calculate expected values by hand
pars_dat <- as.CmdStanMCMC(fixtures_weibull_only$mp)$draws(
pars_dat <- cmdstanr::as.CmdStanMCMC(fixtures_weibull_only$mp)$draws(
c("beta_os_cov", "sm_weibull_ph_lambda", "sm_weibull_ph_gamma"),
format = "draws_df"
) |>
Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test-LongitudinalClaretBruno.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ test_that("Can recover known distributional parameters from a SF joint model", {
}

dat <- summary_post(
as.CmdStanMCMC(mp),
cmdstanr::as.CmdStanMCMC(mp),
c("lm_clbr_mu_b", "lm_clbr_mu_g", "lm_clbr_mu_c", "lm_clbr_mu_p"),
TRUE
)
Expand All @@ -257,7 +257,7 @@ test_that("Can recover known distributional parameters from a SF joint model", {


dat <- summary_post(
as.CmdStanMCMC(mp),
cmdstanr::as.CmdStanMCMC(mp),
c("beta_os_cov", "sm_exp_lambda", "link_dsld", "link_growth", "link_ttg")
)
true_values <- c(
Expand Down Expand Up @@ -496,7 +496,7 @@ test_that("Can recover known distributional parameters from unscaled variance Cl
}

dat <- summary_post(
as.CmdStanMCMC(mp),
cmdstanr::as.CmdStanMCMC(mp),
c(
"lm_clbr_mu_b", "lm_clbr_mu_g", "lm_clbr_mu_c", "lm_clbr_mu_p",
"lm_clbr_omega_b", "lm_clbr_omega_g", "lm_clbr_omega_c", "lm_clbr_omega_p",
Expand Down
Loading
Loading