diff --git a/examples/multiclass_sklearn.py b/examples/multiclass_sklearn.py index 9a37ea9..9013668 100644 --- a/examples/multiclass_sklearn.py +++ b/examples/multiclass_sklearn.py @@ -35,8 +35,6 @@ # Make predictions on the test data y_pred_focal = clf.predict(X_test) - -# Evaluate the model performance using accuracy, log loss, and ROC AUC # Evaluate models print("\nClassification Report:") print(classification_report(y_test, y_pred_focal)) diff --git a/imlightgbm/objective/core.py b/imlightgbm/objective/core.py index 8ef75a1..edad85e 100644 --- a/imlightgbm/objective/core.py +++ b/imlightgbm/objective/core.py @@ -13,40 +13,58 @@ def _safe_log(array: np.ndarray, min_value: float = 1e-6) -> np.ndarray: return np.log(np.clip(array, min_value, None)) -def sklearn_binary_focal_objective( - y_true: np.ndarray, y_pred: np.ndarray, gamma: float +def _weighted_grad_hess( + y_true: np.ndarray, pred_prob: np.ndarray, alpha: float ) -> tuple[np.ndarray, np.ndarray]: - """Return grad, hess for binary focal objective for sklearn API.""" - pred_prob = expit(y_pred) + """Return weighted grad hess.""" + grad = -(alpha**y_true) * (y_true - pred_prob) + hess = (alpha**y_true) * pred_prob * (1.0 - pred_prob) + return grad, hess + - # gradient - g1 = pred_prob * (1 - pred_prob) - g2 = y_true + ((-1) ** y_true) * pred_prob - g3 = pred_prob + y_true - 1 - g4 = 1 - y_true - ((-1) ** y_true) * pred_prob - g5 = y_true + ((-1) ** y_true) * pred_prob - grad = gamma * g3 * _safe_power(g2, gamma) * _safe_log(g4) + ( +def _focal_grad_hess( + y_true: np.ndarray, pred_prob: np.ndarray, gamma: float +) -> tuple[np.ndarray, np.ndarray]: + """Return focal grad hess.""" + prob_product = pred_prob * (1 - pred_prob) + true_diff_pred = y_true + ((-1) ** y_true) * pred_prob + focal_grad_term = pred_prob + y_true - 1 + focal_log_term = 1 - y_true - ((-1) ** y_true) * pred_prob + focal_grad_base = y_true + ((-1) ** y_true) * pred_prob + grad = gamma * focal_grad_term * _safe_power(true_diff_pred, gamma) 
* _safe_log( + focal_log_term + ) + ((-1) ** y_true) * _safe_power(focal_grad_base, (gamma + 1)) + + hess_term1 = _safe_power(true_diff_pred, gamma) + gamma * ( (-1) ** y_true - ) * _safe_power(g5, (gamma + 1)) - # hess - h1 = _safe_power(g2, gamma) + gamma * ((-1) ** y_true) * g3 * _safe_power( - g2, (gamma - 1) + ) * focal_grad_term * _safe_power(true_diff_pred, (gamma - 1)) + hess_term2 = ( + ((-1) ** y_true) + * focal_grad_term + * _safe_power(true_diff_pred, gamma) + / focal_log_term ) - h2 = ((-1) ** y_true) * g3 * _safe_power(g2, gamma) / g4 hess = ( - (h1 * _safe_log(g4) - h2) * gamma + (gamma + 1) * _safe_power(g5, gamma) - ) * g1 + (hess_term1 * _safe_log(focal_log_term) - hess_term2) * gamma + + (gamma + 1) * _safe_power(focal_grad_base, gamma) + ) * prob_product return grad, hess +def sklearn_binary_focal_objective( + y_true: np.ndarray, y_pred: np.ndarray, gamma: float +) -> tuple[np.ndarray, np.ndarray]: + """Return grad, hess for binary focal objective for sklearn API.""" + pred_prob = expit(y_pred) + return _focal_grad_hess(y_true=y_true, pred_prob=pred_prob, gamma=gamma) + + def sklearn_binary_weighted_objective( y_true: np.ndarray, y_pred: np.ndarray, alpha: float ) -> tuple[np.ndarray, np.ndarray]: """Return grad, hess for binary weighted objective for sklearn API.""" pred_prob = expit(y_pred) - grad = -(alpha**y_true) * (y_true - pred_prob) - hess = (alpha**y_true) * pred_prob * (1.0 - pred_prob) - return grad, hess + return _weighted_grad_hess(y_true=y_true, pred_prob=pred_prob, alpha=alpha) def binary_focal_objective( @@ -54,11 +72,7 @@ def binary_focal_objective( ) -> tuple[np.ndarray, np.ndarray]: """Return grad, hess for binary focal objective for engine.""" label = train_data.get_label() - return sklearn_binary_focal_objective( - y_true=label, - y_pred=pred, - gamma=gamma, - ) + return sklearn_binary_focal_objective(y_true=label, y_pred=pred, gamma=gamma) def binary_weighted_objective( @@ -78,26 +92,7 @@ def 
sklearn_multiclass_focal_objective( """Return grad, hess for multclass focal objective for sklearn API..""" pred_prob = softmax(y_pred, axis=1) y_true_onehot = np.eye(num_class)[y_true.astype(int)] - - # gradient - g1 = pred_prob * (1 - pred_prob) - g2 = y_true_onehot + ((-1) ** y_true_onehot) * pred_prob - g3 = pred_prob + y_true_onehot - 1 - g4 = 1 - y_true_onehot - ((-1) ** y_true_onehot) * pred_prob - g5 = y_true_onehot + ((-1) ** y_true_onehot) * pred_prob - grad = gamma * g3 * _safe_power(g2, gamma) * _safe_log(g4) + ( - (-1) ** y_true_onehot - ) * _safe_power(g5, (gamma + 1)) - # hess - h1 = _safe_power(g2, gamma) + gamma * ((-1) ** y_true_onehot) * g3 * _safe_power( - g2, (gamma - 1) - ) - h2 = ((-1) ** y_true_onehot) * g3 * _safe_power(g2, gamma) / g4 - hess = ( - (h1 * _safe_log(g4) - h2) * gamma + (gamma + 1) * _safe_power(g5, gamma) - ) * g1 - - return grad, hess + return _focal_grad_hess(y_true=y_true_onehot, pred_prob=pred_prob, gamma=gamma) def sklearn_multiclass_weighted_objective( @@ -109,9 +104,7 @@ def sklearn_multiclass_weighted_objective( """Return grad, hess for multclass weighted objective for sklearn API.""" pred_prob = softmax(y_pred, axis=1) y_true_onehot = np.eye(num_class)[y_true.astype(int)] - grad = -(alpha**y_true_onehot) * (y_true_onehot - pred_prob) - hess = (alpha**y_true_onehot) * pred_prob * (1.0 - pred_prob) - return grad, hess + return _weighted_grad_hess(y_true=y_true_onehot, pred_prob=pred_prob, alpha=alpha) def multiclass_focal_objective( diff --git a/imlightgbm/objective/engine.py b/imlightgbm/objective/engine.py index 0654320..ad98a2e 100644 --- a/imlightgbm/objective/engine.py +++ b/imlightgbm/objective/engine.py @@ -33,7 +33,11 @@ def _get_metric(task_enum: SupportedTask, metric: str | None) -> str: Defaults to auc (binary), auc_mu (multiclass). 
""" metric_mapper: dict[SupportedTask, list[Metric]] = { - SupportedTask.binary: [Metric.auc, Metric.binary_error, Metric.binary_logloss], + SupportedTask.binary: [ + Metric.auc, + Metric.binary_error, + Metric.binary_logloss, + ], SupportedTask.multiclass: [ Metric.auc_mu, Metric.multi_logloss, @@ -61,8 +65,14 @@ def _get_objective( """Retrieve the appropriate objective function based on task and objective type.""" objective_mapper: dict[SupportedTask, dict[Objective, ObjLike]] = { SupportedTask.binary: { - Objective.binary_focal: partial(binary_focal_objective, gamma=gamma), - Objective.binary_weighted: partial(binary_weighted_objective, alpha=alpha), + Objective.binary_focal: partial( + binary_focal_objective, + gamma=gamma, + ), + Objective.binary_weighted: partial( + binary_weighted_objective, + alpha=alpha, + ), }, SupportedTask.multiclass: { Objective.multiclass_focal: partial( diff --git a/imlightgbm/sklearn.py b/imlightgbm/sklearn.py index 7b485a2..5a1bc92 100644 --- a/imlightgbm/sklearn.py +++ b/imlightgbm/sklearn.py @@ -26,8 +26,8 @@ def __init__( self, *, objective: str, - alpha: float = ALPHA_DEFAULT, - gamma: float = GAMMA_DEFAULT, + alpha: float | None = None, + gamma: float | None = None, boosting_type: str = "gbdt", num_leaves: int = 31, max_depth: int = -1, @@ -61,37 +61,11 @@ def __init__( other parameters: Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more details. 
""" - validate_positive_number(alpha) - validate_positive_number(gamma) - - self.alpha = alpha - self.gamma = gamma self.num_class = num_class - _objective = Objective.get(objective) - if _objective in { - Objective.multiclass_focal, - Objective.multiclass_weighted, - } and not isinstance(num_class, int): - raise ValueError("num_class must be provided") - - _objective_mapper: dict[Objective, _SklearnObjLike] = { - Objective.binary_focal: lambda y_true, - y_pred: sklearn_binary_focal_objective( - y_true=y_true, y_pred=y_pred, gamma=gamma - ), - Objective.binary_weighted: lambda y_true, - y_pred: sklearn_binary_weighted_objective( - y_true=y_true, y_pred=y_pred, alpha=alpha - ), - Objective.multiclass_focal: lambda y_true, - y_pred: sklearn_multiclass_focal_objective( - y_true=y_true, y_pred=y_pred, gamma=gamma, num_class=num_class - ), - Objective.multiclass_weighted: lambda y_true, - y_pred: sklearn_multiclass_weighted_objective( - y_true=y_true, y_pred=y_pred, alpha=alpha, num_class=num_class - ), - } + _objective_enum: Objective = Objective.get(objective) + self.__alpha_select(objective=_objective_enum, alpha=alpha) + self.__gamma_select(objective=_objective_enum, gamma=gamma) + _objective = self.__objective_select(objective_enum=_objective_enum) super().__init__( boosting_type=boosting_type, num_leaves=num_leaves, @@ -99,7 +73,7 @@ def __init__( learning_rate=learning_rate, n_estimators=n_estimators, subsample_for_bin=subsample_for_bin, - objective=_objective_mapper[_objective], + objective=_objective, class_weight=class_weight, min_split_gain=min_split_gain, min_child_weight=min_child_weight, @@ -152,3 +126,69 @@ def predict( return expit(_predict) predict.__doc__ = LGBMClassifier.predict.__doc__ + + def __objective_select(self, objective_enum: Objective) -> _SklearnObjLike: + """Select objective function.""" + if objective_enum in { + Objective.multiclass_focal, + Objective.multiclass_weighted, + } and not isinstance(self.num_class, int): + raise 
ValueError("num_class must be provided") + + _objective_mapper: dict[Objective, _SklearnObjLike] = { + Objective.binary_focal: lambda y_true, + y_pred: sklearn_binary_focal_objective( + y_true=y_true, y_pred=y_pred, gamma=self.gamma + ), + Objective.binary_weighted: lambda y_true, + y_pred: sklearn_binary_weighted_objective( + y_true=y_true, y_pred=y_pred, alpha=self.alpha + ), + Objective.multiclass_focal: lambda y_true, + y_pred: sklearn_multiclass_focal_objective( + y_true=y_true, y_pred=y_pred, gamma=self.gamma, num_class=self.num_class + ), + Objective.multiclass_weighted: lambda y_true, + y_pred: sklearn_multiclass_weighted_objective( + y_true=y_true, y_pred=y_pred, alpha=self.alpha, num_class=self.num_class + ), + } + return _objective_mapper[objective_enum] + + def __param_select( + self, + objective: Objective, + param: float | None, + valid_objectives: set[Objective], + default_value: float, + param_name: str, + ) -> None: + """General method to select appropriate parameter (alpha or gamma).""" + if objective not in valid_objectives: + setattr(self, param_name, None) + return + if param is not None: + validate_positive_number(param) + setattr(self, param_name, param) + return + setattr(self, param_name, default_value) + + def __alpha_select(self, objective: Objective, alpha: float | None) -> None: + """Select appropriate alpha.""" + self.__param_select( + objective=objective, + param=alpha, + valid_objectives={Objective.binary_weighted, Objective.multiclass_weighted}, + default_value=ALPHA_DEFAULT, + param_name="alpha", + ) + + def __gamma_select(self, objective: Objective, gamma: float | None) -> None: + """Select appropriate gamma.""" + self.__param_select( + objective=objective, + param=gamma, + valid_objectives={Objective.binary_focal, Objective.multiclass_focal}, + default_value=GAMMA_DEFAULT, + param_name="gamma", + )