From efcf52263f5113b9e528c1a4b80b4afcbd3d9207 Mon Sep 17 00:00:00 2001 From: Aki Vehtari Date: Thu, 29 Feb 2024 20:11:22 +0200 Subject: [PATCH] make E_loo Pareto-k diagnostic more robust --- R/E_loo.R | 20 +++++++++++++------- tests/testthat/test_E_loo.R | 7 ++++++- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/R/E_loo.R b/R/E_loo.R index 2508162f..a05052f1 100644 --- a/R/E_loo.R +++ b/R/E_loo.R @@ -48,11 +48,11 @@ #' Pareto-k's, which may produce optimistic estimates. #' #' For `type="mean"`, `type="var"`, and `type="sd"`, the returned Pareto-k is -#' the maximum of the Pareto-k's for the left and right tail of \eqn{hr} and -#' the right tail of \eqn{r}, where \eqn{r} is the importance ratio and -#' \eqn{h=x} for `type="mean"` and \eqn{h=x^2} for `type="var"` and -#' `type="sd"`. For `type="quantile"`, the returned Pareto-k is the Pareto-k -#' for the right tail of \eqn{r}. +#' usually the maximum of the Pareto-k's for the left and right tail of \eqn{hr} +#' and the right tail of \eqn{r}, where \eqn{r} is the importance ratio and +#' \eqn{h=x} for `type="mean"` and \eqn{h=x^2} for `type="var"` and `type="sd"`. +#' If \eqn{h} is binary, constant, or not finite, or if `type="quantile"`, the +#' returned Pareto-k is the Pareto-k for the right tail of \eqn{r}. #' } #' } #' @@ -291,10 +291,16 @@ E_loo_khat.matrix <- function(x, psis_object, log_ratios, ...) 
{ h_theta <- x_i r_theta <- exp(log_ratios_i - max(log_ratios_i)) khat_r <- posterior::pareto_khat(r_theta, tail = "right", ndraws_tail = tail_len_i)$khat - if (is.null(x_i)) { + if (is.null(x_i) || is_constant(x_i) || length(unique(x_i))==2 || + anyNA(x_i) || any(is.infinite(x_i))) { khat_r } else { khat_hr <- posterior::pareto_khat(h_theta * r_theta, tail = "both", ndraws_tail = tail_len_i)$khat - max(khat_hr, khat_r) + if (is.na(khat_hr) && is.na(khat_r)) { + k <- NA + } else { + k <- max(khat_hr, khat_r, na.rm=TRUE) + } + k } } diff --git a/tests/testthat/test_E_loo.R b/tests/testthat/test_E_loo.R index 89fb0d24..678b9ed3 100644 --- a/tests/testthat/test_E_loo.R +++ b/tests/testthat/test_E_loo.R @@ -115,6 +115,12 @@ test_that("E_loo.matrix equal to reference", { test_that("E_loo throws correct errors and warnings", { # warnings expect_no_warning(E_loo.matrix(x, psis_mat)) + # no warnings if x is constant, binary, NA, NaN, Inf + expect_no_warning(E_loo.matrix(x*0, psis_mat)) + expect_no_warning(E_loo.matrix(0+(x>0), psis_mat)) + expect_no_warning(E_loo.matrix(x+NA, psis_mat)) + expect_no_warning(E_loo.matrix(x*NaN, psis_mat)) + expect_no_warning(E_loo.matrix(x*Inf, psis_mat)) expect_no_warning(E_test <- E_loo.default(x[, 1], psis_vec)) expect_length(E_test$pareto_k, 1) @@ -191,4 +197,3 @@ test_that("weighted variance works", { w <- c(rep(0.1, 10), rep(0, 90)) expect_equal(.wvar(x, w), var(x[w > 0])) }) -