From 9b031478912b7dce93de1b292a054f2b64deec4b Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 27 Jan 2025 10:28:00 +0100 Subject: [PATCH] refactor: cleanups --- neps/optimizers/__init__.py | 21 +- neps/optimizers/algorithms.py | 29 +- neps/optimizers/bayesian_optimization.py | 1 + neps/optimizers/bracket_optimizer.py | 81 ++--- neps/optimizers/models/gp.py | 52 +-- neps/optimizers/optimizer.py | 15 +- neps/optimizers/priorband.py | 296 +++++++++--------- neps/space/encoding.py | 2 +- neps/state/neps_state.py | 39 ++- neps/utils/cli.py | 7 +- neps/utils/common.py | 10 + tests/test_runtime/test_stopping_criterion.py | 4 +- 12 files changed, 294 insertions(+), 263 deletions(-) diff --git a/neps/optimizers/__init__.py b/neps/optimizers/__init__.py index 436d4854..9b97790a 100644 --- a/neps/optimizers/__init__.py +++ b/neps/optimizers/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Concatenate, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Concatenate, Literal from neps.optimizers.algorithms import ( CustomOptimizer, @@ -9,26 +9,13 @@ PredefinedOptimizers, determine_optimizer_automatically, ) -from neps.optimizers.optimizer import AskFunction # noqa: TC001 +from neps.optimizers.optimizer import AskFunction, OptimizerInfo from neps.utils.common import extract_keyword_defaults if TYPE_CHECKING: from neps.space import SearchSpace -class OptimizerInfo(TypedDict): - """Information about the optimizer.""" - - name: str - """The name of the optimizer.""" - - info: Mapping[str, Any] - """Additional information about the optimizer. - - Usually this will be the keyword arguments used to initialize the optimizer. - """ - - def _load_optimizer_from_string( optimizer: OptimizerChoice | Literal["auto"], space: SearchSpace, @@ -97,9 +84,9 @@ def load_optimizer( # Custom (already initialized) optimizer case CustomOptimizer(initialized=True): - _optimizer = optimizer.optimizer + preinit_opt = optimizer.optimizer info = OptimizerInfo(name=optimizer.name, info=optimizer.kwargs) - return _optimizer, info # type: ignore + return preinit_opt, info # type: ignore case _: raise ValueError( diff --git a/neps/optimizers/algorithms.py b/neps/optimizers/algorithms.py index ee960164..dd383187 100644 --- a/neps/optimizers/algorithms.py +++ b/neps/optimizers/algorithms.py @@ -31,7 +31,7 @@ from neps.optimizers.ifbo import IFBO from neps.optimizers.models.ftpfn import FTPFNSurrogate from neps.optimizers.optimizer import AskFunction # noqa: TC001 -from neps.optimizers.priorband import PriorBandArgs +from neps.optimizers.priorband import PriorBandSampler from neps.optimizers.random_search import RandomSearch from neps.sampling import Prior, Sampler, Uniform from neps.space.encoding import CategoricalToUnitNorm, ConfigEncoder @@ -123,7 +123,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 *, bracket_type: Literal["successive_halving", "hyperband", "asha", "async_hb"], eta: int, - sampler: Literal["uniform", "prior", "priorband"] | PriorBandArgs | Sampler, + sampler: Literal["uniform", "prior", "priorband"] | PriorBandSampler | Sampler, bayesian_optimization_kick_in_point: int | float | None, sample_prior_first: bool | Literal["highest_fidelity"], # NOTE: This is the only argument to get a default, since it @@ -212,6 +212,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 brackets.Sync.create_repeating, rung_sizes=rung_sizes, ) + case "hyperband": assert early_stopping_rate is None 
rung_to_fidelity, bracket_layouts = brackets.calculate_hb_bracket_layouts( @@ -222,6 +223,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 brackets.Hyperband.create_repeating, bracket_layouts=bracket_layouts, ) + case "asha": assert early_stopping_rate is not None rung_to_fidelity, _rung_sizes = brackets.calculate_sh_rungs( @@ -234,6 +236,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 rungs=list(rung_to_fidelity), eta=eta, ) + case "async_hb": assert early_stopping_rate is None rung_to_fidelity, bracket_layouts = brackets.calculate_hb_bracket_layouts( @@ -252,22 +255,31 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 encoder = ConfigEncoder.from_parameters(parameters) - _sampler: Sampler | PriorBandArgs + _sampler: Sampler | PriorBandSampler match sampler: case "uniform": _sampler = Sampler.uniform(ndim=encoder.ndim) case "prior": _sampler = Prior.from_parameters(parameters) case "priorband": - _sampler = PriorBandArgs(mutation_rate=0.5, mutation_std=0.25) - case PriorBandArgs() | Sampler(): + _sampler = PriorBandSampler( + parameters=parameters, + mutation_rate=0.5, + mutation_std=0.25, + encoder=encoder, + eta=eta, + early_stopping_rate=( + early_stopping_rate if early_stopping_rate is not None else 0 + ), + fid_bounds=(fidelity.lower, fidelity.upper), + ) + case PriorBandSampler() | Sampler(): _sampler = sampler case _: raise ValueError(f"Unknown sampler: {sampler}") # TODO: This should be lifted out of this function and have the caller # pass in a `GPSampler`. - # TODO: Better name and parametrization of this if not going with above gp_sampler: GPSampler | None if bayesian_optimization_kick_in_point is not None: if bayesian_optimization_kick_in_point <= 0: @@ -276,8 +288,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 ) # TODO: Parametrize? - two_stage_batch_sample_size = 10 - modelling_strategy = "joint" + two_stage_batch_sample_size = 100 gp_parameters = {**parameters, **pipeline_space.fidelities} gp_sampler = GPSampler( @@ -287,7 +298,6 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915 threshold=bayesian_optimization_kick_in_point, fidelity_name=fidelity_name, fidelity_max=fidelity.upper, - modelling_strategy=modelling_strategy, two_stage_batch_sample_size=two_stage_batch_sample_size, device=device, ) @@ -669,6 +679,7 @@ def asha( sample_prior_first: Whether to sample the prior configuration first, and if so, should it be at the highest fidelity. 
""" + return _bracket_optimizer( pipeline_space=space, bracket_type="asha", diff --git a/neps/optimizers/bayesian_optimization.py b/neps/optimizers/bayesian_optimization.py index 29cd4930..ec556803 100644 --- a/neps/optimizers/bayesian_optimization.py +++ b/neps/optimizers/bayesian_optimization.py @@ -187,6 +187,7 @@ def __call__( costs=data.cost if self.cost_aware is not False else None, cost_percentage_used=cost_percent, costs_on_log_scale=self.cost_aware == "log", + hide_warnings=True, ) configs = encoder.decode(candidates) diff --git a/neps/optimizers/bracket_optimizer.py b/neps/optimizers/bracket_optimizer.py index 61b10bdd..950a811c 100644 --- a/neps/optimizers/bracket_optimizer.py +++ b/neps/optimizers/bracket_optimizer.py @@ -10,6 +10,7 @@ import torch from botorch.acquisition.multi_objective.parego import qLogNoisyExpectedImprovement from botorch.acquisition.objective import LinearMCObjective +from gpytorch.utils.warnings import NumericalWarning from neps.optimizers.models.gp import ( encode_trials_for_gp, @@ -17,9 +18,10 @@ make_default_single_obj_gp, ) from neps.optimizers.optimizer import SampledConfig -from neps.optimizers.priorband import PriorBandArgs, sample_with_priorband +from neps.optimizers.priorband import PriorBandSampler from neps.optimizers.utils.brackets import PromoteAction, SampleAction from neps.sampling.samplers import Sampler +from neps.utils.common import disable_warnings if TYPE_CHECKING: from gpytorch.models.approximate_gp import Any @@ -108,13 +110,6 @@ class GPSampler: fidelity_max: int | float """The maximum fidelity value.""" - modelling_strategy: Literal["joint", "separate"] - """The strategy for which training data to use for the GP model. - - If set to `"joint"`, the GP model will be trained on all data, - across all fidelities jointly, where the fidelity is considered as a dimension. - """ - device: torch.device | None """The device to use for the GP optimization.""" @@ -136,12 +131,10 @@ def sample_config( Please see parameter descriptions in the class docstring for more. """ assert budget_info is None, "cost-aware (using budget_info) not supported yet." - assert self.modelling_strategy == "joint", "Only joint strategy is supported now." # fit the GP model using all trials, using fidelity as a dimension. # Get to top 10 configurations for acquisition fixed at fidelity Z # Switch those configurations to be at fidelity z_max and take the best. # y_max for EI is taken to be the best value seen so far, across all fidelity - data, _ = encode_trials_for_gp( trials, self.parameters, @@ -149,16 +142,18 @@ def sample_config( device=self.device, ) gp = make_default_single_obj_gp(x=data.x, y=data.y, encoder=self.encoder) - acqf = qLogNoisyExpectedImprovement( - model=gp, - X_baseline=data.x, - # Unfortunatly, there's no option to indicate that we minimize - # the AcqFunction so we need to do some kind of transformation. - # https://github.com/pytorch/botorch/issues/2316#issuecomment-2085964607 - objective=LinearMCObjective(weights=torch.tensor([-1.0])), - X_pending=data.x_pending, - prune_baseline=True, - ) + + with disable_warnings(NumericalWarning): + acqf = qLogNoisyExpectedImprovement( + model=gp, + X_baseline=data.x, + # Unfortunatly, there's no option to indicate that we minimize + # the AcqFunction so we need to do some kind of transformation. 
+ # https://github.com/pytorch/botorch/issues/2316#issuecomment-2085964607 + objective=LinearMCObjective(weights=torch.tensor([-1.0])), + X_pending=data.x_pending, + prune_baseline=True, + ) # When it's max fidelity, we can just sample the best configuration we find, # as we do not need to do the two step procedure. @@ -181,6 +176,7 @@ def sample_config( costs=None, cost_percentage_used=None, costs_on_log_scale=False, + hide_warnings=True, ) assert len(candidates) == N @@ -196,7 +192,6 @@ def sample_config( # Next, we set those N configurations to be at the max fidelity # Decode, set max fidelity, and encode again (TODO: Could do directly on tensors) configs = self.encoder.decode(candidates) - print("configs", configs) # noqa: T201 fid_max_configs = [{**c, self.fidelity_name: self.fidelity_max} for c in configs] encoded_fix_max_configs = self.encoder.encode(fid_max_configs) @@ -237,7 +232,7 @@ class BracketOptimizer: create_brackets: Callable[[pd.DataFrame], Sequence[Bracket] | Bracket] """A function that creates the brackets from the table of trials.""" - sampler: Sampler | PriorBandArgs + sampler: Sampler | PriorBandSampler """The sampler used to generate new trials.""" gp_sampler: GPSampler | None @@ -345,21 +340,24 @@ def __call__( # noqa: C901, PLR0912 ) # The bracket would like us to sample a new configuration for a rung - case SampleAction(rung=rung): + # and we have gp sampler which should kick in + case SampleAction(rung=rung) if ( + self.gp_sampler is not None and self.gp_sampler.threshold_reached(trials) + ): # If we should used BO to sample once a threshold has been reached, # do so. Otherwise we proceed to use the original sampler. - if self.gp_sampler is not None and self.gp_sampler.threshold_reached( - trials - ): - target_fidelity = self.rung_to_fid[rung] - config = self.gp_sampler.sample_config( - trials, - budget_info=None, # TODO: budget_info not supported yet - target_fidelity=target_fidelity, - ) - config.update(space.constants) - return SampledConfig(id=f"{nxt_id}_{rung}", config=config) + target_fidelity = self.rung_to_fid[rung] + config = self.gp_sampler.sample_config( + trials, + budget_info=None, # TODO: budget_info not supported yet + target_fidelity=target_fidelity, + ) + config.update(space.constants) + return SampledConfig(id=f"{nxt_id}_{rung}", config=config) + # We need to sample for a new rung, with either no gp or it has + # not yet kicked in. 
+ case SampleAction(rung=rung): # Otherwise, we proceed with the original sampler match self.sampler: case Sampler(): @@ -371,19 +369,8 @@ def __call__( # noqa: C901, PLR0912 } return SampledConfig(id=f"{nxt_id}_{rung}", config=config) - case PriorBandArgs(): - config = sample_with_priorband( - table=table, - parameters=space.searchables, - rung_to_sample_for=rung, - fid_bounds=(self.fid_min, self.fid_max), - encoder=self.encoder, - inc_mutation_rate=self.sampler.mutation_rate, - inc_mutation_std=self.sampler.mutation_std, - eta=self.eta, - seed=None, # TODO - ) - + case PriorBandSampler(): + config = self.sampler.sample_config(table, rung=rung) config = { **config, **space.constants, diff --git a/neps/optimizers/models/gp.py b/neps/optimizers/models/gp.py index 63207793..d503f369 100644 --- a/neps/optimizers/models/gp.py +++ b/neps/optimizers/models/gp.py @@ -4,6 +4,7 @@ import logging from collections.abc import Mapping, Sequence +from contextlib import nullcontext from dataclasses import dataclass from functools import reduce from itertools import product @@ -19,9 +20,11 @@ from botorch.optim import optimize_acqf, optimize_acqf_mixed from gpytorch import ExactMarginalLogLikelihood from gpytorch.kernels import ScaleKernel +from gpytorch.utils.warnings import NumericalWarning from neps.optimizers.acquisition import cost_cooled_acq, pibo_acquisition from neps.space.encoding import CategoricalToIntegerTransformer, ConfigEncoder +from neps.utils.common import disable_warnings if TYPE_CHECKING: from botorch.acquisition import AcquisitionFunction @@ -130,8 +133,12 @@ def optimize_acq( acq_options: Mapping[str, Any] | None = None, fixed_features: dict[str, Any] | None = None, maximum_allowed_categorical_combinations: int = 30, + hide_warnings: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """Optimize the acquisition function.""" + warning_context = ( + disable_warnings(NumericalWarning) if hide_warnings else nullcontext() + ) acq_options = acq_options or {} _fixed_features: dict[int, float] = {} @@ -163,15 +170,16 @@ def optimize_acq( if n_intial_start_points is None: n_intial_start_points = min(64 * len(bounds) ** 2, 4096) - return optimize_acqf( # type: ignore - acq_function=acq_fn, - bounds=bounds, - q=n_candidates_required, - num_restarts=num_restarts, - raw_samples=n_intial_start_points, - fixed_features=_fixed_features, - **acq_options, - ) + with warning_context: + return optimize_acqf( # type: ignore + acq_function=acq_fn, + bounds=bounds, + q=n_candidates_required, + num_restarts=num_restarts, + raw_samples=n_intial_start_points, + fixed_features=_fixed_features, + **acq_options, + ) # We need to generate the product of all possible combinations of categoricals, # first we do a sanity check @@ -215,17 +223,18 @@ def optimize_acq( if len(_fixed_features) > 0: fixed_cats = [{**cat, **_fixed_features} for cat in fixed_cats] - # TODO: we should deterministically shuffle the fixed_categoricals - # as the underlying function does not. - return optimize_acqf_mixed( # type: ignore - acq_function=acq_fn, - bounds=bounds, - num_restarts=min(num_restarts // n_combos, 2), - raw_samples=n_intial_start_points, - q=n_candidates_required, - fixed_features_list=fixed_cats, - **acq_options, - ) + with warning_context: + # TODO: we should deterministically shuffle the fixed_categoricals + # as the underlying function does not. 
+ return optimize_acqf_mixed( # type: ignore + acq_function=acq_fn, + bounds=bounds, + num_restarts=min(num_restarts // n_combos, 2), + raw_samples=n_intial_start_points, + q=n_candidates_required, + fixed_features_list=fixed_cats, + **acq_options, + ) def encode_trials_for_gp( @@ -316,6 +325,7 @@ def fit_and_acquire_from_gp( maximum_allowed_categorical_combinations: int = 30, fixed_acq_features: dict[str, Any] | None = None, acq_options: Mapping[str, Any] | None = None, + hide_warnings: bool = False, ) -> torch.Tensor: """Acquire the next configuration to evaluate using a GP. @@ -360,6 +370,7 @@ def fit_and_acquire_from_gp( combinations to allow. If the number of combinations exceeds this, an error will be raised. acq_options: Additional options to pass to the botorch `optimizer_acqf` function. + hide_warnings: Whether to hide numerical warnings issued during GP routines. Returns: The encoded next configuration(s) to evaluate. Use the encoder you provided @@ -439,5 +450,6 @@ def fit_and_acquire_from_gp( fixed_features=fixed_acq_features, acq_options=acq_options, maximum_allowed_categorical_combinations=maximum_allowed_categorical_combinations, + hide_warnings=hide_warnings, ) return candidates diff --git a/neps/optimizers/optimizer.py b/neps/optimizers/optimizer.py index 1eef2d7e..88dc8fcd 100644 --- a/neps/optimizers/optimizer.py +++ b/neps/optimizers/optimizer.py @@ -27,13 +27,26 @@ def __call__( from abc import abstractmethod from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, TypedDict if TYPE_CHECKING: from neps.state.optimizer import BudgetInfo from neps.state.trial import Trial +class OptimizerInfo(TypedDict): + """Information about the optimizer, usually used for serialization.""" + + name: str + """The name of the optimizer.""" + + info: Mapping[str, Any] + """Additional information about the optimizer. + + Usually this will be the keyword arguments used to initialize the optimizer. + """ + + @dataclass class SampledConfig: id: str diff --git a/neps/optimizers/priorband.py b/neps/optimizers/priorband.py index 51b29ff1..9d6d23e4 100644 --- a/neps/optimizers/priorband.py +++ b/neps/optimizers/priorband.py @@ -4,7 +4,7 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -20,20 +20,155 @@ @dataclass -class PriorBandArgs: - """Arguments for the PriorBand sampler. - - Args: - mutation_rate: The mutation rate for the PriorBand algorithm when sampling - from the incumbent. - mutation_std: The standard deviation for the mutation rate when sampling - from the incumbent. +class PriorBandSampler: + """A Sampler implementing the PriorBand algorithm for sampling. 
+
+    * https://openreview.net/forum?id=uoiwugtpCH&noteId=xECpK2WH6k
+    """
 
-    name: ClassVar = "priorband"
+    parameters: Mapping[str, Parameter]
+    """The parameters to consider."""
+
+    encoder: ConfigEncoder
+    """The encoder to use for encoding and decoding configurations into tensors."""
 
     mutation_rate: float
+    """The mutation rate to use when sampling from the incumbent distribution."""
+
     mutation_std: float
+    """The mutation deviation to use when sampling from the incumbent distribution."""
+
+    eta: int
+    """The eta value to use for the SH bracket."""
+
+    early_stopping_rate: int
+    """The early stopping rate to use for the SH bracket."""
+
+    fid_bounds: tuple[int, int] | tuple[float, float]
+    """The fidelity bounds."""
+
+    def sample_config(self, table: pd.DataFrame, rung: int) -> dict[str, Any]:
+        """Samples a configuration using the PriorBand algorithm.
+
+        Args:
+            table: The table of all the trials that have been run.
+            rung: The rung to sample for.
+
+        Returns:
+            The sampled configuration.
+        """
+        rung_to_fid, rung_sizes = brackets.calculate_sh_rungs(
+            bounds=self.fid_bounds,
+            eta=self.eta,
+            early_stopping_rate=self.early_stopping_rate,
+        )
+        max_rung = max(rung_sizes)
+
+        prior_dist = Prior.from_parameters(self.parameters)
+
+        # Below we will follow the "geometric" spacing
+        w_random = 1 / (1 + self.eta**rung)
+        w_prior = 1 - w_random
+
+        completed: pd.DataFrame = table[table["perf"].notna()]  # type: ignore
+
+        # To see if we activate incumbent sampling, we check:
+        # 1) We have at least one fully complete run
+        # 2) We have spent at least one full SH bracket worth of fidelity
+        # 3) There is at least one rung with eta evaluations to get the top 1/eta configs
+        completed_rungs = completed.index.get_level_values("rung")
+        one_complete_run_at_max_rung = (completed_rungs == max_rung).any()
+
+        # For SH bracket cost, we include the fact we can continue runs,
+        # i.e. resources for rung 2 discounts the cost of evaluating to rung 1,
+        # only counting the difference in fidelity cost between rung 2 and rung 1.
+        cost_per_rung = {
+            i: rung_to_fid[i] - rung_to_fid.get(i - 1, 0) for i in rung_to_fid
+        }
+
+        cost_of_one_sh_bracket = sum(rung_sizes[r] * cost_per_rung[r] for r in rung_sizes)
+        current_cost_used = sum(r * cost_per_rung[r] for r in completed_rungs)
+        spent_one_sh_bracket_worth_of_fidelity = (
+            current_cost_used >= cost_of_one_sh_bracket
+        )
+
+        # Check that there is at least one rung with `eta` evaluations
+        rung_counts = completed.groupby("rung").size()
+        any_rung_with_eta_evals = (rung_counts == self.eta).any()
+
+        # If the conditions are not met, we sample from the prior or randomly depending on
+        # the geometrically distributed prior and uniform weights
+        if (
+            one_complete_run_at_max_rung is False
+            or spent_one_sh_bracket_worth_of_fidelity is False
+            or any_rung_with_eta_evals is False
+        ):
+            policy = np.random.choice(["prior", "random"], p=[w_prior, w_random])
+            match policy:
+                case "prior":
+                    config = prior_dist.sample_config(to=self.encoder)
+                case "random":
+                    _sampler = Sampler.uniform(ndim=self.encoder.ndim)
+                    config = _sampler.sample_config(to=self.encoder)
+
+            return config
+
+        # Otherwise, we now further split the `prior` weight into `(prior, inc)`
+
+        # 1.
Select the top `1//eta` percent of configs at the highest rung supporting it + rungs_with_at_least_eta = rung_counts[rung_counts >= self.eta].index # type: ignore + rung_table: pd.DataFrame = completed[ # type: ignore + completed.index.get_level_values("rung") == rungs_with_at_least_eta.max() + ] + + K = len(rung_table) // self.eta + top_k_configs = rung_table.nsmallest(K, columns=["perf"])["config"].tolist() + + # 2. Get the global incumbent, and build a prior distribution around it + inc = completed.loc[completed["perf"].idxmin()]["config"] + inc_dist = Prior.from_parameters(self.parameters, center_values=inc) + + # 3. Calculate a ratio score of how likely each of the top K configs are under + # the prior and inc distribution, weighing them by their position in the top K + weights = torch.arange(K, 0, -1) + + top_k_pdf_inc = inc_dist.pdf_configs(top_k_configs, frm=self.encoder) # type: ignore + top_k_pdf_prior = prior_dist.pdf_configs(top_k_configs, frm=self.encoder) # type: ignore + + unnormalized_inc_score = (weights * top_k_pdf_inc).sum() + unnormalized_prior_score = (weights * top_k_pdf_prior).sum() + total_score = unnormalized_inc_score + unnormalized_prior_score + + inc_ratio = float(unnormalized_inc_score / total_score) + prior_ratio = float(unnormalized_prior_score / total_score) + + # 4. And finally, we distribute the original w_prior according to this ratio + w_inc = w_prior * inc_ratio + w_prior = w_prior * prior_ratio + assert np.isclose(w_prior + w_inc + w_random, 1.0) + + # Now we use these weights to choose which sampling distribution to sample from + policy = np.random.choice( + ["prior", "inc", "random"], + p=[w_prior, w_inc, w_random], + ) + match policy: + case "prior": + return prior_dist.sample_config(to=self.encoder) + case "random": + _sampler = Sampler.uniform(ndim=self.encoder.ndim) + return _sampler.sample_config(to=self.encoder) + case "inc": + assert inc is not None + return mutate_config( + inc, + parameters=self.parameters, + mutation_rate=self.mutation_rate, + std=self.mutation_std, + seed=None, + ) + + raise RuntimeError(f"Unknown policy: {policy}") def mutate_config( @@ -68,144 +203,3 @@ def mutate_config( key: mutant[key] if select_mutant else config[key] for key, select_mutant in zip(mutant.keys(), mutatant_selection, strict=False) } - - -def sample_with_priorband( - *, - table: pd.DataFrame, - rung_to_sample_for: int, - # Search Space - parameters: Mapping[str, Parameter], - encoder: ConfigEncoder, - # Inc sampling params - inc_mutation_rate: float, - inc_mutation_std: float, - # SH parameters to calculate the rungs - eta: int, - early_stopping_rate: int = 0, - fid_bounds: tuple[int, int] | tuple[float, float], - # Extra - seed: torch.Generator | None = None, -) -> dict[str, Any]: - """Samples a configuration using the PriorBand algorithm. - - Args: - table: The table of all the trials that have been run. - rung_to_sample_for: The rung to sample for. - space: The search space to sample from. - encoder: The encoder to use for the search space. - inc_mutation_rate: The mutation rate for the incumbent. - inc_mutation_std: The standard deviation for the incumbent mutation rate. - eta: The eta parameter for the Successive Halving algorithm. - early_stopping_rate: The early stopping rate for the Successive Halving algorithm. - fid_bounds: The bounds for the fidelity parameter. - seed: The seed to use for the random number generator. - - Returns: - The sampled configuration. 
- """ - rung_to_fid, rung_sizes = brackets.calculate_sh_rungs( - bounds=fid_bounds, - eta=eta, - early_stopping_rate=early_stopping_rate, - ) - max_rung = max(rung_sizes) - - prior_dist = Prior.from_parameters(parameters) - - # Below we will follow the "geomtric" spacing - w_random = 1 / (1 + eta**rung_to_sample_for) - w_prior = 1 - w_random - - completed: pd.DataFrame = table[table["perf"].notna()] # type: ignore - - # To see if we activate incumbent sampling, we check: - # 1) We have at least one fully complete run - # 2) We have spent at least one full SH bracket worth of fidelity - # 3) There is at least one rung with eta evaluations to get the top 1/eta configs of - completed_rungs = completed.index.get_level_values("rung") - one_complete_run_at_max_rung = (completed_rungs == max_rung).any() - - # For SH bracket cost, we include the fact we can continue runs, - # i.e. resources for rung 2 discounts the cost of evaluating to rung 1, - # only counting the difference in fidelity cost between rung 2 and rung 1. - cost_per_rung = {i: rung_to_fid[i] - rung_to_fid.get(i - 1, 0) for i in rung_to_fid} - - cost_of_one_sh_bracket = sum(rung_sizes[r] * cost_per_rung[r] for r in rung_sizes) - current_cost_used = sum(r * cost_per_rung[r] for r in completed_rungs) - spent_one_sh_bracket_worth_of_fidelity = current_cost_used >= cost_of_one_sh_bracket - - # Check that there is at least rung with `eta` evaluations - rung_counts = completed.groupby("rung").size() - any_rung_with_eta_evals = (rung_counts == eta).any() - - # If the conditions are not met, we sample from the prior or randomly depending on - # the geometrically distributed prior and uniform weights - if ( - one_complete_run_at_max_rung is False - or spent_one_sh_bracket_worth_of_fidelity is False - or any_rung_with_eta_evals is False - ): - policy = np.random.choice(["prior", "random"], p=[w_prior, w_random]) - match policy: - case "prior": - config = prior_dist.sample_config(to=encoder) - case "random": - _sampler = Sampler.uniform(ndim=encoder.ndim) - config = _sampler.sample_config(to=encoder) - - return config - - # Otherwise, we now further split the `prior` weight into `(prior, inc)` - - # 1. Select the top `1//eta` percent of configs at the highest rung that supports it - rungs_with_at_least_eta = rung_counts[rung_counts >= eta].index # type: ignore - rung_table: pd.DataFrame = completed[ # type: ignore - completed.index.get_level_values("rung") == rungs_with_at_least_eta.max() - ] - - K = len(rung_table) // eta - top_k_configs = rung_table.nsmallest(K, columns=["perf"])["config"].tolist() - - # 2. Get the global incumbent, and build a prior distribution around it - inc = completed.loc[completed["perf"].idxmin()]["config"] - inc_dist = Prior.from_parameters(parameters, center_values=inc) - - # 3. Calculate a ratio score of how likely each of the top K configs are under - # the prior and inc distribution, weighing them by their position in the top K - weights = torch.arange(K, 0, -1) - - top_k_pdf_inc = inc_dist.pdf_configs(top_k_configs, frm=encoder) # type: ignore - top_k_pdf_prior = prior_dist.pdf_configs(top_k_configs, frm=encoder) # type: ignore - - unnormalized_inc_score = (weights * top_k_pdf_inc).sum() - unnormalized_prior_score = (weights * top_k_pdf_prior).sum() - total_score = unnormalized_inc_score + unnormalized_prior_score - - inc_ratio = float(unnormalized_inc_score / total_score) - prior_ratio = float(unnormalized_prior_score / total_score) - - # 4. 
And finally, we distribute the original w_prior according to this ratio - w_inc = w_prior * inc_ratio - w_prior = w_prior * prior_ratio - assert np.isclose(w_prior + w_inc + w_random, 1.0) - - # Now we use these weights to choose which sampling distribution to sample from - policy = np.random.choice(["prior", "inc", "random"], p=[w_prior, w_inc, w_random]) - match policy: - case "prior": - return prior_dist.sample_config(to=encoder) - case "random": - _sampler = Sampler.uniform(ndim=encoder.ndim) - return _sampler.sample_config(to=encoder) - case "inc": - assert inc is not None - return mutate_config( - inc, - parameters=parameters, - mutation_rate=inc_mutation_rate, - std=inc_mutation_std, - seed=seed, - ) - - raise RuntimeError(f"Unknown policy: {policy}") diff --git a/neps/space/encoding.py b/neps/space/encoding.py index e4298611..d58c63dc 100644 --- a/neps/space/encoding.py +++ b/neps/space/encoding.py @@ -189,7 +189,7 @@ def decode(self, x: torch.Tensor) -> list[Any]: # TODO: Maybe add a shift argument, could be useful to have `0` as midpoint # and `-0.5` as lower bound with `0.5` as upper bound. @dataclass -class MinMaxNormalizer(TensorTransformer[V], Generic[V]): +class MinMaxNormalizer(TensorTransformer[float], Generic[V]): """A transformer that normalizes values to the unit interval.""" original_domain: Domain[V] diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 5e01e82e..0e684a42 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -17,13 +17,7 @@ from collections.abc import Iterable from dataclasses import dataclass, field from pathlib import Path -from typing import ( - TYPE_CHECKING, - Literal, - TypeAlias, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, overload from neps.env import ( GLOBAL_ERR_FILELOCK_POLL, @@ -35,7 +29,6 @@ TRIAL_FILELOCK_TIMEOUT, ) from neps.exceptions import NePSError, TrialAlreadyExistsError, TrialNotFoundError -from neps.optimizers import OptimizerInfo from neps.state.err_dump import ErrDump from neps.state.filebased import ( FileLocker, @@ -48,6 +41,7 @@ from neps.utils.files import atomic_write, deserialize, serialize if TYPE_CHECKING: + from neps.optimizers import OptimizerInfo from neps.optimizers.optimizer import AskFunction logger = logging.getLogger(__name__) @@ -456,7 +450,7 @@ def lock_and_get_errors(self) -> ErrDump: def lock_and_get_optimizer_info(self) -> OptimizerInfo: """Get the optimizer information.""" with self._optimizer_lock.lock(): - return OptimizerInfo(**deserialize(self._optimizer_info_path)) + return _deserialize_optimizer_info(self._optimizer_info_path) def lock_and_get_optimizer_state(self) -> OptimizationState: """Get the optimizer state.""" @@ -602,7 +596,7 @@ def create_or_load( # check the optimizer info. If this assumption changes, then we would have # to first lock before we do this check if not is_new: - existing_info = OptimizerInfo(**deserialize(optimizer_info_path)) + existing_info = _deserialize_optimizer_info(optimizer_info_path) if not load_only and existing_info != optimizer_info: raise NePSError( "The optimizer info on disk does not match the one provided." 
@@ -651,3 +645,28 @@ def create_or_load( _shared_errors_path=shared_errors_path, _shared_errors=error_dump, ) + + +def _deserialize_optimizer_info(path: Path) -> OptimizerInfo: + from neps.optimizers import OptimizerInfo # Fighting circular import + + deserialized = deserialize(path) + if "name" not in deserialized or "info" not in deserialized: + raise NePSError( + f"Invalid optimizer info deserialized from" + f" {path}. Did not find" + " keys 'name' and 'info'." + ) + name = deserialized["name"] + info = deserialized["info"] + if not isinstance(name, str): + raise NePSError( + f"Invalid optimizer name '{name}' deserialized from {path}. Expected a `str`." + ) + + if not isinstance(info, dict | None): + raise NePSError( + f"Invalid optimizer info '{info}' deserialized from" + f" {path}. Expected a `dict` or `None`." + ) + return OptimizerInfo(name=name, info=info or {}) diff --git a/neps/utils/cli.py b/neps/utils/cli.py index ee36b9f0..cd58b461 100644 --- a/neps/utils/cli.py +++ b/neps/utils/cli.py @@ -23,10 +23,7 @@ from neps.state.trial import Trial from neps.exceptions import TrialNotFoundError from neps.optimizers import load_optimizer -from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo - -# Suppress specific warnings -warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils") +from neps.state.optimizer import BudgetInfo, OptimizationState def validate_directory(path: Path) -> bool: @@ -118,7 +115,7 @@ def init_config(args: argparse.Namespace) -> None: is_new = not directory.exists() _ = NePSState.create_or_load( path=directory, - optimizer_info=OptimizerInfo(optimizer_info), + optimizer_info=optimizer_info, optimizer_state=OptimizationState( seed_snapshot=SeedSnapshot.new_capture(), budget=( diff --git a/neps/utils/common.py b/neps/utils/common.py index 7abfe448..02b98865 100644 --- a/neps/utils/common.py +++ b/neps/utils/common.py @@ -7,6 +7,7 @@ import inspect import os import sys +import warnings from collections.abc import Callable, Iterator from contextlib import contextmanager from functools import partial @@ -259,6 +260,15 @@ def gc_disabled() -> Iterator[None]: gc.enable() +@contextmanager +def disable_warnings(*warning_types: type[Warning]) -> Iterator[None]: + """Disable certain warning categories for a specific block.""" + with warnings.catch_warnings(): + for warning_type in warning_types: + warnings.filterwarnings("ignore", category=warning_type) + yield + + def dynamic_load_object(path: str, object_name: str) -> object: """Dynamically loads an object from a given module file path. 
diff --git a/tests/test_runtime/test_stopping_criterion.py b/tests/test_runtime/test_stopping_criterion.py index 8a526b55..08fc3dbf 100644 --- a/tests/test_runtime/test_stopping_criterion.py +++ b/tests/test_runtime/test_stopping_criterion.py @@ -6,6 +6,7 @@ from pytest_cases import fixture from neps.optimizers.algorithms import random_search +from neps.optimizers.optimizer import OptimizerInfo from neps.runtime import DefaultWorker from neps.space import Float, SearchSpace from neps.state import ( @@ -13,7 +14,6 @@ NePSState, OnErrorPossibilities, OptimizationState, - OptimizerInfo, SeedSnapshot, Trial, WorkerSettings, @@ -24,7 +24,7 @@ def neps_state(tmp_path: Path) -> NePSState: return NePSState.create_or_load( path=tmp_path / "neps_state", - optimizer_info=OptimizerInfo(info={"nothing": "here"}), + optimizer_info=OptimizerInfo(name="blah", info={"nothing": "here"}), optimizer_state=OptimizationState( budget=None, seed_snapshot=SeedSnapshot.new_capture(),
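
Below is a minimal usage sketch (not part of the patch itself) of two pieces this refactor touches: the `disable_warnings` context manager added to `neps/utils/common.py` and the `OptimizerInfo` TypedDict relocated to `neps/optimizers/optimizer.py`. The import paths and the `OptimizerInfo(name=..., info=...)` construction mirror the diff above; the concrete values and warning messages are illustrative assumptions only.

# Sketch under the assumptions stated above; values are illustrative, not from the patch.
import warnings

from gpytorch.utils.warnings import NumericalWarning

from neps.optimizers.optimizer import OptimizerInfo
from neps.utils.common import disable_warnings

# OptimizerInfo is a TypedDict, so construction is plain keyword assignment
# (the tests in this patch construct it the same way).
info = OptimizerInfo(name="random_search", info={"seed": 1})

# Warnings of the listed categories are suppressed only inside the block,
# mirroring how the GP/acquisition call sites in this patch wrap botorch calls.
with disable_warnings(NumericalWarning):
    warnings.warn("jitter added to the diagonal", NumericalWarning)  # suppressed

warnings.warn("outside the block, this one is shown", NumericalWarning)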