|
18 | 18 | #' posterior, and which as a result cannot be reliably estimated using
|
19 | 19 | #' importance sampling (i.e., if the Pareto shape parameter is larger than
|
20 | 20 | #' 0.7), have their weights discarded.
|
| 21 | +#' @param incl_orig When \code{TRUE}, include a row for the original |
| 22 | +#' model specification, with all weights equal. Can facilitate comaprison |
| 23 | +#' and plotting later. |
21 | 24 | #'
|
22 | 25 | #' @return A tibble, produced by converting the provided \code{specs} to a
|
23 | 26 | #' tibble (see \code{\link{as.data.frame.adjustr_spec}}), and adding columns
|
24 | 27 | #' \code{.weights}, containing vectors of weights for each draw, and
|
25 | 28 | #' \code{.pareto_k}, containing the diagnostic Pareto shape parameters. Values
|
26 | 29 | #' greater than 0.7 indicate that importance sampling is not reliable.
|
27 |
| -#' Weights can be extracted with the \code{\link{pull.adjustr_weighted}} |
28 |
| -#' method. The returned object also includes the model sample draws, in the |
29 |
| -#' \code{draws} attribute. |
| 30 | +#' If \code{incl_orig} is \code{TRUE}, a row is added for the original model |
| 31 | +#' specification. Weights can be extracted with the |
| 32 | +#' \code{\link{pull.adjustr_weighted}} method. The returned object also |
| 33 | +#' includes the model sample draws, in the \code{draws} attribute. |
30 | 34 | #'
|
31 | 35 | #' @examples \dontrun{
|
32 | 36 | #' model_data = list(
|
|
46 | 50 | #' }
|
47 | 51 | #'
|
48 | 52 | #' @export
|
49 |
| -adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) { |
| 53 | +adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE, incl_orig=TRUE) { |
50 | 54 | # CHECK ARGUMENTS
|
51 | 55 | object = get_fit_obj(object)
|
52 | 56 | model_code = object@stanmodel@model_code
|
53 | 57 | stopifnot(is.adjustr_spec(spec))
|
54 | 58 |
|
55 |
| - parsed_model = parse_model(model_code) |
56 |
| - parsed_vars = get_variables(parsed_model) |
57 |
| - parsed_samp = get_sampling_stmts(parsed_model) |
| 59 | + parsed = parse_model(model_code) |
58 | 60 |
|
59 | 61 | # if no model data provided, we can only resample distributions of parameters
|
60 | 62 | if (is.null(data)) {
|
61 |
| - samp_vars = map_chr(parsed_samp, ~ as.character(f_lhs(.))) |
62 |
| - prior_vars = parsed_vars[samp_vars] != "data" |
63 |
| - parsed_samp = parsed_samp[prior_vars] |
| 63 | + samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.))) |
| 64 | + prior_vars = parsed$vars[samp_vars] != "data" |
| 65 | + parsed$samp = parsed$samp[prior_vars] |
64 | 66 | data = list()
|
65 | 67 | }
|
66 | 68 |
|
67 |
| - matched_samp = match_sampling_stmts(spec$samp, parsed_samp) |
68 |
| - original_lp = calc_original_lp(object, matched_samp, parsed_vars, data) |
69 |
| - specs_lp = calc_specs_lp(object, spec$samp, parsed_vars, data, spec$params) |
| 69 | + matched_samp = match_sampling_stmts(spec$samp, parsed$samp) |
| 70 | + original_lp = calc_original_lp(object, matched_samp, parsed$vars, data) |
| 71 | + specs_lp = calc_specs_lp(object, spec$samp, parsed$vars, data, spec$params) |
70 | 72 |
|
71 | 73 | # compute weights
|
72 | 74 | wgts = map(specs_lp, function(spec_lp) {
|
@@ -95,6 +97,14 @@ adjust_weights = function(spec, object, data=NULL, keep_bad=FALSE) {
|
95 | 97 | attr(adjust_obj, "draws") = rstan::extract(object)
|
96 | 98 | attr(adjust_obj, "data") = data
|
97 | 99 | attr(adjust_obj, "iter") = object@sim$chains * (object@sim$iter - object@sim$warmup)
|
| 100 | + if (incl_orig) { |
| 101 | + adjust_obj = bind_rows(adjust_obj, tibble( |
| 102 | + .weights=list(rep(1, attr(adjust_obj, "iter"))), |
| 103 | + .pareto_k = -Inf)) |
| 104 | + samp_cols = stringr::str_detect(names(adjust_obj), "\\.samp") |
| 105 | + adjust_obj[nrow(adjust_obj), samp_cols] = "<original model>" |
| 106 | + } |
| 107 | + |
98 | 108 | adjust_obj
|
99 | 109 | }
|
100 | 110 |
|
@@ -141,19 +151,17 @@ pull.adjustr_weighted = function(.data, var=".weights") {
|
141 | 151 | extract_samp_stmts = function(object) {
|
142 | 152 | model_code = get_fit_obj(object)@stanmodel@model_code
|
143 | 153 |
|
144 |
| - parsed_model = parse_model(model_code) |
145 |
| - parsed_vars = get_variables(parsed_model) |
146 |
| - parsed_samp = get_sampling_stmts(parsed_model) |
| 154 | + parsed = parse_model(model_code) |
147 | 155 |
|
148 |
| - samp_vars = map_chr(parsed_samp, ~ as.character(f_lhs(.))) |
| 156 | + samp_vars = map_chr(parsed$samp, ~ as.character(f_lhs(.))) |
149 | 157 | type = map_chr(samp_vars, function(var) {
|
150 |
| - if (stringr::str_ends(parsed_vars[var], "data")) "data" else "parameter" |
| 158 | + if (stringr::str_ends(parsed$vars[var], "data")) "data" else "parameter" |
151 | 159 | })
|
152 | 160 | print_order = order(type, samp_vars, decreasing=c(T, F))
|
153 | 161 |
|
154 | 162 | cat(paste0("Sampling statements for model ", object@model_name, ":\n"))
|
155 |
| - purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed_samp[.])))) |
156 |
| - invisible(parsed_samp) |
| 163 | + purrr::walk(print_order, ~ cat(sprintf(" %-9s %s\n", type[.], as.character(parsed$samp[.])))) |
| 164 | + invisible(parsed$samp) |
157 | 165 | }
|
158 | 166 |
|
159 | 167 | # Check that the model object is correct, and extract its Stan code
|
|
0 commit comments