Skip to content

Commit

Permalink
enable BetaCalibrator to handle null values;add test
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasDrews97 committed Apr 17, 2024
1 parent ad2f735 commit 6f9a1f5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
8 changes: 8 additions & 0 deletions hiclass/_calibration/BetaCalibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ def fit(self, y: np.ndarray, scores: np.ndarray, X: np.ndarray = None):
return self

scores_1 = np.log(scores)
# replace negative infinity with limit for log(n), n -> -inf
replace_negative_inf = np.log(1e-300)
scores_1 = np.nan_to_num(scores_1, neginf=replace_negative_inf)

scores_2 = -np.log(1 - scores)
# replace positive infinity with limit for log(n), n -> inf
replace_positive_inf = np.log(1e300)
scores_2 = np.nan_to_num(scores_2, posinf=replace_positive_inf)

feature_matrix = np.column_stack((scores_1, scores_2))

lr = LogisticRegression()
Expand Down
36 changes: 36 additions & 0 deletions tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,42 @@ def test_beta_calibration(binary_calibration_data, binary_test_scores):
)


def test_calibration_methods_can_handle_zeros(binary_test_scores):
cal_scores = np.array(
[
[0, 1],
[0, 1],
[0, 1],
[1, 0],
[1, 0],
[0, 1],
[0, 1],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
[1, 0],
],
dtype=np.float32,
)

cal_labels = np.array([1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0])
assert_array_equal(np.sum(cal_scores, axis=1), np.ones(len(cal_scores)))

for calibrator in [
_PlattScaling(),
_IsotonicRegression(),
_BetaCalibrator(),
_InductiveVennAbersCalibrator(),
_CrossVennAbersCalibrator(LogisticRegression()),
]:
try:
calibrator.fit(cal_labels, cal_scores[:, 1], cal_scores)
calibrator.predict_proba(binary_test_scores[:, 1])
except ValueError as e:
pytest.fail(repr(e))


def test_illegal_calibration_method_raises_error(binary_mock_estimator):
with pytest.raises(ValueError, match="abc is not a valid calibration method."):
_Calibrator(binary_mock_estimator, method="abc")
Expand Down

0 comments on commit 6f9a1f5

Please sign in to comment.