Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the structure of cache to avoid unexpected errors #1022

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions e2e_tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from optuna_dashboard import wsgi
import pytest

from .utils import clear_inmemory_cache


def get_free_port() -> int:
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -30,7 +28,7 @@ def make_test_server(
thread.start()

def stop_server() -> None:
clear_inmemory_cache()
app._inmemory_cache.clear()
httpd.shutdown()
httpd.server_close()
thread.join()
Expand Down
13 changes: 0 additions & 13 deletions e2e_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache
from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache_lock
from optuna_dashboard._storage import trials_cache
from optuna_dashboard._storage import trials_cache_lock
from optuna_dashboard._storage import trials_last_fetched_at
from playwright.sync_api import Page


def clear_inmemory_cache() -> None:
with trials_cache_lock:
trials_cache.clear()
trials_last_fetched_at.clear()
with cached_extra_study_property_cache_lock:
cached_extra_study_property_cache.clear()


def count_components(page: Page, component_name: str):
component_count = page.evaluate(
f"""() => {{
Expand Down
18 changes: 13 additions & 5 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
from . import _note as note
from ._bottle_util import BottleViewReturn
from ._bottle_util import json_api_view
from ._cached_extra_study_property import get_cached_extra_study_property
from ._custom_plot_data import get_plotly_graph_objects
from ._importance import get_param_importance_from_trials_cache
from ._inmemory_cache import get_cached_extra_study_property
from ._inmemory_cache import InMemoryCache
from ._pareto_front import get_pareto_front_trials
from ._preference_setting import _register_preference_feedback_component
from ._preferential_history import NewHistory
Expand Down Expand Up @@ -80,6 +81,7 @@ def create_app(
debug: bool = False,
) -> Bottle:
app = Bottle()
app._inmemory_cache = InMemoryCache()

@app.hook("before_request")
def remove_trailing_slashes_hook() -> None:
Expand Down Expand Up @@ -214,7 +216,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
if study is None:
response.status = 404 # Not found
return {"reason": f"study_id={study_id} is not found"}
trials = get_trials(storage, study_id)
trials = get_trials(app._inmemory_cache, storage, study_id)

system_attrs = getattr(study, "system_attrs", {})
is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False)
Expand All @@ -235,7 +237,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union,
union_user_attrs,
has_intermediate_values,
) = get_cached_extra_study_property(study_id, trials)
) = get_cached_extra_study_property(app._inmemory_cache, study_id, trials)

plotly_graph_objects = get_plotly_graph_objects(system_attrs)
skipped_trial_ids = get_skipped_trial_ids(system_attrs)
Expand All @@ -261,10 +263,16 @@ def get_param_importances(study_id: int) -> dict[str, Any]:
response.status = 404 # Study is not found
return {"reason": f"study_id={study_id} is not found"}

trials = get_trials(storage, study_id)
trials = get_trials(app._inmemory_cache, storage, study_id)
try:
importances = [
get_param_importance_from_trials_cache(storage, study_id, objective_id, trials)
get_param_importance_from_trials_cache(
app._inmemory_cache,
storage,
study_id,
objective_id,
trials,
)
for objective_id in range(n_directions)
]
return {"param_importances": importances}
Expand Down
2 changes: 1 addition & 1 deletion optuna_dashboard/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_debug_server(app: Bottle, host: str, port: int, quiet: bool) -> None:


def auto_select_server(
server_arg: Literal["auto", "gunicorn", "wsgiref"]
server_arg: Literal["auto", "gunicorn", "wsgiref"],
) -> Literal["gunicorn", "wsgiref"]:
if server_arg != "auto":
return server_arg
Expand Down
13 changes: 10 additions & 3 deletions optuna_dashboard/_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna_dashboard._cached_extra_study_property import get_cached_extra_study_property
from optuna_dashboard._inmemory_cache import get_cached_extra_study_property
from optuna_dashboard._inmemory_cache import InMemoryCache


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,7 +98,11 @@ def _get_param_importances(


def get_param_importance_from_trials_cache(
storage: BaseStorage, study_id: int, objective_id: int, trials: list[FrozenTrial]
inmemory_cache: InMemoryCache,
storage: BaseStorage,
study_id: int,
objective_id: int,
trials: list[FrozenTrial],
) -> list[ImportanceType]:
completed_trials = [t for t in trials if t.state == TrialState.COMPLETE]
n_completed_trials = len(completed_trials)
Expand All @@ -118,7 +123,9 @@ def get_param_importance_from_trials_cache(
except RuntimeError:
# RuntimeError("Encountered zero total variance in all trees.") may be raised
# when all objective values are same.
_, union_search_space, _, _ = get_cached_extra_study_property(study_id, trials)
_, union_search_space, _, _ = get_cached_extra_study_property(
inmemory_cache, study_id, trials
)
importance_value = 1 / len(union_search_space)
importance = {
param_name: importance_value for param_name, distribution in union_search_space
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
from datetime import datetime
import numbers
import threading
from typing import List
Expand All @@ -14,25 +15,22 @@
from optuna.trial import TrialState


# In-memory cache
cached_extra_study_property_cache_lock = threading.Lock()
cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}


if TYPE_CHECKING:
SearchSpaceSetT = Set[Tuple[str, BaseDistribution]]
SearchSpaceListT = List[Tuple[str, BaseDistribution]]


def get_cached_extra_study_property(
study_id: int, trials: list[FrozenTrial]
in_memory_cache: InMemoryCache, study_id: int, trials: list[FrozenTrial]
) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]:
with cached_extra_study_property_cache_lock:
cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None)
with in_memory_cache._cached_extra_study_property_cache_lock:
cached_extra_study_property = in_memory_cache._cached_extra_study_property_cache.get(
study_id, None
)
if cached_extra_study_property is None:
cached_extra_study_property = _CachedExtraStudyProperty()
cached_extra_study_property.update(trials)
cached_extra_study_property_cache[study_id] = cached_extra_study_property
in_memory_cache._cached_extra_study_property_cache[study_id] = cached_extra_study_property
return (
cached_extra_study_property.intersection_search_space,
cached_extra_study_property.union_search_space,
Expand All @@ -41,6 +39,20 @@ def get_cached_extra_study_property(
)


class InMemoryCache:
def __init__(self) -> None:
self._cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}
self._cached_extra_study_property_cache_lock = threading.Lock()
self._trials_cache: dict[int, list[FrozenTrial]] = {}
self._trials_cache_lock = threading.Lock()
self._trials_last_fetched_at: dict[int, datetime] = {}

def clear(self) -> None:
self._cached_extra_study_property_cache.clear()
self._trials_cache.clear()
fusawa-yugo marked this conversation as resolved.
Show resolved Hide resolved
self._trials_last_fetched_at.clear()
Comment on lines +51 to +53
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the lock here as well?

Suggested change
self._cached_extra_study_property_cache.clear()
self._trials_cache.clear()
self._trials_last_fetched_at.clear()
with self._cached_extra_study_property_cache_lock:
self._cached_extra_study_property_cache.clear()
with self._trials_cache_lock:
self._trials_cache.clear()
self._trials_last_fetched_at.clear()



class _CachedExtraStudyProperty:
def __init__(self) -> None:
self._cursor: int = -1
Expand Down
2 changes: 1 addition & 1 deletion optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def serialize_distribution(distribution: BaseDistribution) -> DistributionJSON:


def serialize_search_space(
search_space: list[tuple[str, BaseDistribution]]
search_space: list[tuple[str, BaseDistribution]],
) -> list[dict[str, Any]]:
serialized = []
for param_name, distribution in search_space:
Expand Down
23 changes: 10 additions & 13 deletions optuna_dashboard/_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,21 @@

from datetime import datetime
from datetime import timedelta
import threading

from optuna.storages import BaseStorage
from optuna.storages import RDBStorage
from optuna.study import StudyDirection
from optuna.study._frozen import FrozenStudy
from optuna.trial import FrozenTrial

from ._inmemory_cache import InMemoryCache

# In-memory trials cache
trials_cache_lock = threading.Lock()
trials_cache: dict[int, list[FrozenTrial]] = {}
trials_last_fetched_at: dict[int, datetime] = {}


def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]:
with trials_cache_lock:
trials = trials_cache.get(study_id, None)
def get_trials(
in_memory_cache: InMemoryCache, storage: BaseStorage, study_id: int
) -> list[FrozenTrial]:
with in_memory_cache._trials_cache_lock:
trials = in_memory_cache._trials_cache.get(study_id, None)

# Not a big fan of the heuristic, but I can't think of anything better.
if trials is None or len(trials) < 100:
Expand All @@ -29,7 +26,7 @@ def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]:
else:
ttl_seconds = 10

last_fetched_at = trials_last_fetched_at.get(study_id, None)
last_fetched_at = in_memory_cache._trials_last_fetched_at.get(study_id, None)
if (
trials is not None
and last_fetched_at is not None
Expand All @@ -38,9 +35,9 @@ def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]:
return trials
trials = storage.get_all_trials(study_id, deepcopy=False)

with trials_cache_lock:
trials_last_fetched_at[study_id] = datetime.now()
trials_cache[study_id] = trials
with in_memory_cache._trials_cache_lock:
in_memory_cache._trials_last_fetched_at[study_id] = datetime.now()
in_memory_cache._trials_cache[study_id] = trials
return trials


Expand Down
2 changes: 1 addition & 1 deletion python_tests/test_cached_extra_study_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from optuna.distributions import FloatDistribution
from optuna.exceptions import ExperimentalWarning
from optuna.trial import TrialState
from optuna_dashboard._cached_extra_study_property import _CachedExtraStudyProperty
from optuna_dashboard._inmemory_cache import _CachedExtraStudyProperty


class _CachedExtraStudyPropertySearchSpaceTestCase(TestCase):
Expand Down
11 changes: 1 addition & 10 deletions python_tests/wsgi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,13 @@
import typing

from bottle import Bottle
from optuna_dashboard._storage import trials_cache
from optuna_dashboard._storage import trials_cache_lock
from optuna_dashboard._storage import trials_last_fetched_at


if typing.TYPE_CHECKING:
from _typeshed import OptExcInfo
from _typeshed.wsgi import WSGIEnvironment


def clear_inmemory_cache() -> None:
with trials_cache_lock:
trials_cache.clear()
trials_last_fetched_at.clear()


def create_wsgi_env(
path: str,
method: str,
Expand Down Expand Up @@ -77,7 +68,7 @@ def start_response(
queries = queries or {}
env = create_wsgi_env(path, method, content_type, bytes_body, queries, headers)

clear_inmemory_cache()
app._inmemory_cache.clear()
response_body = b""
iterable_body = app(env, start_response)
for b in iterable_body:
Expand Down
Loading