Skip to content

Commit

Permalink
don't allow unrecognized arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Dec 7, 2024
1 parent 2bff29d commit ea59be5
Show file tree
Hide file tree
Showing 60 changed files with 1,110 additions and 687 deletions.
14 changes: 7 additions & 7 deletions R-package/R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
#' will be the same as parameter `begin_iteration`, then next one will add +1, and so on).
#'
#' - iter_feval Evaluation metrics for `evals` that were supplied, either
#' determined by the objective, or by parameter `feval`.
#' determined by the objective, or by parameter `custom_metric`.
#'
#' For [xgb.train()], this will be a named vector with one entry per element in
#' `evals`, where the names are determined as 'evals name' + '-' + 'metric name' - for
Expand Down Expand Up @@ -451,7 +451,7 @@ xgb.cb.print.evaluation <- function(period = 1, showsd = TRUE) {
#' Callback for logging the evaluation history
#'
#' @details This callback creates a table with per-iteration evaluation metrics (see parameters
#' `evals` and `feval` in [xgb.train()]).
#' `evals` and `custom_metric` in [xgb.train()]).
#'
#' Note: in the column names of the final data.table, the dash '-' character is replaced with
#' the underscore '_' in order to make the column names more like regular R identifiers.
Expand Down Expand Up @@ -957,7 +957,7 @@ xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) {
#' label = 1 * (iris$Species == "versicolor"),
#' nthread = nthread
#' )
#' param <- list(
#' param <- xgb.params(
#' booster = "gblinear",
#' objective = "reg:logistic",
#' eval_metric = "auc",
Expand All @@ -973,7 +973,7 @@ xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) {
#' bst <- xgb.train(
#' param,
#' dtrain,
#' list(tr = dtrain),
#' evals = list(tr = dtrain),
#' nrounds = 200,
#' eta = 1.,
#' callbacks = list(xgb.cb.gblinear.history())
Expand All @@ -988,7 +988,7 @@ xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) {
#' bst <- xgb.train(
#' param,
#' dtrain,
#' list(tr = dtrain),
#' evals = list(tr = dtrain),
#' nrounds = 200,
#' eta = 0.8,
#' updater = "coord_descent",
Expand Down Expand Up @@ -1017,7 +1017,7 @@ xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) {
#' #### Multiclass classification:
#' dtrain <- xgb.DMatrix(scale(x), label = as.numeric(iris$Species) - 1, nthread = nthread)
#'
#' param <- list(
#' param <- xgb.params(
#' booster = "gblinear",
#' objective = "multi:softprob",
#' num_class = 3,
Expand All @@ -1031,7 +1031,7 @@ xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) {
#' bst <- xgb.train(
#' param,
#' dtrain,
#' list(tr = dtrain),
#' evals = list(tr = dtrain),
#' nrounds = 50,
#' eta = 0.5,
#' callbacks = list(xgb.cb.gblinear.history())
Expand Down
195 changes: 115 additions & 80 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,13 @@ NVL <- function(x, val) {

# Normalizes the names of booster params (replacing '.' with '_')
# plus runs some checks
check.booster.params <- function(params, ...) {
check.booster.params <- function(params) {
if (!identical(class(params), "list"))
stop("params must be a list")

# in R interface, allow for '.' instead of '_' in parameter names
names(params) <- gsub(".", "_", names(params), fixed = TRUE)

# merge parameters from the params and the dots-expansion
dot_params <- list(...)
names(dot_params) <- gsub(".", "_", names(dot_params), fixed = TRUE)
if (length(intersect(names(params),
names(dot_params))) > 0)
stop("Same parameters in 'params' and in the call are not allowed. Please check your 'params' list.")
params <- c(params, dot_params)

# providing a parameter multiple times makes sense only for 'eval_metric'
name_freqs <- table(names(params))
multi_names <- setdiff(names(name_freqs[name_freqs > 1]), 'eval_metric')
Expand All @@ -110,7 +102,6 @@ check.booster.params <- function(params, ...) {
}

# monotone_constraints parser

if (!is.null(params[['monotone_constraints']]) &&
typeof(params[['monotone_constraints']]) != "character") {
vec2str <- paste(params[['monotone_constraints']], collapse = ',')
Expand Down Expand Up @@ -144,55 +135,56 @@ check.booster.params <- function(params, ...) {


# Performs some checks related to custom objective function.
# WARNING: has side-effects and can modify 'params' and 'obj' in its calling frame
check.custom.obj <- function(env = parent.frame()) {
if (!is.null(env$params[['objective']]) && !is.null(env$obj))
stop("Setting objectives in 'params' and 'obj' at the same time is not allowed")
check.custom.obj <- function(params, objective) {
if (!is.null(params[['objective']]) && !is.null(objective))
stop("Setting objectives in 'params' and 'objective' at the same time is not allowed")

if (!is.null(env$obj) && typeof(env$obj) != 'closure')
stop("'obj' must be a function")
if (!is.null(objective) && typeof(objective) != 'closure')
stop("'objective' must be a function")

# handle the case when custom objective function was provided through params
if (!is.null(env$params[['objective']]) &&
typeof(env$params$objective) == 'closure') {
env$obj <- env$params$objective
env$params$objective <- NULL
if (!is.null(params[['objective']]) &&
typeof(params$objective) == 'closure') {
objective <- params$objective
params$objective <- NULL
}
return(list(params = params, objective = objective))
}

# Performs some checks related to custom evaluation function.
# WARNING: has side-effects and can modify 'params' and 'feval' in its calling frame
check.custom.eval <- function(env = parent.frame()) {
if (!is.null(env$params[['eval_metric']]) && !is.null(env$feval))
stop("Setting evaluation metrics in 'params' and 'feval' at the same time is not allowed")
check.custom.eval <- function(params, custom_metric, maximize, early_stopping_rounds, callbacks) {
if (!is.null(params[['eval_metric']]) && !is.null(custom_metric))
stop("Setting evaluation metrics in 'params' and 'custom_metric' at the same time is not allowed")

if (!is.null(env$feval) && typeof(env$feval) != 'closure')
stop("'feval' must be a function")
if (!is.null(custom_metric) && typeof(custom_metric) != 'closure')
stop("'custom_metric' must be a function")

# handle a situation when custom eval function was provided through params
if (!is.null(env$params[['eval_metric']]) &&
typeof(env$params$eval_metric) == 'closure') {
env$feval <- env$params$eval_metric
env$params$eval_metric <- NULL
if (!is.null(params[['eval_metric']]) &&
typeof(params$eval_metric) == 'closure') {
custom_metric <- params$eval_metric
params$eval_metric <- NULL
}

# require maximize to be set when custom feval and early stopping are used together
if (!is.null(env$feval) &&
is.null(env$maximize) && (
!is.null(env$early_stopping_rounds) ||
has.callbacks(env$callbacks, "early_stop")))
# require maximize to be set when custom metric and early stopping are used together
if (!is.null(custom_metric) &&
is.null(maximize) && (
!is.null(early_stopping_rounds) ||
has.callbacks(callbacks, "early_stop")))
stop("Please set 'maximize' to indicate whether the evaluation metric needs to be maximized or not")

return(list(params = params, custom_metric = custom_metric))
}


# Update a booster handle for an iteration with dtrain data
xgb.iter.update <- function(bst, dtrain, iter, obj) {
xgb.iter.update <- function(bst, dtrain, iter, objective) {
if (!inherits(dtrain, "xgb.DMatrix")) {
stop("dtrain must be of xgb.DMatrix class")
}
handle <- xgb.get.handle(bst)

if (is.null(obj)) {
if (is.null(objective)) {
.Call(XGBoosterUpdateOneIter_R, handle, as.integer(iter), dtrain)
} else {
pred <- predict(
Expand All @@ -201,12 +193,12 @@ xgb.iter.update <- function(bst, dtrain, iter, obj) {
outputmargin = TRUE,
training = TRUE
)
gpair <- obj(pred, dtrain)
n_samples <- dim(dtrain)[1]
gpair <- objective(pred, dtrain)
n_samples <- dim(dtrain)[1L]
grad <- gpair$grad
hess <- gpair$hess

if ((is.matrix(grad) && dim(grad)[1] != n_samples) ||
if ((is.matrix(grad) && dim(grad)[1L] != n_samples) ||
(is.vector(grad) && length(grad) != n_samples) ||
(is.vector(grad) != is.vector(hess))) {
warning(paste(
Expand All @@ -230,14 +222,14 @@ xgb.iter.update <- function(bst, dtrain, iter, obj) {
# Evaluate one iteration.
# Returns a named vector of evaluation metrics
# with the names in a 'datasetname-metricname' format.
xgb.iter.eval <- function(bst, evals, iter, feval) {
xgb.iter.eval <- function(bst, evals, iter, custom_metric) {
handle <- xgb.get.handle(bst)

if (length(evals) == 0)
return(NULL)

evnames <- names(evals)
if (is.null(feval)) {
if (is.null(custom_metric)) {
msg <- .Call(XGBoosterEvalOneIter_R, handle, as.integer(iter), evals, as.list(evnames))
mat <- matrix(strsplit(msg, '\\s+|:')[[1]][-1], nrow = 2)
res <- structure(as.numeric(mat[2, ]), names = mat[1, ])
Expand All @@ -246,7 +238,7 @@ xgb.iter.eval <- function(bst, evals, iter, feval) {
w <- evals[[j]]
## predict using all trees
preds <- predict(bst, w, outputmargin = TRUE, iterationrange = "all")
eval_res <- feval(preds, w)
eval_res <- custom_metric(preds, w)
out <- eval_res$value
names(out) <- paste0(evnames[j], "-", eval_res$metric)
out
Expand Down Expand Up @@ -498,11 +490,13 @@ NULL
#'
#' bst <- xgb.train(
#' data = xgb.DMatrix(agaricus.train$data, label = agaricus.train$label),
#' max_depth = 2,
#' eta = 1,
#' nthread = 2,
#' nrounds = 2,
#' objective = "binary:logistic"
#' params = xgb.params(
#' max_depth = 2,
#' eta = 1,
#' nthread = 2,
#' objective = "binary:logistic"
#' )
#' )
#'
#' # Save as a stand-alone file; load it with xgb.load()
Expand Down Expand Up @@ -535,44 +529,85 @@ NULL
NULL

# Lookup table for the deprecated parameters bookkeeping
depr_par_lut <- matrix(c(
'print.every.n', 'print_every_n',
'early.stop.round', 'early_stopping_rounds',
'training.data', 'data',
'with.stats', 'with_stats',
'numberOfClusters', 'n_clusters',
'features.keep', 'features_keep',
'plot.height', 'plot_height',
'plot.width', 'plot_width',
'n_first_tree', 'trees',
'dummy', 'DUMMY',
'watchlist', 'evals'
), ncol = 2, byrow = TRUE)
colnames(depr_par_lut) <- c('old', 'new')
# Each list below maps a deprecated argument name (the element name) to the
# name that replaces it (the element value). The lists are grouped by the
# family of functions indicated in the variable name; the call sites that
# consume them are outside this view.
deprecated_train_params <- list(
  print.every.n = "print_every_n",
  early.stop.round = "early_stopping_rounds",
  training.data = "data",
  dtrain = "data",
  watchlist = "evals",
  feval = "custom_metric"
)

deprecated_dttree_params <- list(n_first_tree = "trees")

# Renamed plot-size arguments, reused by the composite lists further down.
deprecated_plot_params <- list(
  plot.height = "plot_height",
  plot.width = "plot_width"
)

# Plot-size renames followed by the 'features.keep' rename.
deprecated_multitrees_params <- append(
  deprecated_plot_params,
  list(features.keep = "features_keep")
)

deprecated_dump_params <- list(with.stats = "with_stats")

# Plot-size renames followed by the dump-related renames.
deprecated_plottree_params <- append(
  deprecated_plot_params,
  deprecated_dump_params
)

# Checks the dot-parameters for deprecated names
# (including partial matching), gives a deprecation warning,
# and sets new parameters to the old parameters' values within its parent frame.
# Names that match no deprecated parameter raise an error unless
# `allow_unrecognized = TRUE`.
# WARNING: has side-effects
check.deprecation <- function(..., env = parent.frame()) {
pars <- list(...)
# exact and partial matches
all_match <- pmatch(names(pars), depr_par_lut[, 1])
# indices of matched pars' names
idx_pars <- which(!is.na(all_match))
if (length(idx_pars) == 0) return()
# indices of matched LUT rows
idx_lut <- all_match[idx_pars]
# which of idx_lut were the exact matches?
ex_match <- depr_par_lut[idx_lut, 1] %in% names(pars)
for (i in seq_along(idx_pars)) {
pars_par <- names(pars)[idx_pars[i]]
old_par <- depr_par_lut[idx_lut[i], 1]
new_par <- depr_par_lut[idx_lut[i], 2]
if (!ex_match[i]) {
warning("'", pars_par, "' was partially matched to '", old_par, "'")
check.deprecation <- function(
deprecated_list,
fn_call,
...,
env = parent.frame(),
allow_unrecognized = FALSE
) {
params <- list(...)
if (length(params) == 0) {
return(NULL)
}
if (is.null(names(params)) || min(nchar(names(params))) == 0L) {
stop("Passed invalid positional arguments")
}
all_match <- pmatch(names(params), names(deprecated_list))
# throw error on unrecognized parameters
if (!allow_unrecognized && anyNA(all_match)) {
names_unrecognized <- names(params)[is.na(all_match)]
# make it informative if they match something that goes under 'params'
if (deprecated_list[[1L]] == deprecated_train_params[[1L]]) {
names_params <- formalArgs(xgb.params)
names_params <- c(names_params, gsub("_", ".", names_params))
names_under_params <- intersect(names_unrecognized, names_params)
if (length(names_under_params)) {
stop(
"Passed invalid function arguments: ",
paste(head(names_under_params), collapse = ", "),
". These should be passed as a list to argument 'params'."
)
}
}
# otherwise throw a generic error
stop(
"Passed unrecognized parameters: ",
paste(head(names_unrecognized), collapse = ", ")
)
}

matched_params <- deprecated_list[all_match[!is.na(all_match)]]
idx_orig <- seq_along(params)[!is.na(all_match)]
function_args_passed <- names(as.list(fn_call))[-1L]
for (idx in seq_along(matched_params)) {
match_old <- names(matched_params)[[idx]]
match_new <- matched_params[[idx]]
warning("Parameter '", match_old, "' has been renamed to '", match_new, "'.")
if (match_new %in% function_args_passed) {
stop("Passed both '", match_new, "' and '", match_old, "'.")
}
.Deprecated(new_par, old = old_par, package = 'xgboost')
stop()
env[[match_new]] <- params[[idx_orig[idx]]]
}
}
Loading

0 comments on commit ea59be5

Please sign in to comment.