From bad5d771ed56348e735df4f021c27112cbf35a75 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sat, 14 Dec 2024 17:19:35 +0100 Subject: [PATCH] use class names in importance outputs --- R-package/R/xgb.importance.R | 16 +++++++++++++--- R-package/man/xgb.importance.Rd | 4 +++- R-package/tests/testthat/test_xgboost.R | 17 +++++++++++++++++ 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index c1b45e81bb8c..be02b8cb0498 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -33,7 +33,9 @@ #' For a linear model: #' - `Features`: Names of the features used in the model. #' - `Weight`: Linear coefficient of this feature. -#' - `Class`: Class label (only for multiclass models). +#' - `Class`: Class label (only for multiclass models). For objects of class `xgboost` (as +#' produced by [xgboost()]), it will be a `factor`, while for objects of class `xgb.Booster` +#' (as produced by [xgb.train()]), it will be a zero-based integer vector. #' #' If `feature_names` is not provided and `model` doesn't have `feature_names`, #' the index of the features will be used instead. Because the index is extracted from the model dump @@ -144,11 +146,19 @@ xgb.importance <- function(model = NULL, feature_names = getinfo(model, "feature n_classes <- 0 } importance <- if (n_classes == 0) { - data.table(Feature = results$features, Weight = results$weight)[order(-abs(Weight))] + return(data.table(Feature = results$features, Weight = results$weight)[order(-abs(Weight))]) } else { - data.table( + out <- data.table( Feature = rep(results$features, each = n_classes), Weight = results$weight, Class = seq_len(n_classes) - 1 )[order(Class, -abs(Weight))] + if (inherits(model, "xgboost") && NROW(attributes(model)$metadata$y_levels)) { + class_vec <- out$Class + class_vec <- as.integer(class_vec) + 1L + attributes(class_vec)$levels <- attributes(model)$metadata$y_levels + attributes(class_vec)$class <- "factor" + out[, Class := class_vec] + } + return(out[]) } } else { concatenated <- list() diff --git a/R-package/man/xgb.importance.Rd b/R-package/man/xgb.importance.Rd index f26067d7fef9..503056737dd1 100644 --- a/R-package/man/xgb.importance.Rd +++ b/R-package/man/xgb.importance.Rd @@ -48,7 +48,9 @@ For a linear model: \itemize{ \item \code{Features}: Names of the features used in the model. \item \code{Weight}: Linear coefficient of this feature. -\item \code{Class}: Class label (only for multiclass models). +\item \code{Class}: Class label (only for multiclass models). For objects of class \code{xgboost} (as +produced by \code{\link[=xgboost]{xgboost()}}), it will be a \code{factor}, while for objects of class \code{xgb.Booster} +(as produced by \code{\link[=xgb.train]{xgb.train()}}), it will be a zero-based integer vector. } If \code{feature_names} is not provided and \code{model} doesn't have \code{feature_names}, diff --git a/R-package/tests/testthat/test_xgboost.R b/R-package/tests/testthat/test_xgboost.R index f3278cf37ea2..f7d9c367b8b9 100644 --- a/R-package/tests/testthat/test_xgboost.R +++ b/R-package/tests/testthat/test_xgboost.R @@ -1013,3 +1013,20 @@ test_that("'eval_set' as fraction works", { expect_true(hasName(evaluation_log, "eval_mlogloss")) expect_equal(length(attributes(model)$metadata$y_levels), 3L) }) + +test_that("Linear booster importance uses class names", { + y <- iris$Species + x <- iris[, -5L] + model <- xgboost( + x, + y, + nthreads = 1L, + nrounds = 4L, + verbosity = 0L, + booster = "gblinear", + learning_rate = 0.2 + ) + imp <- xgb.importance(model) + expect_true(is.factor(imp$Class)) + expect_equal(levels(imp$Class), levels(y)) +})