Skip to content

Commit d24848b

Browse files
committed
vignette + improved parser + wasserstein + more
1 parent 86859a5 commit d24848b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1186
-470
lines changed

.Rbuildignore

+2
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@
77
^pkgdown$
88
^codecov\.yml$
99
^\.travis\.yml$
10+
^doc$
11+
^Meta$

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88

99
.Rproj.user
1010
inst/doc
11+
doc
12+
Meta

DESCRIPTION

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,27 @@ Package: adjustr
22
Encoding: UTF-8
33
Type: Package
44
Title: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling
5-
Version: 0.0.0.9000
5+
Version: 0.1.0
66
Authors@R: person("Cory", "McCartan", email = "[email protected]",
77
role = c("aut", "cre"))
88
Description: Functions to help assess the sensitivity of a Bayesian model
99
(fitted using the rstan pakcage) to the specification of its likelihood and
10-
priors.Users provide a series of alternate sampling specifications, and the
10+
priors. Users provide a series of alternate sampling specifications, and the
1111
package uses Pareto-smoothed importance sampling to estimate posterior
1212
quantities of interest under each specification.
1313
License: BSD_3_clause + file LICENSE
1414
Depends: R (>= 3.6.0)
1515
Imports:
1616
tibble,
1717
tidyselect,
18-
dplyr,
18+
dplyr (>= 1.0.0),
1919
purrr,
20+
stringr,
2021
methods,
2122
utils,
2223
stats,
2324
rlang,
2425
rstan,
25-
stringr,
26-
dparser,
2726
ggplot2,
2827
loo
2928
Suggests:
@@ -35,5 +34,5 @@ Suggests:
3534
rmarkdown
3635
URL: https://corymccartan.github.io/adjustr/
3736
LazyData: true
38-
RoxygenNote: 7.1.0
37+
RoxygenNote: 7.1.1
3938
VignetteBuilder: knitr

NAMESPACE

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
S3method(arrange,adjustr_spec)
44
S3method(as.data.frame,adjustr_spec)
55
S3method(length,adjustr_spec)
6-
S3method(plot,adjustr_weighted)
76
S3method(print,adjustr_spec)
87
S3method(pull,adjustr_weighted)
98
S3method(rename,adjustr_spec)
@@ -15,6 +14,7 @@ export(adjust_weights)
1514
export(extract_samp_stmts)
1615
export(get_resampling_idxs)
1716
export(make_spec)
17+
export(spec_plot)
1818
import(dplyr)
1919
import(ggplot2)
2020
import(rlang)

NEWS.md

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# adjustr 0.1.0
2+
3+
* Initial release.
4+
5+
* Basic workflow implemented: `make_spec()`, `adjust_weights()`, and `summarize()`/`spec_plot()`.

R/adjust_weights.R

+28-20
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,19 @@
1818
#' posterior, and which as a result cannot be reliably estimated using
1919
#' importance sampling (i.e., if the Pareto shape parameter is larger than
2020
#' 0.7), have their weights discarded.
21+
#' @param incl_orig When \code{TRUE}, include a row for the original
22+
#' model specification, with all weights equal. Can facilitate comaprison
23+
#' and plotting later.
2124
#'
2225
#' @return A tibble, produced by converting the provided \code{specs} to a
2326
#' tibble (see \code{\link{as.data.frame.adjustr_spec}}), and adding columns
2427
#' \code{.weights}, containing vectors of weights for each draw, and
2528
#' \code{.pareto_k}, containing the diagnostic Pareto shape parameters. Values
2629
#' greater than 0.7 indicate that importance sampling is not reliable.
27-
#' Weights can be extracted with the \code{\link{pull.adjustr_weighted}}
28-
#' method. The returned object also includes the model sample draws, in the
29-
#' \code{draws} attribute.
30+
#' If \code{incl_orig} is \code{TRUE}, a row is added for the original model
31+
#' specification. Weights can be extracted with the
32+
#' \code{\link{pull.adjustr_weighted}} method. The returned object also
33+
#' includes the model sample draws, in the \code{draws} attribute.
3034
#'
3135
#' @examples \dontrun{
3236
#' model_data = list(
@@ -46,27 +50,25 @@
4650
#' }
4751
#'
4852
#' @export
49-
adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) {
53+
adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE, incl_orig=TRUE) {
5054
# CHECK ARGUMENTS
5155
object = get_fit_obj(object)
5256
model_code = object@stanmodel@model_code
5357
stopifnot(is.adjustr_spec(spec))
5458

55-
parsed_model = parse_model(model_code)
56-
parsed_vars = get_variables(parsed_model)
57-
parsed_samp = get_sampling_stmts(parsed_model)
59+
parsed = parse_model(model_code)
5860

5961
# if no model data provided, we can only resample distributions of parameters
6062
if (is.null(data)) {
61-
samp_vars = map_chr(parsed_samp, ~ as.character(f_lhs(.)))
62-
prior_vars = parsed_vars[samp_vars] != "data"
63-
parsed_samp = parsed_samp[prior_vars]
63+
samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.)))
64+
prior_vars = parsed$vars[samp_vars] != "data"
65+
parsed$samp = parsed$samp[prior_vars]
6466
data = list()
6567
}
6668

67-
matched_samp = match_sampling_stmts(spec$samp, parsed_samp)
68-
original_lp = calc_original_lp(object, matched_samp, parsed_vars, data)
69-
specs_lp = calc_specs_lp(object, spec$samp, parsed_vars, data, spec$params)
69+
matched_samp = match_sampling_stmts(spec$samp, parsed$samp)
70+
original_lp = calc_original_lp(object, matched_samp, parsed$vars, data)
71+
specs_lp = calc_specs_lp(object, spec$samp, parsed$vars, data, spec$params)
7072

7173
# compute weights
7274
wgts = map(specs_lp, function(spec_lp) {
@@ -95,6 +97,14 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) {
9597
attr(adjust_obj, "draws") = rstan::extract(object)
9698
attr(adjust_obj, "data") = data
9799
attr(adjust_obj, "iter") = object@sim$chains * (object@sim$iter - object@sim$warmup)
100+
if (incl_orig) {
101+
adjust_obj = bind_rows(adjust_obj, tibble(
102+
.weights=list(rep(1, attr(adjust_obj, "iter"))),
103+
.pareto_k = -Inf))
104+
samp_cols = stringr::str_detect(names(adjust_obj), "\\.samp")
105+
adjust_obj[nrow(adjust_obj), samp_cols] = "<original model>"
106+
}
107+
98108
adjust_obj
99109
}
100110

@@ -141,19 +151,17 @@ pull.adjustr_weighted = function(.data, var=".weights") {
141151
extract_samp_stmts = function(object) {
142152
model_code = get_fit_obj(object)@stanmodel@model_code
143153

144-
parsed_model = parse_model(model_code)
145-
parsed_vars = get_variables(parsed_model)
146-
parsed_samp = get_sampling_stmts(parsed_model)
154+
parsed = parse_model(model_code)
147155

148-
samp_vars = map_chr(parsed_samp, ~ as.character(f_lhs(.)))
156+
samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.)))
149157
type = map_chr(samp_vars, function(var) {
150-
if (stringr::str_ends(parsed_vars[var], "data")) "data" else "parameter"
158+
if (stringr::str_ends(parsed$vars[var], "data")) "data" else "parameter"
151159
})
152160
print_order = order(type, samp_vars, decreasing=c(T, F))
153161

154162
cat(paste0("Sampling statements for model ", object@model_name, ":\n"))
155-
purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed_samp[.]))))
156-
invisible(parsed_samp)
163+
purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed$samp[.]))))
164+
invisible(parsed$samp)
157165
}
158166

159167
# Check that the model object is correct, and extract its Stan code

R/adjustr-package.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ pkg_env = new_environment()
3434

3535
.onLoad = function(libname, pkgname) { # nocov start
3636
# create the Stan parser
37-
tryCatch(get_parser(), error = function(e) {})
37+
#tryCatch(get_parser(), error = function(e) {})
3838

3939
utils::globalVariables(c("name", "pos", "value", ".y", ".y_ol", ".y_ou",
40-
".y_il", ".y_iu", ".y_med"))
40+
".y_il", ".y_iu", ".y_med", "quantile", "median"))
4141
} # nocov end
4242
#> NULL

R/make_spec.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#' frame, each entry in each column will be substituted into the corresponding
2727
#' parameter in the sampling statements.
2828
#'
29-
#' List arguments are coerced to data frame. They can either be lists of named
29+
#' List arguments are coerced to data frames. They can either be lists of named
3030
#' vectors, or lists of lists of single-element named vector.
3131
#'
3232
#' The lengths of all parameter arguments must be consistent. Named vectors

R/mockup.R

-111
This file was deleted.

0 commit comments

Comments
 (0)