Skip to content

Commit

Permalink
[Feature] add cv function (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Sep 15, 2024
1 parent e8f1947 commit 1922fe2
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 43 deletions.
2 changes: 1 addition & 1 deletion imlightgbm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# ruff: noqa
from imlightgbm.engine import train
from imlightgbm.engine import cv, train
108 changes: 67 additions & 41 deletions imlightgbm/engine.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
from functools import partial
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, Literal

import lightgbm as lgb
from sklearn.utils.multiclass import type_of_target
import numpy as np
from sklearn.model_selection import BaseCrossValidator

from imlightgbm.objective import (
binary_focal_eval,
binary_focal_objective,
multiclass_focal_eval,
multiclass_focal_objective,
)
from imlightgbm.utils import logger, modify_docstring
from imlightgbm.objective import set_fobj_feval
from imlightgbm.utils import docstring, logger


@docstring(lgb.train.__doc__)
def train(
params: dict[str, Any],
train_set: lgb.Dataset,
Expand All @@ -25,42 +23,19 @@ def train(
keep_training_booster: bool = False,
callbacks: list[Callable] | None = None,
) -> lgb.Booster:
if "objective" in params:
_params = deepcopy(params)
if "objective" in _params:
logger.warning("'objective' exists in params will not used.")
params.pop("objective")
del _params["objective"]

params.setdefault("alpha", 0.05)
params.setdefault("gamma", 0.01)
_alpha = _params.pop("alpha", 0.05)
_gamma = _params.pop("gamma", 0.01)

inferred_task = type_of_target(train_set.get_label())
if inferred_task not in {"binary", "multiclass"}:
raise ValueError(
f"Invalid target type: {inferred_task}. Supported types are 'binary' or 'multiclass'."
)

eval_mapper = {
"binary": partial(
binary_focal_eval, alpha=params["alpha"], gamma=params["gamma"]
),
"multiclass": partial(
multiclass_focal_eval, alpha=params["alpha"], gamma=params["gamma"]
),
}
objective_mapper = {
"binary": partial(
binary_focal_objective, alpha=params["alpha"], gamma=params["gamma"]
),
"multiclass": partial(
multiclass_focal_objective, alpha=params["alpha"], gamma=params["gamma"]
),
}

fobj = objective_mapper.get(inferred_task)
feval = eval_mapper.get(inferred_task)
params.update({"objective": fobj})
fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
_params.update({"objective": fobj})

return lgb.train(
params=params,
params=_params,
train_set=train_set,
valid_sets=valid_sets,
valid_names=valid_names,
Expand All @@ -74,4 +49,55 @@ def train(
)


train.__doc__ = modify_docstring(lgb.train.__doc__)
@docstring(lgb.cv.__doc__)
def cv(
params: dict[str, Any],
train_set: lgb.Dataset,
num_boost_round: int = 100,
folds: Iterable[tuple[np.ndarray, np.ndarray]] | BaseCrossValidator | None = None,
nfold: int = 5,
stratified: bool = True,
shuffle: bool = True,
metrics: str | list[str] | None = None,
init_model: str | lgb.Path | lgb.Booster | None = None,
feature_name: list[str] | Literal["auto"] = "auto",
categorical_feature: list[str] | list[int] | Literal["auto"] = "auto",
fpreproc: Callable[
[lgb.Dataset, lgb.Dataset, dict[str, Any]],
tuple[lgb.Dataset, lgb.Dataset, dict[str, Any]],
]
| None = None,
seed: int = 0,
callbacks: list[Callable] | None = None,
eval_train_metric: bool = False,
return_cvbooster: bool = False,
) -> dict[str, list[float] | lgb.CVBooster]:
_params = deepcopy(params)
if "objective" in _params:
logger.warning("'objective' exists in params will not used.")
del _params["objective"]

_alpha = _params.pop("alpha", 0.05)
_gamma = _params.pop("gamma", 0.01)

fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
_params.update({"objective": fobj})
return lgb.cv(
params=_params,
train_set=train_set,
num_boost_round=num_boost_round,
folds=folds,
nfold=nfold,
stratified=stratified,
shuffle=shuffle,
metrics=metrics,
feavl=feval,
init_model=init_model,
feature_name=feature_name,
categorical_feature=categorical_feature,
fpreproc=fpreproc,
seed=seed,
callbacks=callbacks,
eval_train_metric=eval_train_metric,
return_cvbooster=return_cvbooster,
)
29 changes: 29 additions & 0 deletions imlightgbm/objective.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from functools import partial
from typing import Callable

import numpy as np
from lightgbm import Dataset
from sklearn.utils.multiclass import type_of_target

EvalLike = Callable[[np.ndarray, Dataset], tuple[str, float, bool]]
ObjLike = Callable[[np.ndarray, Dataset], tuple[np.ndarray, np.ndarray]]


def binary_focal_eval(
Expand Down Expand Up @@ -28,3 +35,25 @@ def multiclass_focal_objective(
):
# TODO
return


def set_fobj_feval(
train_set: Dataset, alpha: float, gamma: float
) -> tuple[ObjLike, EvalLike]:
inferred_task = type_of_target(train_set.get_label())
if inferred_task not in {"binary", "multiclass"}:
raise ValueError(
f"Invalid target type: {inferred_task}. Supported types are 'binary' or 'multiclass'."
)
objective_mapper: dict[str, ObjLike] = {
"binary": partial(binary_focal_objective, alpha=alpha, gamma=gamma),
"multiclass": partial(multiclass_focal_objective, alpha=alpha, gamma=gamma),
}
eval_mapper: dict[str, EvalLike] = {
"binary": partial(binary_focal_eval, alpha=alpha, gamma=gamma),
"multiclass": partial(multiclass_focal_eval, alpha=alpha, gamma=gamma),
}
fobj = objective_mapper[inferred_task]
feval = eval_mapper[inferred_task]

return fobj, feval
11 changes: 10 additions & 1 deletion imlightgbm/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import logging
from typing import Callable


def modify_docstring(docstring: str) -> str:
def _modify_docstring(docstring: str) -> str:
lines = docstring.splitlines()

feval_start = next(i for i, line in enumerate(lines) if "feval" in line)
Expand All @@ -14,6 +15,14 @@ def modify_docstring(docstring: str) -> str:
return "\n".join(lines)


def docstring(doc: str) -> Callable[[Callable], Callable]:
def decorator(func: Callable) -> Callable:
func.__doc__ = _modify_docstring(doc)
return func

return decorator


def init_logger() -> logging.Logger:
logger = logging.getLogger("imlightgbm")
logger.setLevel(logging.INFO)
Expand Down

0 comments on commit 1922fe2

Please sign in to comment.