Skip to content

Commit

Permalink
[Feature] introduce scipy.special.expit (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Sep 24, 2024
1 parent a647887 commit 9e2f382
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 225 deletions.
9 changes: 7 additions & 2 deletions imlightgbm/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from enum import Enum

ALPHA_DEFAULT: float = 0.25
GAMMA_DEFAULT: float = 2.0


class BaseEnum(str, Enum):
@classmethod
Expand Down Expand Up @@ -31,5 +34,7 @@ class Metric(BaseEnum):


class Objective(BaseEnum):
focal: str = "focal"
weighted: str = "weighted"
binary_focal: str = "binary_focal"
binary_weighted: str = "binary_weighted"
multiclass_focal: str = "multiclass_focal"
multiclass_weighted: str = "multiclass_weighted"
2 changes: 1 addition & 1 deletion imlightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sklearn.model_selection import BaseCrossValidator

from imlightgbm.docstring import add_docstring
from imlightgbm.objective import set_params
from imlightgbm.objective.engine import set_params


@add_docstring("train")
Expand Down
204 changes: 0 additions & 204 deletions imlightgbm/objective.py

This file was deleted.

Empty file.
85 changes: 85 additions & 0 deletions imlightgbm/objective/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
from lightgbm import Dataset
from scipy.special import expit


def _safe_power(num_base: np.ndarray, num_pow: float):
"""Safe power."""
return np.sign(num_base) * (np.abs(num_base)) ** (num_pow)


def _safe_log(array: np.ndarray, min_value: float = 1e-6) -> np.ndarray:
"""Safe log."""
return np.log(np.clip(array, min_value))


def sklearn_binary_focal_objective(
y_true: np.ndarray, y_pred: np.ndarray, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
"""Return grad, hess for binary focal objective."""
pred_prob = expit(y_pred)

# gradient
g1 = pred_prob * (1 - pred_prob)
g2 = y_true + ((-1) ** y_true) * pred_prob
g3 = pred_prob + y_true - 1
g4 = 1 - y_true - ((-1) ** y_true) * pred_prob
g5 = y_true + ((-1) ** y_true) * pred_prob
grad = gamma * g3 * _safe_power(g2, gamma) * _safe_log(g4) + (
(-1) ** y_true
) * _safe_power(g5, (gamma + 1))
# hess
h1 = _safe_power(g2, gamma) + gamma * ((-1) ** y_true) * g3 * _safe_power(
g2, (gamma - 1)
)
h2 = ((-1) ** y_true) * g3 * _safe_power(g2, gamma) / g4
hess = (
(h1 * _safe_log(g4) - h2) * gamma + (gamma + 1) * _safe_power(g5, gamma)
) * g1
return grad, hess


def sklearn_binary_weighted_objective(
y_true: np.ndarray, y_pred: np.ndarray, alpha: float
) -> tuple[np.ndarray, np.ndarray]:
"""Return grad, hess for binary weighted objective."""
pred_prob = expit(y_pred)
grad = -(alpha**y_true) * (y_true - pred_prob)
hess = (alpha**y_true) * pred_prob * (1.0 - pred_prob)
return grad, hess


def binary_focal_objective(
pred: np.ndarray, train_data: Dataset, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
"""Return grad, hess for binary focal objective."""
label = train_data.get_label()
grad, hess = sklearn_binary_focal_objective(
y_true=label,
y_pred=pred,
gamma=gamma,
)
return grad, hess


def binary_weighted_objective(pred: np.ndarray, train_data: Dataset, alpha: float):
"""Return grad, hess for binary weighted objective."""
label = train_data.get_label()
grad, hess = sklearn_binary_weighted_objective(
y_true=label, y_pred=pred, alpha=alpha
)
return grad, hess


def multiclass_focal_objective(
pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
# TODO
return


def multiclass_weighted_objective(
pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
) -> tuple[str, float, bool]:
# TODO
return
Loading

0 comments on commit 9e2f382

Please sign in to comment.