From 678860e09d239faa80337e84ac05cc77c2f694c1 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Wed, 11 Dec 2024 18:36:59 +0100 Subject: [PATCH 1/4] ensure data is kept alive until next set or reset --- R-package/R/xgb.DMatrix.R | 1 - R-package/src/xgboost_R.cc | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index ce06f5c50370..66bd7205570b 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -569,7 +569,6 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) { tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types) lst$feature_types <- tmp$feature_types .Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst) - rm(tmp) } else if (is.matrix(lst$data)) { .Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data) } else if (inherits(lst$data, "dgRMatrix")) { diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 0e7234a18708..2aa22b68a3c0 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -682,6 +682,7 @@ XGB_DLL SEXP XGProxyDMatrixCreate_R() { XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) { R_API_BEGIN(); + R_SetExternalPtrProtected(handle, R_mat); DMatrixHandle proxy_dmat = R_ExternalPtrAddr(handle); int res_code; { @@ -695,6 +696,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) { XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { R_API_BEGIN(); + R_SetExternalPtrProtected(handle, lst); DMatrixHandle proxy_dmat = R_ExternalPtrAddr(handle); int res_code; { @@ -715,6 +717,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) { R_API_BEGIN(); + R_SetExternalPtrProtected(handle, lst); DMatrixHandle proxy_dmat = R_ExternalPtrAddr(handle); int res_code; { From f0de83eeac914c92605ad766e3af54ad9a1d03db Mon Sep 17 00:00:00 2001 From: david-cortes Date: Wed, 11 Dec 2024 19:25:11 +0100 Subject: [PATCH 2/4] unprotect only after setting next batch --- R-package/src/xgboost_R.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 2aa22b68a3c0..9b8730da9d46 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -682,12 +682,12 @@ XGB_DLL SEXP XGProxyDMatrixCreate_R() { XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) { R_API_BEGIN(); - R_SetExternalPtrProtected(handle, R_mat); DMatrixHandle proxy_dmat = R_ExternalPtrAddr(handle); int res_code; { std::string array_str = MakeArrayInterfaceFromRMat(R_mat); res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str()); + R_SetExternalPtrProtected(handle, R_mat); } CHECK_CALL(res_code); R_API_END(); @@ -696,7 +696,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) { XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { R_API_BEGIN(); - R_SetExternalPtrProtected(handle, lst); DMatrixHandle proxy_dmat = R_ExternalPtrAddr(handle); int res_code; { @@ -709,6 +708,7 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { array_str_indices.c_str(), array_str_data.c_str(), ncol); + R_SetExternalPtrProtected(handle, lst); } CHECK_CALL(res_code); R_API_END(); @@ -717,12 +717,12 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) { R_API_BEGIN(); - R_SetExternalPtrProtected(handle, lst); DMatrixHandle proxy_dmat = R_ExternalPtrAddr(handle); int res_code; { std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst); res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str()); + R_SetExternalPtrProtected(handle, lst); } CHECK_CALL(res_code); R_API_END(); From 34110fec3daecc5f6b273ffc9fc30b53b124f969 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Wed, 11 Dec 2024 19:59:16 +0100 Subject: [PATCH 3/4] unprotect data before generating the next batch --- R-package/src/xgboost_R.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 9b8730da9d46..1aca29b4a029 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -736,17 +736,20 @@ struct _RDataIterator { SEXP f_reset; SEXP calling_env; SEXP continuation_token; + SEXP proxy_dmat; _RDataIterator( - SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) : + SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEP proxy_dmat) : f_next(f_next), f_reset(f_reset), calling_env(calling_env), - continuation_token(continuation_token) {} + continuation_token(continuation_token), proxy_dmat(proxy_dmat) {} void reset() { + R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue); SafeExecFun(this->f_reset, this->calling_env, this->continuation_token); } int next() { + R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue); SEXP R_res = Rf_protect( SafeExecFun(this->f_next, this->calling_env, this->continuation_token)); int res = Rf_asInteger(R_res); @@ -774,7 +777,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R( int res_code; try { - _RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token); + _RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat); std::string str_cache_prefix; xgboost::Json jconfig{xgboost::Object{}}; From 98bdc7fd330ecb1fd0e899637fd48a9a3d4de802 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Wed, 11 Dec 2024 20:04:15 +0100 Subject: [PATCH 4/4] typo --- R-package/src/xgboost_R.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 1aca29b4a029..adb9649bf33d 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -739,7 +739,7 @@ struct _RDataIterator { SEXP proxy_dmat; _RDataIterator( - SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEP proxy_dmat) : + SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) : f_next(f_next), f_reset(f_reset), calling_env(calling_env), continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}