[R] Use inplace predict (#9829)
---------

Co-authored-by: Hyunsu Cho <[email protected]>
david-cortes and hcho3 authored Feb 23, 2024
1 parent 729fd97 commit f7005d3
Showing 7 changed files with 450 additions and 46 deletions.
137 changes: 119 additions & 18 deletions R-package/R/xgb.Booster.R
@@ -77,26 +77,45 @@ xgb.get.handle <- function(object) {

#' Predict method for XGBoost model
#'
#' Predicted values based on either xgboost model or model handle object.
#' Predict values on data based on xgboost model.
#'
#' @param object Object of class `xgb.Booster`.
#' @param newdata Takes `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' @param newdata Takes `data.frame`, `matrix`, `dgCMatrix`, `dgRMatrix`, `dsparseVector`,
#' local data file, or `xgb.DMatrix`.
#' For single-row predictions on sparse data, it is recommended to use the CSR format.
#' If passing a sparse vector, it will take it as a row vector.
#' @param missing Only used when input is a dense matrix. Pick a float value that represents
#' missing values in data (e.g., 0 or some other extreme value).
#'
#' For single-row predictions on sparse data, it's recommended to use CSR format. If passing
#' a sparse vector, it will take it as a row vector.
#'
#' Note that, for repeated predictions on the same data, one might want to create a DMatrix to
#' pass here instead of passing R types like matrices or data frames, as predictions will be
#' faster on DMatrix.
#'
#' If `newdata` is a `data.frame`, be aware that:\itemize{
#' \item Columns will be converted to numeric if they aren't already, which could potentially make
#' the operation slower than in an equivalent `matrix` object.
#' \item The order of the columns must match with that of the data from which the model was fitted
#' (i.e. columns will not be referenced by their names, just by their order in the data).
#' \item If the model was fitted to data with categorical columns, these columns must be of
#' `factor` type here, and must use the same encoding (i.e. have the same levels).
#' \item If `newdata` contains any `factor` columns, they will be converted to base-0
#' encoding (same as during DMatrix creation) - hence, one should not pass a `factor`
#' under a column which during training had a different type.
#' }
#' @param missing Float value that represents missing values in data (e.g., 0 or some other extreme value).
#'
#' This parameter is not used when `newdata` is an `xgb.DMatrix` - in such cases, should pass
#' this as an argument to the DMatrix constructor instead.
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
#' logistic regression would return log-odds instead of probabilities.
#' @param predleaf Whether to predict pre-tree leaf indices.
#' @param predleaf Whether to predict per-tree leaf indices.
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
#' or `predinteraction` is `TRUE`.
#' @param training Whether the predictions are used for training. For dart booster,
#' @param training Whether the prediction result is used for training. For dart booster,
#' training predicting will perform dropout.
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
@@ -111,6 +130,12 @@ xgb.get.handle <- function(object) {
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
#' type and shape of predictions are invariant to the model type.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
#' be ignored as it needs to be added to the DMatrix instead (e.g. by passing it as
#' an argument in its constructor, or by calling \link{setinfo.xgb.DMatrix}).
#'
#' @param validate_features When `TRUE`, validate that the Booster's and newdata's feature_names
#' match (only applicable when both `object` and `newdata` have feature names).
#'
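
The `data.frame` notes in the documentation above (columns coerced to numeric, factors mapped to base-0 codes, columns matched by position) can be illustrated with a small base-R sketch. The helper name `df_to_columnar` and the sample data are hypothetical and are not part of xgboost; the sketch only mirrors the conversion shown in this diff.

```r
# Hypothetical helper (not an xgboost function): turn a data.frame into a
# plain list of numeric columns, encoding factors with base-0 codes.
df_to_columnar <- function(df) {
  stopifnot(is.data.frame(df))
  lapply(df, function(x) {
    if (is.factor(x)) {
      # R factor codes start at 1; shift them so the first level becomes 0
      as.numeric(x) - 1
    } else {
      as.numeric(x)
    }
  })
}

# Assumed example data: the factor levels are set explicitly so that the
# codes match whatever encoding was used when the model was fitted.
train_levels <- c("low", "mid", "high")
df <- data.frame(
  size = c(1L, 2L, 3L),
  grade = factor(c("mid", "high", "low"), levels = train_levels)
)

cols <- df_to_columnar(df)
str(cols)
```

Because the codes depend on the order of the levels, a factor built with different levels (for example the default alphabetical order) would produce different numbers for the same labels, which is why the documentation asks for the same encoding as in training.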
@@ -287,16 +312,80 @@ xgb.get.handle <- function(object) {
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
validate_features = FALSE, ...) {
validate_features = FALSE, base_margin = NULL, ...) {
if (validate_features) {
newdata <- validate.features(object, newdata)
}
if (!inherits(newdata, "xgb.DMatrix")) {
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
if (is_dmatrix && !is.null(base_margin)) {
stop(
"'base_margin' is not supported when passing 'xgb.DMatrix' as input.",
" Should be passed as argument to 'xgb.DMatrix' constructor."
)
}

use_as_df <- FALSE
use_as_dense_matrix <- FALSE
use_as_csr_matrix <- FALSE
n_row <- NULL
if (!is_dmatrix) {

inplace_predict_supported <- !predcontrib && !predinteraction && !predleaf
if (inplace_predict_supported) {
booster_type <- xgb.booster_type(object)
if (booster_type == "gblinear" || (booster_type == "dart" && training)) {
inplace_predict_supported <- FALSE
}
}
if (inplace_predict_supported) {

if (is.matrix(newdata)) {
use_as_dense_matrix <- TRUE
} else if (is.data.frame(newdata)) {
# note: since here it turns it into a non-data-frame list,
# needs to keep track of the number of rows it had for later
n_row <- nrow(newdata)
newdata <- lapply(
newdata,
function(x) {
if (is.factor(x)) {
return(as.numeric(x) - 1)
} else {
return(as.numeric(x))
}
}
)
use_as_df <- TRUE
} else if (inherits(newdata, "dgRMatrix")) {
use_as_csr_matrix <- TRUE
csr_data <- list(newdata@p, newdata@j, newdata@x, ncol(newdata))
} else if (inherits(newdata, "dsparseVector")) {
use_as_csr_matrix <- TRUE
n_row <- 1L
i <- newdata@i - 1L
if (storage.mode(i) != "integer") {
storage.mode(i) <- "integer"
}
csr_data <- list(c(0L, length(i)), i, newdata@x, length(newdata))
}

}

} # if (!is_dmatrix)

if (!is_dmatrix && !use_as_dense_matrix && !use_as_csr_matrix && !use_as_df) {
nthread <- xgb.nthread(object)
newdata <- xgb.DMatrix(
newdata,
missing = missing, nthread = NVL(nthread, -1)
missing = missing,
base_margin = base_margin,
nthread = NVL(nthread, -1)
)
is_dmatrix <- TRUE
}

if (is.null(n_row)) {
n_row <- nrow(newdata)
}
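
The branching in this hunk can be summarised as a pure function: in-place prediction is skipped for leaf, contribution and interaction outputs, for `gblinear`, and for `dart` in training mode, and any input type without a dedicated path falls back to building a DMatrix. The function `choose_predict_path` and its string return values are hypothetical labels introduced for illustration only; the real code sets flags and calls C routines.

```r
# Hypothetical summary of the dispatch logic; class and booster type are
# passed as plain strings so the sketch runs without xgboost installed.
choose_predict_path <- function(newdata_class, booster_type,
                                predleaf = FALSE, predcontrib = FALSE,
                                predinteraction = FALSE, training = FALSE) {
  if (newdata_class == "xgb.DMatrix") {
    return("dmatrix")
  }
  inplace_ok <- !predcontrib && !predinteraction && !predleaf
  if (inplace_ok &&
      (booster_type == "gblinear" || (booster_type == "dart" && training))) {
    inplace_ok <- FALSE
  }
  if (!inplace_ok) {
    return("dmatrix")
  }
  switch(newdata_class,
    matrix = "dense",
    data.frame = "columnar",
    dgRMatrix = "csr",
    dsparseVector = "csr",
    "dmatrix" # other inputs (e.g. dgCMatrix) are converted to a DMatrix
  )
}

choose_predict_path("matrix", "gbtree")
choose_predict_path("matrix", "dart", training = TRUE)
choose_predict_path("dgCMatrix", "gbtree")
```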


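
The `dsparseVector` branch treats the vector as a single CSR row: the 1-based indices become 0-based column indices, and the row pointer is simply `c(0, number_of_nonzeros)`. A minimal sketch of that step follows, taking the index and value slots as plain vectors so it runs without the Matrix package. The helper name and the list element names (`indptr`, `indices`, `data`, `ncol`) are assumptions for readability; the diff passes an unnamed list.

```r
# Hypothetical helper: build single-row CSR components from the 1-based
# indices `i`, the values `x`, and the total vector length `len`.
sparse_vec_to_csr <- function(i, x, len) {
  idx <- i - 1L
  # indices may arrive as doubles, so force integer storage
  storage.mode(idx) <- "integer"
  list(
    indptr = c(0L, length(idx)), # one row: starts at 0, ends at nnz
    indices = idx,               # 0-based column indices
    data = x,
    ncol = len
  )
}

# Assumed example: non-zeros at positions 2 and 5 of a length-6 vector
row <- sparse_vec_to_csr(i = c(2, 5), x = c(0.5, 1.5), len = 6L)
str(row)
```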
@@ -354,18 +443,30 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
args$type <- set_type(6)
}

predts <- .Call(
XGBoosterPredictFromDMatrix_R,
xgb.get.handle(object),
newdata,
jsonlite::toJSON(args, auto_unbox = TRUE)
)
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
if (is_dmatrix) {
predts <- .Call(
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
)
} else if (use_as_dense_matrix) {
predts <- .Call(
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
} else if (use_as_csr_matrix) {
predts <- .Call(
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
)
} else if (use_as_df) {
predts <- .Call(
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
)
}

names(predts) <- c("shape", "results")
shape <- predts$shape
arr <- predts$results

n_ret <- length(arr)
n_row <- nrow(newdata)
if (n_row != shape[1]) {
stop("Incorrect predict shape.")
}
45 changes: 37 additions & 8 deletions R-package/man/predict.xgb.Booster.Rd

Some generated files are not rendered by default.

6 changes: 6 additions & 0 deletions R-package/src/init.c
@@ -37,6 +37,9 @@ extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromDense_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromColumnar_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
@@ -96,6 +99,9 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
{"XGBoosterPredictFromDense_R", (DL_FUNC) &XGBoosterPredictFromDense_R, 5},
{"XGBoosterPredictFromCSR_R", (DL_FUNC) &XGBoosterPredictFromCSR_R, 5},
{"XGBoosterPredictFromColumnar_R", (DL_FUNC) &XGBoosterPredictFromColumnar_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},