Skip to content

Commit

Permalink
refactor objective
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 18, 2024
1 parent 4eb7d84 commit dc54077
Showing 1 changed file with 64 additions and 33 deletions.
97 changes: 64 additions & 33 deletions imlightgbm/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from lightgbm import Dataset
from sklearn.utils.multiclass import type_of_target

from imlightgbm.base import Metric, Objective, SupportedTask

ObjLike = Callable[[np.ndarray, Dataset], tuple[np.ndarray, np.ndarray]]
ALPHA_DEFAULT: float = 0.25
GAMMA_DEFAULT: float = 2.0
Expand Down Expand Up @@ -54,8 +56,8 @@ def binary_focal_objective(
return grad, hess


def weighted_binary_cross_entropy(pred: np.ndarray, train_data: Dataset, alpha: float):
"""Return grad, hess for binary focal objective."""
def binary_weighted_objective(pred: np.ndarray, train_data: Dataset, alpha: float):
"""Return grad, hess for binary weighted objective."""
label = train_data.get_label()
pred_prob = _sigmoid(pred)
grad = -(alpha**label) * (label - pred_prob)
Expand All @@ -70,41 +72,73 @@ def multiclass_focal_objective(
return


def multiclass_focal_eval(
def multiclass_weighted_objective(
    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
    """Return grad, hess for multiclass weighted objective.

    Registered as an objective (``ObjLike``) in ``_get_objective``, so it must
    eventually return ``(grad, hess)`` arrays; the previous annotation
    ``tuple[str, float, bool]`` was left over from the renamed eval function.

    # TODO: not implemented yet — currently returns None.
    """
    # TODO
    return


def _set_fobj_feval(
def _get_metric(task_enum: SupportedTask, metric: str | None) -> str:
    """Resolve the metric name to use for the given task.

    When *metric* is supplied it is validated against the metrics supported
    for the task; otherwise the task's default (first listed) metric is used.

    Raises:
        ValueError: if *metric* is not supported for the task.
    """
    supported: dict[SupportedTask, list[Metric]] = {
        SupportedTask.binary: [Metric.auc, Metric.binary_error, Metric.binary_logloss],
        SupportedTask.multiclass: [
            Metric.auc_mu,
            Metric.multi_logloss,
            Metric.multi_error,
        ],
    }
    if metric:
        chosen = Metric.get(metric)
        candidates = supported[task_enum]
        if chosen not in candidates:
            valid_metrics = ", ".join(m.value for m in candidates)
            raise ValueError(f"Invalid metric: Supported metrics are {valid_metrics}")
        return chosen.value

    # No explicit metric requested: fall back to the task default.
    return supported[task_enum][0].value


def _get_objective(
    task_enum: SupportedTask, objective: str | None, alpha: float, gamma: float
) -> ObjLike:
    """Resolve the objective callable for the given task.

    The hyper-parameters *alpha* and *gamma* are bound into the returned
    callable via ``partial``; when *objective* is not given, the focal
    objective is used as the default.
    """
    binary_objectives: dict[Objective, ObjLike] = {
        Objective.focal: partial(binary_focal_objective, gamma=gamma),
        Objective.weighted: partial(binary_weighted_objective, alpha=alpha),
    }
    multiclass_objectives: dict[Objective, ObjLike] = {
        Objective.focal: partial(multiclass_focal_objective, alpha=alpha, gamma=gamma),
        Objective.weighted: partial(
            multiclass_weighted_objective, alpha=alpha, gamma=gamma
        ),
    }
    objective_mapper: dict[SupportedTask, dict[Objective, ObjLike]] = {
        SupportedTask.binary: binary_objectives,
        SupportedTask.multiclass: multiclass_objectives,
    }
    if objective:
        return objective_mapper[task_enum][Objective.get(objective)]

    # Default objective: focal loss.
    return objective_mapper[task_enum][Objective.focal]


def _get_fobj_feval(
    train_set: Dataset,
    alpha: float,
    gamma: float,
    objective: str | None,
    metric: str | None,
) -> tuple[ObjLike, str]:
    """Return obj and eval with respect to task type.

    The task (binary / multiclass) is inferred from the training labels,
    then the metric name and the objective callable (with *alpha* and
    *gamma* bound in) are resolved for that task.
    """
    inferred = type_of_target(train_set.get_label())
    task_enum = SupportedTask.get(inferred)
    feval = _get_metric(task_enum=task_enum, metric=metric)
    fobj = _get_objective(
        task_enum=task_enum,
        objective=objective,
        alpha=alpha,
        gamma=gamma,
    )
    return fobj, feval


Expand All @@ -117,13 +151,10 @@ def set_params(params: dict[str, Any], train_set: Dataset) -> dict[str, Any]:
if _metric and not isinstance(_metric, str):
raise ValueError("metric must be str")

_alpha = _params.pop("alpha", ALPHA_DEFAULT)
_gamma = _params.pop("gamma", GAMMA_DEFAULT)

fobj, feval = _set_fobj_feval(
fobj, feval = _get_fobj_feval(
train_set=train_set,
alpha=_alpha,
gamma=_gamma,
alpha=_params.pop("alpha", ALPHA_DEFAULT),
gamma=_params.pop("gamma", GAMMA_DEFAULT),
objective=_objective,
metric=_metric,
)
Expand Down

0 comments on commit dc54077

Please sign in to comment.