Skip to content

Commit

Permalink
test cache
Browse files Browse the repository at this point in the history
  • Loading branch information
fusawayugo committed Jan 29, 2025
1 parent 765735c commit 7a92876
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 73 deletions.
4 changes: 1 addition & 3 deletions e2e_tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from optuna_dashboard import wsgi
import pytest

from .utils import clear_inmemory_cache


def get_free_port() -> int:
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -30,7 +28,7 @@ def make_test_server(
thread.start()

def stop_server() -> None:
clear_inmemory_cache()
app._inmemory_cache.clear()
httpd.shutdown()
httpd.server_close()
thread.join()
Expand Down
13 changes: 0 additions & 13 deletions e2e_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache
from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache_lock
from optuna_dashboard._storage import trials_cache
from optuna_dashboard._storage import trials_cache_lock
from optuna_dashboard._storage import trials_last_fetched_at
from playwright.sync_api import Page


def clear_inmemory_cache() -> None:
    """Empty the dashboard's module-level caches between test servers.

    The trials cache and its fetch timestamps are cleared while holding
    ``trials_cache_lock``; the extra-study-property cache is cleared while
    holding ``cached_extra_study_property_cache_lock``.
    """
    with trials_cache_lock:
        for store in (trials_cache, trials_last_fetched_at):
            store.clear()
    with cached_extra_study_property_cache_lock:
        cached_extra_study_property_cache.clear()


def count_components(page: Page, component_name: str):
component_count = page.evaluate(
f"""() => {{
Expand Down
16 changes: 10 additions & 6 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from . import _note as note
from ._bottle_util import BottleViewReturn
from ._bottle_util import json_api_view
from ._cached_extra_study_property import get_cached_extra_study_property
from ._custom_plot_data import get_plotly_graph_objects
from ._importance import get_param_importance_from_trials_cache
from ._inmemory_cache import get_cached_extra_study_property
from ._inmemory_cache import get_trials
from ._inmemory_cache import InMemoryCache
from ._pareto_front import get_pareto_front_trials
from ._preference_setting import _register_preference_feedback_component
from ._preferential_history import NewHistory
Expand All @@ -43,7 +45,6 @@
from ._storage import create_new_study
from ._storage import get_studies
from ._storage import get_study
from ._storage import get_trials
from ._storage_url import get_storage
from .artifact._backend import delete_all_artifacts
from .artifact._backend import register_artifact_route
Expand Down Expand Up @@ -80,6 +81,7 @@ def create_app(
debug: bool = False,
) -> Bottle:
app = Bottle()
app._inmemory_cache = InMemoryCache()

@app.hook("before_request")
def remove_trailing_slashes_hook() -> None:
Expand Down Expand Up @@ -214,7 +216,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
if study is None:
response.status = 404 # Not found
return {"reason": f"study_id={study_id} is not found"}
trials = get_trials(storage, study_id)
trials = get_trials(app._inmemory_cache, storage, study_id)

system_attrs = getattr(study, "system_attrs", {})
is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False)
Expand All @@ -235,7 +237,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union,
union_user_attrs,
has_intermediate_values,
) = get_cached_extra_study_property(study_id, trials)
) = get_cached_extra_study_property(app._inmemory_cache, study_id, trials)

plotly_graph_objects = get_plotly_graph_objects(system_attrs)
skipped_trial_ids = get_skipped_trial_ids(system_attrs)
Expand All @@ -261,10 +263,12 @@ def get_param_importances(study_id: int) -> dict[str, Any]:
response.status = 404 # Study is not found
return {"reason": f"study_id={study_id} is not found"}

trials = get_trials(storage, study_id)
trials = get_trials(app._inmemory_cache, storage, study_id)
try:
importances = [
get_param_importance_from_trials_cache(storage, study_id, objective_id, trials)
get_param_importance_from_trials_cache(
storage, study_id, objective_id, trials, app._inmemory_cache
)
for objective_id in range(n_directions)
]
return {"param_importances": importances}
Expand Down
13 changes: 10 additions & 3 deletions optuna_dashboard/_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna_dashboard._cached_extra_study_property import get_cached_extra_study_property
from optuna_dashboard._inmemory_cache import get_cached_extra_study_property
from optuna_dashboard._inmemory_cache import InMemoryCache


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,7 +98,11 @@ def _get_param_importances(


def get_param_importance_from_trials_cache(
storage: BaseStorage, study_id: int, objective_id: int, trials: list[FrozenTrial]
storage: BaseStorage,
study_id: int,
objective_id: int,
trials: list[FrozenTrial],
inmemory_cache: InMemoryCache,
) -> list[ImportanceType]:
completed_trials = [t for t in trials if t.state == TrialState.COMPLETE]
n_completed_trials = len(completed_trials)
Expand All @@ -118,7 +123,9 @@ def get_param_importance_from_trials_cache(
except RuntimeError:
# RuntimeError("Encountered zero total variance in all trees.") may be raised
# when all objective values are same.
_, union_search_space, _, _ = get_cached_extra_study_property(study_id, trials)
_, union_search_space, _, _ = get_cached_extra_study_property(
inmemory_cache, study_id, trials
)
importance_value = 1 / len(union_search_space)
importance = {
param_name: importance_value for param_name, distribution in union_search_space
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import copy
from datetime import datetime
from datetime import timedelta
import numbers
import threading
from typing import List
Expand All @@ -10,29 +12,27 @@
from typing import TYPE_CHECKING

from optuna.distributions import BaseDistribution
from optuna.storages import BaseStorage
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


# In-memory cache
cached_extra_study_property_cache_lock = threading.Lock()
cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}


if TYPE_CHECKING:
SearchSpaceSetT = Set[Tuple[str, BaseDistribution]]
SearchSpaceListT = List[Tuple[str, BaseDistribution]]


def get_cached_extra_study_property(
study_id: int, trials: list[FrozenTrial]
in_memory_cache: InMemoryCache, study_id: int, trials: list[FrozenTrial]
) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]:
with cached_extra_study_property_cache_lock:
cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None)
with in_memory_cache._cached_extra_study_property_cache_lock:
cached_extra_study_property = in_memory_cache._cached_extra_study_property_cache.get(
study_id, None
)
if cached_extra_study_property is None:
cached_extra_study_property = _CachedExtraStudyProperty()
cached_extra_study_property.update(trials)
cached_extra_study_property_cache[study_id] = cached_extra_study_property
in_memory_cache._cached_extra_study_property_cache[study_id] = cached_extra_study_property
return (
cached_extra_study_property.intersection_search_space,
cached_extra_study_property.union_search_space,
Expand All @@ -41,6 +41,48 @@ def get_cached_extra_study_property(
)


def get_trials(
    in_memory_cache: InMemoryCache, storage: BaseStorage, study_id: int
) -> list[FrozenTrial]:
    """Return all trials of ``study_id``, reusing the cached list while it is fresh.

    A cached list is reused when it was fetched less than a TTL ago. The TTL
    grows with the number of cached trials: 2 seconds below 100 trials,
    5 seconds below 500, and 10 seconds otherwise. When there is no fresh
    entry, the trials are read with
    ``storage.get_all_trials(study_id, deepcopy=False)`` and both the cached
    list and its fetch time are replaced.

    Args:
        in_memory_cache: Cache holder providing ``_trials_cache``,
            ``_trials_last_fetched_at`` and ``_trials_cache_lock``.
        storage: Storage to read from on a cache miss or expiry.
        study_id: ID of the study whose trials are requested.

    Returns:
        The cached or freshly fetched list of trials.
    """
    with in_memory_cache._trials_cache_lock:
        cached = in_memory_cache._trials_cache.get(study_id)
        fetched_at = in_memory_cache._trials_last_fetched_at.get(study_id)

        if cached is not None and fetched_at is not None:
            # Not a big fan of the heuristic, but I can't think of anything better.
            n_cached = len(cached)
            if n_cached < 100:
                ttl = timedelta(seconds=2)
            elif n_cached < 500:
                ttl = timedelta(seconds=5)
            else:
                ttl = timedelta(seconds=10)

            if datetime.now() - fetched_at < ttl:
                return cached

    # The storage round-trip happens outside the lock so other requests are
    # not blocked while it runs.
    fresh = storage.get_all_trials(study_id, deepcopy=False)

    with in_memory_cache._trials_cache_lock:
        in_memory_cache._trials_last_fetched_at[study_id] = datetime.now()
        in_memory_cache._trials_cache[study_id] = fresh
    return fresh


class InMemoryCache:
    """Per-application holder for the dashboard's in-memory caches.

    Stores, keyed by study id, the cached trial lists together with the time
    each list was last fetched, and the cached extra study properties. Each
    group has its own lock, which the accessor functions in this module hold
    while reading or writing the corresponding dictionaries.
    """

    def __init__(self) -> None:
        self._cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}
        self._cached_extra_study_property_cache_lock = threading.Lock()
        self._trials_cache: dict[int, list[FrozenTrial]] = {}
        self._trials_cache_lock = threading.Lock()
        # Fetch timestamps are guarded by ``_trials_cache_lock`` as well.
        self._trials_last_fetched_at: dict[int, datetime] = {}

    def clear(self) -> None:
        """Drop every cached entry, including the trial fetch timestamps.

        Each dictionary is cleared while holding the lock that guards it, so a
        request running concurrently on another thread never observes a
        half-cleared cache. The two locks are taken one after the other, never
        nested.
        """
        with self._trials_cache_lock:
            self._trials_cache.clear()
            # Keep the timestamps in step with the trial lists; otherwise
            # entries for studies that are never requested again pile up.
            self._trials_last_fetched_at.clear()
        with self._cached_extra_study_property_cache_lock:
            self._cached_extra_study_property_cache.clear()


class _CachedExtraStudyProperty:
def __init__(self) -> None:
self._cursor: int = -1
Expand Down
28 changes: 0 additions & 28 deletions optuna_dashboard/_storage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from datetime import datetime
from datetime import timedelta
import threading

from optuna.storages import BaseStorage
Expand All @@ -17,33 +16,6 @@
trials_last_fetched_at: dict[int, datetime] = {}


def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]:
    """Return all trials of ``study_id``, reusing the module-level cache while fresh.

    A cached list is reused when it was fetched less than a TTL ago. The TTL
    grows with the number of cached trials: 2 seconds below 100 trials,
    5 seconds below 500, and 10 seconds otherwise. When there is no fresh
    entry, the trials are read with
    ``storage.get_all_trials(study_id, deepcopy=False)`` and both
    ``trials_cache`` and ``trials_last_fetched_at`` are updated.
    """
    with trials_cache_lock:
        cached = trials_cache.get(study_id)
        fetched_at = trials_last_fetched_at.get(study_id)

        if cached is not None and fetched_at is not None:
            # Not a big fan of the heuristic, but I can't think of anything better.
            n_cached = len(cached)
            if n_cached < 100:
                ttl = timedelta(seconds=2)
            elif n_cached < 500:
                ttl = timedelta(seconds=5)
            else:
                ttl = timedelta(seconds=10)

            if datetime.now() - fetched_at < ttl:
                return cached

    # The storage round-trip happens outside the lock.
    fresh = storage.get_all_trials(study_id, deepcopy=False)

    with trials_cache_lock:
        trials_last_fetched_at[study_id] = datetime.now()
        trials_cache[study_id] = fresh
    return fresh


def get_studies(storage: BaseStorage) -> list[FrozenStudy]:
frozen_studies = storage.get_all_studies()
if isinstance(storage, RDBStorage):
Expand Down
2 changes: 1 addition & 1 deletion python_tests/test_cached_extra_study_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from optuna.distributions import FloatDistribution
from optuna.exceptions import ExperimentalWarning
from optuna.trial import TrialState
from optuna_dashboard._cached_extra_study_property import _CachedExtraStudyProperty
from optuna_dashboard._inmemory_cache import _CachedExtraStudyProperty


class _CachedExtraStudyPropertySearchSpaceTestCase(TestCase):
Expand Down
11 changes: 1 addition & 10 deletions python_tests/wsgi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,13 @@
import typing

from bottle import Bottle
from optuna_dashboard._storage import trials_cache
from optuna_dashboard._storage import trials_cache_lock
from optuna_dashboard._storage import trials_last_fetched_at


if typing.TYPE_CHECKING:
from _typeshed import OptExcInfo
from _typeshed.wsgi import WSGIEnvironment


def clear_inmemory_cache() -> None:
    """Empty the module-level trials cache and its fetch timestamps.

    Both dictionaries are cleared while holding ``trials_cache_lock``.
    """
    with trials_cache_lock:
        for store in (trials_cache, trials_last_fetched_at):
            store.clear()


def create_wsgi_env(
path: str,
method: str,
Expand Down Expand Up @@ -77,7 +68,7 @@ def start_response(
queries = queries or {}
env = create_wsgi_env(path, method, content_type, bytes_body, queries, headers)

clear_inmemory_cache()
app._inmemory_cache.clear()
response_body = b""
iterable_body = app(env, start_response)
for b in iterable_body:
Expand Down

0 comments on commit 7a92876

Please sign in to comment.