diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 8742a2271353..a82913819565 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -59,6 +59,32 @@ namespace { return ""; } +[[nodiscard]] std::string MakeArrayInterfaceFromRVector(SEXP R_vec) { + const size_t vec_len = Rf_xlength(R_vec); + + // Lambda for type dispatch. + auto make_vec = [=](auto const *ptr) { + using namespace xgboost; // NOLINT + auto v = linalg::MakeVec(ptr, vec_len); + return linalg::ArrayInterfaceStr(v); + }; + + const SEXPTYPE arr_type = TYPEOF(R_vec); + switch (arr_type) { + case REALSXP: + return make_vec(REAL(R_vec)); + case INTSXP: + return make_vec(INTEGER(R_vec)); + case LGLSXP: + return make_vec(LOGICAL(R_vec)); + default: + LOG(FATAL) << "Array or matrix has unsupported type."; + } + + LOG(FATAL) << "Not reachable"; + return ""; +} + [[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) { using namespace ::xgboost; // NOLINT Json jconfig{Object{}}; @@ -159,12 +185,15 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) { SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); R_API_BEGIN(); - auto array_str = MakeArrayInterfaceFromRMat(mat); - auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat)); - DMatrixHandle handle; - CHECK_CALL(XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle)); + int res_code; + { + auto array_str = MakeArrayInterfaceFromRMat(mat); + auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat)); + res_code = XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle); + } + CHECK_CALL(res_code); R_SetExternalPtrAddr(ret, handle); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_API_END(); @@ -279,23 +308,15 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) { XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { R_API_BEGIN(); - int len = length(array); - const char *name = CHAR(asChar(field)); - auto ctx = DMatrixCtx(R_ExternalPtrAddr(handle)); - if (!strcmp("group", name)) { - std::vector vec(len); - xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) { - vec[i] = static_cast(INTEGER(array)[i]); - }); - CHECK_CALL( - XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len)); - } else { - std::vector vec(len); - xgboost::common::ParallelFor(len, ctx->Threads(), - [&](xgboost::omp_ulong i) { vec[i] = REAL(array)[i]; }); - CHECK_CALL( - XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len)); + SEXP field_ = PROTECT(Rf_asChar(field)); + int res_code; + { + const std::string array_str = MakeArrayInterfaceFromRVector(array); + res_code = XGDMatrixSetInfoFromInterface( + R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str()); } + CHECK_CALL(res_code); + UNPROTECT(1); R_API_END(); return R_NilValue; }