From 55474dafb1d15d8e0211971a58d9fba0f9b9a033 Mon Sep 17 00:00:00 2001 From: Auguste Baum Date: Fri, 4 Oct 2024 12:29:54 +0200 Subject: [PATCH] refactor: Factorize `project` fixture to `conftest.py` --- tests/conftest.py | 14 +++ tests/integration/ui/test_ui.py | 32 ++--- tests/unit/test_project.py | 201 +++++++++++++++----------------- 3 files changed, 119 insertions(+), 128 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a43043d9..fba13182 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,10 @@ from datetime import UTC, datetime import pytest +from skore.item.item_repository import ItemRepository +from skore.persistence.in_memory_storage import InMemoryStorage +from skore.project import Project +from skore.view.view_repository import ViewRepository @pytest.fixture @@ -23,3 +27,13 @@ def now(*args, **kwargs): return mock_now return MockDatetime + + +@pytest.fixture +def in_memory_project(): + item_repository = ItemRepository(storage=InMemoryStorage()) + view_repository = ViewRepository(storage=InMemoryStorage()) + return Project( + item_repository=item_repository, + view_repository=view_repository, + ) diff --git a/tests/integration/ui/test_ui.py b/tests/integration/ui/test_ui.py index a1dc11e1..0212ce21 100644 --- a/tests/integration/ui/test_ui.py +++ b/tests/integration/ui/test_ui.py @@ -1,26 +1,12 @@ import pytest from fastapi.testclient import TestClient -from skore.item.item_repository import ItemRepository -from skore.persistence.in_memory_storage import InMemoryStorage -from skore.project import Project from skore.ui.app import create_app from skore.view.view import View -from skore.view.view_repository import ViewRepository @pytest.fixture -def project(): - item_repository = ItemRepository(storage=InMemoryStorage()) - view_repository = ViewRepository(storage=InMemoryStorage()) - return Project( - item_repository=item_repository, - view_repository=view_repository, - ) - - -@pytest.fixture -def client(project): - return TestClient(app=create_app(project=project)) +def client(in_memory_project): + return TestClient(app=create_app(project=in_memory_project)) def test_app_state(client): @@ -34,14 +20,14 @@ def test_skore_ui_index(client): assert b"" in response.content -def test_get_items(client, project): +def test_get_items(client, in_memory_project): response = client.get("/api/project/items") assert response.status_code == 200 assert response.json() == {"views": {}, "items": {}} - project.put("test", "test") - item = project.get_item("test") + in_memory_project.put("test", "test") + item = in_memory_project.get_item("test") response = client.get("/api/project/items") assert response.status_code == 200 @@ -58,8 +44,8 @@ def test_get_items(client, project): } -def test_share_view(client, project): - project.put_view("hello", View(layout=[])) +def test_share_view(client, in_memory_project): + in_memory_project.put_view("hello", View(layout=[])) response = client.post("/api/project/views/share?key=hello") assert response.status_code == 200 @@ -76,8 +62,8 @@ def test_put_view_layout(client): assert response.status_code == 201 -def test_delete_view(client, project): - project.put_view("hello", View(layout=[])) +def test_delete_view(client, in_memory_project): + in_memory_project.put_view("hello", View(layout=[])) response = client.delete("/api/project/views?key=hello") assert response.status_code == 202 diff --git a/tests/unit/test_project.py b/tests/unit/test_project.py index b65ff17b..a60ae202 100644 --- a/tests/unit/test_project.py +++ b/tests/unit/test_project.py @@ -10,71 +10,62 @@ from matplotlib import pyplot as plt from PIL import Image from sklearn.ensemble import RandomForestClassifier -from skore.item import ItemRepository -from skore.persistence.in_memory_storage import InMemoryStorage from skore.project import Project, ProjectLoadError, ProjectPutError, load from skore.view.view import View -from skore.view.view_repository import ViewRepository -@pytest.fixture -def project(): - return Project( - item_repository=ItemRepository(InMemoryStorage()), - view_repository=ViewRepository(InMemoryStorage()), - ) - +def test_put_string_item(in_memory_project): + in_memory_project.put("string_item", "Hello, World!") + assert in_memory_project.get("string_item") == "Hello, World!" -def test_put_string_item(project): - project.put("string_item", "Hello, World!") - assert project.get("string_item") == "Hello, World!" +def test_put_int_item(in_memory_project): + in_memory_project.put("int_item", 42) + assert in_memory_project.get("int_item") == 42 -def test_put_int_item(project): - project.put("int_item", 42) - assert project.get("int_item") == 42 +def test_put_float_item(in_memory_project): + in_memory_project.put("float_item", 3.14) + assert in_memory_project.get("float_item") == 3.14 -def test_put_float_item(project): - project.put("float_item", 3.14) - assert project.get("float_item") == 3.14 +def test_put_bool_item(in_memory_project): + in_memory_project.put("bool_item", True) + assert in_memory_project.get("bool_item") is True -def test_put_bool_item(project): - project.put("bool_item", True) - assert project.get("bool_item") is True +def test_put_list_item(in_memory_project): + in_memory_project.put("list_item", [1, 2, 3]) + assert in_memory_project.get("list_item") == [1, 2, 3] -def test_put_list_item(project): - project.put("list_item", [1, 2, 3]) - assert project.get("list_item") == [1, 2, 3] +def test_put_dict_item(in_memory_project): + in_memory_project.put("dict_item", {"key": "value"}) + assert in_memory_project.get("dict_item") == {"key": "value"} -def test_put_dict_item(project): - project.put("dict_item", {"key": "value"}) - assert project.get("dict_item") == {"key": "value"} - -def test_put_pandas_dataframe(project): +def test_put_pandas_dataframe(in_memory_project): dataframe = pandas.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) - project.put("pandas_dataframe", dataframe) - pandas.testing.assert_frame_equal(project.get("pandas_dataframe"), dataframe) + in_memory_project.put("pandas_dataframe", dataframe) + pandas.testing.assert_frame_equal( + in_memory_project.get("pandas_dataframe"), dataframe + ) -def test_put_pandas_series(project): +def test_put_pandas_series(in_memory_project): series = pandas.Series([0, 1, 2]) - project.put("pandas_series", series) - pandas.testing.assert_series_equal(project.get("pandas_series"), series) + in_memory_project.put("pandas_series", series) + pandas.testing.assert_series_equal(in_memory_project.get("pandas_series"), series) -def test_put_numpy_array(project): +def test_put_numpy_array(in_memory_project): # Add a Numpy array arr = numpy.array([1, 2, 3, 4, 5]) - project.put("numpy_array", arr) # NumpyArrayItem - numpy.testing.assert_array_equal(project.get("numpy_array"), arr) + in_memory_project.put("numpy_array", arr) # NumpyArrayItem + numpy.testing.assert_array_equal(in_memory_project.get("numpy_array"), arr) -def test_put_mpl_figure(project, monkeypatch): +def test_put_mpl_figure(in_memory_project, monkeypatch): # Add a Matplotlib figure def savefig(*args, **kwargs): return "" @@ -83,35 +74,35 @@ def savefig(*args, **kwargs): fig, ax = plt.subplots() ax.plot([1, 2, 3, 4]) - project.put("mpl_figure", fig) # MediaItem (SVG) - assert isinstance(project.get("mpl_figure"), bytes) + in_memory_project.put("mpl_figure", fig) # MediaItem (SVG) + assert isinstance(in_memory_project.get("mpl_figure"), bytes) -def test_put_vega_chart(project): +def test_put_vega_chart(in_memory_project): # Add an Altair chart altair_chart = altair.Chart().mark_point() - project.put("vega_chart", altair_chart) - assert isinstance(project.get("vega_chart"), bytes) + in_memory_project.put("vega_chart", altair_chart) + assert isinstance(in_memory_project.get("vega_chart"), bytes) -def test_put_pil_image(project): +def test_put_pil_image(in_memory_project): # Add a PIL Image pil_image = Image.new("RGB", (100, 100), color="red") with BytesIO() as output: # FIXME: Not JPEG! pil_image.save(output, format="jpeg") - project.put("pil_image", pil_image) # MediaItem (PNG) - assert isinstance(project.get("pil_image"), bytes) + in_memory_project.put("pil_image", pil_image) # MediaItem (PNG) + assert isinstance(in_memory_project.get("pil_image"), bytes) -def test_put_rf_model(project, monkeypatch): +def test_put_rf_model(in_memory_project, monkeypatch): # Add a scikit-learn model monkeypatch.setattr("sklearn.utils.estimator_html_repr", lambda _: "") model = RandomForestClassifier() model.fit(numpy.array([[1, 2], [3, 4]]), [0, 1]) - project.put("rf_model", model) # ScikitLearnModelItem - assert isinstance(project.get("rf_model"), RandomForestClassifier) + in_memory_project.put("rf_model", model) # ScikitLearnModelItem + assert isinstance(in_memory_project.get("rf_model"), RandomForestClassifier) def test_load(tmp_path): @@ -127,82 +118,82 @@ def test_load(tmp_path): assert isinstance(p, Project) -def test_put(project): - project.put("key1", 1) - project.put("key2", 2) - project.put("key3", 3) - project.put("key4", 4) +def test_put(in_memory_project): + in_memory_project.put("key1", 1) + in_memory_project.put("key2", 2) + in_memory_project.put("key3", 3) + in_memory_project.put("key4", 4) - assert project.list_item_keys() == ["key1", "key2", "key3", "key4"] + assert in_memory_project.list_item_keys() == ["key1", "key2", "key3", "key4"] -def test_put_twice(project): - project.put("key2", 2) - project.put("key2", 5) +def test_put_twice(in_memory_project): + in_memory_project.put("key2", 2) + in_memory_project.put("key2", 5) - assert project.get("key2") == 5 + assert in_memory_project.get("key2") == 5 -def test_put_int_key(project, caplog): +def test_put_int_key(in_memory_project, caplog): # Warns that 0 is not a string, but doesn't raise - project.put(0, "hello") + in_memory_project.put(0, "hello") assert len(caplog.record_tuples) == 1 - assert project.list_item_keys() == [] + assert in_memory_project.list_item_keys() == [] -def test_get(project): - project.put("key1", 1) - assert project.get("key1") == 1 +def test_get(in_memory_project): + in_memory_project.put("key1", 1) + assert in_memory_project.get("key1") == 1 with pytest.raises(KeyError): - project.get("key2") + in_memory_project.get("key2") -def test_delete(project): - project.put("key1", 1) - project.delete_item("key1") +def test_delete(in_memory_project): + in_memory_project.put("key1", 1) + in_memory_project.delete_item("key1") - assert project.list_item_keys() == [] + assert in_memory_project.list_item_keys() == [] with pytest.raises(KeyError): - project.delete_item("key2") + in_memory_project.delete_item("key2") -def test_keys(project): - project.put("key1", 1) - project.put("key2", 2) - assert project.list_item_keys() == ["key1", "key2"] +def test_keys(in_memory_project): + in_memory_project.put("key1", 1) + in_memory_project.put("key2", 2) + assert in_memory_project.list_item_keys() == ["key1", "key2"] -def test_view(project): +def test_view(in_memory_project): layout = ["key1", "key2"] view = View(layout=layout) - project.put_view("view", view) - assert project.get_view("view") == view + in_memory_project.put_view("view", view) + assert in_memory_project.get_view("view") == view -def test_list_view_keys(project): +def test_list_view_keys(in_memory_project): view = View(layout=[]) - project.put_view("view", view) - assert project.list_view_keys() == ["view"] + in_memory_project.put_view("view", view) + assert in_memory_project.list_view_keys() == ["view"] -def test_put_several_happy_path(project): - project.put({"a": "foo", "b": "bar"}) - assert project.list_item_keys() == ["a", "b"] +def test_put_several_happy_path(in_memory_project): + in_memory_project.put({"a": "foo", "b": "bar"}) + assert in_memory_project.list_item_keys() == ["a", "b"] -def test_put_several_canonical(project): +def test_put_several_canonical(in_memory_project): """Use `put_several` instead of the `put` alias.""" - project.put_several({"a": "foo", "b": "bar"}) - assert project.list_item_keys() == ["a", "b"] + in_memory_project.put_several({"a": "foo", "b": "bar"}) + assert in_memory_project.list_item_keys() == ["a", "b"] -def test_put_several_some_errors(project, caplog): - project.put( +def test_put_several_some_errors(in_memory_project, caplog): + in_memory_project.put( { 0: "hello", 1: "hello", @@ -210,34 +201,34 @@ def test_put_several_some_errors(project, caplog): } ) assert len(caplog.record_tuples) == 3 - assert project.list_item_keys() == [] + assert in_memory_project.list_item_keys() == [] -def test_put_several_nested(project): - project.put({"a": {"b": "baz"}}) - assert project.list_item_keys() == ["a"] - assert project.get("a") == {"b": "baz"} +def test_put_several_nested(in_memory_project): + in_memory_project.put({"a": {"b": "baz"}}) + assert in_memory_project.list_item_keys() == ["a"] + assert in_memory_project.get("a") == {"b": "baz"} -def test_put_several_error(project): +def test_put_several_error(in_memory_project): """If some key-value pairs are wrong, add all that are valid and print a warning.""" - project.put({"a": "foo", "b": (lambda: "unsupported object")}) - assert project.list_item_keys() == ["a"] + in_memory_project.put({"a": "foo", "b": (lambda: "unsupported object")}) + assert in_memory_project.list_item_keys() == ["a"] -def test_put_key_is_a_tuple(project): +def test_put_key_is_a_tuple(in_memory_project): """If key is not a string, warn.""" - project.put(("a", "foo"), ("b", "bar")) - assert project.list_item_keys() == [] + in_memory_project.put(("a", "foo"), ("b", "bar")) + assert in_memory_project.list_item_keys() == [] -def test_put_key_is_a_set(project): +def test_put_key_is_a_set(in_memory_project): """Cannot use an unhashable type as a key.""" with pytest.raises(ProjectPutError): - project.put(set(), "hello", on_error="raise") + in_memory_project.put(set(), "hello", on_error="raise") -def test_put_wrong_key_and_value_raise(project): +def test_put_wrong_key_and_value_raise(in_memory_project): """When `on_error` is "raise", raise the first error that occurs.""" with pytest.raises(ProjectPutError): - project.put(0, (lambda: "unsupported object"), on_error="raise") + in_memory_project.put(0, (lambda: "unsupported object"), on_error="raise")