Split out some mostly cosmetic changes from PR #720
bpkroth committed Jan 9, 2025
1 parent 013005c commit e765331
Showing 1 changed file with 20 additions and 38 deletions.
58 changes: 20 additions & 38 deletions mlos_bench/mlos_bench/schedulers/base_scheduler.py
@@ -82,7 +82,7 @@ def __init__(  # pylint: disable=too-many-arguments
         self._do_teardown = bool(config.get("teardown", True))
 
         self._experiment: Storage.Experiment | None = None
-        self.environment = environment
+        self._environment = environment
         self._optimizer = optimizer
         self._storage = storage
         self._root_env_config = root_env_config
@@ -132,23 +132,9 @@ def experiment(self) -> Storage.Experiment | None:
         return self._experiment
 
     @property
-    def root_environment(self) -> Environment:
-        """
-        Gets the root (prototypical) Environment from the first TrialRunner.
-
-        Notes
-        -----
-        All TrialRunners have the same Environment config and are made
-        unique by their use of the unique trial_runner_id assigned to each
-        TrialRunner's Environment's global_config.
-        """
-        # Use the first TrialRunner's Environment as the root Environment.
-        return self._trial_runners[self._trial_runner_ids[0]].environment
-
-    @property
-    def environments(self) -> Iterable[Environment]:
-        """Gets the Environment from the TrialRunners."""
-        return (trial_runner.environment for trial_runner in self._trial_runners.values())
+    def environment(self) -> Environment:
+        """Gets the Environment."""
+        return self._environment
 
     @property
     def optimizer(self) -> Optimizer:
@@ -176,8 +162,8 @@ def __enter__(self) -> "Scheduler":
         _LOG.debug("Scheduler START :: %s", self)
         assert self.experiment is None
         assert not self._in_context
-        self._environment.__enter__()
-        self._optimizer.__enter__()
+        self.environment.__enter__()
+        self.optimizer.__enter__()
         # Start new or resume the existing experiment. Verify that the
         # experiment configuration is compatible with the previous runs.
         # If the `merge` config parameter is present, merge in the data
@@ -186,8 +172,8 @@
             experiment_id=self._experiment_id,
             trial_id=self._trial_id,
             root_env_config=self._root_env_config,
-            description=self.root_environment.name,
-            tunables=self.root_environment.tunable_params,
+            description=self.environment.name,
+            tunables=self.environment.tunable_params,
             opt_targets=self.optimizer.targets,
         ).__enter__()
         self._in_context = True
Expand All @@ -206,10 +192,10 @@ def __exit__(
assert ex_type and ex_val
_LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
assert self._in_context
assert self._experiment is not None
self._experiment.__exit__(ex_type, ex_val, ex_tb)
self._optimizer.__exit__(ex_type, ex_val, ex_tb)
self._experiment.__exit__(ex_type, ex_val, ex_tb)
assert self.experiment is not None
self.experiment.__exit__(ex_type, ex_val, ex_tb)
self.optimizer.__exit__(ex_type, ex_val, ex_tb)
self.environment.__exit__(ex_type, ex_val, ex_tb)
self._experiment = None
self._in_context = False
return False # Do not suppress exceptions
Expand All @@ -220,46 +206,42 @@ def start(self) -> None:
assert self.experiment is not None
_LOG.info(
"START: Experiment: %s Env: %s Optimizer: %s",
self._experiment,
self.root_environment,
self.experiment,
self.environment,
self.optimizer,
)
if _LOG.isEnabledFor(logging.INFO):
_LOG.info("Root Environment:\n%s", self.root_environment.pprint())
_LOG.info("Root Environment:\n%s", self.environment.pprint())

if self._config_id > 0:
tunables = self.load_tunable_config(self._config_id)
self.schedule_trial(tunables)

def teardown(self) -> None:
"""
Tear down the TrialRunners/Environment(s).
Tear down the environment.
Call it after the completion of the `.start()` in the scheduler context.
"""
assert self.experiment is not None
if self._do_teardown:
for trial_runner in self._trial_runners.values():
assert not trial_runner.is_running
trial_runner.teardown()
self.environment.teardown()

def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
"""Get the best observation from the optimizer."""
(best_score, best_config) = self.optimizer.get_best_observation()
_LOG.info("Env: %s best score: %s", self.root_environment, best_score)
_LOG.info("Env: %s best score: %s", self.environment, best_score)
return (best_score, best_config)

def load_tunable_config(self, config_id: int) -> TunableGroups:
"""Load the existing tunable configuration from the storage."""
assert self.experiment is not None
tunable_values = self.experiment.load_tunable_config(config_id)
tunables = TunableGroups()
for environment in self.environments:
tunables = environment.tunable_params.assign(tunable_values)
tunables = self.environment.tunable_params.assign(tunable_values)
_LOG.info("Load config from storage: %d", config_id)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
return tunables.copy()
return tunables

def _schedule_new_optimizer_suggestions(self) -> bool:
"""
Expand Down
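Aside: the recurring change in this diff is swapping direct reads of private attributes (self._environment, self._optimizer, self._experiment) for the corresponding read-only properties (self.environment, self.optimizer, self.experiment), while keeping the __exit__ teardown order the mirror image of __enter__. The following is a minimal, self-contained sketch of that pattern; Resource and MiniScheduler are hypothetical stand-ins for illustration only, not part of mlos_bench.

from types import TracebackType


class Resource:
    """Hypothetical stand-in for the Environment/Optimizer context managers."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self) -> "Resource":
        print(f"enter {self.name}")
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> bool:
        print(f"exit {self.name}")
        return False  # Do not suppress exceptions.


class MiniScheduler:
    """Hypothetical scheduler skeleton showing the property-based access pattern."""

    def __init__(self, environment: Resource, optimizer: Resource) -> None:
        # Store the dependencies as private attributes ...
        self._environment = environment
        self._optimizer = optimizer

    @property
    def environment(self) -> Resource:
        """Gets the Environment."""
        return self._environment

    @property
    def optimizer(self) -> Resource:
        """Gets the Optimizer."""
        return self._optimizer

    def __enter__(self) -> "MiniScheduler":
        # ... but read them back through the public read-only properties,
        # as the call sites in the diff above now do.
        self.environment.__enter__()
        self.optimizer.__enter__()
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> bool:
        # Exit in the reverse order of entry, mirroring Scheduler.__exit__.
        self.optimizer.__exit__(ex_type, ex_val, ex_tb)
        self.environment.__exit__(ex_type, ex_val, ex_tb)
        return False


if __name__ == "__main__":
    with MiniScheduler(Resource("environment"), Resource("optimizer")) as sched:
        print("running with", sched.environment.name, "and", sched.optimizer.name)

Exposing the attributes through properties keeps external access read-only while leaving the assignment in the constructor private, which is why changes like these can be split out as "mostly cosmetic".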
