[Feature] add set params function (#3)
RektPunk authored Sep 15, 2024
1 parent 1922fe2 commit 6745204
Showing 2 changed files with 32 additions and 30 deletions.
imlightgbm/engine.py (30 changes: 5 additions & 25 deletions)
@@ -1,13 +1,12 @@
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, Literal

import lightgbm as lgb
import numpy as np
from sklearn.model_selection import BaseCrossValidator

from imlightgbm.objective import set_fobj_feval
from imlightgbm.utils import docstring, logger
from imlightgbm.objective import set_params
from imlightgbm.utils import docstring


@docstring(lgb.train.__doc__)
@@ -23,17 +22,7 @@ def train(
    keep_training_booster: bool = False,
    callbacks: list[Callable] | None = None,
) -> lgb.Booster:
    _params = deepcopy(params)
    if "objective" in _params:
        logger.warning("'objective' exists in params and will not be used.")
        del _params["objective"]

    _alpha = _params.pop("alpha", 0.05)
    _gamma = _params.pop("gamma", 0.01)

    fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
    _params.update({"objective": fobj})

    _params, feval = set_params(params=params, train_set=train_set)
    return lgb.train(
        params=_params,
        train_set=train_set,
@@ -72,16 +61,7 @@ def cv(
    eval_train_metric: bool = False,
    return_cvbooster: bool = False,
) -> dict[str, list[float] | lgb.CVBooster]:
    _params = deepcopy(params)
    if "objective" in _params:
        logger.warning("'objective' exists in params and will not be used.")
        del _params["objective"]

    _alpha = _params.pop("alpha", 0.05)
    _gamma = _params.pop("gamma", 0.01)

    fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
    _params.update({"objective": fobj})
    _params, feval = set_params(params=params, train_set=train_set)
    return lgb.cv(
        params=_params,
        train_set=train_set,
@@ -91,7 +71,7 @@ def cv(
        stratified=stratified,
        shuffle=shuffle,
        metrics=metrics,
        feavl=feval,
        feval=feval,
        init_model=init_model,
        feature_name=feature_name,
        categorical_feature=categorical_feature,
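
With this change, train and cv no longer handle focal settings themselves: set_params drops any user-supplied "objective" (with a warning), reads alpha and gamma from params, and injects the generated fobj and feval. A minimal usage sketch of the resulting API; the toy data and the 0.25/2.0 values are illustrative, and since the focal objective bodies are still TODO in this commit, it shows the intended call shape rather than a run against 6745204:

import lightgbm as lgb
import numpy as np

from imlightgbm.engine import train

# Toy binary-classification data, purely for illustration.
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 4))
y = rng.integers(0, 2, size=200)
train_set = lgb.Dataset(X, label=y)

# "alpha" and "gamma" are popped by set_params and routed to the focal objective;
# an explicit "objective" key would be discarded with a warning.
params = {"learning_rate": 0.1, "alpha": 0.25, "gamma": 2.0}
booster = train(params=params, train_set=train_set)
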
imlightgbm/objective.py (32 changes: 27 additions & 5 deletions)
@@ -1,38 +1,44 @@
from copy import deepcopy
from functools import partial
from typing import Callable
from typing import Any, Callable

import numpy as np
from lightgbm import Dataset
from sklearn.utils.multiclass import type_of_target

from imlightgbm.utils import logger

EvalLike = Callable[[np.ndarray, Dataset], tuple[str, float, bool]]
ObjLike = Callable[[np.ndarray, Dataset], tuple[np.ndarray, np.ndarray]]
ALPHA_DEFAULT: float = 0.05
GAMMA_DEFAULT: float = 0.05
OBJECTIVE_STR: str = "objective"


def binary_focal_eval(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
) -> tuple[str, float, bool]:
    is_higher_better = False
    return "binary_focal", ..., is_higher_better


def binary_focal_objective(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
) -> tuple[np.ndarray, np.ndarray]:
    # TODO
    return ...


def multiclass_focal_eval(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
) -> tuple[str, float, bool]:
    # TODO
    return


def multiclass_focal_objective(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
) -> tuple[np.ndarray, np.ndarray]:
    # TODO
    return

@@ -57,3 +63,19 @@ def set_fobj_feval(
    feval = eval_mapper[inferred_task]

    return fobj, feval


def set_params(
    params: dict[str, Any], train_set: Dataset
) -> tuple[dict[str, Any], EvalLike]:
    _params = deepcopy(params)
    if OBJECTIVE_STR in _params:
        logger.warning(f"'{OBJECTIVE_STR}' exists in params and will not be used.")
        del _params[OBJECTIVE_STR]

    _alpha = _params.pop("alpha", ALPHA_DEFAULT)
    _gamma = _params.pop("gamma", GAMMA_DEFAULT)

    fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
    _params.update({OBJECTIVE_STR: fobj})
    return _params, feval
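
The focal-loss bodies above are still TODO stubs in this commit. For reference, a minimal sketch of what the binary pair could look like, following the standard focal loss of Lin et al. (2017); the _sketch names, the gradient/Hessian algebra, and the clipping constant are illustrative assumptions, not code from this repository:

import numpy as np
from lightgbm import Dataset


def binary_focal_objective_sketch(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
    # Gradient and Hessian of the focal loss with respect to the raw score.
    y = train_data.get_label()
    p = 1.0 / (1.0 + np.exp(-pred))  # pred is the raw score for a custom objective
    pt = np.clip(np.where(y == 1, p, 1.0 - p), 1e-15, 1.0 - 1e-15)  # prob. of true class
    at = np.where(y == 1, alpha, 1.0 - alpha)
    log_pt = np.log(pt)
    sign = 2.0 * y - 1.0  # d(pt)/d(raw score) = sign * pt * (1 - pt)
    grad = sign * at * (1.0 - pt) ** gamma * (gamma * pt * log_pt - (1.0 - pt))
    hess = at * pt * (1.0 - pt) ** gamma * (
        -gamma * (gamma * pt * log_pt + pt - 1.0)
        + (1.0 - pt) * (gamma * log_pt + gamma + 1.0)
    )
    return grad, hess


def binary_focal_eval_sketch(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
) -> tuple[str, float, bool]:
    # Mean focal loss as a validation metric; lower is better.
    y = train_data.get_label()
    p = 1.0 / (1.0 + np.exp(-pred))
    pt = np.clip(np.where(y == 1, p, 1.0 - p), 1e-15, 1.0)
    at = np.where(y == 1, alpha, 1.0 - alpha)
    loss = -at * (1.0 - pt) ** gamma * np.log(pt)
    return "binary_focal", float(np.mean(loss)), False

At gamma = 0 both functions reduce to alpha-weighted log loss (gradient sign * at * (pt - 1), Hessian at * pt * (1 - pt)), which is a quick sanity check on the algebra.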
