From 648d6dd7bdf1baf5bf34064ac6ed4b28e823537b Mon Sep 17 00:00:00 2001
From: Zhi Lin
Date: Sat, 22 Apr 2023 06:51:48 +0800
Subject: [PATCH] Add QuantileDMatrix support (#279)

* add quantile matrix

Signed-off-by: Zhi Lin

* revert unrelated changes

Signed-off-by: Zhi Lin

* add helper function

Signed-off-by: Zhi Lin

* format

Signed-off-by: Zhi Lin

---------

Signed-off-by: Zhi Lin
---
Review notes (this section is ignored by git-am):

- matrix.py: the availability probe now imports `QuantileDMatrix`
  (capital "M"). The previous spelling `QuantileDmatrix` does not exist in
  xgboost.core, so the import always failed and QUANTILE_AVAILABLE was
  always False, leaving the new code path in _get_dmatrix unreachable.
- main.py: the RayDeviceQuantileDMatrix check is now chained with `elif`.
  With a plain `if`, a RayQuantileDMatrix would build its QuantileDMatrix
  and then fall through to the trailing `else:` branch, which replaces it
  with a regular DMatrix.
- Indentation of context lines was reconstructed from a whitespace-collapsed
  copy of this patch; use `git apply --ignore-whitespace` if it does not
  match the tree. The `index` blob hashes are those of the original commit
  and were not regenerated.

 xgboost_ray/main.py   | 39 ++++++++++++++++++++++++---------------
 xgboost_ray/matrix.py | 12 ++++++++++++
 2 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/xgboost_ray/main.py b/xgboost_ray/main.py
index b6fdd966..039a7b35 100644
--- a/xgboost_ray/main.py
+++ b/xgboost_ray/main.py
@@ -75,7 +75,7 @@ def inner_f(*args, **kwargs):
 
 from xgboost_ray.matrix import RayDMatrix, combine_data, \
     RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes, \
-    LEGACY_MATRIX
+    LEGACY_MATRIX, QUANTILE_AVAILABLE, RayQuantileDMatrix
 from xgboost_ray.session import init_session, put_queue, \
     set_session_queue, get_rabit_rank
 
@@ -320,7 +320,28 @@ def _set_omp_num_threads():
     return int(float(os.environ.get("OMP_NUM_THREADS", "0.0")))
 
 
+def _prepare_dmatrix_params(param: Dict) -> Dict:
+    dm_param = {
+        "data": concat_dataframes(param["data"]),
+        "label": concat_dataframes(param["label"]),
+        "weight": concat_dataframes(param["weight"]),
+        "feature_weights": concat_dataframes(param["feature_weights"]),
+        "qid": concat_dataframes(param["qid"]),
+        "base_margin": concat_dataframes(param["base_margin"]),
+        "label_lower_bound": concat_dataframes(param["label_lower_bound"]),
+        "label_upper_bound": concat_dataframes(param["label_upper_bound"]),
+    }
+    return dm_param
+
+
 def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
+    if QUANTILE_AVAILABLE and isinstance(data, RayQuantileDMatrix):
+        if isinstance(param["data"], list):
+            qdm_param = _prepare_dmatrix_params(param)
+            param.update(qdm_param)
+        if data.enable_categorical is not None:
+            param["enable_categorical"] = data.enable_categorical
+        matrix = xgb.QuantileDMatrix(**param)
-    if not LEGACY_MATRIX and isinstance(data, RayDeviceQuantileDMatrix):
+    elif not LEGACY_MATRIX and isinstance(data, RayDeviceQuantileDMatrix):
         # If we only got a single data shard, create a list so we can
         # iterate over it
@@ -355,18 +376,7 @@ def _get_dmatrix(data: RayDMatrix, param: Dict) -> xgb.DMatrix:
         matrix = xgb.DeviceQuantileDMatrix(it, **dm_param)
     else:
         if isinstance(param["data"], list):
-            dm_param = {
-                "data": concat_dataframes(param["data"]),
-                "label": concat_dataframes(param["label"]),
-                "weight": concat_dataframes(param["weight"]),
-                "feature_weights": concat_dataframes(param["feature_weights"]),
-                "qid": concat_dataframes(param["qid"]),
-                "base_margin": concat_dataframes(param["base_margin"]),
-                "label_lower_bound": concat_dataframes(
-                    param["label_lower_bound"]),
-                "label_upper_bound": concat_dataframes(
-                    param["label_upper_bound"]),
-            }
+            dm_param = _prepare_dmatrix_params(param)
             param.update(dm_param)
 
         ll = param.pop("label_lower_bound", None)
@@ -669,7 +679,6 @@ def _train():
                     for deval, name in evals:
                         local_evals.append((_get_dmatrix(
                             deval, self._data[deval]), name))
-
                     if LEGACY_CALLBACK:
                         for xgb_callback in kwargs.get("callbacks", []):
                             if isinstance(xgb_callback, TrainingCallback):
diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py
index c2fbbfb4..5bf92838 100644
--- a/xgboost_ray/matrix.py
+++ b/xgboost_ray/matrix.py
@@ -37,6 +37,13 @@ class RayDataset:
     DataIter = object
     LEGACY_MATRIX = True
 
+try:
+    from xgboost.core import QuantileDMatrix
+    QUANTILE_AVAILABLE = True
+except ImportError:
+    QuantileDMatrix = object
+    QUANTILE_AVAILABLE = False
+
 if TYPE_CHECKING:
     from xgboost_ray.xgb import xgboost as xgb
 
@@ -875,6 +882,11 @@ def __eq__(self, other):
         return self.__hash__() == other.__hash__()
 
 
+class RayQuantileDMatrix(RayDMatrix):
+    """Currently just a thin wrapper for type detection"""
+    pass
+
+
 class RayDeviceQuantileDMatrix(RayDMatrix):
     """Currently just a thin wrapper for type detection"""
 