From c1cc3dfe81120e3d6b68e17be063370ea7666e3f Mon Sep 17 00:00:00 2001 From: Daniel Date: Wed, 4 Oct 2023 10:37:17 +0200 Subject: [PATCH] r2 --- R/model_performance.lm.R | 5 +++ R/r2.R | 19 +++++------ R/r2_coxsnell.R | 20 ++++++++++++ R/r2_nagelkerke.R | 20 ++++++++++++ R/r2_tjur.R | 32 +++++++++++++++++++ tests/testthat/_snaps/nestedLogit.md | 8 ++--- tests/testthat/test-nestedLogit.R | 47 ++++++++++++++++++++++++++++ 7 files changed, 135 insertions(+), 16 deletions(-) diff --git a/R/model_performance.lm.R b/R/model_performance.lm.R index f03ee726e..e8e5b07cf 100644 --- a/R/model_performance.lm.R +++ b/R/model_performance.lm.R @@ -260,6 +260,11 @@ model_performance.nestedLogit <- function(model, metrics = "all", verbose = TRUE data.frame(Response = names(mp), stringsAsFactors = FALSE), do.call(rbind, mp) ) + # need to handle R2 separately + if (any(c("ALL", "R2") %in% toupper(metrics))) { + out$R2 <- unlist(r2_tjur(model)) + } + row.names(out) <- NULL class(out) <- unique(c("performance_model", class(out))) out diff --git a/R/r2.R b/R/r2.R index c0a53af64..0e9760bba 100644 --- a/R/r2.R +++ b/R/r2.R @@ -49,11 +49,8 @@ r2 <- function(model, ...) { } - - # Default models ----------------------------------------------- - #' @rdname r2 #' @export r2.default <- function(model, ci = NULL, verbose = TRUE, ...) { @@ -115,6 +112,8 @@ r2.lm <- function(model, ci = NULL, ...) { #' @export r2.phylolm <- r2.lm +# helper ------------- + .r2_lm <- function(model_summary, ci = NULL) { out <- list( R2 = model_summary$r.squared, @@ -140,7 +139,6 @@ r2.phylolm <- r2.lm } - #' @export r2.summary.lm <- function(model, ci = NULL, ...) { if (!is.null(ci) && !is.na(ci)) { @@ -150,7 +148,6 @@ r2.summary.lm <- function(model, ci = NULL, ...) { } - #' @export r2.systemfit <- function(model, ...) { out <- lapply(summary(model)$eq, function(model_summary) { @@ -198,8 +195,6 @@ r2.ols <- function(model, ...) { structure(class = "r2_generic", out) } - - #' @export r2.lrm <- r2.ols @@ -207,7 +202,6 @@ r2.lrm <- r2.ols r2.cph <- r2.ols - #' @export r2.mhurdle <- function(model, ...) { resp <- insight::get_response(model, verbose = FALSE) @@ -230,7 +224,6 @@ r2.mhurdle <- function(model, ...) { } - #' @export r2.aov <- function(model, ci = NULL, ...) { if (!is.null(ci) && !is.na(ci)) { @@ -252,7 +245,6 @@ r2.aov <- function(model, ci = NULL, ...) { } - #' @export r2.mlm <- function(model, ...) { model_summary <- summary(model) @@ -276,7 +268,6 @@ r2.mlm <- function(model, ...) { } - #' @export r2.glm <- function(model, ci = NULL, verbose = TRUE, ...) { if (!is.null(ci) && !is.na(ci)) { @@ -312,9 +303,13 @@ r2.glm <- function(model, ci = NULL, verbose = TRUE, ...) { #' @export r2.glmx <- r2.glm + #' @export r2.nestedLogit <- function(model, ci = NULL, verbose = TRUE, ...) { - lapply(r2, model$models, ci = ci, verbose = verbose, ...) + out <- list("R2_Tjur" = r2_tjur(model, ...)) + attr(out, "model_type") <- "Logistic" + class(out) <- c("r2_pseudo", class(out)) + out } diff --git a/R/r2_coxsnell.R b/R/r2_coxsnell.R index c151136d4..cdb73dea7 100644 --- a/R/r2_coxsnell.R +++ b/R/r2_coxsnell.R @@ -88,6 +88,26 @@ r2_coxsnell.glm <- function(model, verbose = TRUE, ...) { #' @export r2_coxsnell.BBreg <- r2_coxsnell.glm + +#' @export +r2_coxsnell.nestedLogit <- function(model, ...) { + n <- insight::n_obs(model, disaggregate = TRUE) + stats::setNames( + lapply(names(model$models), function(i) { + m <- model$models[[i]] + # if no deviance, return NA + if (is.null(m$deviance)) { + return(NA) + } + r2_coxsnell <- (1 - exp((m$deviance - m$null.deviance) / n[[i]])) + names(r2_coxsnell) <- "Cox & Snell's R2" + r2_coxsnell + }), + names(model$models) + ) +} + + #' @export r2_coxsnell.mclogit <- function(model, ...) { insight::check_if_installed("mclogit", reason = "to calculate R2") diff --git a/R/r2_nagelkerke.R b/R/r2_nagelkerke.R index bb6230f22..85bcc6e8f 100644 --- a/R/r2_nagelkerke.R +++ b/R/r2_nagelkerke.R @@ -77,6 +77,26 @@ r2_nagelkerke.glm <- function(model, verbose = TRUE, ...) { #' @export r2_nagelkerke.BBreg <- r2_nagelkerke.glm + +#' @export +r2_nagelkerke.nestedLogit <- function(model, ...) { + n <- insight::n_obs(model, disaggregate = TRUE) + stats::setNames( + lapply(names(model$models), function(i) { + m <- model$models[[i]] + # if no deviance, return NA + if (is.null(m$deviance)) { + return(NA) + } + r2_nagelkerke <- (1 - exp((m$deviance - m$null.deviance) / n[[i]])) / (1 - exp(-m$null.deviance / n[[i]])) + names(r2_nagelkerke) <- "Nagelkerke's R2" + r2_nagelkerke + }), + names(model$models) + ) +} + + #' @export r2_nagelkerke.bife <- function(model, ...) { r2_nagelkerke <- r2_coxsnell(model) / (1 - exp(-model$null_deviance / insight::n_obs(model))) diff --git a/R/r2_tjur.R b/R/r2_tjur.R index e0aa6117c..efdca98c2 100644 --- a/R/r2_tjur.R +++ b/R/r2_tjur.R @@ -23,6 +23,11 @@ #' #' @export r2_tjur <- function(model, ...) { + UseMethod("r2_tjur") +} + +#' @export +r2_tjur.default <- function(model, ...) { info <- list(...)$model_info if (is.null(info)) { info <- suppressWarnings(insight::model_info(model, verbose = FALSE)) @@ -50,3 +55,30 @@ r2_tjur <- function(model, ...) { names(tjur_d) <- "Tjur's R2" tjur_d } + +#' @export +r2_tjur.nestedLogit <- function(model, ...) { + resp <- insight::get_response(model, dichotomies = TRUE, verbose = FALSE) + + stats::setNames( + lapply(names(model$models), function(i) { + y <- resp[[i]] + m <- model$models[[i]] + pred <- stats::predict(m, type = "response") + # delete pred for cases with missing residuals + if (anyNA(stats::residuals(m))) { + pred <- pred[!is.na(stats::residuals(m))] + } + categories <- unique(y) + m1 <- mean(pred[which(y == categories[1])], na.rm = TRUE) + m2 <- mean(pred[which(y == categories[2])], na.rm = TRUE) + + tjur_d <- abs(m2 - m1) + + names(tjur_d) <- "Tjur's R2" + tjur_d + } + ), + names(model$models) + ) +} diff --git a/tests/testthat/_snaps/nestedLogit.md b/tests/testthat/_snaps/nestedLogit.md index 907ae3fa5..f5c9a0bdd 100644 --- a/tests/testthat/_snaps/nestedLogit.md +++ b/tests/testthat/_snaps/nestedLogit.md @@ -5,8 +5,8 @@ Output # Indices of model performance - Response | AIC | BIC | RMSE | Sigma - -------------------------------------------- - work | 325.733 | 336.449 | 0.456 | 1.000 - full | 110.495 | 118.541 | 0.398 | 1.000 + Response | AIC | BIC | RMSE | Sigma | R2 + ---------------------------------------------------- + work | 325.733 | 336.449 | 0.456 | 1.000 | 0.138 + full | 110.495 | 118.541 | 0.398 | 1.000 | 0.333 diff --git a/tests/testthat/test-nestedLogit.R b/tests/testthat/test-nestedLogit.R index beaa701d4..78a8db3c7 100644 --- a/tests/testthat/test-nestedLogit.R +++ b/tests/testthat/test-nestedLogit.R @@ -13,6 +13,53 @@ mnl <- nestedLogit::nestedLogit( data = Womenlf ) +test_that("r2", { + out <- r2(mnl) + expect_equal( + out, + list(R2_Tjur = list( + work = c(`Tjur's R2` = 0.137759452521642), + full = c(`Tjur's R2` = 0.332536937208286) + )), + ignore_attr = TRUE, + tolerance = 1e-4 + ) + + out <- r2_tjur(mnl) + expect_equal( + out, + list(R2_Tjur = list( + work = c(`Tjur's R2` = 0.137759452521642), + full = c(`Tjur's R2` = 0.332536937208286) + )), + ignore_attr = TRUE, + tolerance = 1e-4 + ) + + out <- r2_coxsnell(mnl) + expect_equal( + out, + list( + work = c(`Cox & Snell's R2` = 0.129313084315599), + full = c(`Cox & Snell's R2` = 0.308541455410686) + ), + ignore_attr = TRUE, + tolerance = 1e-4 + ) + + out <- r2_nagelkerke(mnl) + expect_equal( + out, + list( + work = c(`Nagelkerke's R2` = 0.174313365512442), + full = c(`Nagelkerke's R2` = 0.418511411473948) + ), + ignore_attr = TRUE, + tolerance = 1e-4 + ) +}) + + test_that("model_performance", { expect_snapshot(model_performance(mnl)) })