From f5d01bd8af84068db9f442654e2898fd0e0a85e6 Mon Sep 17 00:00:00 2001
From: RektPunk <110188257+RektPunk@users.noreply.github.com>
Date: Wed, 25 Sep 2024 11:00:56 +0900
Subject: [PATCH] [Feature] add custom booster (#20)

---
Review notes (ignored by `git am`):
- ImbalancedBooster.predict: `isinstance(x, spmatrix | list[spmatrix])`
  raises TypeError for dense results because isinstance() rejects
  parameterized generics; it is replaced with `(spmatrix, list)`.
- Docstrings and a why-comment were added to ImbalancedBooster, so the
  engine.py hunk offsets and the diffstat differ from the original patch,
  and the post-image blob hash on the engine.py `index` line is stale.
- NOTE(review): train() rebuilds the booster from `model_to_string()`;
  confirm this is compatible with `keep_training_booster=True`.

 imlightgbm/engine.py  | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 imlightgbm/sklearn.py |  3 ++-
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/imlightgbm/engine.py b/imlightgbm/engine.py
index 2c69494..b9c57dd 100644
--- a/imlightgbm/engine.py
+++ b/imlightgbm/engine.py
@@ -3,12 +3,62 @@
 
 import lightgbm as lgb
 import numpy as np
+from scipy.sparse import spmatrix
+from scipy.special import expit
 from sklearn.model_selection import BaseCrossValidator
 
 from imlightgbm.docstring import add_docstring
 from imlightgbm.objective.engine import set_params
 
 
+class ImbalancedBooster(lgb.Booster):
+    """Booster whose default 1-D predictions are passed through a sigmoid."""
+
+    def predict(
+        self,
+        data: lgb.basic._LGBM_PredictDataType,
+        start_iteration: int = 0,
+        num_iteration: int | None = None,
+        raw_score: bool = False,
+        pred_leaf: bool = False,
+        pred_contrib: bool = False,
+        data_has_header: bool = False,
+        validate_features: bool = False,
+        **kwargs: Any,
+    ) -> np.ndarray | spmatrix | list[spmatrix]:
+        """Predict with the parent Booster and apply ``expit`` when applicable.
+
+        The parent result is returned unchanged when ``raw_score``,
+        ``pred_leaf`` or ``pred_contrib`` is set, when it is a sparse matrix
+        or a list, or when it is not one-dimensional.
+        """
+        _predict = super().predict(
+            data=data,
+            start_iteration=start_iteration,
+            num_iteration=num_iteration,
+            raw_score=raw_score,
+            pred_leaf=pred_leaf,
+            pred_contrib=pred_contrib,
+            data_has_header=data_has_header,
+            validate_features=validate_features,
+            **kwargs,
+        )
+        # isinstance() rejects parameterized generics such as list[spmatrix],
+        # so the sparse-list case is matched with plain ``list``.
+        if (
+            raw_score
+            or pred_leaf
+            or pred_contrib
+            or isinstance(_predict, (spmatrix, list))
+        ):
+            return _predict
+
+        if len(_predict.shape) == 1:
+            return expit(_predict)
+
+        return _predict  # TODO: multiclass
+
+
 @add_docstring("train")
 def train(
     params: dict[str, Any],
@@ -19,9 +69,9 @@ def train(
     init_model: str | lgb.Path | lgb.Booster | None = None,
     keep_training_booster: bool = False,
     callbacks: list[Callable] | None = None,
-) -> lgb.Booster:
+) -> ImbalancedBooster:
     _params = set_params(params=params, train_set=train_set)
-    return lgb.train(
+    _booster = lgb.train(
         params=_params,
         train_set=train_set,
         num_boost_round=num_boost_round,
@@ -31,6 +81,8 @@ def train(
         keep_training_booster=keep_training_booster,
         callbacks=callbacks,
     )
+    _booster_str = _booster.model_to_string()
+    return ImbalancedBooster(model_str=_booster_str)
 
 
 @add_docstring("cv")
diff --git a/imlightgbm/sklearn.py b/imlightgbm/sklearn.py
index 8d896af..c9aa1d5 100644
--- a/imlightgbm/sklearn.py
+++ b/imlightgbm/sklearn.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 from lightgbm.sklearn import LGBMClassifier, _LGBM_ScikitMatrixLike
+from scipy.sparse import spmatrix
 from scipy.special import expit
 
 from imlightgbm.base import ALPHA_DEFAULT, GAMMA_DEFAULT, Objective
@@ -99,7 +100,7 @@ def predict(
         pred_contrib: bool = False,
         validate_features: bool = False,
         **kwargs: Any,
-    ):
+    ) -> np.ndarray | spmatrix | list[spmatrix]:
         """Docstring is inherited from the LGBMClassifier."""
         result = super().predict(
             X=X,