Skip to content

Commit 86859a5

Browse files
committed
all but vignette
1 parent 9d60a4b commit 86859a5

40 files changed

+1160
-70
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
*.bak
88

99
.Rproj.user
10+
inst/doc

DESCRIPTION

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
Package: adjustr
2+
Encoding: UTF-8
3+
Type: Package
24
Title: Stan Model Adjustments and Sensitivity Analyses using Importance Sampling
35
Version: 0.0.0.9000
46
Authors@R: person("Cory", "McCartan", email = "[email protected]",
@@ -12,22 +14,26 @@ License: BSD_3_clause + file LICENSE
1214
Depends: R (>= 3.6.0)
1315
Imports:
1416
tibble,
17+
tidyselect,
1518
dplyr,
1619
purrr,
1720
methods,
1821
utils,
1922
stats,
2023
rlang,
2124
rstan,
22-
ggplot2,
2325
stringr,
2426
dparser,
27+
ggplot2,
2528
loo
2629
Suggests:
27-
tidyr,
2830
extraDistr,
31+
tidyr,
2932
testthat,
30-
covr
31-
Encoding: UTF-8
33+
covr,
34+
knitr,
35+
rmarkdown
36+
URL: https://corymccartan.github.io/adjustr/
3237
LazyData: true
3338
RoxygenNote: 7.1.0
39+
VignetteBuilder: knitr

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
S3method(arrange,adjustr_spec)
44
S3method(as.data.frame,adjustr_spec)
55
S3method(length,adjustr_spec)
6+
S3method(plot,adjustr_weighted)
67
S3method(print,adjustr_spec)
78
S3method(pull,adjustr_weighted)
89
S3method(rename,adjustr_spec)

R/adjust_weights.R

+29-5
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,25 @@
2828
#' method. The returned object also includes the model sample draws, in the
2929
#' \code{draws} attribute.
3030
#'
31+
#' @examples \dontrun{
32+
#' model_data = list(
33+
#' J = 8,
34+
#' y = c(28, 8, -3, 7, -1, 1, 18, 12),
35+
#' sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
36+
#' )
37+
#'
38+
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
39+
#' adjust_weights(spec, eightschools_m)
40+
#' adjust_weights(spec, eightschools_m, keep_bad=TRUE)
41+
#'
42+
#' spec = make_spec(y ~ student_t(df, theta, sigma), df=1:10)
43+
#' adjust_weights(spec, eightschools_m, data=model_data)
44+
#' # will throw an error because `y` and `sigma` aren't provided
45+
#' adjust_weights(spec, eightschools_m)
46+
#' }
47+
#'
3148
#' @export
32-
adjust_weights = function(spec, object, data=NULL, keep_bad=F) {
49+
adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) {
3350
# CHECK ARGUMENTS
3451
object = get_fit_obj(object)
3552
model_code = object@stanmodel@model_code
@@ -59,12 +76,12 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=F) {
5976
psis_wgt = suppressWarnings(loo::psis(lratio, r_eff=r_eff))
6077
pareto_k = loo::pareto_k_values(psis_wgt)
6178
if (all(psis_wgt$log_weights == psis_wgt$log_weights[1])) {
62-
warning("New specification equal to old specification.", call.=F)
79+
warning("New specification equal to old specification.", call.=FALSE)
6380
pareto_k = -Inf
6481
}
6582

6683
list(
67-
weights = loo::weights.importance_sampling(psis_wgt, log=F),
84+
weights = loo::weights.importance_sampling(psis_wgt, log=FALSE),
6885
pareto_k = pareto_k
6986
)
7087
})
@@ -84,7 +101,7 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=F) {
84101

85102
# Generic methods
86103
is.adjustr_weighted = function(x) inherits(x, "adjustr_weighted")
87-
#' Extract Weights From an \code{adjustr_spec_weighted} Object
104+
#' Extract Weights From an \code{adjustr_weighted} Object
88105
#'
89106
#' This function modifies the default behavior of \code{dplyr::pull} to extract
90107
#' the \code{.weights} column.
@@ -96,10 +113,11 @@ is.adjustr_weighted = function(x) inherits(x, "adjustr_weighted")
96113
#'
97114
#' @export
98115
pull.adjustr_weighted = function(.data, var=".weights") {
116+
var = tidyselect::vars_pull(names(.data), !!enquo(var))
99117
if (nrow(.data) == 1 && var == ".weights") {
100118
.data$.weights[[1]]
101119
} else {
102-
NextMethod(.data, var=var)
120+
.data[[var]]
103121
}
104122
}
105123

@@ -113,6 +131,12 @@ pull.adjustr_weighted = function(.data, var=".weights") {
113131
#'
114132
#' @return Invisbly returns a list of sampling formulas.
115133
#'
134+
#' @examples \dontrun{
135+
#' extract_samp_stmts(eightschools_m)
136+
#' #> Sampling statements for model 2c8d1d8a30137533422c438f23b83428:
137+
#' #> parameter eta ~ std_normal()
138+
#' #> data y ~ normal(theta, sigma)
139+
#' }
116140
#' @export
117141
extract_samp_stmts = function(object) {
118142
model_code = get_fit_obj(object)@stanmodel@model_code

R/adjustr-package.R

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#' \item \code{\link{make_spec}}
1717
#' \item \code{\link{adjust_weights}}
1818
#' \item \code{\link{summarize}}
19+
#' \item \code{\link{plot}}
1920
#' }
2021
#'
2122
#' @importFrom methods is
@@ -35,6 +36,7 @@ pkg_env = new_environment()
3536
# create the Stan parser
3637
tryCatch(get_parser(), error = function(e) {})
3738

38-
utils::globalVariables(c("name", "pos", "value"))
39+
utils::globalVariables(c("name", "pos", "value", ".y", ".y_ol", ".y_ou",
40+
".y_il", ".y_iu", ".y_med"))
3941
} # nocov end
4042
#> NULL

R/logprob.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ get_base_data = function(object, samps, parsed_vars, data, extra_names=NULL) {
5151
if (!all(found)) stop(paste(vars_indata[!found], collapse=", "), " not found")
5252
# combine draws and data
5353
base_data = append(
54-
map(vars_indraws, ~ rstan::extract(object, ., permuted=F)) %>%
54+
map(vars_indraws, ~ rstan::extract(object, ., permuted=FALSE)) %>%
5555
set_names(vars_indraws),
5656
map(vars_indata, ~ reshape_data(data[[.]])) %>%
5757
set_names(vars_indata),

R/make_spec.R

+20-2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@
4242
#' \code{\link[dplyr]{rename}}, and \code{\link[dplyr]{slice}}) are
4343
#' supported and operate on the underlying table of specification parameters.
4444
#'
45+
#' @examples
46+
#' make_spec(eta ~ cauchy(0, 1))
47+
#'
48+
#' make_spec(eta ~ student_t(df, 0, 1), df=1:10)
49+
#'
50+
#' params = tidyr::crossing(df=1:10, infl=c(1, 1.5, 2))
51+
#' make_spec(eta ~ student_t(df, 0, 1),
52+
#' y ~ normal(theta, infl*sigma),
53+
#' params)
54+
#'
4555
#' @export
4656
make_spec = function(...) {
4757
args = dots_list(..., .check_assign=T)
@@ -151,6 +161,14 @@ as.data.frame.adjustr_spec = function(x, ...) {
151161
#' @param ... additional arguments to underlying method
152162
#' @param .preserve as in \code{filter} and \code{slice}
153163
#' @name dplyr.adjustr_spec
164+
#'
165+
#' @examples \dontrun{
166+
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
167+
#'
168+
#' arrange(spec, desc(df))
169+
#' slice(spec, 4:7)
170+
#' filter(spec, df == 2)
171+
#' }
154172
NULL
155173
# dplyr generics
156174
dplyr_handler = function(dplyr_func, x, ...) {
@@ -165,7 +183,7 @@ dplyr_handler = function(dplyr_func, x, ...) {
165183

166184
# no @export because R CMD CHECK didn't like it
167185
#' @rdname dplyr.adjustr_spec
168-
filter.adjustr_spec = function(.data, ..., .preserve=F) {
186+
filter.adjustr_spec = function(.data, ..., .preserve=FALSE) {
169187
dplyr_handler(dplyr::filter, .data, ..., .preserve=.preserve)
170188
}
171189
#' @rdname dplyr.adjustr_spec
@@ -185,7 +203,7 @@ select.adjustr_spec = function(.data, ...) {
185203
}
186204
#' @rdname dplyr.adjustr_spec
187205
#' @export
188-
slice.adjustr_spec = function(.data, ..., .preserve=F) {
206+
slice.adjustr_spec = function(.data, ..., .preserve=FALSE) {
189207
dplyr_handler(dplyr::slice, .data, ..., .preserve=.preserve)
190208
}
191209

R/use_weights.R

+101-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
#' containing the sampled indices. If any weights are \code{NA}, the indices
1818
#' will also be \code{NA}.
1919
#'
20+
#' @examples \dontrun{
21+
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
22+
#' adjusted = adjust_weights(spec, eightschools_m)
23+
#'
24+
#' get_resampling_idxs(adjusted)
25+
#' get_resampling_idxs(adjusted, frac=0.1, replace=FALSE)
26+
#' }
27+
#'
2028
#' @export
2129
get_resampling_idxs = function(x, frac=1, replace=T) {
2230
if (frac < 0) stop("`frac` parameter must be nonnegative")
@@ -48,7 +56,10 @@ get_resampling_idxs = function(x, frac=1, replace=T) {
4856
#' posterior distribution of eight alternative specification. For example,
4957
#' a value of \code{mean(theta)} will compute the posterior mean of
5058
#' \code{theta} for each alternative specification.
51-
#' @param .resampling Wether to compute summary statistics by first resampling
59+
#'
60+
#' The arguments in \code{...} are automatically quoted and evaluated in the
61+
#' context of \code{.data}. They support unquoting and splicing.
62+
#' @param .resampling Whether to compute summary statistics by first resampling
5263
#' the data according to the weights. Defaults to \code{FALSE}, but will be
5364
#' used for any summary statistic that is not \code{mean}, \code{var} or
5465
#' \code{sd}.
@@ -58,9 +69,24 @@ get_resampling_idxs = function(x, frac=1, replace=T) {
5869
#' @return An \code{adjustr_weighted} object, wth the new columns specified in
5970
#' \code{...} added.
6071
#'
72+
#' @examples \dontrun{
73+
#' model_data = list(
74+
#' J = 8,
75+
#' y = c(28, 8, -3, 7, -1, 1, 18, 12),
76+
#' sigma = c(15, 10, 16, 11, 9, 11, 10, 18)
77+
#' )
78+
#'
79+
#' spec = make_spec(eta ~ student_t(df, 0, 1), df=1:10)
80+
#' adjusted = adjust_weights(spec, eightschools_m)
81+
#'
82+
#' summarize(adjusted, mean(mu), var(mu))
83+
#' summarize(adjusted, diff_1 = mean(y[1] - theta[1]), .model_data=model_data)
84+
#' summarize(adjusted, quantile(tau, probs=c(0.05, 0.5, 0.95)))
85+
#' }
86+
#'
6187
#' @rdname summarize.adjustr_weighted
6288
#' @export
63-
summarise.adjustr_weighted = function(.data, ..., .resampling=F, .model_data=NULL) {
89+
summarise.adjustr_weighted = function(.data, ..., .resampling=FALSE, .model_data=NULL) {
6490
stopifnot(is.adjustr_weighted(.data)) # just in case called manually
6591
args = enexprs(...)
6692

@@ -89,7 +115,7 @@ summarise.adjustr_weighted = function(.data, ..., .resampling=F, .model_data=NUL
89115
expr = expr_deparse(call_args(call)[[1]])
90116
expr = stringr::str_replace_all(expr, "\\[(\\d)", "[,\\1")
91117
expr = stringr::str_replace_all(expr, "(?<![a-zA-Z0-9._])mean\\(", "rowMeans(")
92-
expr = stringr::str_replace_all(expr, "(?<![a-zA-Z0-9._])sum\\(", "rowSum(")
118+
expr = stringr::str_replace_all(expr, "(?<![a-zA-Z0-9._])sum\\(", "rowSums(")
93119
computed = as.array(eval_tidy(parse_expr(expr), data))
94120
if (length(dim(computed)) == 1) dim(computed) = c(dim(computed), 1)
95121

@@ -124,3 +150,75 @@ funs_env = new_environment(list(
124150
var = wtd_array_var,
125151
sd = wtd_array_sd
126152
))
153+
154+
155+
#' Plot Posterior Quantities of Interest Under Alternative Model Specifications
156+
#'
157+
#' Uses weights computed in \code{\link{adjust_weights}} to plot posterior
158+
#' quantities of interest versus
159+
#'
160+
#' @param x An \code{adjustr_weighted} object.
161+
#' @param by The x-axis variable, which is usually one of the specification
162+
#' parameters. Can be set to \code{1} if there is only one specification.
163+
#' Automatically quoted and evaluated in the context of \code{x}.
164+
#' @param post The posterior quantity of interest, to be computed for each
165+
#' resampled draw of each specificaiton. Should evaluate to a single number
166+
#' for each draw. Automatically quoted and evaluated in the context of \code{x}.
167+
#' @param only_mean Whether to only plot the posterior mean. May be more stable.
168+
#' @param ci_level The inner credible interval to plot. Central
169+
#' 100*ci_level% intervals are computed from the quantiles of the resampled
170+
#' posterior draws.
171+
#' @param outer_level The outer credible interval to plot.
172+
#' @param ... Ignored.
173+
#'
174+
#' @return A \code{\link[ggplot2]{ggplot}} object which can be further
175+
#' customized with the \strong{ggplot2} package.
176+
#'
177+
#' @examples \dontrun{
178+
#' spec = make_spec(eta ~ student_t(df, 0, scale),
179+
#' df=1:10, scale=seq(2, 1, -1/9))
180+
#' adjusted = adjust_weights(spec, eightschools_m)
181+
#'
182+
#' plot(adjusted, df, theta[1])
183+
#' plot(adjusted, df, mu, only_mean=TRUE)
184+
#' plot(adjusted, scale, tau)
185+
#' }
186+
#'
187+
#' @export
188+
plot.adjustr_weighted = function(x, by, post, only_mean=FALSE, ci_level=0.8,
189+
outer_level=0.95, ...) {
190+
if (!requireNamespace("ggplot2", quietly=TRUE)) { # nocov start
191+
stop("Package `ggplot2` must be installed to plot posterior quantities of interest.")
192+
} # nocov end
193+
if (ci_level > outer_level) stop("`ci_level` should be less than `outer_level`")
194+
195+
post = enexpr(post)
196+
if (!only_mean) {
197+
outer = (1 - outer_level) / 2
198+
inner = (1 - ci_level) / 2
199+
q_probs = c(outer, inner, 0.5, 1-inner, 1-outer)
200+
sum_arg = quo(stats::quantile(!!post, probs = !!q_probs))
201+
202+
summarise.adjustr_weighted(x, .y = !!sum_arg) %>%
203+
rowwise() %>%
204+
mutate(.y_ol = .y[1],
205+
.y_il = .y[2],
206+
.y_med = .y[3],
207+
.y_iu = .y[4],
208+
.y_ou = .y[5]) %>%
209+
ggplot(aes({{ by }}, .y_med)) +
210+
geom_ribbon(aes(ymin=.y_ol, ymax=.y_ou), alpha=0.4) +
211+
geom_ribbon(aes(ymin=.y_il, ymax=.y_iu), alpha=0.5) +
212+
geom_line() +
213+
geom_point(size=3) +
214+
theme_minimal() +
215+
labs(y= expr_name(post))
216+
} else {
217+
summarise.adjustr_weighted(x, .y = mean(!!post)) %>%
218+
ggplot(aes({{ by }}, .y)) +
219+
geom_line() +
220+
geom_point(size=3) +
221+
theme_minimal() +
222+
labs(y = expr_name(post))
223+
}
224+
}

_pkgdown.yml

+5-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ template:
1313
navbar:
1414
title: "adjustr"
1515
left:
16-
#- text: "Vignettes"
17-
# href: articles/
16+
- text: "Vignettes"
17+
href: articles/index.html
1818
- text: "Functions"
19-
href: reference/
19+
href: reference/index.html
2020
- text: "Other Packages"
2121
menu:
2222
- text: "rstan"
@@ -72,7 +72,8 @@ reference:
7272
contents:
7373
- make_spec
7474
- adjust_weights
75-
- summarise.adjustr_weighted
75+
- summarize.adjustr_weighted
76+
- plot.adjustr_weighted
7677
- title: "Helper Functions"
7778
desc: >
7879
Various helper functions for examining a model or building sampling

docs/404.html

+4-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/LICENSE-text.html

+4-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)