Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R] Ensure ProxyDMatrix creation keeps data until next iteration #11092

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
lst$feature_types <- tmp$feature_types
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
rm(tmp)
} else if (is.matrix(lst$data)) {
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
} else if (inherits(lst$data, "dgRMatrix")) {
Expand Down
12 changes: 9 additions & 3 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) {
{
std::string array_str = MakeArrayInterfaceFromRMat(R_mat);
res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str());
R_SetExternalPtrProtected(handle, R_mat);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -707,6 +708,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) {
array_str_indices.c_str(),
array_str_data.c_str(),
ncol);
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -720,6 +722,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) {
{
std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst);
res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str());
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -733,17 +736,20 @@ struct _RDataIterator {
SEXP f_reset;
SEXP calling_env;
SEXP continuation_token;
SEXP proxy_dmat;

_RDataIterator(
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) :
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) :
f_next(f_next), f_reset(f_reset), calling_env(calling_env),
continuation_token(continuation_token) {}
continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}

void reset() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SafeExecFun(this->f_reset, this->calling_env, this->continuation_token);
}

int next() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SEXP R_res = Rf_protect(
SafeExecFun(this->f_next, this->calling_env, this->continuation_token));
int res = Rf_asInteger(R_res);
Expand Down Expand Up @@ -771,7 +777,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R(

int res_code;
try {
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token);
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat);

std::string str_cache_prefix;
xgboost::Json jconfig{xgboost::Object{}};
Expand Down
Loading