diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index ce06f5c50370..66bd7205570b 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -569,7 +569,6 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) { tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types) lst$feature_types <- tmp$feature_types .Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst) - rm(tmp) } else if (is.matrix(lst$data)) { .Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data) } else if (inherits(lst$data, "dgRMatrix")) { diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 0e7234a18708..adb9649bf33d 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -687,6 +687,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) { { std::string array_str = MakeArrayInterfaceFromRMat(R_mat); res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str()); + R_SetExternalPtrProtected(handle, R_mat); } CHECK_CALL(res_code); R_API_END(); @@ -707,6 +708,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { array_str_indices.c_str(), array_str_data.c_str(), ncol); + R_SetExternalPtrProtected(handle, lst); } CHECK_CALL(res_code); R_API_END(); @@ -720,6 +722,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) { { std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst); res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str()); + R_SetExternalPtrProtected(handle, lst); } CHECK_CALL(res_code); R_API_END(); @@ -733,17 +736,20 @@ struct _RDataIterator { SEXP f_reset; SEXP calling_env; SEXP continuation_token; + SEXP proxy_dmat; _RDataIterator( - SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) : + SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) : f_next(f_next), f_reset(f_reset), calling_env(calling_env), - continuation_token(continuation_token) {} + continuation_token(continuation_token), proxy_dmat(proxy_dmat) {} void reset() { + R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue); SafeExecFun(this->f_reset, this->calling_env, this->continuation_token); } int next() { + R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue); SEXP R_res = Rf_protect( SafeExecFun(this->f_next, this->calling_env, this->continuation_token)); int res = Rf_asInteger(R_res); @@ -771,7 +777,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R( int res_code; try { - _RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token); + _RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat); std::string str_cache_prefix; xgboost::Json jconfig{xgboost::Object{}};