From 237dec9edf1c1d6d4e661f3d9ac50529f226b526 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sat, 5 Oct 2024 22:15:41 +0900 Subject: [PATCH 1/2] label mean if reference --- mqboost/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mqboost/dataset.py b/mqboost/dataset.py index 5dea1e7..e43a5e9 100644 --- a/mqboost/dataset.py +++ b/mqboost/dataset.py @@ -33,7 +33,7 @@ class MQDataset: data (pd.DataFrame | pd.Series | np.ndarray): The input features. label (pd.Series | np.ndarray): The target labels (if provided). model (str): The model type (LightGBM or XGBoost). - reference (MQBoost | None): Reference dataset for label encoding. + reference (MQBoost | None): Reference dataset for label encoding and label mean. Property: train_dtype: Returns the data type function for training data. @@ -81,7 +81,7 @@ def __init__( self._data = prepare_x(x=_data, alphas=self._alphas) self._columns = self._data.columns if label is not None: - self._label_mean = label.mean() + self._label_mean = reference.label_mean if reference else label.mean() self._label = prepare_y(y=label - self._label_mean, alphas=self._alphas) self._is_none_label = False From 67d3f2bb183c2422bb0dcc11d4612ef5093d9cb2 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sat, 5 Oct 2024 22:30:17 +0900 Subject: [PATCH 2/2] add kwargs in fit method --- mqboost/regressor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mqboost/regressor.py b/mqboost/regressor.py index 461178b..16c3d29 100644 --- a/mqboost/regressor.py +++ b/mqboost/regressor.py @@ -58,6 +58,7 @@ def fit( self, dataset: MQDataset, eval_set: MQDataset | None = None, + **kwargs, ) -> None: """ Fit the regressor to the dataset. @@ -65,6 +66,8 @@ def fit( dataset (MQDataset): The dataset to fit the model on. eval_set (Optional[MQDataset]): The validation dataset. If None, the dataset is used for evaluation. + **kwargs: + train parameters. """ if eval_set: _eval_set = eval_set.dtrain @@ -92,6 +95,7 @@ def fit( params=params, feval=self._MQObj.feval, valid_sets=[_eval_set], + **kwargs, ) elif self.__is_xgb: self.model = xgb.train( @@ -101,6 +105,7 @@ def fit( obj=self._MQObj.fobj, custom_metric=self._MQObj.feval, evals=[(_eval_set, "eval")], + **kwargs, ) self._colnames = dataset.columns.to_list() self._fitted = True