Skip to content

Commit

Permalink
UPD: add DeprecationWarning for partial_fit
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Blot <[email protected]>
  • Loading branch information
thibaultcordier and vincentblot28 authored Dec 21, 2023
1 parent fc0522f commit 5a1c3ab
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mapie/regression/time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ def partial_fit(
If the length of ``y`` is greater than
the length of the training set.
"""
warnings.warn(
"WARNING: Deprecated method. "
+ "The method \"partial_fit\" is outdated. "
+ "Prefer to use \"update\" instead to keep "
+ "the same behavior in the future.",
DeprecationWarning
)
check_is_fitted(self, self.fit_attributes)
X, y = cast(NDArray, X), cast(NDArray, y)
m, n = len(X), len(self.conformity_scores_)
Expand Down
10 changes: 10 additions & 0 deletions mapie/tests/test_time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,16 @@ def test_aci__get_alpha_with_unknown_alpha() -> None:
np.testing.assert_allclose(mapie_ts_reg.current_alpha[0.2], 0.3, rtol=1e-3)


def test_deprecated_partial_fit_warning(method: str) -> None:
"""Test that a warning is raised if use partial_fit"""
mapie_ts_reg = MapieTimeSeriesRegressor(method='enbpi', cv=-1)
mapie_ts_reg.fit(X_toy, y_toy)
with pytest.warns(
DeprecationWarning, match=r".*WARNING: Deprecated method.*"
):
mapie_ts_reg = mapie_ts_reg.partial_fit(X_toy, y_toy)


@pytest.mark.parametrize("method", ["wrong_method"])
def test_method_error_in_update(monkeypatch: Any, method: str) -> None:
"""Test else condition for the method in .update"""
Expand Down

0 comments on commit 5a1c3ab

Please sign in to comment.