From 1922fe2213c9adcc07966bf4715fdd816fb21d1e Mon Sep 17 00:00:00 2001 From: RektPunk <110188257+RektPunk@users.noreply.github.com> Date: Sun, 15 Sep 2024 13:25:16 +0900 Subject: [PATCH] [Feature] add `cv` function (#2) --- imlightgbm/__init__.py | 2 +- imlightgbm/engine.py | 108 +++++++++++++++++++++++++--------------- imlightgbm/objective.py | 29 +++++++++++ imlightgbm/utils.py | 11 +++- 4 files changed, 107 insertions(+), 43 deletions(-) diff --git a/imlightgbm/__init__.py b/imlightgbm/__init__.py index aae3d35..447fb82 100644 --- a/imlightgbm/__init__.py +++ b/imlightgbm/__init__.py @@ -1,2 +1,2 @@ # ruff: noqa -from imlightgbm.engine import train +from imlightgbm.engine import cv, train diff --git a/imlightgbm/engine.py b/imlightgbm/engine.py index 9eef31e..dcc0e8e 100644 --- a/imlightgbm/engine.py +++ b/imlightgbm/engine.py @@ -1,18 +1,16 @@ -from functools import partial +from collections.abc import Iterable +from copy import deepcopy from typing import Any, Callable, Literal import lightgbm as lgb -from sklearn.utils.multiclass import type_of_target +import numpy as np +from sklearn.model_selection import BaseCrossValidator -from imlightgbm.objective import ( - binary_focal_eval, - binary_focal_objective, - multiclass_focal_eval, - multiclass_focal_objective, -) -from imlightgbm.utils import logger, modify_docstring +from imlightgbm.objective import set_fobj_feval +from imlightgbm.utils import docstring, logger +@docstring(lgb.train.__doc__) def train( params: dict[str, Any], train_set: lgb.Dataset, @@ -25,42 +23,19 @@ def train( keep_training_booster: bool = False, callbacks: list[Callable] | None = None, ) -> lgb.Booster: - if "objective" in params: + _params = deepcopy(params) + if "objective" in _params: logger.warning("'objective' exists in params will not used.") - params.pop("objective") + del _params["objective"] - params.setdefault("alpha", 0.05) - params.setdefault("gamma", 0.01) + _alpha = _params.pop("alpha", 0.05) + _gamma = _params.pop("gamma", 0.01) - inferred_task = type_of_target(train_set.get_label()) - if inferred_task not in {"binary", "multiclass"}: - raise ValueError( - f"Invalid target type: {inferred_task}. Supported types are 'binary' or 'multiclass'." - ) - - eval_mapper = { - "binary": partial( - binary_focal_eval, alpha=params["alpha"], gamma=params["gamma"] - ), - "multiclass": partial( - multiclass_focal_eval, alpha=params["alpha"], gamma=params["gamma"] - ), - } - objective_mapper = { - "binary": partial( - binary_focal_objective, alpha=params["alpha"], gamma=params["gamma"] - ), - "multiclass": partial( - multiclass_focal_objective, alpha=params["alpha"], gamma=params["gamma"] - ), - } - - fobj = objective_mapper.get(inferred_task) - feval = eval_mapper.get(inferred_task) - params.update({"objective": fobj}) + fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma) + _params.update({"objective": fobj}) return lgb.train( - params=params, + params=_params, train_set=train_set, valid_sets=valid_sets, valid_names=valid_names, @@ -74,4 +49,55 @@ def train( ) -train.__doc__ = modify_docstring(lgb.train.__doc__) +@docstring(lgb.cv.__doc__) +def cv( + params: dict[str, Any], + train_set: lgb.Dataset, + num_boost_round: int = 100, + folds: Iterable[tuple[np.ndarray, np.ndarray]] | BaseCrossValidator | None = None, + nfold: int = 5, + stratified: bool = True, + shuffle: bool = True, + metrics: str | list[str] | None = None, + init_model: str | lgb.Path | lgb.Booster | None = None, + feature_name: list[str] | Literal["auto"] = "auto", + categorical_feature: list[str] | list[int] | Literal["auto"] = "auto", + fpreproc: Callable[ + [lgb.Dataset, lgb.Dataset, dict[str, Any]], + tuple[lgb.Dataset, lgb.Dataset, dict[str, Any]], + ] + | None = None, + seed: int = 0, + callbacks: list[Callable] | None = None, + eval_train_metric: bool = False, + return_cvbooster: bool = False, +) -> dict[str, list[float] | lgb.CVBooster]: + _params = deepcopy(params) + if "objective" in _params: + logger.warning("'objective' exists in params will not used.") + del _params["objective"] + + _alpha = _params.pop("alpha", 0.05) + _gamma = _params.pop("gamma", 0.01) + + fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma) + _params.update({"objective": fobj}) + return lgb.cv( + params=_params, + train_set=train_set, + num_boost_round=num_boost_round, + folds=folds, + nfold=nfold, + stratified=stratified, + shuffle=shuffle, + metrics=metrics, + feavl=feval, + init_model=init_model, + feature_name=feature_name, + categorical_feature=categorical_feature, + fpreproc=fpreproc, + seed=seed, + callbacks=callbacks, + eval_train_metric=eval_train_metric, + return_cvbooster=return_cvbooster, + ) diff --git a/imlightgbm/objective.py b/imlightgbm/objective.py index ed87b74..62f6f6e 100644 --- a/imlightgbm/objective.py +++ b/imlightgbm/objective.py @@ -1,5 +1,12 @@ +from functools import partial +from typing import Callable + import numpy as np from lightgbm import Dataset +from sklearn.utils.multiclass import type_of_target + +EvalLike = Callable[[np.ndarray, Dataset], tuple[str, float, bool]] +ObjLike = Callable[[np.ndarray, Dataset], tuple[np.ndarray, np.ndarray]] def binary_focal_eval( @@ -28,3 +35,25 @@ def multiclass_focal_objective( ): # TODO return + + +def set_fobj_feval( + train_set: Dataset, alpha: float, gamma: float +) -> tuple[ObjLike, EvalLike]: + inferred_task = type_of_target(train_set.get_label()) + if inferred_task not in {"binary", "multiclass"}: + raise ValueError( + f"Invalid target type: {inferred_task}. Supported types are 'binary' or 'multiclass'." + ) + objective_mapper: dict[str, ObjLike] = { + "binary": partial(binary_focal_objective, alpha=alpha, gamma=gamma), + "multiclass": partial(multiclass_focal_objective, alpha=alpha, gamma=gamma), + } + eval_mapper: dict[str, EvalLike] = { + "binary": partial(binary_focal_eval, alpha=alpha, gamma=gamma), + "multiclass": partial(multiclass_focal_eval, alpha=alpha, gamma=gamma), + } + fobj = objective_mapper[inferred_task] + feval = eval_mapper[inferred_task] + + return fobj, feval diff --git a/imlightgbm/utils.py b/imlightgbm/utils.py index a052c68..2cc1ad2 100644 --- a/imlightgbm/utils.py +++ b/imlightgbm/utils.py @@ -1,7 +1,8 @@ import logging +from typing import Callable -def modify_docstring(docstring: str) -> str: +def _modify_docstring(docstring: str) -> str: lines = docstring.splitlines() feval_start = next(i for i, line in enumerate(lines) if "feval" in line) @@ -14,6 +15,14 @@ def modify_docstring(docstring: str) -> str: return "\n".join(lines) +def docstring(doc: str) -> Callable[[Callable], Callable]: + def decorator(func: Callable) -> Callable: + func.__doc__ = _modify_docstring(doc) + return func + + return decorator + + def init_logger() -> logging.Logger: logger = logging.getLogger("imlightgbm") logger.setLevel(logging.INFO)