diff --git a/supervised/tuner/optuna/lightgbm.py b/supervised/tuner/optuna/lightgbm.py index c12f5681..6e015b17 100644 --- a/supervised/tuner/optuna/lightgbm.py +++ b/supervised/tuner/optuna/lightgbm.py @@ -139,16 +139,17 @@ def __call__(self, trial): pruning_callback = optuna.integration.LightGBMPruningCallback( trial, metric_name, "validation" ) + early_stopping_callback = lgb.early_stopping( + self.early_stopping_rounds, verbose=False + ) gbm = lgb.train( param, self.dtrain, valid_sets=[self.dvalid], valid_names=["validation"], - verbose_eval=False, - callbacks=[pruning_callback], + callbacks=[pruning_callback, early_stopping_callback], num_boost_round=self.rounds, - early_stopping_rounds=self.early_stopping_rounds, feval=self.custom_eval_metric, )