Skip to content

Commit

Permalink
wip: some changes from other prs
Browse files Browse the repository at this point in the history
  • Loading branch information
bpkroth committed Jan 9, 2025
1 parent d7294fa commit 013005c
Showing 1 changed file with 76 additions and 32 deletions.
108 changes: 76 additions & 32 deletions mlos_bench/mlos_bench/schedulers/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ def __init__( # pylint: disable=too-many-arguments
Parameters
----------
config : dict
The configuration for the scheduler.
The configuration for the Scheduler.
global_config : dict
he global configuration for the experiment.
The global configuration for the Experiment.
environment : Environment
The environment to benchmark/optimize.
The Environment to optimize.
optimizer : Optimizer
The optimizer to use.
The Optimizer to use.
storage : Storage
The storage to use.
The Storage to use.
root_env_config : str
Path to the root environment configuration.
Path to the root Environment configuration.
"""
self.global_config = global_config
config = merge_parameters(
Expand All @@ -66,6 +66,7 @@ def __init__( # pylint: disable=too-many-arguments
)
self._validate_json_config(config)

self._in_context = False
self._experiment_id = config["experiment_id"].strip()
self._trial_id = int(config["trial_id"])
self._config_id = int(config.get("config_id", -1))
Expand All @@ -80,10 +81,10 @@ def __init__( # pylint: disable=too-many-arguments

self._do_teardown = bool(config.get("teardown", True))

self.experiment: Storage.Experiment | None = None
self._experiment: Storage.Experiment | None = None
self.environment = environment
self.optimizer = optimizer
self.storage = storage
self._optimizer = optimizer
self._storage = storage
self._root_env_config = root_env_config
self._last_trial_id = -1
self._ran_trials: list[Storage.Trial] = []
Expand Down Expand Up @@ -125,6 +126,40 @@ def max_trials(self) -> int:
"""
return self._max_trials

@property
def experiment(self) -> Storage.Experiment | None:
"""Gets the Experiment Storage."""
return self._experiment

@property
def root_environment(self) -> Environment:
"""
Gets the root (prototypical) Environment from the first TrialRunner.
Notes
-----
All TrialRunners have the same Environment config and are made
unique by their use of the unique trial_runner_id assigned to each
TrialRunner's Environment's global_config.
"""
# Use the first TrialRunner's Environment as the root Environment.
return self._trial_runners[self._trial_runner_ids[0]].environment

@property
def environments(self) -> Iterable[Environment]:
"""Gets the Environment from the TrialRunners."""
return (trial_runner.environment for trial_runner in self._trial_runners.values())

@property
def optimizer(self) -> Optimizer:
"""Gets the Optimizer."""
return self._optimizer

@property
def storage(self) -> Storage:
"""Gets the Storage."""
return self._storage

def __repr__(self) -> str:
"""
Produce a human-readable version of the Scheduler (mostly for logging).
Expand All @@ -140,20 +175,22 @@ def __enter__(self) -> "Scheduler":
"""Enter the scheduler's context."""
_LOG.debug("Scheduler START :: %s", self)
assert self.experiment is None
self.environment.__enter__()
self.optimizer.__enter__()
assert not self._in_context
self._environment.__enter__()
self._optimizer.__enter__()
# Start new or resume the existing experiment. Verify that the
# experiment configuration is compatible with the previous runs.
# If the `merge` config parameter is present, merge in the data
# from other experiments and check for compatibility.
self.experiment = self.storage.experiment(
self._experiment = self.storage.experiment(
experiment_id=self._experiment_id,
trial_id=self._trial_id,
root_env_config=self._root_env_config,
description=self.environment.name,
tunables=self.environment.tunable_params,
description=self.root_environment.name,
tunables=self.root_environment.tunable_params,
opt_targets=self.optimizer.targets,
).__enter__()
self._in_context = True
return self

def __exit__(
Expand All @@ -168,55 +205,61 @@ def __exit__(
else:
assert ex_type and ex_val
_LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
assert self.experiment is not None
self.experiment.__exit__(ex_type, ex_val, ex_tb)
self.optimizer.__exit__(ex_type, ex_val, ex_tb)
self.environment.__exit__(ex_type, ex_val, ex_tb)
self.experiment = None
assert self._in_context
assert self._experiment is not None
self._experiment.__exit__(ex_type, ex_val, ex_tb)
self._optimizer.__exit__(ex_type, ex_val, ex_tb)
self._experiment.__exit__(ex_type, ex_val, ex_tb)
self._experiment = None
self._in_context = False
return False # Do not suppress exceptions

@abstractmethod
def start(self) -> None:
"""Start the optimization loop."""
"""Start the scheduling loop."""
assert self.experiment is not None
_LOG.info(
"START: Experiment: %s Env: %s Optimizer: %s",
self.experiment,
self.environment,
self._experiment,
self.root_environment,
self.optimizer,
)
if _LOG.isEnabledFor(logging.INFO):
_LOG.info("Root Environment:\n%s", self.environment.pprint())
_LOG.info("Root Environment:\n%s", self.root_environment.pprint())

if self._config_id > 0:
tunables = self.load_config(self._config_id)
tunables = self.load_tunable_config(self._config_id)
self.schedule_trial(tunables)

def teardown(self) -> None:
"""
Tear down the environment.
Tear down the TrialRunners/Environment(s).
Call it after the completion of the `.start()` in the scheduler context.
"""
assert self.experiment is not None
if self._do_teardown:
self.environment.teardown()
for trial_runner in self._trial_runners.values():
assert not trial_runner.is_running
trial_runner.teardown()

def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
"""Get the best observation from the optimizer."""
(best_score, best_config) = self.optimizer.get_best_observation()
_LOG.info("Env: %s best score: %s", self.environment, best_score)
_LOG.info("Env: %s best score: %s", self.root_environment, best_score)
return (best_score, best_config)

def load_config(self, config_id: int) -> TunableGroups:
def load_tunable_config(self, config_id: int) -> TunableGroups:
"""Load the existing tunable configuration from the storage."""
assert self.experiment is not None
tunable_values = self.experiment.load_tunable_config(config_id)
tunables = self.environment.tunable_params.assign(tunable_values)
tunables = TunableGroups()
for environment in self.environments:
tunables = environment.tunable_params.assign(tunable_values)
_LOG.info("Load config from storage: %d", config_id)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
return tunables
return tunables.copy()

def _schedule_new_optimizer_suggestions(self) -> bool:
"""
Expand Down Expand Up @@ -270,13 +313,13 @@ def _add_trial_to_queue(
config: dict[str, Any] | None = None,
) -> None:
"""
Add a configuration to the queue of trials.
Add a configuration to the queue of trials in the Storage backend.
A wrapper for the `Experiment.new_trial` method.
"""
assert self.experiment is not None
trial = self.experiment.new_trial(tunables, ts_start, config)
_LOG.info("QUEUE: Add new trial: %s", trial)
_LOG.info("QUEUE: Added new trial: %s", trial)

def _run_schedule(self, running: bool = False) -> None:
"""
Expand Down Expand Up @@ -305,6 +348,7 @@ def run_trial(self, trial: Storage.Trial) -> None:
Save the results in the storage.
"""
assert self._in_context
assert self.experiment is not None
self._trial_count += 1
self._ran_trials.append(trial)
Expand Down

0 comments on commit 013005c

Please sign in to comment.