Skip to content

Commit

Permalink
r2
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed Oct 4, 2023
1 parent 07d5b3a commit c1cc3df
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 16 deletions.
5 changes: 5 additions & 0 deletions R/model_performance.lm.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ model_performance.nestedLogit <- function(model, metrics = "all", verbose = TRUE
data.frame(Response = names(mp), stringsAsFactors = FALSE),
do.call(rbind, mp)
)
# need to handle R2 separately
if (any(c("ALL", "R2") %in% toupper(metrics))) {
out$R2 <- unlist(r2_tjur(model))
}

row.names(out) <- NULL
class(out) <- unique(c("performance_model", class(out)))
out
Expand Down
19 changes: 7 additions & 12 deletions R/r2.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,8 @@ r2 <- function(model, ...) {
}




# Default models -----------------------------------------------


#' @rdname r2
#' @export
r2.default <- function(model, ci = NULL, verbose = TRUE, ...) {
Expand Down Expand Up @@ -115,6 +112,8 @@ r2.lm <- function(model, ci = NULL, ...) {
#' @export
r2.phylolm <- r2.lm

# helper -------------

.r2_lm <- function(model_summary, ci = NULL) {
out <- list(
R2 = model_summary$r.squared,
Expand All @@ -140,7 +139,6 @@ r2.phylolm <- r2.lm
}



#' @export
r2.summary.lm <- function(model, ci = NULL, ...) {
if (!is.null(ci) && !is.na(ci)) {
Expand All @@ -150,7 +148,6 @@ r2.summary.lm <- function(model, ci = NULL, ...) {
}



#' @export
r2.systemfit <- function(model, ...) {
out <- lapply(summary(model)$eq, function(model_summary) {
Expand Down Expand Up @@ -198,16 +195,13 @@ r2.ols <- function(model, ...) {
structure(class = "r2_generic", out)
}



#' @export
r2.lrm <- r2.ols

#' @export
r2.cph <- r2.ols



#' @export
r2.mhurdle <- function(model, ...) {
resp <- insight::get_response(model, verbose = FALSE)
Expand All @@ -230,7 +224,6 @@ r2.mhurdle <- function(model, ...) {
}



#' @export
r2.aov <- function(model, ci = NULL, ...) {
if (!is.null(ci) && !is.na(ci)) {
Expand All @@ -252,7 +245,6 @@ r2.aov <- function(model, ci = NULL, ...) {
}



#' @export
r2.mlm <- function(model, ...) {
model_summary <- summary(model)
Expand All @@ -276,7 +268,6 @@ r2.mlm <- function(model, ...) {
}



#' @export
r2.glm <- function(model, ci = NULL, verbose = TRUE, ...) {
if (!is.null(ci) && !is.na(ci)) {
Expand Down Expand Up @@ -312,9 +303,13 @@ r2.glm <- function(model, ci = NULL, verbose = TRUE, ...) {
#' @export
r2.glmx <- r2.glm


#' @export
r2.nestedLogit <- function(model, ci = NULL, verbose = TRUE, ...) {
lapply(r2, model$models, ci = ci, verbose = verbose, ...)
out <- list("R2_Tjur" = r2_tjur(model, ...))

Check warning on line 309 in R/r2.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/r2.R,line=309,col=15,[keyword_quote_linter] Only quote named arguments to functions if necessary, i.e., if the name is not a valid R symbol (see ?make.names).
attr(out, "model_type") <- "Logistic"
class(out) <- c("r2_pseudo", class(out))
out
}


Expand Down
20 changes: 20 additions & 0 deletions R/r2_coxsnell.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@ r2_coxsnell.glm <- function(model, verbose = TRUE, ...) {
#' @export
r2_coxsnell.BBreg <- r2_coxsnell.glm


#' @export
r2_coxsnell.nestedLogit <- function(model, ...) {
n <- insight::n_obs(model, disaggregate = TRUE)
stats::setNames(
lapply(names(model$models), function(i) {
m <- model$models[[i]]
# if no deviance, return NA
if (is.null(m$deviance)) {
return(NA)
}
r2_coxsnell <- (1 - exp((m$deviance - m$null.deviance) / n[[i]]))
names(r2_coxsnell) <- "Cox & Snell's R2"
r2_coxsnell
}),
names(model$models)
)
}


#' @export
r2_coxsnell.mclogit <- function(model, ...) {
insight::check_if_installed("mclogit", reason = "to calculate R2")
Expand Down
20 changes: 20 additions & 0 deletions R/r2_nagelkerke.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,26 @@ r2_nagelkerke.glm <- function(model, verbose = TRUE, ...) {
#' @export
r2_nagelkerke.BBreg <- r2_nagelkerke.glm


#' @export
r2_nagelkerke.nestedLogit <- function(model, ...) {
n <- insight::n_obs(model, disaggregate = TRUE)
stats::setNames(
lapply(names(model$models), function(i) {
m <- model$models[[i]]
# if no deviance, return NA
if (is.null(m$deviance)) {
return(NA)
}
r2_nagelkerke <- (1 - exp((m$deviance - m$null.deviance) / n[[i]])) / (1 - exp(-m$null.deviance / n[[i]]))
names(r2_nagelkerke) <- "Nagelkerke's R2"
r2_nagelkerke
}),
names(model$models)
)
}


#' @export
r2_nagelkerke.bife <- function(model, ...) {
r2_nagelkerke <- r2_coxsnell(model) / (1 - exp(-model$null_deviance / insight::n_obs(model)))
Expand Down
32 changes: 32 additions & 0 deletions R/r2_tjur.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
#'
#' @export
r2_tjur <- function(model, ...) {
UseMethod("r2_tjur")
}

#' @export
r2_tjur.default <- function(model, ...) {
info <- list(...)$model_info
if (is.null(info)) {
info <- suppressWarnings(insight::model_info(model, verbose = FALSE))
Expand Down Expand Up @@ -50,3 +55,30 @@ r2_tjur <- function(model, ...) {
names(tjur_d) <- "Tjur's R2"
tjur_d
}

#' @export
r2_tjur.nestedLogit <- function(model, ...) {
resp <- insight::get_response(model, dichotomies = TRUE, verbose = FALSE)

stats::setNames(
lapply(names(model$models), function(i) {
y <- resp[[i]]
m <- model$models[[i]]
pred <- stats::predict(m, type = "response")
# delete pred for cases with missing residuals
if (anyNA(stats::residuals(m))) {
pred <- pred[!is.na(stats::residuals(m))]
}
categories <- unique(y)
m1 <- mean(pred[which(y == categories[1])], na.rm = TRUE)
m2 <- mean(pred[which(y == categories[2])], na.rm = TRUE)

tjur_d <- abs(m2 - m1)

names(tjur_d) <- "Tjur's R2"
tjur_d
}
),
names(model$models)
)
}
8 changes: 4 additions & 4 deletions tests/testthat/_snaps/nestedLogit.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
Output
# Indices of model performance
Response | AIC | BIC | RMSE | Sigma
--------------------------------------------
work | 325.733 | 336.449 | 0.456 | 1.000
full | 110.495 | 118.541 | 0.398 | 1.000
Response | AIC | BIC | RMSE | Sigma | R2
----------------------------------------------------
work | 325.733 | 336.449 | 0.456 | 1.000 | 0.138
full | 110.495 | 118.541 | 0.398 | 1.000 | 0.333

47 changes: 47 additions & 0 deletions tests/testthat/test-nestedLogit.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,53 @@ mnl <- nestedLogit::nestedLogit(
data = Womenlf
)

test_that("r2", {
out <- r2(mnl)
expect_equal(
out,
list(R2_Tjur = list(
work = c(`Tjur's R2` = 0.137759452521642),
full = c(`Tjur's R2` = 0.332536937208286)
)),
ignore_attr = TRUE,
tolerance = 1e-4
)

out <- r2_tjur(mnl)
expect_equal(
out,
list(R2_Tjur = list(
work = c(`Tjur's R2` = 0.137759452521642),
full = c(`Tjur's R2` = 0.332536937208286)
)),
ignore_attr = TRUE,
tolerance = 1e-4
)

out <- r2_coxsnell(mnl)
expect_equal(
out,
list(
work = c(`Cox & Snell's R2` = 0.129313084315599),
full = c(`Cox & Snell's R2` = 0.308541455410686)
),
ignore_attr = TRUE,
tolerance = 1e-4
)

out <- r2_nagelkerke(mnl)
expect_equal(
out,
list(
work = c(`Nagelkerke's R2` = 0.174313365512442),
full = c(`Nagelkerke's R2` = 0.418511411473948)
),
ignore_attr = TRUE,
tolerance = 1e-4
)
})


test_that("model_performance", {
expect_snapshot(model_performance(mnl))
})

0 comments on commit c1cc3df

Please sign in to comment.