Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] refactor select objective and objective core #25

Merged
merged 8 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions examples/multiclass_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
# Make predictions on the test data
y_pred_focal = clf.predict(X_test)


# Evaluate the model performance using accuracy, log loss, and ROC AUC
# Evaluate models
print("\nClassification Report:")
print(classification_report(y_test, y_pred_focal))
91 changes: 42 additions & 49 deletions imlightgbm/objective/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,66 @@ def _safe_log(array: np.ndarray, min_value: float = 1e-6) -> np.ndarray:
return np.log(np.clip(array, min_value, None))


def sklearn_binary_focal_objective(
y_true: np.ndarray, y_pred: np.ndarray, gamma: float
def _weighted_grad_hess(
y_true: np.ndarray, pred_prob: np.ndarray, alpha: float
) -> tuple[np.ndarray, np.ndarray]:
"""Return grad, hess for binary focal objective for sklearn API."""
pred_prob = expit(y_pred)
"""Return weighted grad hess."""
grad = -(alpha**y_true) * (y_true - pred_prob)
hess = (alpha**y_true) * pred_prob * (1.0 - pred_prob)
return grad, hess


# gradient
g1 = pred_prob * (1 - pred_prob)
g2 = y_true + ((-1) ** y_true) * pred_prob
g3 = pred_prob + y_true - 1
g4 = 1 - y_true - ((-1) ** y_true) * pred_prob
g5 = y_true + ((-1) ** y_true) * pred_prob
grad = gamma * g3 * _safe_power(g2, gamma) * _safe_log(g4) + (
def _focal_grad_hess(
y_true: np.ndarray, pred_prob: np.ndarray, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
"""Return focal grad hess."""
prob_product = pred_prob * (1 - pred_prob)
true_diff_pred = y_true + ((-1) ** y_true) * pred_prob
focal_grad_term = pred_prob + y_true - 1
focal_log_term = 1 - y_true - ((-1) ** y_true) * pred_prob
focal_grad_base = y_true + ((-1) ** y_true) * pred_prob
grad = gamma * focal_grad_term * _safe_power(true_diff_pred, gamma) * _safe_log(
focal_log_term
) + ((-1) ** y_true) * _safe_power(focal_grad_base, (gamma + 1))

hess_term1 = _safe_power(true_diff_pred, gamma) + gamma * (
(-1) ** y_true
) * _safe_power(g5, (gamma + 1))
# hess
h1 = _safe_power(g2, gamma) + gamma * ((-1) ** y_true) * g3 * _safe_power(
g2, (gamma - 1)
) * focal_grad_term * _safe_power(true_diff_pred, (gamma - 1))
hess_term2 = (
((-1) ** y_true)
* focal_grad_term
* _safe_power(true_diff_pred, gamma)
/ focal_log_term
)
h2 = ((-1) ** y_true) * g3 * _safe_power(g2, gamma) / g4
hess = (
(h1 * _safe_log(g4) - h2) * gamma + (gamma + 1) * _safe_power(g5, gamma)
) * g1
(hess_term1 * _safe_log(focal_log_term) - hess_term2) * gamma
+ (gamma + 1) * _safe_power(focal_grad_base, gamma)
) * prob_product
return grad, hess


def sklearn_binary_focal_objective(
    y_true: np.ndarray, y_pred: np.ndarray, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
    """Compute grad and hess of the binary focal objective (sklearn API).

    ``y_pred`` is passed through ``expit`` and the result is handed to
    ``_focal_grad_hess`` together with ``y_true`` and ``gamma``; whatever
    that helper returns is returned unchanged.
    """
    return _focal_grad_hess(y_true=y_true, pred_prob=expit(y_pred), gamma=gamma)


def sklearn_binary_weighted_objective(
    y_true: np.ndarray, y_pred: np.ndarray, alpha: float
) -> tuple[np.ndarray, np.ndarray]:
    """Return grad, hess for binary weighted objective for sklearn API.

    ``y_pred`` is mapped through ``expit`` and every element is scaled by
    ``alpha ** y_true`` (1 where ``y_true`` is 0, ``alpha`` where it is 1).

    Args:
        y_true: Target values.
        y_pred: Predictions before the ``expit`` transform is applied.
        alpha: Base of the per-element weight ``alpha ** y_true``.

    Returns:
        A ``(grad, hess)`` tuple with
        ``grad = -(alpha ** y_true) * (y_true - p)`` and
        ``hess = (alpha ** y_true) * p * (1 - p)``, where ``p = expit(y_pred)``.
    """
    pred_prob = expit(y_pred)
    # Computed once and shared by both terms.
    class_weight = alpha**y_true
    grad = -class_weight * (y_true - pred_prob)
    hess = class_weight * pred_prob * (1.0 - pred_prob)
    return grad, hess


def binary_focal_objective(
    pred: np.ndarray, train_data: Dataset, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
    """Return grad, hess for binary focal objective for engine.

    Reads the labels from ``train_data`` via ``get_label()`` and forwards
    them, together with ``pred`` and ``gamma``, to
    ``sklearn_binary_focal_objective``.

    Args:
        pred: Predictions, forwarded as ``y_pred``.
        train_data: Dataset whose ``get_label()`` supplies ``y_true``.
        gamma: Focal parameter, forwarded unchanged.

    Returns:
        The ``(grad, hess)`` tuple produced by
        ``sklearn_binary_focal_objective``.
    """
    label = train_data.get_label()
    return sklearn_binary_focal_objective(y_true=label, y_pred=pred, gamma=gamma)


def binary_weighted_objective(
Expand All @@ -78,26 +92,7 @@ def sklearn_multiclass_focal_objective(
"""Return grad, hess for multiclass focal objective for sklearn API."""
pred_prob = softmax(y_pred, axis=1)
y_true_onehot = np.eye(num_class)[y_true.astype(int)]

# gradient
g1 = pred_prob * (1 - pred_prob)
g2 = y_true_onehot + ((-1) ** y_true_onehot) * pred_prob
g3 = pred_prob + y_true_onehot - 1
g4 = 1 - y_true_onehot - ((-1) ** y_true_onehot) * pred_prob
g5 = y_true_onehot + ((-1) ** y_true_onehot) * pred_prob
grad = gamma * g3 * _safe_power(g2, gamma) * _safe_log(g4) + (
(-1) ** y_true_onehot
) * _safe_power(g5, (gamma + 1))
# hess
h1 = _safe_power(g2, gamma) + gamma * ((-1) ** y_true_onehot) * g3 * _safe_power(
g2, (gamma - 1)
)
h2 = ((-1) ** y_true_onehot) * g3 * _safe_power(g2, gamma) / g4
hess = (
(h1 * _safe_log(g4) - h2) * gamma + (gamma + 1) * _safe_power(g5, gamma)
) * g1

return grad, hess
return _focal_grad_hess(y_true=y_true_onehot, pred_prob=pred_prob, gamma=gamma)


def sklearn_multiclass_weighted_objective(
Expand All @@ -109,9 +104,7 @@ def sklearn_multiclass_weighted_objective(
"""Return grad, hess for multiclass weighted objective for sklearn API."""
pred_prob = softmax(y_pred, axis=1)
y_true_onehot = np.eye(num_class)[y_true.astype(int)]
grad = -(alpha**y_true_onehot) * (y_true_onehot - pred_prob)
hess = (alpha**y_true_onehot) * pred_prob * (1.0 - pred_prob)
return grad, hess
return _weighted_grad_hess(y_true=y_true_onehot, pred_prob=pred_prob, alpha=alpha)


def multiclass_focal_objective(
Expand Down
16 changes: 13 additions & 3 deletions imlightgbm/objective/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def _get_metric(task_enum: SupportedTask, metric: str | None) -> str:
Defaults to auc (binary), auc_mu (multiclass).
"""
metric_mapper: dict[SupportedTask, list[Metric]] = {
SupportedTask.binary: [Metric.auc, Metric.binary_error, Metric.binary_logloss],
SupportedTask.binary: [
Metric.auc,
Metric.binary_error,
Metric.binary_logloss,
],
SupportedTask.multiclass: [
Metric.auc_mu,
Metric.multi_logloss,
Expand Down Expand Up @@ -61,8 +65,14 @@ def _get_objective(
"""Retrieve the appropriate objective function based on task and objective type."""
objective_mapper: dict[SupportedTask, dict[Objective, ObjLike]] = {
SupportedTask.binary: {
Objective.binary_focal: partial(binary_focal_objective, gamma=gamma),
Objective.binary_weighted: partial(binary_weighted_objective, alpha=alpha),
Objective.binary_focal: partial(
binary_focal_objective,
gamma=gamma,
),
Objective.binary_weighted: partial(
binary_weighted_objective,
alpha=alpha,
),
},
SupportedTask.multiclass: {
Objective.multiclass_focal: partial(
Expand Down
106 changes: 73 additions & 33 deletions imlightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def __init__(
self,
*,
objective: str,
alpha: float = ALPHA_DEFAULT,
gamma: float = GAMMA_DEFAULT,
alpha: float | None = None,
gamma: float | None = None,
boosting_type: str = "gbdt",
num_leaves: int = 31,
max_depth: int = -1,
Expand Down Expand Up @@ -61,45 +61,19 @@ def __init__(
other parameters:
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more details.
"""
validate_positive_number(alpha)
validate_positive_number(gamma)

self.alpha = alpha
self.gamma = gamma
self.num_class = num_class
_objective = Objective.get(objective)
if _objective in {
Objective.multiclass_focal,
Objective.multiclass_weighted,
} and not isinstance(num_class, int):
raise ValueError("num_class must be provided")

_objective_mapper: dict[Objective, _SklearnObjLike] = {
Objective.binary_focal: lambda y_true,
y_pred: sklearn_binary_focal_objective(
y_true=y_true, y_pred=y_pred, gamma=gamma
),
Objective.binary_weighted: lambda y_true,
y_pred: sklearn_binary_weighted_objective(
y_true=y_true, y_pred=y_pred, alpha=alpha
),
Objective.multiclass_focal: lambda y_true,
y_pred: sklearn_multiclass_focal_objective(
y_true=y_true, y_pred=y_pred, gamma=gamma, num_class=num_class
),
Objective.multiclass_weighted: lambda y_true,
y_pred: sklearn_multiclass_weighted_objective(
y_true=y_true, y_pred=y_pred, alpha=alpha, num_class=num_class
),
}
_objective_enum: Objective = Objective.get(objective)
self.__alpha_select(objective=_objective_enum, alpha=alpha)
self.__gamma_select(objective=_objective_enum, gamma=gamma)
_objective = self.__objective_select(objective_enum=_objective_enum)
super().__init__(
boosting_type=boosting_type,
num_leaves=num_leaves,
max_depth=max_depth,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample_for_bin=subsample_for_bin,
objective=_objective_mapper[_objective],
objective=_objective,
class_weight=class_weight,
min_split_gain=min_split_gain,
min_child_weight=min_child_weight,
Expand Down Expand Up @@ -152,3 +126,69 @@ def predict(
return expit(_predict)

predict.__doc__ = LGBMClassifier.predict.__doc__

def __objective_select(self, objective_enum: Objective) -> _SklearnObjLike:
    """Build the sklearn-style objective callable for ``objective_enum``.

    The returned callable takes ``(y_true, y_pred)`` and reads
    ``self.gamma`` / ``self.alpha`` / ``self.num_class`` when it is called.

    Raises:
        ValueError: If a multiclass objective is requested while
            ``self.num_class`` is not an ``int``.
        KeyError: If ``objective_enum`` is not one of the four objectives
            handled here.
    """
    multiclass_objectives = {
        Objective.multiclass_focal,
        Objective.multiclass_weighted,
    }
    is_multiclass = objective_enum in multiclass_objectives
    if is_multiclass and not isinstance(self.num_class, int):
        raise ValueError("num_class must be provided")

    def _binary_focal(y_true, y_pred):
        return sklearn_binary_focal_objective(
            y_true=y_true, y_pred=y_pred, gamma=self.gamma
        )

    def _binary_weighted(y_true, y_pred):
        return sklearn_binary_weighted_objective(
            y_true=y_true, y_pred=y_pred, alpha=self.alpha
        )

    def _multiclass_focal(y_true, y_pred):
        return sklearn_multiclass_focal_objective(
            y_true=y_true, y_pred=y_pred, gamma=self.gamma, num_class=self.num_class
        )

    def _multiclass_weighted(y_true, y_pred):
        return sklearn_multiclass_weighted_objective(
            y_true=y_true, y_pred=y_pred, alpha=self.alpha, num_class=self.num_class
        )

    dispatch: dict[Objective, _SklearnObjLike] = {
        Objective.binary_focal: _binary_focal,
        Objective.binary_weighted: _binary_weighted,
        Objective.multiclass_focal: _multiclass_focal,
        Objective.multiclass_weighted: _multiclass_weighted,
    }
    return dispatch[objective_enum]

def __param_select(
    self,
    objective: Objective,
    param: float | None,
    valid_objectives: set[Objective],
    default_value: float,
    param_name: str,
) -> None:
    """General method to select appropriate parameter (alpha or gamma).

    Stores the chosen value on ``self`` under ``param_name``:

    * ``None`` when ``objective`` is not in ``valid_objectives``;
    * ``default_value`` when ``param`` is ``None`` (not supplied);
    * ``param`` itself otherwise, after ``validate_positive_number(param)``.

    Args:
        objective: Objective the estimator was configured with.
        param: User-supplied value, or ``None`` to request the default.
        valid_objectives: Objectives for which this parameter applies.
        default_value: Value stored when ``param`` is ``None``.
        param_name: Attribute name to set on ``self``.

    Raises:
        Whatever ``validate_positive_number`` raises for a rejected value
        (presumably for non-positive numbers; confirm against its
        definition).
    """
    if objective not in valid_objectives:
        setattr(self, param_name, None)
        return
    if param is None:
        setattr(self, param_name, default_value)
        return
    # Compare against None, not truthiness: an explicit 0 / 0.0 must reach
    # validation instead of being silently replaced by the default.
    validate_positive_number(param)
    setattr(self, param_name, param)

def __alpha_select(self, objective: Objective, alpha: float | None) -> None:
    """Resolve and store ``self.alpha`` for the given objective.

    Delegates to ``__param_select`` with the two weighted objectives as
    the valid set and ``ALPHA_DEFAULT`` as the default value.
    """
    weighted_objectives = {
        Objective.binary_weighted,
        Objective.multiclass_weighted,
    }
    self.__param_select(
        objective=objective,
        param=alpha,
        valid_objectives=weighted_objectives,
        default_value=ALPHA_DEFAULT,
        param_name="alpha",
    )

def __gamma_select(self, objective: Objective, gamma: float | None) -> None:
    """Resolve and store ``self.gamma`` for the given objective.

    Delegates to ``__param_select`` with the two focal objectives as the
    valid set and ``GAMMA_DEFAULT`` as the default value.
    """
    focal_objectives = {
        Objective.binary_focal,
        Objective.multiclass_focal,
    }
    self.__param_select(
        objective=objective,
        param=gamma,
        valid_objectives=focal_objectives,
        default_value=GAMMA_DEFAULT,
        param_name="gamma",
    )