Skip to content

Commit eb4fafa

Browse files
authored
Improve binned_residuals() (#641)
1 parent 3408a7c commit eb4fafa

File tree

8 files changed

+285
-95
lines changed

8 files changed

+285
-95
lines changed

NEWS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# performance 0.10.7
22

3+
## Breaking changes
4+
5+
* `binned_residuals()` gains a few new arguments to control the residuals used
6+
for the test, as well as different options to calculate confidence intervals
7+
(namely, `ci_type`, `residuals`, `ci` and `iterations`). The default values
8+
to compute binned residuals have changed. Default residuals are now "deviance"
9+
residuals (and no longer "response" residuals). Default confidence intervals
10+
are now "exact" intervals (and no longer based on Gaussian approximation).
11+
Use `ci_type = "gaussian"` and `residuals = "response"` to get the old defaults.
12+
313
## Changes to functions
414

515
* `binned_residuals()` - like `check_model()` - gains a `show_dots` argument to

R/binned_residuals.R

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@
1111
#' @param n_bins Numeric, the number of bins to divide the data. If
1212
#' `n_bins = NULL`, the square root of the number of observations is
1313
#' taken.
14+
#' @param ci Numeric, the confidence level for the error bounds.
15+
#' @param ci_type Character, the type of error bounds to calculate. Can be
16+
#' `"exact"` (default), `"gaussian"` or `"boot"`. `"exact"` calculates the
17+
#' error bounds based on the exact binomial distribution, using [`binom.test()`].
18+
#' `"gaussian"` uses the Gaussian approximation, while `"boot"` uses a simple
19+
#' bootstrap method, where confidence intervals are calculated based on the
20+
#' quantiles of the bootstrap distribution.
21+
#' @param residuals Character, the type of residuals to calculate. Can be
22+
#' `"deviance"` (default), `"pearson"` or `"response"`. It is recommended to
23+
#' use `"response"` only for those models where other residuals are not
24+
#' available.
25+
#' @param iterations Integer, the number of iterations to use for the
26+
#' bootstrap method. Only used if `ci_type = "boot"`.
1427
#' @param show_dots Logical, if `TRUE`, will show data points in the plot. Set
1528
#' to `FALSE` for models with many observations, if generating the plot is too
1629
#' time-consuming. By default, `show_dots = NULL`. In this case `binned_residuals()`
@@ -62,12 +75,24 @@
6275
#' }
6376
#'
6477
#' @export
65-
binned_residuals <- function(model, term = NULL, n_bins = NULL, show_dots = NULL, ...) {
66-
fv <- stats::fitted(model)
78+
binned_residuals <- function(model,
79+
term = NULL,
80+
n_bins = NULL,
81+
show_dots = NULL,
82+
ci = 0.95,
83+
ci_type = c("exact", "gaussian", "boot"),
84+
residuals = c("deviance", "pearson", "response"),
85+
iterations = 1000,
86+
...) {
87+
# match arguments
88+
ci_type <- match.arg(ci_type)
89+
residuals <- match.arg(residuals)
90+
91+
fitted_values <- stats::fitted(model)
6792
mf <- insight::get_data(model, verbose = FALSE)
6893

6994
if (is.null(term)) {
70-
pred <- fv
95+
pred <- fitted_values
7196
} else {
7297
pred <- mf[[term]]
7398
}
@@ -78,7 +103,20 @@ binned_residuals <- function(model, term = NULL, n_bins = NULL, show_dots = NULL
78103
show_dots <- is.null(n) || n <= 1e5
79104
}
80105

81-
y <- .recode_to_zero(insight::get_response(model, verbose = FALSE)) - fv
106+
# make sure response is 0/1 (and numeric)
107+
y0 <- .recode_to_zero(insight::get_response(model, verbose = FALSE))
108+
109+
# calculate residuals
110+
y <- switch(residuals,
111+
response = y0 - fitted_values,
112+
pearson = .safe((y0 - fitted_values) / sqrt(fitted_values * (1 - fitted_values))),
113+
deviance = .safe(stats::residuals(model, type = "deviance"))
114+
)
115+
116+
# make sure we really have residuals
117+
if (is.null(y)) {
118+
insight::format_error("Could not calculate residuals. Try using `residuals = \"response\"`.")
119+
}
82120

83121
if (is.null(n_bins)) n_bins <- round(sqrt(length(pred)))
84122

@@ -95,24 +133,32 @@ binned_residuals <- function(model, term = NULL, n_bins = NULL, show_dots = NULL
95133
n <- length(items)
96134
sdev <- stats::sd(y[items], na.rm = TRUE)
97135

98-
data.frame(
136+
conf_int <- switch(ci_type,
137+
gaussian = stats::qnorm(c((1 - ci) / 2, (1 + ci) / 2), mean = ybar, sd = sdev / sqrt(n)),
138+
exact = {
139+
out <- stats::binom.test(sum(y0[items]), n)$conf.int
140+
# center CIs around point estimate
141+
out <- out - (min(out) - ybar) - (diff(out) / 2)
142+
out
143+
},
144+
boot = .boot_binned_ci(y[items], ci, iterations)
145+
)
146+
names(conf_int) <- c("CI_low", "CI_high")
147+
148+
d0 <- data.frame(
99149
xbar = xbar,
100150
ybar = ybar,
101151
n = n,
102152
x.lo = model.range[1],
103153
x.hi = model.range[2],
104-
se = stats::qnorm(0.975) * sdev / sqrt(n),
105-
ci_range = sdev / sqrt(n)
154+
se = stats::qnorm((1 + ci) / 2) * sdev / sqrt(n)
106155
)
156+
cbind(d0, rbind(conf_int))
107157
}))
108158

109159
d <- do.call(rbind, d)
110160
d <- d[stats::complete.cases(d), ]
111161

112-
# CIs
113-
d$CI_low <- d$ybar - stats::qnorm(0.975) * d$ci_range
114-
d$CI_high <- d$ybar + stats::qnorm(0.975) * d$ci_range
115-
116162
gr <- abs(d$ybar) > abs(d$se)
117163
d$group <- "yes"
118164
d$group[gr] <- "no"
@@ -129,6 +175,21 @@ binned_residuals <- function(model, term = NULL, n_bins = NULL, show_dots = NULL
129175
}
130176

131177

178+
# utilities ---------------------------
179+
180+
# Percentile bootstrap confidence interval for the mean of `x`.
#
# x           Numeric vector; `NA` values are dropped before resampling.
# ci          Confidence level. The interval is bounded by the
#             (1 - ci) / 2 and (1 + ci) / 2 quantiles of the bootstrap
#             distribution of the mean.
# iterations  Number of bootstrap resamples to draw.
#
# Returns a numeric vector of length two, named `CI_low` and `CI_high`.
# Both bounds are `NA` when `x` has no non-missing values.
.boot_binned_ci <- function(x, ci = 0.95, iterations = 1000) {
  x <- x[!is.na(x)]
  n <- length(x)

  # Nothing to resample from: the resampled means would be 0 / 0 = NaN and
  # `quantile()` would stop with an error, so return missing bounds instead.
  if (n == 0L) {
    return(c(CI_low = NA_real_, CI_high = NA_real_))
  }

  # Each iteration draws `n` observations with replacement and stores the
  # mean of that resample.
  boot_means <- vapply(
    seq_len(iterations),
    function(i) sum(x[sample.int(n, n, replace = TRUE)]) / n,
    numeric(1)
  )

  # `names = FALSE` drops the "2.5%"-style labels; otherwise `c()` would
  # paste them onto the names below (e.g. "CI_low.2.5%").
  quant <- stats::quantile(
    boot_means,
    c((1 - ci) / 2, (1 + ci) / 2),
    names = FALSE
  )
  c(CI_low = quant[1L], CI_high = quant[2L])
}
192+
132193

133194
# methods -----------------------------
134195

R/check_model.R

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
#' tries to guess whether performance will be poor due to a very large model
3636
#' and thus automatically shows or hides dots.
3737
#' @param verbose If `FALSE` (default), suppress most warning messages.
38-
#' @param ... Currently not used.
38+
#' @param ... Arguments passed down to the individual check functions, especially
39+
#' to `check_predictions()` and `binned_residuals()`.
3940
#' @inheritParams check_predictions
4041
#'
4142
#' @return The data frame that is used for plotting.
@@ -185,11 +186,11 @@ check_model.default <- function(x,
185186
ca <- tryCatch(
186187
{
187188
if (minfo$is_bayesian) {
188-
suppressWarnings(.check_assumptions_stan(x))
189+
suppressWarnings(.check_assumptions_stan(x, ...))
189190
} else if (minfo$is_linear) {
190-
suppressWarnings(.check_assumptions_linear(x, minfo, verbose))
191+
suppressWarnings(.check_assumptions_linear(x, minfo, verbose, ...))
191192
} else {
192-
suppressWarnings(.check_assumptions_glm(x, minfo, verbose))
193+
suppressWarnings(.check_assumptions_glm(x, minfo, verbose, ...))
193194
}
194195
},
195196
error = function(e) {
@@ -202,7 +203,7 @@ check_model.default <- function(x,
202203
}
203204

204205
# try to find sensible default for "type" argument
205-
suggest_dots <- (minfo$is_bernoulli || minfo$is_count || minfo$is_ordinal || minfo$is_categorical || minfo$is_multinomial)
206+
suggest_dots <- (minfo$is_bernoulli || minfo$is_count || minfo$is_ordinal || minfo$is_categorical || minfo$is_multinomial) # nolint
206207
if (missing(type) && suggest_dots) {
207208
type <- "discrete_interval"
208209
}
@@ -330,7 +331,7 @@ check_model.model_fit <- function(x,
330331

331332
# compile plots for checks of linear models ------------------------
332333

333-
.check_assumptions_linear <- function(model, model_info, verbose = TRUE) {
334+
.check_assumptions_linear <- function(model, model_info, verbose = TRUE, ...) {
334335
dat <- list()
335336

336337
dat$VIF <- .diag_vif(model, verbose = verbose)
@@ -340,13 +341,13 @@ check_model.model_fit <- function(x,
340341
dat$NCV <- .diag_ncv(model, verbose = verbose)
341342
dat$HOMOGENEITY <- .diag_homogeneity(model, verbose = verbose)
342343
dat$OUTLIERS <- check_outliers(model, method = "cook")
343-
if (!is.null(dat$OUTLIERS)) {
344-
threshold <- attributes(dat$OUTLIERS)$threshold$cook
345-
} else {
344+
if (is.null(dat$OUTLIERS)) {
346345
threshold <- NULL
346+
} else {
347+
threshold <- attributes(dat$OUTLIERS)$threshold$cook
347348
}
348349
dat$INFLUENTIAL <- .influential_obs(model, threshold = threshold)
349-
dat$PP_CHECK <- .safe(check_predictions(model))
350+
dat$PP_CHECK <- .safe(check_predictions(model, ...))
350351

351352
dat <- insight::compact_list(dat)
352353
class(dat) <- c("check_model", "see_check_model")
@@ -357,23 +358,23 @@ check_model.model_fit <- function(x,
357358

358359
# compile plots for checks of generalized linear models ------------------------
359360

360-
.check_assumptions_glm <- function(model, model_info, verbose = TRUE) {
361+
.check_assumptions_glm <- function(model, model_info, verbose = TRUE, ...) {
361362
dat <- list()
362363

363364
dat$VIF <- .diag_vif(model, verbose = verbose)
364365
dat$QQ <- .diag_qq(model, verbose = verbose)
365366
dat$HOMOGENEITY <- .diag_homogeneity(model, verbose = verbose)
366367
dat$REQQ <- .diag_reqq(model, level = 0.95, model_info = model_info, verbose = verbose)
367368
dat$OUTLIERS <- check_outliers(model, method = "cook")
368-
if (!is.null(dat$OUTLIERS)) {
369-
threshold <- attributes(dat$OUTLIERS)$threshold$cook
370-
} else {
369+
if (is.null(dat$OUTLIERS)) {
371370
threshold <- NULL
371+
} else {
372+
threshold <- attributes(dat$OUTLIERS)$threshold$cook
372373
}
373374
dat$INFLUENTIAL <- .influential_obs(model, threshold = threshold)
374-
dat$PP_CHECK <- .safe(check_predictions(model))
375+
dat$PP_CHECK <- .safe(check_predictions(model, ...))
375376
if (isTRUE(model_info$is_binomial)) {
376-
dat$BINNED_RESID <- binned_residuals(model)
377+
dat$BINNED_RESID <- binned_residuals(model, ...)
377378
}
378379
if (isTRUE(model_info$is_count)) {
379380
dat$OVERDISPERSION <- .diag_overdispersion(model)
@@ -388,7 +389,7 @@ check_model.model_fit <- function(x,
388389

389390
# compile plots for checks of Bayesian models ------------------------
390391

391-
.check_assumptions_stan <- function(model) {
392+
.check_assumptions_stan <- function(model, ...) {
392393
if (inherits(model, "brmsfit")) {
393394
# check if brms can be loaded
394395

0 commit comments

Comments
 (0)