Skip to content

Commit

Permalink
add objective initial
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 14, 2024
1 parent 8da3f98 commit 003f3b7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
38 changes: 32 additions & 6 deletions imlightgbm/engine.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from functools import partial
from typing import Any, Callable, Literal

import lightgbm as lgb
from sklearn.utils.multiclass import type_of_target

from imlightgbm.objective import (
binary_focal_eval,
binary_focal_objective,
multiclass_focal_eval,
multiclass_focal_objective,
)
from imlightgbm.utils import logger, modify_docstring


Expand All @@ -22,15 +29,34 @@ def train(
logger.warning("'objective' exists in params will not used.")
params.pop("objective")

inferred_type = type_of_target(train_set.get_label())
if inferred_type not in {"binary", "multiclass"}:
params.setdefault("alpha", 0.05)
params.setdefault("gamma", 0.01)

inferred_task = type_of_target(train_set.get_label())
if inferred_task not in {"binary", "multiclass"}:
raise ValueError(
f"Invalid target type: {inferred_type}. Supported types are 'binary' or 'multiclass'."
f"Invalid target type: {inferred_task}. Supported types are 'binary' or 'multiclass'."
)

# MAPPER # TODO
fobj = ... # Focal eval function # TODO
feval = ... # Focal objective function, # TODO
eval_mapper = {
"binary": partial(
binary_focal_eval, alpha=params["alpha"], gamma=params["gamma"]
),
"multiclass": partial(
multiclass_focal_eval, alpha=params["alpha"], gamma=params["gamma"]
),
}
objective_mapper = {
"binary": partial(
binary_focal_objective, alpha=params["alpha"], gamma=params["gamma"]
),
"multiclass": partial(
multiclass_focal_objective, alpha=params["alpha"], gamma=params["gamma"]
),
}

fobj = objective_mapper.get(inferred_task)
feval = eval_mapper.get(inferred_task)
params.update({"objective": fobj})

return lgb.train(
Expand Down
30 changes: 30 additions & 0 deletions imlightgbm/objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
from lightgbm import Dataset


def binary_focal_eval(
pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
is_higher_better = False
return "binary_focal", ..., is_higher_better


def binary_focal_objective(
pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
# TODO
return ...


def multiclass_focal_eval(
pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
# TODO
return


def multiclass_focal_objective(
pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
):
# TODO
return

0 comments on commit 003f3b7

Please sign in to comment.