Skip to content

Commit

Permalink
[R] Ensure ProxyDMatrix creation keeps data until next iteration (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Dec 11, 2024
1 parent 376009c commit 3162e0d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
1 change: 0 additions & 1 deletion R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
lst$feature_types <- tmp$feature_types
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
rm(tmp)
} else if (is.matrix(lst$data)) {
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
} else if (inherits(lst$data, "dgRMatrix")) {
Expand Down
12 changes: 9 additions & 3 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) {
{
std::string array_str = MakeArrayInterfaceFromRMat(R_mat);
res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str());
R_SetExternalPtrProtected(handle, R_mat);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -707,6 +708,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) {
array_str_indices.c_str(),
array_str_data.c_str(),
ncol);
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -720,6 +722,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) {
{
std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst);
res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str());
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -733,17 +736,20 @@ struct _RDataIterator {
SEXP f_reset;
SEXP calling_env;
SEXP continuation_token;
SEXP proxy_dmat;

_RDataIterator(
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) :
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) :
f_next(f_next), f_reset(f_reset), calling_env(calling_env),
continuation_token(continuation_token) {}
continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}

void reset() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SafeExecFun(this->f_reset, this->calling_env, this->continuation_token);
}

int next() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SEXP R_res = Rf_protect(
SafeExecFun(this->f_next, this->calling_env, this->continuation_token));
int res = Rf_asInteger(R_res);
Expand Down Expand Up @@ -771,7 +777,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R(

int res_code;
try {
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token);
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat);

std::string str_cache_prefix;
xgboost::Json jconfig{xgboost::Object{}};
Expand Down

0 comments on commit 3162e0d

Please sign in to comment.