Skip to content

Commit

Permalink
fix: Make reports pickable out-of-the-box (#1179)
Browse files Browse the repository at this point in the history
This PR makes sure that `EstimatorReport` and
`CrossValidationReport` are picklable objects.

Right now, because the reports keep a reference to the rich `Progress` object, which
holds a lock, the objects are not picklable. I reset `_parent_progress` and
`_progress_info` to `None` in the `finally` block of the decorator.

An alternative would be to change the `set_state`, but here I think that we can
just apply this clean-up instead.
  • Loading branch information
glemaitre authored Jan 21, 2025
1 parent 4ddf226 commit 7ff192d
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 53 deletions.
4 changes: 3 additions & 1 deletion skore/src/skore/sklearn/_cross_validation/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def __init__(
cv_splitter=None,
n_jobs=None,
):
self._parent_progress = None # used for the different progress bars
# used to know whether a parent has launched a progress bar manager
self._parent_progress = None

self._estimator = clone(estimator)

# private storage to be able to invalidate the cache when the user alters
Expand Down
5 changes: 2 additions & 3 deletions skore/src/skore/sklearn/_estimator/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def __init__(
X_test=None,
y_test=None,
):
self._parent_progress = None # used to display progress bar
# used to know whether a parent has launched a progress bar manager
self._parent_progress = None

if fit == "auto":
try:
Expand All @@ -129,8 +130,6 @@ def __init__(
self._X_test = X_test
self._y_test = y_test

self._parent_progress = None

self._initialize_state()

def _initialize_state(self):
Expand Down
3 changes: 3 additions & 0 deletions skore/src/skore/utils/_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def wrapper(*args, **kwargs):
task, completed=progress.tasks[task].total, refresh=True
)
progress.stop()
# clean up to make the object picklable
self_obj._parent_progress = None
self_obj._progress_info = None

return wrapper

Expand Down
21 changes: 17 additions & 4 deletions skore/tests/unit/sklearn/test_cross_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re

import joblib
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -182,7 +183,7 @@ def test_cross_validation_report_repr(binary_classification_data):
],
)
@pytest.mark.parametrize("n_jobs", [None, 1, 2])
def test_estimator_report_cache_predictions(
def test_cross_validation_report_cache_predictions(
request, fixture_name, expected_n_keys, n_jobs
):
"""Check that calling cache_predictions fills the cache."""
Expand All @@ -202,6 +203,18 @@ def test_estimator_report_cache_predictions(
assert estimator_report._cache == {}


def test_cross_validation_report_pickle(tmp_path, binary_classification_data):
    """Check that a cross-validation report can be pickled.

    ``cache_predictions`` is called before dumping so that the progress bar
    has been triggered; the dump then verifies that no unpicklable
    progress-bar state is left on the report afterwards.
    """
    model, data, target = binary_classification_data
    cv_report = CrossValidationReport(model, data, target, cv_splitter=2)
    # Run a method that drives the progress bar before serializing.
    cv_report.cache_predictions()
    destination = tmp_path / "report.joblib"
    joblib.dump(cv_report, destination)


########################################################################################
# Check the plot methods
########################################################################################
Expand Down Expand Up @@ -265,7 +278,7 @@ def test_cross_validation_report_display_regression(pyplot, regression_data, dis
########################################################################################


def test_estimator_report_metrics_help(capsys, binary_classification_data):
def test_cross_validation_report_metrics_help(capsys, binary_classification_data):
"""Check that the help method writes to the console."""
estimator, X, y = binary_classification_data
report = CrossValidationReport(estimator, X, y, cv_splitter=2)
Expand All @@ -275,7 +288,7 @@ def test_estimator_report_metrics_help(capsys, binary_classification_data):
assert "Available metrics methods" in captured.out


def test_estimator_report_metrics_repr(binary_classification_data):
def test_cross_validation_report_metrics_repr(binary_classification_data):
"""Check that __repr__ returns a string starting with the expected prefix."""
estimator, X, y = binary_classification_data
report = CrossValidationReport(estimator, X, y, cv_splitter=2)
Expand Down Expand Up @@ -618,7 +631,7 @@ def test_cross_validation_report_report_metrics_error_scoring_strings(
report.metrics.report_metrics(scoring=[scoring])


def test_estimator_report_report_metrics_with_scorer(regression_data):
def test_cross_validation_report_report_metrics_with_scorer(regression_data):
"""Check that we can pass scikit-learn scorer with different parameters to
the `report_metrics` method."""
estimator, X, y = regression_data
Expand Down
12 changes: 12 additions & 0 deletions skore/tests/unit/sklearn/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,18 @@ def test_estimator_report_cache_predictions(
assert report._cache.keys() == stored_cache.keys()


def test_estimator_report_pickle(tmp_path, binary_classification_data):
    """Check that an estimator report can be pickled.

    ``cache_predictions`` is called before dumping so that the progress bar
    has been triggered; the dump then verifies that no unpicklable
    progress-bar state is left on the report afterwards.
    """
    model, data, target = binary_classification_data
    estimator_report = EstimatorReport(model, X_test=data, y_test=target)
    # Run a method that drives the progress bar before serializing.
    estimator_report.cache_predictions()
    destination = tmp_path / "report.joblib"
    joblib.dump(estimator_report, destination)


########################################################################################
# Check the plot methods
########################################################################################
Expand Down
71 changes: 26 additions & 45 deletions skore/tests/unit/utils/test_progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from rich.progress import Progress
from skore.utils._progress_bar import progress_decorator


Expand All @@ -17,23 +16,18 @@ def run(self, iterations=5):
task = self._progress_info["current_task"]
progress.update(task, total=iterations)

for _ in range(iterations):
for i in range(iterations):
progress.update(task, advance=1)
self._standalone_n_calls = i
return "done"

task = StandaloneTask()
result = task.run()

assert result == "done"
assert task._progress_info is not None
assert isinstance(task._progress_info["current_progress"], Progress)
assert task._standalone_n_calls == 4
assert task._progress_info is None
assert task._parent_progress is None
assert (
task._progress_info["current_progress"]
.tasks[task._progress_info["current_task"]]
.completed
== 5
)


def test_nested_progress():
Expand All @@ -50,15 +44,17 @@ def run(self, iterations=3):
task = self._progress_info["current_task"]
progress.update(task, total=iterations)

child = ChildTask(self._progress_info["current_progress"])
for _ in range(iterations):
child.run()
self._child = ChildTask()
for i in range(iterations):
self._child._parent_progress = progress
self._child.run()
progress.update(task, advance=1)
self._parent_n_calls = i
return "done"

class ChildTask:
def __init__(self, parent_progress):
self._parent_progress = parent_progress
def __init__(self):
self._parent_progress = None
self._progress_info = None

@progress_decorator("Child Task")
Expand All @@ -67,22 +63,21 @@ def run(self, iterations=2):
task = self._progress_info["current_task"]
progress.update(task, total=iterations)

for _ in range(iterations):
for i in range(iterations):
progress.update(task, advance=1)
self._child_n_calls = i
return "done"

parent = ParentTask()
result = parent.run()

assert result == "done"
assert parent._progress_info is not None
assert isinstance(parent._progress_info["current_progress"], Progress)
assert (
parent._progress_info["current_progress"]
.tasks[parent._progress_info["current_task"]]
.completed
== 3
)
assert parent._progress_info is None
assert parent._parent_progress is None
assert parent._parent_n_calls == 2
assert parent._child._child_n_calls == 1
assert parent._child._parent_progress is None
assert parent._child._progress_info is None


def test_dynamic_description():
Expand All @@ -101,27 +96,18 @@ def run(self, iterations=4):
task = self._progress_info["current_task"]
progress.update(task, total=iterations)

for _ in range(iterations):
for i in range(iterations):
progress.update(task, advance=1)
self._dynamic_n_calls = i
return self.name

task = DynamicTask("test_task")
result = task.run()

assert result == "test_task"
assert task._progress_info is not None
assert (
task._progress_info["current_progress"]
.tasks[task._progress_info["current_task"]]
.description
== "Processing test_task"
)
assert (
task._progress_info["current_progress"]
.tasks[task._progress_info["current_task"]]
.completed
== 4
)
assert task._progress_info is None
assert task._parent_progress is None
assert task._dynamic_n_calls == 3


def test_exception_handling():
Expand All @@ -146,10 +132,5 @@ def run(self):
task.run()

# Verify progress bar was cleaned up
assert task._progress_info is not None
assert (
task._progress_info["current_progress"]
.tasks[task._progress_info["current_task"]]
.completed
== 1
)
assert task._progress_info is None
assert task._parent_progress is None

0 comments on commit 7ff192d

Please sign in to comment.