Skip to content

Commit

Permalink
[Feature] add custom booster (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Sep 25, 2024
1 parent 9e2f382 commit f5d01bd
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
46 changes: 44 additions & 2 deletions imlightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,52 @@

import lightgbm as lgb
import numpy as np
from scipy.sparse import spmatrix
from scipy.special import expit
from sklearn.model_selection import BaseCrossValidator

from imlightgbm.docstring import add_docstring
from imlightgbm.objective.engine import set_params


class ImbalancedBooster(lgb.Booster):
def predict(
self,
data: lgb.basic._LGBM_PredictDataType,
start_iteration: int = 0,
num_iteration: int | None = None,
raw_score: bool = False,
pred_leaf: bool = False,
pred_contrib: bool = False,
data_has_header: bool = False,
validate_features: bool = False,
**kwargs: Any,
) -> np.ndarray | spmatrix | list[spmatrix]:
_predict = super().predict(
data=data,
start_iteration=start_iteration,
num_iteration=num_iteration,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
data_has_header=data_has_header,
validate_features=validate_features,
**kwargs,
)
if (
raw_score
or pred_leaf
or pred_contrib
or isinstance(_predict, spmatrix | list[spmatrix])
):
return _predict

if len(_predict.shape) == 1:
return expit(_predict)

return _predict # TODO: multiclass


@add_docstring("train")
def train(
params: dict[str, Any],
Expand All @@ -19,9 +59,9 @@ def train(
init_model: str | lgb.Path | lgb.Booster | None = None,
keep_training_booster: bool = False,
callbacks: list[Callable] | None = None,
) -> lgb.Booster:
) -> ImbalancedBooster:
_params = set_params(params=params, train_set=train_set)
return lgb.train(
_booster = lgb.train(
params=_params,
train_set=train_set,
num_boost_round=num_boost_round,
Expand All @@ -31,6 +71,8 @@ def train(
keep_training_booster=keep_training_booster,
callbacks=callbacks,
)
_booster_str = _booster.model_to_string()
return ImbalancedBooster(model_str=_booster_str)


@add_docstring("cv")
Expand Down
3 changes: 2 additions & 1 deletion imlightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from lightgbm.sklearn import LGBMClassifier, _LGBM_ScikitMatrixLike
from scipy.sparse import spmatrix
from scipy.special import expit

from imlightgbm.base import ALPHA_DEFAULT, GAMMA_DEFAULT, Objective
Expand Down Expand Up @@ -99,7 +100,7 @@ def predict(
pred_contrib: bool = False,
validate_features: bool = False,
**kwargs: Any,
):
) -> np.ndarray | spmatrix | list[spmatrix]:
"""Docstring is inherited from the LGBMClassifier."""
result = super().predict(
X=X,
Expand Down

0 comments on commit f5d01bd

Please sign in to comment.