From 674520409cf1ad76c7658ce3389afe7492c717e1 Mon Sep 17 00:00:00 2001 From: RektPunk <110188257+RektPunk@users.noreply.github.com> Date: Sun, 15 Sep 2024 13:44:39 +0900 Subject: [PATCH] [Feature] add set params function (#3) --- imlightgbm/engine.py | 30 +++++------------------------- imlightgbm/objective.py | 32 +++++++++++++++++++++++++++----- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/imlightgbm/engine.py b/imlightgbm/engine.py index dcc0e8e..df5a3b7 100644 --- a/imlightgbm/engine.py +++ b/imlightgbm/engine.py @@ -1,13 +1,12 @@ from collections.abc import Iterable -from copy import deepcopy from typing import Any, Callable, Literal import lightgbm as lgb import numpy as np from sklearn.model_selection import BaseCrossValidator -from imlightgbm.objective import set_fobj_feval -from imlightgbm.utils import docstring, logger +from imlightgbm.objective import set_params +from imlightgbm.utils import docstring @docstring(lgb.train.__doc__) @@ -23,17 +22,7 @@ def train( keep_training_booster: bool = False, callbacks: list[Callable] | None = None, ) -> lgb.Booster: - _params = deepcopy(params) - if "objective" in _params: - logger.warning("'objective' exists in params will not used.") - del _params["objective"] - - _alpha = _params.pop("alpha", 0.05) - _gamma = _params.pop("gamma", 0.01) - - fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma) - _params.update({"objective": fobj}) - + _params, feval = set_params(params=params, train_set=train_set) return lgb.train( params=_params, train_set=train_set, @@ -72,16 +61,7 @@ def cv( eval_train_metric: bool = False, return_cvbooster: bool = False, ) -> dict[str, list[float] | lgb.CVBooster]: - _params = deepcopy(params) - if "objective" in _params: - logger.warning("'objective' exists in params will not used.") - del _params["objective"] - - _alpha = _params.pop("alpha", 0.05) - _gamma = _params.pop("gamma", 0.01) - - fobj, feval = set_fobj_feval(train_set=train_set, 
alpha=_alpha, gamma=_gamma) - _params.update({"objective": fobj}) + _params, feval = set_params(params=params, train_set=train_set) return lgb.cv( params=_params, train_set=train_set, @@ -91,7 +71,7 @@ def cv( stratified=stratified, shuffle=shuffle, metrics=metrics, - feavl=feval, + feval=feval, init_model=init_model, feature_name=feature_name, categorical_feature=categorical_feature, diff --git a/imlightgbm/objective.py b/imlightgbm/objective.py index 62f6f6e..66dd63c 100644 --- a/imlightgbm/objective.py +++ b/imlightgbm/objective.py @@ -1,38 +1,44 @@ +from copy import deepcopy from functools import partial -from typing import Callable +from typing import Any, Callable import numpy as np from lightgbm import Dataset from sklearn.utils.multiclass import type_of_target +from imlightgbm.utils import logger + EvalLike = Callable[[np.ndarray, Dataset], tuple[str, float, bool]] ObjLike = Callable[[np.ndarray, Dataset], tuple[np.ndarray, np.ndarray]] +ALPHA_DEFAULT: float = 0.05 +GAMMA_DEFAULT: float = 0.01 +OBJECTIVE_STR: str = "objective" def binary_focal_eval( pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float -): +) -> tuple[str, float, bool]: is_higher_better = False return "binary_focal", ..., is_higher_better def binary_focal_objective( pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float -): +) -> tuple[np.ndarray, np.ndarray]: # TODO return ... 
def multiclass_focal_eval( pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float -): +) -> tuple[str, float, bool]: # TODO return def multiclass_focal_objective( pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float -): +) -> tuple[np.ndarray, np.ndarray]: # TODO return @@ -57,3 +63,19 @@ def set_fobj_feval( feval = eval_mapper[inferred_task] return fobj, feval + + +def set_params( + params: dict[str, Any], train_set: Dataset +) -> tuple[dict[str, Any], EvalLike]: + _params = deepcopy(params) + if OBJECTIVE_STR in _params: + logger.warning(f"'{OBJECTIVE_STR}' exists in params and will not be used.") + del _params[OBJECTIVE_STR] + + _alpha = _params.pop("alpha", ALPHA_DEFAULT) + _gamma = _params.pop("gamma", GAMMA_DEFAULT) + + fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma) + _params.update({OBJECTIVE_STR: fobj}) + return _params, feval