diff --git a/e2e_tests/test_server.py b/e2e_tests/test_server.py index 03083c610..b3d7ed496 100644 --- a/e2e_tests/test_server.py +++ b/e2e_tests/test_server.py @@ -8,8 +8,6 @@ from optuna_dashboard import wsgi import pytest -from .utils import clear_inmemory_cache - def get_free_port() -> int: tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -30,7 +28,7 @@ def make_test_server( thread.start() def stop_server() -> None: - clear_inmemory_cache() + app._inmemory_cache.clear() httpd.shutdown() httpd.server_close() thread.join() diff --git a/e2e_tests/utils.py b/e2e_tests/utils.py index ecd7d2cc1..a7b22e5f8 100644 --- a/e2e_tests/utils.py +++ b/e2e_tests/utils.py @@ -1,19 +1,6 @@ -from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache -from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache_lock -from optuna_dashboard._storage import trials_cache -from optuna_dashboard._storage import trials_cache_lock -from optuna_dashboard._storage import trials_last_fetched_at from playwright.sync_api import Page -def clear_inmemory_cache() -> None: - with trials_cache_lock: - trials_cache.clear() - trials_last_fetched_at.clear() - with cached_extra_study_property_cache_lock: - cached_extra_study_property_cache.clear() - - def count_components(page: Page, component_name: str): component_count = page.evaluate( f"""() => {{ diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index e1c040ac2..f7dbc17c1 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -27,9 +27,11 @@ from . 
import _note as note from ._bottle_util import BottleViewReturn from ._bottle_util import json_api_view -from ._cached_extra_study_property import get_cached_extra_study_property from ._custom_plot_data import get_plotly_graph_objects from ._importance import get_param_importance_from_trials_cache +from ._inmemory_cache import get_cached_extra_study_property +from ._inmemory_cache import get_trials +from ._inmemory_cache import InMemoryCache from ._pareto_front import get_pareto_front_trials from ._preference_setting import _register_preference_feedback_component from ._preferential_history import NewHistory @@ -43,7 +45,6 @@ from ._storage import create_new_study from ._storage import get_studies from ._storage import get_study -from ._storage import get_trials from ._storage_url import get_storage from .artifact._backend import delete_all_artifacts from .artifact._backend import register_artifact_route @@ -80,6 +81,7 @@ def create_app( debug: bool = False, ) -> Bottle: app = Bottle() + app._inmemory_cache = InMemoryCache() @app.hook("before_request") def remove_trailing_slashes_hook() -> None: @@ -214,7 +216,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: if study is None: response.status = 404 # Not found return {"reason": f"study_id={study_id} is not found"} - trials = get_trials(storage, study_id) + trials = get_trials(app._inmemory_cache, storage, study_id) system_attrs = getattr(study, "system_attrs", {}) is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False) @@ -235,7 +237,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: union, union_user_attrs, has_intermediate_values, - ) = get_cached_extra_study_property(study_id, trials) + ) = get_cached_extra_study_property(app._inmemory_cache, study_id, trials) plotly_graph_objects = get_plotly_graph_objects(system_attrs) skipped_trial_ids = get_skipped_trial_ids(system_attrs) @@ -261,10 +263,12 @@ def get_param_importances(study_id: int) -> dict[str, Any]: response.status = 
404 # Study is not found return {"reason": f"study_id={study_id} is not found"} - trials = get_trials(storage, study_id) + trials = get_trials(app._inmemory_cache, storage, study_id) try: importances = [ - get_param_importance_from_trials_cache(storage, study_id, objective_id, trials) + get_param_importance_from_trials_cache( + storage, study_id, objective_id, trials, app._inmemory_cache + ) for objective_id in range(n_directions) ] return {"param_importances": importances} diff --git a/optuna_dashboard/_importance.py b/optuna_dashboard/_importance.py index b71894120..da1be75d1 100644 --- a/optuna_dashboard/_importance.py +++ b/optuna_dashboard/_importance.py @@ -11,7 +11,8 @@ from optuna.study import Study from optuna.trial import FrozenTrial from optuna.trial import TrialState -from optuna_dashboard._cached_extra_study_property import get_cached_extra_study_property +from optuna_dashboard._inmemory_cache import get_cached_extra_study_property +from optuna_dashboard._inmemory_cache import InMemoryCache _logger = logging.getLogger(__name__) @@ -97,7 +98,11 @@ def _get_param_importances( def get_param_importance_from_trials_cache( - storage: BaseStorage, study_id: int, objective_id: int, trials: list[FrozenTrial] + storage: BaseStorage, + study_id: int, + objective_id: int, + trials: list[FrozenTrial], + inmemory_cache: InMemoryCache, ) -> list[ImportanceType]: completed_trials = [t for t in trials if t.state == TrialState.COMPLETE] n_completed_trials = len(completed_trials) @@ -118,7 +123,9 @@ def get_param_importance_from_trials_cache( except RuntimeError: # RuntimeError("Encountered zero total variance in all trees.") may be raised # when all objective values are same. 
- _, union_search_space, _, _ = get_cached_extra_study_property(study_id, trials) + _, union_search_space, _, _ = get_cached_extra_study_property( + inmemory_cache, study_id, trials + ) importance_value = 1 / len(union_search_space) importance = { param_name: importance_value for param_name, distribution in union_search_space diff --git a/optuna_dashboard/_cached_extra_study_property.py b/optuna_dashboard/_inmemory_cache.py similarity index 65% rename from optuna_dashboard/_cached_extra_study_property.py rename to optuna_dashboard/_inmemory_cache.py index 24e223723..367d46c34 100644 --- a/optuna_dashboard/_cached_extra_study_property.py +++ b/optuna_dashboard/_inmemory_cache.py @@ -1,6 +1,8 @@ from __future__ import annotations import copy +from datetime import datetime +from datetime import timedelta import numbers import threading from typing import List @@ -10,29 +12,27 @@ from typing import TYPE_CHECKING from optuna.distributions import BaseDistribution +from optuna.storages import BaseStorage from optuna.trial import FrozenTrial from optuna.trial import TrialState -# In-memory cache -cached_extra_study_property_cache_lock = threading.Lock() -cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {} - - if TYPE_CHECKING: SearchSpaceSetT = Set[Tuple[str, BaseDistribution]] SearchSpaceListT = List[Tuple[str, BaseDistribution]] def get_cached_extra_study_property( - study_id: int, trials: list[FrozenTrial] + in_memory_cache: InMemoryCache, study_id: int, trials: list[FrozenTrial] ) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]: - with cached_extra_study_property_cache_lock: - cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None) + with in_memory_cache._cached_extra_study_property_cache_lock: + cached_extra_study_property = in_memory_cache._cached_extra_study_property_cache.get( + study_id, None + ) if cached_extra_study_property is None: cached_extra_study_property = 
class InMemoryCache:
    """Container for the in-memory caches owned by one dashboard application.

    Holds two independent caches keyed by ``study_id``:

    * the trials cache (``_trials_cache``) together with the time each entry
      was last fetched (``_trials_last_fetched_at``), both accessed under
      ``_trials_cache_lock``;
    * the extra-study-property cache (``_cached_extra_study_property_cache``),
      accessed under ``_cached_extra_study_property_cache_lock``.
    """

    def __init__(self) -> None:
        self._cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}
        self._cached_extra_study_property_cache_lock = threading.Lock()
        self._trials_cache: dict[int, list[FrozenTrial]] = {}
        self._trials_cache_lock = threading.Lock()
        self._trials_last_fetched_at: dict[int, datetime] = {}

    def clear(self) -> None:
        """Drop every cached entry.

        Each cache is emptied while holding the same lock its readers and
        writers take, so a concurrent request never observes a half-cleared
        state. The fetch timestamps are cleared together with the trials so
        no stale per-study entries are left behind.
        """
        with self._trials_cache_lock:
            self._trials_cache.clear()
            self._trials_last_fetched_at.clear()
        with self._cached_extra_study_property_cache_lock:
            self._cached_extra_study_property_cache.clear()


def get_trials(
    in_memory_cache: InMemoryCache, storage: "BaseStorage", study_id: int
) -> "list[FrozenTrial]":
    """Return the trials of ``study_id``, served from the cache while still fresh.

    Args:
        in_memory_cache: Cache instance that stores the trials and fetch times.
        storage: Object providing ``get_all_trials(study_id, deepcopy=False)``.
        study_id: Identifier of the study whose trials are requested.

    Returns:
        The cached list when it was fetched less than the TTL ago, otherwise
        the list freshly read from ``storage`` (which is then cached).
    """
    with in_memory_cache._trials_cache_lock:
        trials = in_memory_cache._trials_cache.get(study_id, None)

        # Heuristic: the more trials a study has, the longer the cached list
        # is considered fresh.
        if trials is None or len(trials) < 100:
            ttl_seconds = 2
        elif len(trials) < 500:
            ttl_seconds = 5
        else:
            ttl_seconds = 10

        last_fetched_at = in_memory_cache._trials_last_fetched_at.get(study_id, None)
        if (
            trials is not None
            and last_fetched_at is not None
            and datetime.now() - last_fetched_at < timedelta(seconds=ttl_seconds)
        ):
            return trials

    # The storage read happens outside the lock; only the cache update below
    # re-acquires it.
    trials = storage.get_all_trials(study_id, deepcopy=False)

    with in_memory_cache._trials_cache_lock:
        in_memory_cache._trials_last_fetched_at[study_id] = datetime.now()
        in_memory_cache._trials_cache[study_id] = trials
    return trials
+++ b/optuna_dashboard/_storage.py @@ -1,7 +1,6 @@ from __future__ import annotations from datetime import datetime -from datetime import timedelta import threading from optuna.storages import BaseStorage @@ -17,33 +16,6 @@ trials_last_fetched_at: dict[int, datetime] = {} -def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]: - with trials_cache_lock: - trials = trials_cache.get(study_id, None) - - # Not a big fan of the heuristic, but I can't think of anything better. - if trials is None or len(trials) < 100: - ttl_seconds = 2 - elif len(trials) < 500: - ttl_seconds = 5 - else: - ttl_seconds = 10 - - last_fetched_at = trials_last_fetched_at.get(study_id, None) - if ( - trials is not None - and last_fetched_at is not None - and datetime.now() - last_fetched_at < timedelta(seconds=ttl_seconds) - ): - return trials - trials = storage.get_all_trials(study_id, deepcopy=False) - - with trials_cache_lock: - trials_last_fetched_at[study_id] = datetime.now() - trials_cache[study_id] = trials - return trials - - def get_studies(storage: BaseStorage) -> list[FrozenStudy]: frozen_studies = storage.get_all_studies() if isinstance(storage, RDBStorage): diff --git a/python_tests/test_cached_extra_study_property.py b/python_tests/test_cached_extra_study_property.py index 381cbcca1..8e316719a 100644 --- a/python_tests/test_cached_extra_study_property.py +++ b/python_tests/test_cached_extra_study_property.py @@ -11,7 +11,7 @@ from optuna.distributions import FloatDistribution from optuna.exceptions import ExperimentalWarning from optuna.trial import TrialState -from optuna_dashboard._cached_extra_study_property import _CachedExtraStudyProperty +from optuna_dashboard._inmemory_cache import _CachedExtraStudyProperty class _CachedExtraStudyPropertySearchSpaceTestCase(TestCase): diff --git a/python_tests/wsgi_client.py b/python_tests/wsgi_client.py index d0abef9a7..0141dde4f 100644 --- a/python_tests/wsgi_client.py +++ b/python_tests/wsgi_client.py @@ -4,9 +4,6 @@ 
import typing from bottle import Bottle -from optuna_dashboard._storage import trials_cache -from optuna_dashboard._storage import trials_cache_lock -from optuna_dashboard._storage import trials_last_fetched_at if typing.TYPE_CHECKING: @@ -14,12 +11,6 @@ from _typeshed.wsgi import WSGIEnvironment -def clear_inmemory_cache() -> None: - with trials_cache_lock: - trials_cache.clear() - trials_last_fetched_at.clear() - - def create_wsgi_env( path: str, method: str, @@ -77,7 +68,7 @@ def start_response( queries = queries or {} env = create_wsgi_env(path, method, content_type, bytes_body, queries, headers) - clear_inmemory_cache() + app._inmemory_cache.clear() response_body = b"" iterable_body = app(env, start_response) for b in iterable_body: