Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] Don't cap global number of threads #10028

Merged
merged 4 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ Suggests:
testthat,
igraph (>= 1.0.1),
float,
titanic
titanic,
RhpcBLASctl
Depends:
R (>= 4.3.0)
Imports:
Expand Down
1 change: 1 addition & 0 deletions R-package/R/xgb.DMatrix.save.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#' @param fname the name of the file to write.
#'
#' @examples
#' \dontshow{RhpcBLASctl::omp_set_num_threads(1)}
#' data(agaricus.train, package='xgboost')
#' dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2))
#' fname <- file.path(tempdir(), "xgb.DMatrix.data")
Expand Down
7 changes: 7 additions & 0 deletions R-package/R/xgb.config.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
#' values of one or more global-scope parameters. Use \code{xgb.get.config} to fetch the current
#' values of all global-scope parameters (listed in
#' \url{https://xgboost.readthedocs.io/en/stable/parameter.html}).
#' @details
#' Note that serialization-related functions might use a globally-configured number of threads,
#' which is managed by the system's OpenMP (OMP) configuration instead. Typically, XGBoost methods
#' accept an `nthread` parameter, but some methods like `readRDS` might get executed before such a
#' parameter can be supplied.
#'
#' The number of OMP threads can in turn be configured for example through an environment variable
#' `OMP_NUM_THREADS` (needs to be set before R is started), or through `RhpcBLASctl::omp_set_num_threads`.
#' @rdname xgbConfig
#' @title Set and get global configuration
#' @name xgb.set.config, xgb.get.config
Expand Down
1 change: 1 addition & 0 deletions R-package/R/xgb.dump.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#' as a \code{character} vector. Otherwise it will return \code{TRUE}.
#'
#' @examples
#' \dontshow{RhpcBLASctl::omp_set_num_threads(1)}
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
Expand Down
1 change: 1 addition & 0 deletions R-package/R/xgb.load.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#' \code{\link{xgb.save}}
#'
#' @examples
#' \dontshow{RhpcBLASctl::omp_set_num_threads(1)}
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#'
Expand Down
1 change: 1 addition & 0 deletions R-package/R/xgb.save.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#' \code{\link{xgb.load}}
#'
#' @examples
#' \dontshow{RhpcBLASctl::omp_set_num_threads(1)}
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#'
Expand Down
1 change: 1 addition & 0 deletions R-package/R/xgb.save.raw.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#' }
#'
#' @examples
#' \dontshow{RhpcBLASctl::omp_set_num_threads(1)}
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#'
Expand Down
2 changes: 2 additions & 0 deletions R-package/demo/basic_walkthrough.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ print(paste("test-error=", err))
# save model to binary local file
xgb.save(bst, "xgboost.model")
# load binary model to R
# 'xgb.load' doesn't take an 'nthread' argument, but the number of OMP threads can be set like this:
RhpcBLASctl::omp_set_num_threads(1)
trivialfis marked this conversation as resolved.
Show resolved Hide resolved
bst2 <- xgb.load("xgboost.model")
pred2 <- predict(bst2, test$data)
# pred2 should be identical to pred
Expand Down
1 change: 1 addition & 0 deletions R-package/man/xgb.DMatrix.save.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/man/xgb.dump.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/man/xgb.load.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/man/xgb.save.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/man/xgb.save.raw.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions R-package/man/xgbConfig.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions R-package/tests/helper_scripts/install_deps.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pkgs <- c(
"igraph",
"float",
"titanic",
"RhpcBLASctl",
## imports
"Matrix",
"methods",
Expand Down
1 change: 1 addition & 0 deletions R-package/tests/testthat.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ library(testthat)
library(xgboost)

test_check("xgboost", reporter = ProgressReporter)
RhpcBLASctl::omp_set_num_threads(1)
3 changes: 3 additions & 0 deletions R-package/vignettes/xgboostPresentation.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ An interesting test to see how identical our saved model is to the original one

```{r loadModel, message=F, warning=F}
# load binary model to R
# Note that the number of threads for 'xgb.load' is taken from the global config,
# which can be modified like this:
RhpcBLASctl::omp_set_num_threads(1)
bst2 <- xgb.load(fname)
xgb.parameters(bst2) <- list(nthread = 2)
pred2 <- predict(bst2, test$data)
Expand Down
21 changes: 2 additions & 19 deletions src/gbm/gbtree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,13 @@ void GBTreeModel::Load(dmlc::Stream* fi) {
Validate(*this);
}

namespace {
std::int32_t IOThreads(Context const* ctx) {
CHECK(ctx);
std::int32_t n_threads = ctx->Threads();
// CRAN checks for number of threads used by examples, but we might not have the right
// number of threads when serializing/unserializing models as nthread is a booster
// parameter, which is only effective after booster initialization.
//
// The threshold ratio of CPU time to user time for R is 2.5, we set the number of
// threads to 2.
#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
n_threads = std::min(2, n_threads);
#endif
return n_threads;
}
} // namespace

void GBTreeModel::SaveModel(Json* p_out) const {
auto& out = *p_out;
CHECK_EQ(param.num_trees, static_cast<int>(trees.size()));
out["gbtree_model_param"] = ToJson(param);
std::vector<Json> trees_json(trees.size());

common::ParallelFor(trees.size(), IOThreads(ctx_), [&](auto t) {
common::ParallelFor(trees.size(), ctx_->Threads(), [&](auto t) {
auto const& tree = trees[t];
Json jtree{Object{}};
tree->SaveModel(&jtree);
Expand Down Expand Up @@ -167,7 +150,7 @@ void GBTreeModel::LoadModel(Json const& in) {
CHECK_EQ(tree_info_json.size(), param.num_trees);
tree_info.resize(param.num_trees);

common::ParallelFor(param.num_trees, IOThreads(ctx_), [&](auto t) {
common::ParallelFor(param.num_trees, ctx_->Threads(), [&](auto t) {
auto tree_id = get<Integer const>(trees_json[t]["id"]);
trees.at(tree_id).reset(new RegTree{});
trees[tree_id]->LoadModel(trees_json[t]);
Expand Down
Loading