From 003f3b7124f3e4f81a2eb55ada8d22e2223d1a63 Mon Sep 17 00:00:00 2001 From: RektPunk Date: Sat, 14 Sep 2024 19:11:30 +0900 Subject: [PATCH] add objective initial --- imlightgbm/engine.py | 38 ++++++++++++++++++++++++++++++++------ imlightgbm/objective.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/imlightgbm/engine.py b/imlightgbm/engine.py index c3b6286..9eef31e 100644 --- a/imlightgbm/engine.py +++ b/imlightgbm/engine.py @@ -1,8 +1,15 @@ +from functools import partial from typing import Any, Callable, Literal import lightgbm as lgb from sklearn.utils.multiclass import type_of_target +from imlightgbm.objective import ( + binary_focal_eval, + binary_focal_objective, + multiclass_focal_eval, + multiclass_focal_objective, +) from imlightgbm.utils import logger, modify_docstring @@ -22,15 +29,34 @@ def train( logger.warning("'objective' exists in params will not used.") params.pop("objective") - inferred_type = type_of_target(train_set.get_label()) - if inferred_type not in {"binary", "multiclass"}: + params.setdefault("alpha", 0.05) + params.setdefault("gamma", 0.01) + + inferred_task = type_of_target(train_set.get_label()) + if inferred_task not in {"binary", "multiclass"}: raise ValueError( - f"Invalid target type: {inferred_type}. Supported types are 'binary' or 'multiclass'." + f"Invalid target type: {inferred_task}. Supported types are 'binary' or 'multiclass'." ) - # MAPPER # TODO - fobj = ... # Focal eval function # TODO - feval = ... # Focal objective function, # TODO + eval_mapper = { + "binary": partial( + binary_focal_eval, alpha=params["alpha"], gamma=params["gamma"] + ), + "multiclass": partial( + multiclass_focal_eval, alpha=params["alpha"], gamma=params["gamma"] + ), + } + objective_mapper = { + "binary": partial( + binary_focal_objective, alpha=params["alpha"], gamma=params["gamma"] + ), + "multiclass": partial( + multiclass_focal_objective, alpha=params["alpha"], gamma=params["gamma"] + ), + } + + fobj = objective_mapper.get(inferred_task) + feval = eval_mapper.get(inferred_task) params.update({"objective": fobj}) return lgb.train( diff --git a/imlightgbm/objective.py b/imlightgbm/objective.py index e69de29..ed87b74 100644 --- a/imlightgbm/objective.py +++ b/imlightgbm/objective.py @@ -0,0 +1,30 @@ +import numpy as np +from lightgbm import Dataset + + +def binary_focal_eval( + pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float +): + is_higher_better = False + return "binary_focal", ..., is_higher_better + + +def binary_focal_objective( + pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float +): + # TODO + return ... + + +def multiclass_focal_eval( + pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float +): + # TODO + return + + +def multiclass_focal_objective( + pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float +): + # TODO + return