diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 7e909342c01f..c71c233df908 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -1103,6 +1103,8 @@ def fit( # type: ignore[override] self._classes = self._le.classes_ self._n_classes = len(self._classes) # type: ignore[arg-type] + if self.objective is None: + self._objective = None # adjust eval metrics to match whether binary or multiclass # classification is being performed diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index e41719845c0a..2247c9a512d2 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1561,3 +1561,20 @@ def test_ranking_minimally_works_with_all_all_accepted_data_types(X_type, y_type ) preds = model.predict(X) assert spearmanr(preds, y).correlation >= 0.99 + + +def test_classifier_fit_detects_classes_every_time(): + rng = np.random.default_rng(seed=123) + nrows = 1000 + ncols = 20 + + X = rng.standard_normal(size=(nrows, ncols)) + y_bin = (rng.random(size=nrows) <= .3).astype(np.float64) + y_multi = rng.integers(4, size=nrows) + + model = lgb.LGBMClassifier(verbose=-1) + for _ in range(2): + model.fit(X, y_multi) + assert model.objective_ == "multiclass" + model.fit(X, y_bin) + assert model.objective_ == "binary"