Skip to content

Commit

Permalink
test new cache
Browse files Browse the repository at this point in the history
  • Loading branch information
fusawayugo committed Jan 29, 2025
1 parent 765735c commit f086a20
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 174 deletions.
4 changes: 1 addition & 3 deletions e2e_tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from optuna_dashboard import wsgi
import pytest

from .utils import clear_inmemory_cache


def get_free_port() -> int:
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand All @@ -30,7 +28,7 @@ def make_test_server(
thread.start()

def stop_server() -> None:
clear_inmemory_cache()
app._inmemory_cache.clear()
httpd.shutdown()
httpd.server_close()
thread.join()
Expand Down
13 changes: 0 additions & 13 deletions e2e_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache
from optuna_dashboard._cached_extra_study_property import cached_extra_study_property_cache_lock
from optuna_dashboard._storage import trials_cache
from optuna_dashboard._storage import trials_cache_lock
from optuna_dashboard._storage import trials_last_fetched_at
from playwright.sync_api import Page


def clear_inmemory_cache() -> None:
    """Empty the dashboard's module-level in-memory caches.

    Clears the trials cache and its last-fetched timestamps under
    ``trials_cache_lock``, then clears the extra-study-property cache under
    its own lock. The two locks are taken one after the other, never nested.
    """
    with trials_cache_lock:
        for store in (trials_cache, trials_last_fetched_at):
            store.clear()
    with cached_extra_study_property_cache_lock:
        cached_extra_study_property_cache.clear()


def count_components(page: Page, component_name: str):
component_count = page.evaluate(
f"""() => {{
Expand Down
16 changes: 10 additions & 6 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
from . import _note as note
from ._bottle_util import BottleViewReturn
from ._bottle_util import json_api_view
from ._cached_extra_study_property import get_cached_extra_study_property
from ._custom_plot_data import get_plotly_graph_objects
from ._importance import get_param_importance_from_trials_cache
from ._inmemory_cache import get_cached_extra_study_property
from ._inmemory_cache import get_trials
from ._inmemory_cache import InMemoryCache
from ._pareto_front import get_pareto_front_trials
from ._preference_setting import _register_preference_feedback_component
from ._preferential_history import NewHistory
Expand All @@ -43,7 +45,6 @@
from ._storage import create_new_study
from ._storage import get_studies
from ._storage import get_study
from ._storage import get_trials
from ._storage_url import get_storage
from .artifact._backend import delete_all_artifacts
from .artifact._backend import register_artifact_route
Expand Down Expand Up @@ -80,6 +81,7 @@ def create_app(
debug: bool = False,
) -> Bottle:
app = Bottle()
app._inmemory_cache = InMemoryCache()

@app.hook("before_request")
def remove_trailing_slashes_hook() -> None:
Expand Down Expand Up @@ -214,7 +216,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
if study is None:
response.status = 404 # Not found
return {"reason": f"study_id={study_id} is not found"}
trials = get_trials(storage, study_id)
trials = get_trials(app._inmemory_cache, storage, study_id)

system_attrs = getattr(study, "system_attrs", {})
is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False)
Expand All @@ -235,7 +237,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union,
union_user_attrs,
has_intermediate_values,
) = get_cached_extra_study_property(study_id, trials)
) = get_cached_extra_study_property(app._inmemory_cache, study_id, trials)

plotly_graph_objects = get_plotly_graph_objects(system_attrs)
skipped_trial_ids = get_skipped_trial_ids(system_attrs)
Expand All @@ -261,10 +263,12 @@ def get_param_importances(study_id: int) -> dict[str, Any]:
response.status = 404 # Study is not found
return {"reason": f"study_id={study_id} is not found"}

trials = get_trials(storage, study_id)
trials = get_trials(app._inmemory_cache, storage, study_id)
try:
importances = [
get_param_importance_from_trials_cache(storage, study_id, objective_id, trials)
get_param_importance_from_trials_cache(
storage, study_id, objective_id, trials, app._inmemory_cache
)
for objective_id in range(n_directions)
]
return {"param_importances": importances}
Expand Down
222 changes: 111 additions & 111 deletions optuna_dashboard/_cached_extra_study_property.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,111 @@
from __future__ import annotations

import copy
import numbers
import threading
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING

from optuna.distributions import BaseDistribution
from optuna.trial import FrozenTrial
from optuna.trial import TrialState


# In-memory cache
# The lock is held by get_cached_extra_study_property() while it looks up,
# updates and stores entries in the dict below.
cached_extra_study_property_cache_lock = threading.Lock()
# Maps study_id -> the incrementally updated summary object for that study.
cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}


# Aliases that exist only for static type checkers. They are referenced solely
# from annotations, which stay unevaluated thanks to
# `from __future__ import annotations`, so they must not be used at runtime.
if TYPE_CHECKING:
    # Unordered collection of (parameter name, distribution) pairs.
    SearchSpaceSetT = Set[Tuple[str, BaseDistribution]]
    # The same pairs as a list.
    SearchSpaceListT = List[Tuple[str, BaseDistribution]]


def get_cached_extra_study_property(
    study_id: int, trials: list[FrozenTrial]
) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]:
    """Return the cached per-study summary, refreshed with ``trials``.

    Looks up the ``_CachedExtraStudyProperty`` stored for ``study_id``
    (creating one on first use), feeds it ``trials`` and writes it back to
    the module-level cache, all while holding the cache lock.

    Args:
        study_id: Key of the study in the module-level cache.
        trials: Trials passed to the entry's ``update()``.

    Returns:
        A 4-tuple ``(intersection_search_space, union_search_space,
        union_user_attrs, has_intermediate_values)`` read from the entry.
    """
    with cached_extra_study_property_cache_lock:
        entry = cached_extra_study_property_cache.get(study_id)
        if entry is None:
            entry = _CachedExtraStudyProperty()
        entry.update(trials)
        # Stored only after a successful update, exactly like the lookup
        # order above: create, update, then publish.
        cached_extra_study_property_cache[study_id] = entry

        intersection = entry.intersection_search_space
        union = entry.union_search_space
        user_attrs = entry.union_user_attrs
        has_intermediate = entry.has_intermediate_values
    return (intersection, union, user_attrs, has_intermediate)


class _CachedExtraStudyProperty:
    """Incrementally accumulated summary of one study's trials.

    Tracks the union and intersection of the trials' search spaces, the set
    of user-attribute names (with whether each is numerically sortable), and
    whether any trial reported intermediate values. ``update()`` may be
    called repeatedly; every accumulator is idempotent, so re-processing a
    trial is harmless.
    """

    def __init__(self) -> None:
        # Number of the lowest-numbered unfinished trial seen by the last
        # update(); trials numbered below it are skipped on the next call.
        self._cursor: int = -1
        # None until the first non-failed trial has been processed.
        self._intersection_search_space: Optional[SearchSpaceSetT] = None
        self._union_search_space: SearchSpaceSetT = set()
        self._union_user_attrs: dict[str, bool] = {}  # attr_name: is_sortable (= is_number)
        self.has_intermediate_values: bool = False

    @property
    def intersection_search_space(self) -> SearchSpaceListT:
        """(name, distribution) pairs common to all processed trials, sorted by name."""
        if self._intersection_search_space is None:
            return []
        return sorted(self._intersection_search_space, key=lambda x: x[0])

    @property
    def union_search_space(self) -> SearchSpaceListT:
        """(name, distribution) pairs seen in any processed trial, sorted by name."""
        return sorted(self._union_search_space, key=lambda x: x[0])

    @property
    def union_user_attrs(self) -> list[tuple[str, bool]]:
        """(attr_name, is_sortable) pairs for every user attribute seen, sorted by name."""
        # The previous code called sorted() and threw the result away, so the
        # list came back in insertion order instead of sorted by name.
        return sorted(self._union_user_attrs.items(), key=lambda x: x[0])

    def update(self, trials: list[FrozenTrial]) -> None:
        """Fold ``trials`` into the accumulated summary.

        Walks ``trials`` from the end and stops at the first trial whose
        number is below the stored cursor. Presumably ``trials`` is ordered
        by ascending trial number; verify against the caller.
        """
        next_cursor = self._cursor
        for trial in reversed(trials):
            if self._cursor > trial.number:
                break

            # Walking backwards, the last assignment wins, so the cursor ends
            # up at the lowest-numbered unfinished trial that was visited.
            if not trial.state.is_finished():
                next_cursor = trial.number

            self._update_user_attrs(trial)
            # Failed trials contribute user attrs only.
            if trial.state != TrialState.FAIL:
                self._update_intermediate_values(trial)
                self._update_search_space(trial)

        self._cursor = next_cursor

    def _update_user_attrs(self, trial: FrozenTrial) -> None:
        """Record the trial's user-attr names; a name stays sortable only while every value is a real number."""
        for attr_name, value in trial.user_attrs.items():
            # bool is a numbers.Real subclass but is not treated as sortable.
            is_number = not isinstance(value, bool) and isinstance(value, numbers.Real)
            was_sortable = self._union_user_attrs.get(attr_name, True)
            self._union_user_attrs[attr_name] = was_sortable and is_number

    def _update_intermediate_values(self, trial: FrozenTrial) -> None:
        """Latch ``has_intermediate_values`` once any trial reports one."""
        if not self.has_intermediate_values and len(trial.intermediate_values) > 0:
            self.has_intermediate_values = True

    def _update_search_space(self, trial: FrozenTrial) -> None:
        """Merge the trial's (name, distribution) pairs into the union and intersection."""
        current = set(trial.distributions.items())
        # Rebind rather than mutate in place, as the original did, so a list
        # already built from the old set is never affected.
        self._union_search_space = self._union_search_space | current

        if self._intersection_search_space is None:
            # `current` is a fresh local set, so no defensive copy is needed.
            self._intersection_search_space = current
        else:
            self._intersection_search_space = self._intersection_search_space & current
# from __future__ import annotations

# import copy
# import numbers
# import threading
# from typing import List
# from typing import Optional
# from typing import Set
# from typing import Tuple
# from typing import TYPE_CHECKING

# from optuna.distributions import BaseDistribution
# from optuna.trial import FrozenTrial
# from optuna.trial import TrialState


# # In-memory cache
# cached_extra_study_property_cache_lock = threading.Lock()
# cached_extra_study_property_cache: dict[int, "_CachedExtraStudyProperty"] = {}


# if TYPE_CHECKING:
# SearchSpaceSetT = Set[Tuple[str, BaseDistribution]]
# SearchSpaceListT = List[Tuple[str, BaseDistribution]]


# def get_cached_extra_study_property(
# study_id: int, trials: list[FrozenTrial]
# ) -> tuple[SearchSpaceListT, SearchSpaceListT, list[tuple[str, bool]], bool]:
# with cached_extra_study_property_cache_lock:
# cached_extra_study_property = cached_extra_study_property_cache.get(study_id, None)
# if cached_extra_study_property is None:
# cached_extra_study_property = _CachedExtraStudyProperty()
# cached_extra_study_property.update(trials)
# cached_extra_study_property_cache[study_id] = cached_extra_study_property
# return (
# cached_extra_study_property.intersection_search_space,
# cached_extra_study_property.union_search_space,
# cached_extra_study_property.union_user_attrs,
# cached_extra_study_property.has_intermediate_values,
# )


# class _CachedExtraStudyProperty:
# def __init__(self) -> None:
# self._cursor: int = -1
# self._intersection_search_space: Optional[SearchSpaceSetT] = None
# self._union_search_space: SearchSpaceSetT = set()
# self._union_user_attrs: dict[str, bool] = {} # attr_name: is_sortable (= is_number)
# self.has_intermediate_values: bool = False

# @property
# def intersection_search_space(self) -> SearchSpaceListT:
# if self._intersection_search_space is None:
# return []
# intersection = list(self._intersection_search_space)
# intersection.sort(key=lambda x: x[0])
# return intersection

# @property
# def union_search_space(self) -> SearchSpaceListT:
# union = list(self._union_search_space)
# union.sort(key=lambda x: x[0])
# return union

# @property
# def union_user_attrs(self) -> list[tuple[str, bool]]:
# union = [(name, is_sortable) for name, is_sortable in self._union_user_attrs.items()]
# sorted(union, key=lambda x: x[0])
# return union

# def update(self, trials: list[FrozenTrial]) -> None:
# next_cursor = self._cursor
# for trial in reversed(trials):
# if self._cursor > trial.number:
# break

# if not trial.state.is_finished():
# next_cursor = trial.number

# self._update_user_attrs(trial)
# if trial.state != TrialState.FAIL:
# self._update_intermediate_values(trial)
# self._update_search_space(trial)

# self._cursor = next_cursor

# def _update_user_attrs(self, trial: FrozenTrial) -> None:
# current_user_attrs = {
# k: not isinstance(v, bool) and isinstance(v, numbers.Real)
# for k, v in trial.user_attrs.items()
# }
# for attr_name, current_is_sortable in current_user_attrs.items():
# is_sortable = self._union_user_attrs.get(attr_name)
# if is_sortable is None:
# self._union_user_attrs[attr_name] = current_is_sortable
# elif is_sortable and not current_is_sortable:
# self._union_user_attrs[attr_name] = False

# def _update_intermediate_values(self, trial: FrozenTrial) -> None:
# if not self.has_intermediate_values and len(trial.intermediate_values) > 0:
# self.has_intermediate_values = True

# def _update_search_space(self, trial: FrozenTrial) -> None:
# current = set([(n, d) for n, d in trial.distributions.items()])
# self._union_search_space = self._union_search_space.union(current)

# if self._intersection_search_space is None:
# self._intersection_search_space = copy.copy(current)
# else:
# self._intersection_search_space = self._intersection_search_space.intersection(current)
13 changes: 10 additions & 3 deletions optuna_dashboard/_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna_dashboard._cached_extra_study_property import get_cached_extra_study_property
from optuna_dashboard._inmemory_cache import get_cached_extra_study_property
from optuna_dashboard._inmemory_cache import InMemoryCache


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,7 +98,11 @@ def _get_param_importances(


def get_param_importance_from_trials_cache(
storage: BaseStorage, study_id: int, objective_id: int, trials: list[FrozenTrial]
storage: BaseStorage,
study_id: int,
objective_id: int,
trials: list[FrozenTrial],
inmemory_cache: InMemoryCache,
) -> list[ImportanceType]:
completed_trials = [t for t in trials if t.state == TrialState.COMPLETE]
n_completed_trials = len(completed_trials)
Expand All @@ -118,7 +123,9 @@ def get_param_importance_from_trials_cache(
except RuntimeError:
# RuntimeError("Encountered zero total variance in all trees.") may be raised
# when all objective values are same.
_, union_search_space, _, _ = get_cached_extra_study_property(study_id, trials)
_, union_search_space, _, _ = get_cached_extra_study_property(
inmemory_cache, study_id, trials
)
importance_value = 1 / len(union_search_space)
importance = {
param_name: importance_value for param_name, distribution in union_search_space
Expand Down
Loading

0 comments on commit f086a20

Please sign in to comment.