diff --git a/imlightgbm/objective.py b/imlightgbm/objective.py index 5c3e965..e750786 100644 --- a/imlightgbm/objective.py +++ b/imlightgbm/objective.py @@ -6,6 +6,8 @@ from lightgbm import Dataset from sklearn.utils.multiclass import type_of_target +from imlightgbm.base import Metric, Objective, SupportedTask + ObjLike = Callable[[np.ndarray, Dataset], tuple[np.ndarray, np.ndarray]] ALPHA_DEFAULT: float = 0.25 GAMMA_DEFAULT: float = 2.0 @@ -54,8 +56,8 @@ def binary_focal_objective( return grad, hess -def weighted_binary_cross_entropy(pred: np.ndarray, train_data: Dataset, alpha: float): - """Return grad, hess for binary focal objective.""" +def binary_weighted_objective(pred: np.ndarray, train_data: Dataset, alpha: float): + """Return grad, hess for binary weighted objective.""" label = train_data.get_label() pred_prob = _sigmoid(pred) grad = -(alpha**label) * (label - pred_prob) @@ -70,41 +72,73 @@ def multiclass_focal_objective( return -def multiclass_focal_eval( +def multiclass_weighted_objective( pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float ) -> tuple[str, float, bool]: # TODO return -def _set_fobj_feval( +def _get_metric(task_enum: SupportedTask, metric: str | None) -> str: + """Retrieve the appropriate metric function based on task.""" + metric_mapper: dict[SupportedTask, list[Metric]] = { + SupportedTask.binary: [Metric.auc, Metric.binary_error, Metric.binary_logloss], + SupportedTask.multiclass: [ + Metric.auc_mu, + Metric.multi_logloss, + Metric.multi_error, + ], + } + if metric: + metric_enum = Metric.get(metric) + metric_enums = metric_mapper[task_enum] + if metric_enum not in metric_enums: + valid_metrics = ", ".join([m.value for m in metric_enums]) + raise ValueError(f"Invalid metric: Supported metrics are {valid_metrics}") + return metric_enum.value + + return metric_mapper[task_enum][0].value + + +def _get_objective( + task_enum: SupportedTask, objective: str | None, alpha: float, gamma: float +) -> ObjLike: + """Retrieve the appropriate objective function based on task and objective type.""" + objective_mapper: dict[SupportedTask, dict[Objective, ObjLike]] = { + SupportedTask.binary: { + Objective.focal: partial(binary_focal_objective, gamma=gamma), + Objective.weighted: partial(binary_weighted_objective, alpha=alpha), + }, + SupportedTask.multiclass: { + Objective.focal: partial( + multiclass_focal_objective, alpha=alpha, gamma=gamma + ), + Objective.weighted: partial( + multiclass_weighted_objective, alpha=alpha, gamma=gamma + ), + }, + } + if objective: + objective_enum = Objective.get(objective) + return objective_mapper[task_enum][objective_enum] + + return objective_mapper[task_enum][Objective.focal] + + +def _get_fobj_feval( train_set: Dataset, alpha: float, gamma: float, - objective: str | None = None, - metric: str | None = None, -) -> tuple[ObjLike, list[str]]: + objective: str | None, + metric: str | None, +) -> tuple[ObjLike, str]: """Return obj and eval with respect to task type.""" - inferred_task = type_of_target(train_set.get_label()) - if inferred_task not in {"binary"}: # TODO: multiclass - raise ValueError( - f"Invalid target type: {inferred_task}. Supported types is 'binary'." - ) - - feval = metric if metric else "auc" - if objective: - if objective not in {"focal", "weighted"}: - raise ValueError( - f"Invalid objective: {objective}. Supported types is 'focal' and 'weighted'." - ) - objective_mapper: dict[str, ObjLike] = { - "focal": partial(binary_focal_objective, gamma=gamma), - "weighted": partial(weighted_binary_cross_entropy, alpha=alpha), - } - fobj = objective_mapper[objective] - else: - fobj: ObjLike = partial(binary_focal_objective, gamma=gamma) - + _task = type_of_target(train_set.get_label()) + task_enum = SupportedTask.get(_task) + feval = _get_metric(task_enum=task_enum, metric=metric) + fobj = _get_objective( + task_enum=task_enum, objective=objective, alpha=alpha, gamma=gamma + ) return fobj, feval @@ -117,13 +151,10 @@ def set_params(params: dict[str, Any], train_set: Dataset) -> dict[str, Any]: if _metric and not isinstance(_metric, str): raise ValueError("metric must be str") - _alpha = _params.pop("alpha", ALPHA_DEFAULT) - _gamma = _params.pop("gamma", GAMMA_DEFAULT) - - fobj, feval = _set_fobj_feval( + fobj, feval = _get_fobj_feval( train_set=train_set, - alpha=_alpha, - gamma=_gamma, + alpha=_params.pop("alpha", ALPHA_DEFAULT), + gamma=_params.pop("gamma", GAMMA_DEFAULT), objective=_objective, metric=_metric, )