From 6f9a1f588c7aee325ba3e7f42799c02737ae4622 Mon Sep 17 00:00:00 2001 From: Lukas Drews Date: Wed, 17 Apr 2024 12:50:05 +0200 Subject: [PATCH] enable BetaCalibrator to handle null values;add test --- hiclass/_calibration/BetaCalibrator.py | 8 ++++++ tests/test_calibration.py | 36 ++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/hiclass/_calibration/BetaCalibrator.py b/hiclass/_calibration/BetaCalibrator.py index 22c36dde..00bbb6e9 100644 --- a/hiclass/_calibration/BetaCalibrator.py +++ b/hiclass/_calibration/BetaCalibrator.py @@ -19,7 +19,15 @@ def fit(self, y: np.ndarray, scores: np.ndarray, X: np.ndarray = None): return self scores_1 = np.log(scores) + # replace negative infinity with limit for log(n), n -> -inf + replace_negative_inf = np.log(1e-300) + scores_1 = np.nan_to_num(scores_1, neginf=replace_negative_inf) + scores_2 = -np.log(1 - scores) + # replace positive infinity with limit for log(n), n -> inf + replace_positive_inf = np.log(1e300) + scores_2 = np.nan_to_num(scores_2, posinf=replace_positive_inf) + feature_matrix = np.column_stack((scores_1, scores_2)) lr = LogisticRegression() diff --git a/tests/test_calibration.py b/tests/test_calibration.py index 6f4440f2..6e419895 100644 --- a/tests/test_calibration.py +++ b/tests/test_calibration.py @@ -284,6 +284,42 @@ def test_beta_calibration(binary_calibration_data, binary_test_scores): ) +def test_calibration_methods_can_handle_zeros(binary_test_scores): + cal_scores = np.array( + [ + [0, 1], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + [0, 1], + [0, 1], + [1, 0], + [1, 0], + [1, 0], + [1, 0], + [1, 0], + ], + dtype=np.float32, + ) + + cal_labels = np.array([1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0]) + assert_array_equal(np.sum(cal_scores, axis=1), np.ones(len(cal_scores))) + + for calibrator in [ + _PlattScaling(), + _IsotonicRegression(), + _BetaCalibrator(), + _InductiveVennAbersCalibrator(), + _CrossVennAbersCalibrator(LogisticRegression()), + ]: + try: + calibrator.fit(cal_labels, cal_scores[:, 1], cal_scores) + calibrator.predict_proba(binary_test_scores[:, 1]) + except ValueError as e: + pytest.fail(repr(e)) + + def test_illegal_calibration_method_raises_error(binary_mock_estimator): with pytest.raises(ValueError, match="abc is not a valid calibration method."): _Calibrator(binary_mock_estimator, method="abc")