diff --git a/e2e_tests/test_server.py b/e2e_tests/test_server.py index 03083c610..b3d7ed496 100644 --- a/e2e_tests/test_server.py +++ b/e2e_tests/test_server.py @@ -8,8 +8,6 @@ from optuna_dashboard import wsgi import pytest -from .utils import clear_inmemory_cache - def get_free_port() -> int: tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -30,7 +28,7 @@ def make_test_server( thread.start() def stop_server() -> None: - clear_inmemory_cache() + app._inmemory_cache.clear() httpd.shutdown() httpd.server_close() thread.join() diff --git a/e2e_tests/utils.py b/e2e_tests/utils.py index ecd7d2cc1..a7b22e5f8 100644 --- a/e2e_tests/utils.py +++ b/e2e_tests/utils.py @@ -1,19 +1,6 @@ -from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache -from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache_lock -from optuna_dashboard._storage import trials_cache -from optuna_dashboard._storage import trials_cache_lock -from optuna_dashboard._storage import trials_last_fetched_at from playwright.sync_api import Page -def clear_inmemory_cache() -> None: - with trials_cache_lock: - trials_cache.clear() - trials_last_fetched_at.clear() - with cached_extra_study_property_cache_lock: - cached_extra_study_property_cache.clear() - - def count_components(page: Page, component_name: str): component_count = page.evaluate( f"""() => {{ diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index e1c040ac2..d06f13e32 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -27,9 +27,10 @@ from . import _note as note from ._bottle_util import BottleViewReturn from ._bottle_util import json_api_view -from ._cached_extra_study_property import get_cached_extra_study_property from ._custom_plot_data import get_plotly_graph_objects from ._importance import get_param_importance_from_trials_cache +from ._inmemory_cache import get_cached_extra_study_property +from ._inmemory_cache import InMemoryCache from ._pareto_front import get_pareto_front_trials from ._preference_setting import _register_preference_feedback_component from ._preferential_history import NewHistory @@ -80,6 +81,7 @@ def create_app( debug: bool = False, ) -> Bottle: app = Bottle() + app._inmemory_cache = InMemoryCache() @app.hook("before_request") def remove_trailing_slashes_hook() -> None: @@ -214,7 +216,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: if study is None: response.status = 404 # Not found return {"reason": f"study_id={study_id} is not found"} - trials = get_trials(storage, study_id) + trials = get_trials(app._inmemory_cache, storage, study_id) system_attrs = getattr(study, "system_attrs", {}) is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False) @@ -235,7 +237,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union, union_user_attrs, has_intermediate_values, - ) = get_cached_extra_study_property(study_id, trials) + ) = get_cached_extra_study_property(app._inmemory_cache, study_id, trials) plotly_graph_objects = get_plotly_graph_objects(system_attrs) skipped_trial_ids = get_skipped_trial_ids(system_attrs) @@ -261,10 +263,16 @@ def get_param_importances(study_id: int) -> dict[str, Any]: response.status = 404 # Study is not found return {"reason": f"study_id={study_id} is not found"} - trials = get_trials(storage, study_id) + trials = get_trials(app._inmemory_cache, storage, study_id) try: importances = [ - get_param_importance_from_trials_cache(storage, study_id, objective_id, trials) + get_param_importance_from_trials_cache( + app._inmemory_cache, + storage, + study_id, + objective_id, + trials, + ) for objective_id in range(n_directions) ] return {"param_importances": importances} diff --git a/optuna_dashboard/_cli.py b/optuna_dashboard/_cli.py index e0df75e5e..518836b70 100644 --- a/optuna_dashboard/_cli.py +++ b/optuna_dashboard/_cli.py @@ -74,7 +74,7 @@ def run_debug_server(app: Bottle, host: str, port: int, quiet: bool) -> None: def auto_select_server( - server_arg: Literal["auto", "gunicorn", "wsgiref"] + server_arg: Literal["auto", "gunicorn", "wsgiref"], ) -> Literal["gunicorn", "wsgiref"]: if server_arg != "auto": return server_arg diff --git a/optuna_dashboard/_importance.py b/optuna_dashboard/_importance.py index b71894120..8ef96c9d4 100644 --- a/optuna_dashboard/_importance.py +++ b/optuna_dashboard/_importance.py @@ -11,7 +11,8 @@ from optuna.study import Study from optuna.trial import FrozenTrial from optuna.trial import TrialState -from optuna_dashboard._cached_extra_study_property import get_cached_extra_study_property +from optuna_dashboard._inmemory_cache import get_cached_extra_study_property +from optuna_dashboard._inmemory_cache import InMemoryCache _logger = logging.getLogger(__name__) @@ -97,7 +98,11 @@ def _get_param_importances( def get_param_importance_from_trials_cache( - storage: BaseStorage, study_id: int, objective_id: int, trials: list[FrozenTrial] + inmemory_cache: InMemoryCache, + storage: BaseStorage, + study_id: int, + objective_id: int, + trials: list[FrozenTrial], ) -> list[ImportanceType]: completed_trials = [t for t in trials if t.state == TrialState.COMPLETE] n_completed_trials = len(completed_trials) @@ -118,7 +123,9 @@ def get_param_importance_from_trials_cache( except RuntimeError: # RuntimeError("Encountered zero total variance in all trees.") may be raised # when all objective values are same. - _, union_search_space, _, _ = get_cached_extra_study_property(study_id, trials) + _, union_search_space, _, _ = get_cached_extra_study_property( + inmemory_cache, study_id, trials + ) importance_value = 1 / len(union_search_space) importance = { param_name: importance_value for param_name, distribution in union_search_space diff --git a/optuna_dashboard/_cached_extra_study_property.py b/optuna_dashboard/_inmemory_cache.py similarity index 80% rename from optuna_dashboard/_cached_extra_study_property.py rename to optuna_dashboard/_inmemory_cache.py index 24e223723..356972ca6 100644 --- a/optuna_dashboard/_cached_extra_study_property.py +++ b/optuna_dashboard/_inmemory_cache.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +from datetime import datetime import numbers import threading from typing import List @@ -14,25 +15,22 @@ from optuna.trial import TrialState -# In-memory cache -cached_extra_study_property_cache_lock = threading.Lock() -cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {} - - if TYPE_CHECKING: SearchSpaceSetT = Set[Tuple[str, BaseDistribution]] SearchSpaceListT = List[Tuple[str, BaseDistribution]] def get_cached_extra_study_property( - study_id: int, trials: list[FrozenTrial] + in_memory_cache: InMemoryCache, study_id: int, trials: list[FrozenTrial] ) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]: - with cached_extra_study_property_cache_lock: - cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None) + with in_memory_cache._cached_extra_study_property_cache_lock: + cached_extra_study_property = in_memory_cache._cached_extra_study_property_cache.get( + study_id, None + ) if cached_extra_study_property is None: cached_extra_study_property = _CachedExtraStudyProperty() cached_extra_study_property.update(trials) - cached_extra_study_property_cache[study_id] = cached_extra_study_property + in_memory_cache._cached_extra_study_property_cache[study_id] = cached_extra_study_property return ( cached_extra_study_property.intersection_search_space, cached_extra_study_property.union_search_space, @@ -41,6 +39,20 @@ def get_cached_extra_study_property( ) +class InMemoryCache: + def __init__(self) -> None: + self._cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {} + self._cached_extra_study_property_cache_lock = threading.Lock() + self._trials_cache: dict[int, list[FrozenTrial]] = {} + self._trials_cache_lock = threading.Lock() + self._trials_last_fetched_at: dict[int, datetime] = {} + + def clear(self) -> None: + self._cached_extra_study_property_cache.clear() + self._trials_cache.clear() + self._trials_last_fetched_at.clear() + + class _CachedExtraStudyProperty: def __init__(self) -> None: self._cursor: int = -1 diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index 55eb19c32..2f4050b4a 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -315,7 +315,7 @@ def serialize_distribution(distribution: BaseDistribution) -> DistributionJSON: def serialize_search_space( - search_space: list[tuple[str, BaseDistribution]] + search_space: list[tuple[str, BaseDistribution]], ) -> list[dict[str, Any]]: serialized = [] for param_name, distribution in search_space: diff --git a/optuna_dashboard/_storage.py b/optuna_dashboard/_storage.py index 8fe10e0a0..80094b0d9 100644 --- a/optuna_dashboard/_storage.py +++ b/optuna_dashboard/_storage.py @@ -2,7 +2,6 @@ from datetime import datetime from datetime import timedelta -import threading from optuna.storages import BaseStorage from optuna.storages import RDBStorage @@ -10,16 +9,14 @@ from optuna.study._frozen import FrozenStudy from optuna.trial import FrozenTrial +from ._inmemory_cache import InMemoryCache -# In-memory trials cache -trials_cache_lock = threading.Lock() -trials_cache: dict[int, list[FrozenTrial]] = {} -trials_last_fetched_at: dict[int, datetime] = {} - -def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]: - with trials_cache_lock: - trials = trials_cache.get(study_id, None) +def get_trials( + in_memory_cache: InMemoryCache, storage: BaseStorage, study_id: int +) -> list[FrozenTrial]: + with in_memory_cache._trials_cache_lock: + trials = in_memory_cache._trials_cache.get(study_id, None) # Not a big fan of the heuristic, but I can't think of anything better. if trials is None or len(trials) < 100: @@ -29,7 +26,7 @@ def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]: else: ttl_seconds = 10 - last_fetched_at = trials_last_fetched_at.get(study_id, None) + last_fetched_at = in_memory_cache._trials_last_fetched_at.get(study_id, None) if ( trials is not None and last_fetched_at is not None @@ -38,9 +35,9 @@ def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]: return trials trials = storage.get_all_trials(study_id, deepcopy=False) - with trials_cache_lock: - trials_last_fetched_at[study_id] = datetime.now() - trials_cache[study_id] = trials + with in_memory_cache._trials_cache_lock: + in_memory_cache._trials_last_fetched_at[study_id] = datetime.now() + in_memory_cache._trials_cache[study_id] = trials return trials diff --git a/python_tests/test_cached_extra_study_property.py b/python_tests/test_cached_extra_study_property.py index 381cbcca1..8e316719a 100644 --- a/python_tests/test_cached_extra_study_property.py +++ b/python_tests/test_cached_extra_study_property.py @@ -11,7 +11,7 @@ from optuna.distributions import FloatDistribution from optuna.exceptions import ExperimentalWarning from optuna.trial import TrialState -from optuna_dashboard._cached_extra_study_property import _CachedExtraStudyProperty +from optuna_dashboard._inmemory_cache import _CachedExtraStudyProperty class _CachedExtraStudyPropertySearchSpaceTestCase(TestCase): diff --git a/python_tests/wsgi_client.py b/python_tests/wsgi_client.py index d0abef9a7..0141dde4f 100644 --- a/python_tests/wsgi_client.py +++ b/python_tests/wsgi_client.py @@ -4,9 +4,6 @@ import typing from bottle import Bottle -from optuna_dashboard._storage import trials_cache -from optuna_dashboard._storage import trials_cache_lock -from optuna_dashboard._storage import trials_last_fetched_at if typing.TYPE_CHECKING: @@ -14,12 +11,6 @@ from _typeshed.wsgi import WSGIEnvironment -def clear_inmemory_cache() -> None: - with trials_cache_lock: - trials_cache.clear() - trials_last_fetched_at.clear() - - def create_wsgi_env( path: str, method: str, @@ -77,7 +68,7 @@ def start_response( queries = queries or {} env = create_wsgi_env(path, method, content_type, bytes_body, queries, headers) - clear_inmemory_cache() + app._inmemory_cache.clear() response_body = b"" iterable_body = app(env, start_response) for b in iterable_body: