[Feature] remove feval (#5)
RektPunk authored Sep 17, 2024
1 parent 91bcdf7 commit 4e19e88
Showing 2 changed files with 33 additions and 30 deletions.
6 changes: 2 additions & 4 deletions imlightgbm/engine.py
@@ -22,13 +22,12 @@ def train(
     keep_training_booster: bool = False,
     callbacks: list[Callable] | None = None,
 ) -> lgb.Booster:
-    _params, feval = set_params(params=params, train_set=train_set)
+    _params = set_params(params=params, train_set=train_set)
     return lgb.train(
         params=_params,
         train_set=train_set,
         valid_sets=valid_sets,
         valid_names=valid_names,
-        feval=feval,
         num_boost_round=num_boost_round,
         init_model=init_model,
         feature_name=feature_name,
@@ -61,7 +60,7 @@ def cv(
     eval_train_metric: bool = False,
     return_cvbooster: bool = False,
 ) -> dict[str, list[float] | lgb.CVBooster]:
-    _params, feval = set_params(params=params, train_set=train_set)
+    _params = set_params(params=params, train_set=train_set)
     return lgb.cv(
         params=_params,
         train_set=train_set,
@@ -71,7 +70,6 @@
         stratified=stratified,
         shuffle=shuffle,
         metrics=metrics,
-        feval=feval,
         init_model=init_model,
         feature_name=feature_name,
         categorical_feature=categorical_feature,
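After this change, neither train nor cv passes a feval callable through to LightGBM: set_params now returns a single params dict that already carries the objective and the eval metric (see objective.py below). A minimal sketch of the new call pattern — the import path follows the file layout in this commit, and the toy data and parameter values are illustrative only:

import numpy as np
import lightgbm as lgb

from imlightgbm.engine import train

# Toy imbalanced binary-classification data.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (rng.random(200) < 0.1).astype(int)
train_set = lgb.Dataset(X, label=y)

# "alpha" and "gamma" are popped from params by set_params and baked into
# the focal objective; "objective" and "metric" are filled in automatically,
# so no feval argument is passed anymore.
booster = train(
    params={"alpha": 0.25, "gamma": 2.0},
    train_set=train_set,
    num_boost_round=10,
)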
57 changes: 31 additions & 26 deletions imlightgbm/objective.py
@@ -13,38 +13,30 @@
 ALPHA_DEFAULT: float = 0.25
 GAMMA_DEFAULT: float = 2.0
 OBJECTIVE_STR: str = "objective"
-IS_HIGHER_BETTER = False
+METRIC_STR: str = "metric"
+IS_HIGHER_BETTER: bool = False
 
 
 def _power(num_base: np.ndarray, num_pow: float):
     """Safe power."""
     return np.sign(num_base) * (np.abs(num_base)) ** (num_pow)
 
 
 def _log(array: np.ndarray, is_prob: bool = False) -> np.ndarray:
     """Safe log."""
     _upper = 1 if is_prob else None
     return np.log(np.clip(array, 1e-6, _upper))
 
 
 def _sigmoid(x: np.ndarray) -> np.ndarray:
-    """Convert raw predictions to probabilities in binary task"""
+    """Convert raw predictions to probabilities in binary task."""
     return 1 / (1 + np.exp(-x))
 
 
-def binary_focal_eval(
-    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
-) -> tuple[str, float, bool]:
-    label = train_data.get_label()
-    pred_prob = _sigmoid(pred)
-    p_t = np.where(label == 1, pred_prob, 1 - pred_prob)
-    loss = -alpha * ((1 - p_t) ** gamma) * _log(p_t, True)
-
-    focal_loss = np.mean(loss)
-    return "binary_focal", focal_loss, IS_HIGHER_BETTER
-
-
 def binary_focal_objective(
     pred: np.ndarray, train_data: Dataset, gamma: float
 ) -> tuple[np.ndarray, np.ndarray]:
+    """Return binary focal objective."""
     label = train_data.get_label()
     pred_prob = _sigmoid(pred)
 
@@ -65,11 +57,17 @@ def binary_focal_objective(
     return grad, hess
 
 
-def multiclass_focal_eval(
+def binary_focal_eval(
     pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
 ) -> tuple[str, float, bool]:
-    # TODO
-    return
+    """Return binary focal eval."""
+    label = train_data.get_label()
+    pred_prob = _sigmoid(pred)
+    p_t = np.where(label == 1, pred_prob, 1 - pred_prob)
+    loss = -alpha * ((1 - p_t) ** gamma) * _log(p_t, True)
+
+    focal_loss = np.mean(loss)
+    return "focal", focal_loss, IS_HIGHER_BETTER
 
 
 def multiclass_focal_objective(
@@ -79,9 +77,17 @@ def multiclass_focal_objective(
     return
 
 
-def set_fobj_feval(
+def multiclass_focal_eval(
+    pred: np.ndarray, train_data: Dataset, alpha: float, gamma: float
+) -> tuple[str, float, bool]:
+    # TODO
+    return
+
+
+def _set_fobj_feval(
     train_set: Dataset, alpha: float, gamma: float
 ) -> tuple[ObjLike, EvalLike]:
+    """Return obj and eval with respect to task type."""
     inferred_task = type_of_target(train_set.get_label())
     if inferred_task not in {"binary", "multiclass"}:
         raise ValueError(
@@ -92,18 +98,17 @@ def set_fobj_feval(
         "multiclass": partial(multiclass_focal_objective, alpha=alpha, gamma=gamma),
     }
     eval_mapper: dict[str, EvalLike] = {
-        "binary": partial(binary_focal_eval, alpha=alpha, gamma=gamma),
-        "multiclass": partial(multiclass_focal_eval, alpha=alpha, gamma=gamma),
+        "binary": "binary_logloss",
+        "multiclass": "multi_logloss",
     }
     fobj = objective_mapper[inferred_task]
     feval = eval_mapper[inferred_task]
 
     return fobj, feval
 
 
-def set_params(
-    params: dict[str, Any], train_set: Dataset
-) -> tuple[dict[str, Any], EvalLike]:
+def set_params(params: dict[str, Any], train_set: Dataset) -> dict[str, Any]:
+    """Set objective and eval metric in params."""
     _params = deepcopy(params)
     if OBJECTIVE_STR in _params:
         logger.warning(f"'{OBJECTIVE_STR}' exists in params and will not be used.")
@@ -112,6 +117,6 @@
     _alpha = _params.pop("alpha", ALPHA_DEFAULT)
     _gamma = _params.pop("gamma", GAMMA_DEFAULT)
 
-    fobj, feval = set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
-    _params.update({OBJECTIVE_STR: fobj})
-    return _params, feval
+    fobj, feval = _set_fobj_feval(train_set=train_set, alpha=_alpha, gamma=_gamma)
+    _params.update({OBJECTIVE_STR: fobj, METRIC_STR: feval})
+    return _params
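For intuition on the relocated binary_focal_eval: with gamma=0 and alpha=1, the per-example term -alpha * ((1 - p_t) ** gamma) * _log(p_t, True) collapses to plain binary log loss, and gamma > 0 down-weights well-classified examples. A self-contained sketch of that check, restating the helpers shown above rather than importing them from the package:

import numpy as np

def _sigmoid(x: np.ndarray) -> np.ndarray:
    return 1 / (1 + np.exp(-x))

def _log(array: np.ndarray, is_prob: bool = False) -> np.ndarray:
    upper = 1 if is_prob else None
    return np.log(np.clip(array, 1e-6, upper))

def focal_loss(pred: np.ndarray, label: np.ndarray, alpha: float, gamma: float) -> float:
    # Mirrors binary_focal_eval: p_t is the predicted probability of the true class.
    p_t = np.where(label == 1, _sigmoid(pred), 1 - _sigmoid(pred))
    return float(np.mean(-alpha * ((1 - p_t) ** gamma) * _log(p_t, True)))

label = np.array([1, 0, 1, 0])
pred = np.array([2.0, -1.0, 0.5, 0.0])  # raw scores, pre-sigmoid

# gamma=0, alpha=1 reduces the focal loss to binary log loss.
log_loss = focal_loss(pred, label, alpha=1.0, gamma=0.0)

# gamma=2 shrinks every per-example term, since (1 - p_t) ** 2 < 1.
assert focal_loss(pred, label, alpha=1.0, gamma=2.0) < log_loss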
