Skip to content

Commit

Permalink
more usage of array interface, fix potential memory leaks of std::str…
Browse files Browse the repository at this point in the history
…ing (#9824)
  • Loading branch information
david-cortes authored Nov 30, 2023
1 parent 37da66f commit 95af5c0
Showing 1 changed file with 41 additions and 20 deletions.
61 changes: 41 additions & 20 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,32 @@ namespace {
return "";
}

[[nodiscard]] std::string MakeArrayInterfaceFromRVector(SEXP R_vec) {
const size_t vec_len = Rf_xlength(R_vec);

// Lambda for type dispatch.
auto make_vec = [=](auto const *ptr) {
using namespace xgboost; // NOLINT
auto v = linalg::MakeVec(ptr, vec_len);
return linalg::ArrayInterfaceStr(v);
};

const SEXPTYPE arr_type = TYPEOF(R_vec);
switch (arr_type) {
case REALSXP:
return make_vec(REAL(R_vec));
case INTSXP:
return make_vec(INTEGER(R_vec));
case LGLSXP:
return make_vec(LOGICAL(R_vec));
default:
LOG(FATAL) << "Array or matrix has unsupported type.";
}

LOG(FATAL) << "Not reachable";
return "";
}

[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
using namespace ::xgboost; // NOLINT
Json jconfig{Object{}};
Expand Down Expand Up @@ -159,12 +185,15 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN();

auto array_str = MakeArrayInterfaceFromRMat(mat);
auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat));

DMatrixHandle handle;
CHECK_CALL(XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle));
int res_code;
{
auto array_str = MakeArrayInterfaceFromRMat(mat);
auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat));

res_code = XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, handle);
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END();
Expand Down Expand Up @@ -279,23 +308,15 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {

XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
R_API_BEGIN();
int len = length(array);
const char *name = CHAR(asChar(field));
auto ctx = DMatrixCtx(R_ExternalPtrAddr(handle));
if (!strcmp("group", name)) {
std::vector<unsigned> vec(len);
xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
});
CHECK_CALL(
XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
} else {
std::vector<float> vec(len);
xgboost::common::ParallelFor(len, ctx->Threads(),
[&](xgboost::omp_ulong i) { vec[i] = REAL(array)[i]; });
CHECK_CALL(
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
SEXP field_ = PROTECT(Rf_asChar(field));
int res_code;
{
const std::string array_str = MakeArrayInterfaceFromRVector(array);
res_code = XGDMatrixSetInfoFromInterface(
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
}
CHECK_CALL(res_code);
UNPROTECT(1);
R_API_END();
return R_NilValue;
}
Expand Down

0 comments on commit 95af5c0

Please sign in to comment.