Skip to content

Commit

Permalink
plot.check_predictions for Stan models (#336)
Browse files Browse the repository at this point in the history
* plot.check_predictions for Stan models

* remove NA here

* fix
  • Loading branch information
strengejacke authored Mar 30, 2024
1 parent 1a090e1 commit fea309a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: see
Title: Model Visualisation Toolbox for 'easystats' and 'ggplot2'
Version: 0.8.3.3
Version: 0.8.3.4
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
downstream plot-functions (i.e., `plot()` for `check_model()` passes arguments
to change geom sizes to the underlying plot-functions).

* `plot()` for `check_predictions()` now supports Bayesian regression models from
*brms* and *rstanarm*.

# see 0.8.3

## Major changes
Expand Down
90 changes: 90 additions & 0 deletions R/plot.check_predictions.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
#' @export
data_plot.performance_pp_check <- function(x, type = "density", ...) {
# for data from "bayesplot::pp_check()", data is already in shape
if (isTRUE(attributes(x)$is_stan) && type != "density") {
class(x) <- c("data_plot", "see_performance_pp_check", "data.frame")
attr(x, "info") <- list(
xlab = attr(x, "response_name"),
ylab = ifelse(identical(type, "density"), "Density", "Counts"),
title = "Posterior Predictive Check",
check_range = attr(x, "check_range"),
bandwidth = attr(x, "bandwidth"),
model_info = attr(x, "model_info")
)
return(x)
}

columns <- colnames(x)
dataplot <- stats::reshape(
x,
Expand Down Expand Up @@ -88,6 +102,7 @@ print.see_performance_pp_check <- function(x,
orig_x <- x
check_range <- isTRUE(attributes(x)$check_range)
plot_type <- attributes(x)$type
is_stan <- attributes(x)$is_stan

if (missing(type) && !is.null(plot_type) && plot_type %in% c("density", "discrete_dots", "discrete_interval", "discrete_both")) {

Check warning on line 107 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/plot.check_predictions.R,line=107,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 131 characters.
type <- plot_type
Expand All @@ -111,6 +126,7 @@ print.see_performance_pp_check <- function(x,
size_axis_title = size_axis_title,
type = type,
x_limits = x_limits,
is_stan = is_stan,
...
)

Expand Down Expand Up @@ -143,6 +159,7 @@ plot.see_performance_pp_check <- function(x,
orig_x <- x
check_range <- isTRUE(attributes(x)$check_range)
plot_type <- attributes(x)$type
is_stan <- attributes(x)$is_stan

if (missing(type) && !is.null(plot_type) && plot_type %in% c("density", "discrete_dots", "discrete_interval", "discrete_both")) { # nolint
type <- plot_type
Expand All @@ -166,6 +183,7 @@ plot.see_performance_pp_check <- function(x,
colors = colors,
type = type,
x_limits = x_limits,
is_stan = is_stan,
...
)

Expand All @@ -190,9 +208,16 @@ plot.see_performance_pp_check <- function(x,
colors,

Check warning on line 208 in R/plot.check_predictions.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/plot.check_predictions.R,line=208,col=28,[function_argument_linter] Arguments without defaults should come before arguments with defaults.
type = "density",
x_limits = NULL,
is_stan = NULL,
...) {
info <- attr(x, "info")

# discrete plot type from "bayesplot::pp_check()" returns a different data
# structure, so we need to handle it differently
if (isTRUE(is_stan) && type != "density") {
return(.plot_check_predictions_stan_dots(x, colors, info, size_line, size_point, line_alpha, ...))
}

# default bandwidth, for smooting
bandwidth <- info$bandwidth
if (is.null(bandwidth)) {
Expand Down Expand Up @@ -450,6 +475,71 @@ plot.see_performance_pp_check <- function(x,
}


.plot_check_predictions_stan_dots <- function(x,
colors,
info,
size_line,
size_point,
line_alpha,
...) {
# make sure we have a factor, so "table()" generates frequencies for all levels
# for each group - we need tables of same size to bind data frames
x$Group[x$Group == "y"] <- "Observed data"
x$Group[x$Group == "Mean"] <- "Model-predicted data"

# sanity check, remove NA rows
x <- x[!is.na(x$Count), ]

p <- ggplot2::ggplot() +
ggplot2::geom_pointrange(
data = x[x$Group == "Model-predicted data", ],
mapping = ggplot2::aes(
x = .data$x,
y = .data$Count,
ymin = .data$CI_low,
ymax = .data$CI_high,
color = .data$Group
),
position = ggplot2::position_nudge(x = 0.2),
size = 0.4 * size_point,
linewidth = size_line,
stroke = 0,
shape = 16
) +
ggplot2::geom_point(
data = x[x$Group == "Observed data", ],
mapping = ggplot2::aes(
x = .data$x,
y = .data$Count,
color = .data$Group
),
size = 1.5 * size_point,
stroke = 0,
shape = 16
) +
ggplot2::scale_y_continuous() +
ggplot2::scale_color_manual(values = c(
"Observed data" = colors[1],
"Model-predicted data" = colors[2]
)) +
ggplot2::labs(
x = info$xlab,
y = info$ylab,
color = "",
size = "",
alpha = "",
title = "Posterior Predictive Check",
subtitle = "Model-predicted intervals should include observed data points"
) +
ggplot2::guides(
color = ggplot2::guide_legend(reverse = TRUE),
size = ggplot2::guide_legend(reverse = TRUE)
)

return(p)
}


.plot_pp_check_range <- function(x,
size_bar = 0.7,
colors = unname(social_colors(c("green", "blue")))) {
Expand Down

0 comments on commit fea309a

Please sign in to comment.