Skip to content

Commit c42b49f

Browse files
committed
cmdstanr support & target += robustness
1 parent 7a5af8d commit c42b49f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+311
-251
lines changed

.Rbuildignore

100644100755
File mode changed.

.gitignore

100644100755
File mode changed.

.travis.yml

100644100755
File mode changed.

DESCRIPTION

100644100755
+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ Package: adjustr
22
Encoding: UTF-8
33
Type: Package
44
Title: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling
5-
Version: 0.1.1
5+
Version: 0.1.2
66
Authors@R: person("Cory", "McCartan", email = "[email protected]",
77
role = c("aut", "cre"))
88
Description: Functions to help assess the sensitivity of a Bayesian model

LICENSE

100644100755
File mode changed.

NAMESPACE

100644100755
File mode changed.

NEWS.md

100644100755
File mode changed.

R/adjust_weights.R

100644100755
+19-9
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
#' \code{\link{make_spec}}, containing the new sampling sampling statements
2020
#' to replace their counterparts in the original Stan model, and the data,
2121
#' if any, by which these sampling statements are parametrized.
22-
#' @param object A \code{\link[rstan]{stanfit}} model object.
22+
#' @param object A model object, either of type \code{\link[rstan]{stanfit}},
23+
#' \code{\link[rstanarm]{stanreg}}, \code{\link[brms]{brmsfit}}, or
24+
#' a list with two elements: \code{model} containing a
25+
#' \code{\link[cmdstanr]{CmdStanModel}}, and \code{fit} containing a
26+
#' \code{\link[cmdstanr]{CmdStanMCMC}} object.
2327
#' @param data The data that was used to fit the model in \code{object}.
2428
#' Required only if one of the new sampling specifications involves Stan data
2529
#' variables.
@@ -71,14 +75,14 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE, incl_orig=TRU
7175
if (is.null(data) & is(object, "brmsfit"))
7276
data = object$data
7377
object = get_fit_obj(object)
74-
model_code = object@stanmodel@model_code
7578
stopifnot(is.adjustr_spec(spec))
7679

77-
parsed = parse_model(model_code)
80+
parsed = parse_model(object@stanmodel@model_code)
7881

7982
# if no model data provided, we can only resample distributions of parameters
8083
if (is.null(data)) {
81-
samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.)))
84+
samp_vars = map(parsed$samp, ~ as.character(f_lhs(.))) %>%
85+
purrr::as_vector()
8286
prior_vars = parsed$vars[samp_vars] != "data"
8387
parsed$samp = parsed$samp[prior_vars]
8488
data = list()
@@ -170,11 +174,11 @@ pull.adjustr_weighted = function(.data, var=".weights", name=NULL, ...) {
170174
#' @export
171175
extract_samp_stmts = function(object) {
172176
object = get_fit_obj(object)
173-
model_code = object@stanmodel@model_code
174177

175-
parsed = parse_model(model_code)
178+
parsed = parse_model(object@stanmodel@model_code)
176179

177-
samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.)))
180+
samp_vars = map(parsed$samp, ~ as.character(f_lhs(.))) %>%
181+
purrr::as_vector()
178182
type = map_chr(samp_vars, function(var) {
179183
if (stringr::str_ends(parsed$vars[var], "data")) "data" else "parameter"
180184
})
@@ -185,15 +189,21 @@ extract_samp_stmts = function(object) {
185189
invisible(parsed$samp)
186190
}
187191

188-
# Check that the model object is correct, and extract its Stan code
192+
# Check that the model object is correct, and put it into a convenient format
189193
get_fit_obj = function(object) {
190194
if (is(object, "stanfit")) {
191195
object
192196
} else if (is(object, "stanreg")) {
193197
object$stanfit
194198
} else if (is(object, "brmsfit")) {
195199
object$fit
200+
} else if (is(object, "list") && all(c("fit", "model") %in% names(object))
201+
&& is(object$fit, "CmdStanMCMC") && is(object$model, "CmdStanModel")) {
202+
out = rstan::read_stan_csv(object$fit$output_files())
203+
out@stanmodel@model_code = paste0(object$model$code(), collapse="\n")
204+
out
196205
} else {
197-
stop("`object` must be of class `stanfit`, `stanreg`, or `brmsfit`.")
206+
stop("`object` must be of class `stanfit`, `stanreg`, `brmsfit`, or ",
207+
"a list with `CmdStanModel` and `CmdStanMCMC` objects.")
198208
}
199209
}

R/adjustr-package.R

100644100755
File mode changed.

R/logprob.R

100644100755
File mode changed.

R/make_spec.R

100644100755
File mode changed.

R/parsing.R

100644100755
+6-3
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ parse_model = function(model_code) {
6161
samps = flatten(samps)
6262

6363
parameters = names(vars)[vars == "parameters"]
64-
sampled_pars = map_chr(samps, ~ as.character(f_lhs(.)))
64+
sampled_pars = map(samps, ~ as.character(f_lhs(.))) %>%
65+
purrr::as_vector()
6566
uniform_pars = setdiff(parameters, sampled_pars)
6667
uniform_samp = paste0(uniform_pars, " ~ uniform(-1e100, 1e100)")
6768
uniform_samp = map(uniform_samp, ~ stats::as.formula(., env=empty_env()))
@@ -73,8 +74,10 @@ parse_model = function(model_code) {
7374
# Take a list of provided sampling formulas and return a matching list of
7475
# sampling statements from a reference list
7576
match_sampling_stmts = function(new_samp, ref_samp) {
76-
ref_vars = map_chr(ref_samp, ~ as.character(f_lhs(.)))
77-
new_vars = map_chr(new_samp, ~ as.character(f_lhs(.)))
77+
ref_vars = map(ref_samp, ~ as.character(f_lhs(.))) %>%
78+
purrr::as_vector()
79+
new_vars = map(new_samp, ~ as.character(f_lhs(.))) %>%
80+
purrr::as_vector()
7881
indices = match(new_vars, ref_vars)
7982
# check that every prior was matched
8083
if (any(is.na(indices))) {

R/use_weights.R

100644100755
File mode changed.

README.md

100644100755
File mode changed.

_pkgdown.yml

100644100755
File mode changed.

adjustr.Rproj

100644100755
File mode changed.

codecov.yml

100644100755
File mode changed.

docs/404.html

100644100755
+5-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/LICENSE-text.html

100644100755
+5-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)