refactor: cleanups
eddiebergman committed Jan 27, 2025
1 parent eca62ce commit 9b03147
Showing 12 changed files with 294 additions and 263 deletions.
21 changes: 4 additions & 17 deletions neps/optimizers/__init__.py
@@ -1,34 +1,21 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, Concatenate, Literal, TypedDict
from typing import TYPE_CHECKING, Any, Concatenate, Literal

from neps.optimizers.algorithms import (
CustomOptimizer,
OptimizerChoice,
PredefinedOptimizers,
determine_optimizer_automatically,
)
from neps.optimizers.optimizer import AskFunction # noqa: TC001
from neps.optimizers.optimizer import AskFunction, OptimizerInfo
from neps.utils.common import extract_keyword_defaults

if TYPE_CHECKING:
from neps.space import SearchSpace


class OptimizerInfo(TypedDict):
"""Information about the optimizer."""

name: str
"""The name of the optimizer."""

info: Mapping[str, Any]
"""Additional information about the optimizer.
Usually this will be the keyword arguments used to initialize the optimizer.
"""


def _load_optimizer_from_string(
optimizer: OptimizerChoice | Literal["auto"],
space: SearchSpace,
@@ -97,9 +84,9 @@ def load_optimizer(

# Custom (already initialized) optimizer
case CustomOptimizer(initialized=True):
_optimizer = optimizer.optimizer
preinit_opt = optimizer.optimizer
info = OptimizerInfo(name=optimizer.name, info=optimizer.kwargs)
return _optimizer, info # type: ignore
return preinit_opt, info # type: ignore

case _:
raise ValueError(
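Note: the `OptimizerInfo` TypedDict removed above is now imported from `neps.optimizers.optimizer` instead. Its new definition is not part of the shown hunks; presumably it keeps the same fields as the removed block, roughly like this sketch:

```python
# Presumed relocated definition in neps/optimizers/optimizer.py (not shown in
# this diff); assumes the fields are unchanged from the removed block above.
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, TypedDict


class OptimizerInfo(TypedDict):
    """Information about the optimizer."""

    name: str
    """The name of the optimizer."""

    info: Mapping[str, Any]
    """Additional information about the optimizer.

    Usually this will be the keyword arguments used to initialize the optimizer.
    """
```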
29 changes: 20 additions & 9 deletions neps/optimizers/algorithms.py
@@ -31,7 +31,7 @@
from neps.optimizers.ifbo import IFBO
from neps.optimizers.models.ftpfn import FTPFNSurrogate
from neps.optimizers.optimizer import AskFunction # noqa: TC001
from neps.optimizers.priorband import PriorBandArgs
from neps.optimizers.priorband import PriorBandSampler
from neps.optimizers.random_search import RandomSearch
from neps.sampling import Prior, Sampler, Uniform
from neps.space.encoding import CategoricalToUnitNorm, ConfigEncoder
@@ -123,7 +123,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915
*,
bracket_type: Literal["successive_halving", "hyperband", "asha", "async_hb"],
eta: int,
sampler: Literal["uniform", "prior", "priorband"] | PriorBandArgs | Sampler,
sampler: Literal["uniform", "prior", "priorband"] | PriorBandSampler | Sampler,
bayesian_optimization_kick_in_point: int | float | None,
sample_prior_first: bool | Literal["highest_fidelity"],
# NOTE: This is the only argument to get a default, since it
@@ -212,6 +212,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915
brackets.Sync.create_repeating,
rung_sizes=rung_sizes,
)

case "hyperband":
assert early_stopping_rate is None
rung_to_fidelity, bracket_layouts = brackets.calculate_hb_bracket_layouts(
@@ -222,6 +223,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915
brackets.Hyperband.create_repeating,
bracket_layouts=bracket_layouts,
)

case "asha":
assert early_stopping_rate is not None
rung_to_fidelity, _rung_sizes = brackets.calculate_sh_rungs(
@@ -234,6 +236,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915
rungs=list(rung_to_fidelity),
eta=eta,
)

case "async_hb":
assert early_stopping_rate is None
rung_to_fidelity, bracket_layouts = brackets.calculate_hb_bracket_layouts(
@@ -252,22 +255,31 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915

encoder = ConfigEncoder.from_parameters(parameters)

_sampler: Sampler | PriorBandArgs
_sampler: Sampler | PriorBandSampler
match sampler:
case "uniform":
_sampler = Sampler.uniform(ndim=encoder.ndim)
case "prior":
_sampler = Prior.from_parameters(parameters)
case "priorband":
_sampler = PriorBandArgs(mutation_rate=0.5, mutation_std=0.25)
case PriorBandArgs() | Sampler():
_sampler = PriorBandSampler(
parameters=parameters,
mutation_rate=0.5,
mutation_std=0.25,
encoder=encoder,
eta=eta,
early_stopping_rate=(
early_stopping_rate if early_stopping_rate is not None else 0
),
fid_bounds=(fidelity.lower, fidelity.upper),
)
case PriorBandSampler() | Sampler():
_sampler = sampler
case _:
raise ValueError(f"Unknown sampler: {sampler}")

# TODO: This should be lifted out of this function and have the caller
# pass in a `GPSampler`.
# TODO: Better name and parametrization of this if not going with above
gp_sampler: GPSampler | None
if bayesian_optimization_kick_in_point is not None:
if bayesian_optimization_kick_in_point <= 0:
@@ -276,8 +288,7 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915
)

# TODO: Parametrize?
two_stage_batch_sample_size = 10
modelling_strategy = "joint"
two_stage_batch_sample_size = 100

gp_parameters = {**parameters, **pipeline_space.fidelities}
gp_sampler = GPSampler(
@@ -287,7 +298,6 @@ def _bracket_optimizer( # noqa: C901, PLR0912, PLR0915
threshold=bayesian_optimization_kick_in_point,
fidelity_name=fidelity_name,
fidelity_max=fidelity.upper,
modelling_strategy=modelling_strategy,
two_stage_batch_sample_size=two_stage_batch_sample_size,
device=device,
)
@@ -669,6 +679,7 @@ def asha(
sample_prior_first: Whether to sample the prior configuration first,
and if so, whether it should be at the highest fidelity.
"""

return _bracket_optimizer(
pipeline_space=space,
bracket_type="asha",
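Note: this file's refactor replaces the `PriorBandArgs` config object (and the free function `sample_with_priorband`) with a `PriorBandSampler` that is fully constructed up front and later invoked as `self.sampler.sample_config(table, rung=rung)` in `bracket_optimizer.py`. The real `PriorBandSampler` is not shown in this diff; a hedged sketch of the interface this implies, with the constructor arguments mirrored from the call in `algorithms.py`:

```python
# Hypothetical sketch only -- neps.optimizers.priorband.PriorBandSampler is not
# shown in this diff, and its actual fields and logic may differ.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Protocol

import pandas as pd


class RungSampler(Protocol):
    """What the bracket optimizer needs from the PriorBand branch after this refactor."""

    def sample_config(self, table: pd.DataFrame, *, rung: int) -> dict[str, Any]: ...


@dataclass
class PriorBandSamplerSketch:
    """Stand-in mirroring the constructor arguments seen in algorithms.py."""

    parameters: dict[str, Any]
    mutation_rate: float
    mutation_std: float
    encoder: Any
    eta: int
    early_stopping_rate: int
    fid_bounds: tuple[float, float]

    def sample_config(self, table: pd.DataFrame, *, rung: int) -> dict[str, Any]:
        # The real sampler mixes prior, incumbent-mutation, and uniform samples;
        # this placeholder just returns an empty config.
        return {}
```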
1 change: 1 addition & 0 deletions neps/optimizers/bayesian_optimization.py
@@ -187,6 +187,7 @@ def __call__(
costs=data.cost if self.cost_aware is not False else None,
cost_percentage_used=cost_percent,
costs_on_log_scale=self.cost_aware == "log",
hide_warnings=True,
)

configs = encoder.decode(candidates)
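This change, together with the `disable_warnings(NumericalWarning)` block in `bracket_optimizer.py` below, suppresses GPyTorch numerical warnings during GP fitting and acquisition. The `disable_warnings` helper from `neps.utils.common` is not shown in this diff; a minimal sketch of what such a context manager typically looks like, assuming it simply ignores the given warning categories:

```python
# Hypothetical sketch only -- the real neps.utils.common.disable_warnings is
# not shown in this diff. Assumes it just ignores the given warning categories
# for the duration of the `with` block.
from __future__ import annotations

import warnings
from collections.abc import Iterator
from contextlib import contextmanager


@contextmanager
def disable_warnings(*categories: type[Warning]) -> Iterator[None]:
    """Temporarily ignore the given warning categories."""
    with warnings.catch_warnings():
        for category in categories:
            warnings.simplefilter("ignore", category=category)
        yield


# Usage, mirroring bracket_optimizer.py:
# with disable_warnings(NumericalWarning):
#     acqf = qLogNoisyExpectedImprovement(...)
```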
81 changes: 34 additions & 47 deletions neps/optimizers/bracket_optimizer.py
@@ -10,16 +10,18 @@
import torch
from botorch.acquisition.multi_objective.parego import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import LinearMCObjective
from gpytorch.utils.warnings import NumericalWarning

from neps.optimizers.models.gp import (
encode_trials_for_gp,
fit_and_acquire_from_gp,
make_default_single_obj_gp,
)
from neps.optimizers.optimizer import SampledConfig
from neps.optimizers.priorband import PriorBandArgs, sample_with_priorband
from neps.optimizers.priorband import PriorBandSampler
from neps.optimizers.utils.brackets import PromoteAction, SampleAction
from neps.sampling.samplers import Sampler
from neps.utils.common import disable_warnings

if TYPE_CHECKING:
from gpytorch.models.approximate_gp import Any
@@ -108,13 +110,6 @@ class GPSampler:
fidelity_max: int | float
"""The maximum fidelity value."""

modelling_strategy: Literal["joint", "separate"]
"""The strategy for which training data to use for the GP model.
If set to `"joint"`, the GP model will be trained on all data,
across all fidelities jointly, where the fidelity is considered as a dimension.
"""

device: torch.device | None
"""The device to use for the GP optimization."""

@@ -136,29 +131,29 @@ def sample_config(
Please see parameter descriptions in the class docstring for more.
"""
assert budget_info is None, "cost-aware (using budget_info) not supported yet."
assert self.modelling_strategy == "joint", "Only joint strategy is supported now."
# Fit the GP model using all trials, using fidelity as a dimension.
# Get the top 10 configurations for acquisition fixed at fidelity Z.
# Switch those configurations to be at fidelity z_max and take the best.
# y_max for EI is taken to be the best value seen so far, across all fidelities.

data, _ = encode_trials_for_gp(
trials,
self.parameters,
encoder=self.encoder,
device=self.device,
)
gp = make_default_single_obj_gp(x=data.x, y=data.y, encoder=self.encoder)
acqf = qLogNoisyExpectedImprovement(
model=gp,
X_baseline=data.x,
# Unfortunately, there's no option to indicate that we minimize
# the AcqFunction so we need to do some kind of transformation.
# https://github.com/pytorch/botorch/issues/2316#issuecomment-2085964607
objective=LinearMCObjective(weights=torch.tensor([-1.0])),
X_pending=data.x_pending,
prune_baseline=True,
)

with disable_warnings(NumericalWarning):
acqf = qLogNoisyExpectedImprovement(
model=gp,
X_baseline=data.x,
# Unfortunately, there's no option to indicate that we minimize
# the AcqFunction so we need to do some kind of transformation.
# https://github.com/pytorch/botorch/issues/2316#issuecomment-2085964607
objective=LinearMCObjective(weights=torch.tensor([-1.0])),
X_pending=data.x_pending,
prune_baseline=True,
)

# When it's max fidelity, we can just sample the best configuration we find,
# as we do not need to do the two step procedure.
@@ -181,6 +176,7 @@ def sample_config(
costs=None,
cost_percentage_used=None,
costs_on_log_scale=False,
hide_warnings=True,
)
assert len(candidates) == N

@@ -196,7 +192,6 @@
# Next, we set those N configurations to be at the max fidelity
# Decode, set max fidelity, and encode again (TODO: Could do directly on tensors)
configs = self.encoder.decode(candidates)
print("configs", configs) # noqa: T201
fid_max_configs = [{**c, self.fidelity_name: self.fidelity_max} for c in configs]
encoded_fix_max_configs = self.encoder.encode(fid_max_configs)

@@ -237,7 +232,7 @@ class BracketOptimizer:
create_brackets: Callable[[pd.DataFrame], Sequence[Bracket] | Bracket]
"""A function that creates the brackets from the table of trials."""

sampler: Sampler | PriorBandArgs
sampler: Sampler | PriorBandSampler
"""The sampler used to generate new trials."""

gp_sampler: GPSampler | None
@@ -345,21 +340,24 @@ def __call__( # noqa: C901, PLR0912
)

# The bracket would like us to sample a new configuration for a rung
case SampleAction(rung=rung):
# and we have a gp sampler which should kick in
case SampleAction(rung=rung) if (
self.gp_sampler is not None and self.gp_sampler.threshold_reached(trials)
):
# If we should use BO to sample once a threshold has been reached,
# do so. Otherwise we proceed to use the original sampler.
if self.gp_sampler is not None and self.gp_sampler.threshold_reached(
trials
):
target_fidelity = self.rung_to_fid[rung]
config = self.gp_sampler.sample_config(
trials,
budget_info=None, # TODO: budget_info not supported yet
target_fidelity=target_fidelity,
)
config.update(space.constants)
return SampledConfig(id=f"{nxt_id}_{rung}", config=config)
target_fidelity = self.rung_to_fid[rung]
config = self.gp_sampler.sample_config(
trials,
budget_info=None, # TODO: budget_info not supported yet
target_fidelity=target_fidelity,
)
config.update(space.constants)
return SampledConfig(id=f"{nxt_id}_{rung}", config=config)

# We need to sample for a new rung, either with no gp or one that has
# not yet kicked in.
case SampleAction(rung=rung):
# Otherwise, we proceed with the original sampler
match self.sampler:
case Sampler():
Expand All @@ -371,19 +369,8 @@ def __call__( # noqa: C901, PLR0912
}
return SampledConfig(id=f"{nxt_id}_{rung}", config=config)

case PriorBandArgs():
config = sample_with_priorband(
table=table,
parameters=space.searchables,
rung_to_sample_for=rung,
fid_bounds=(self.fid_min, self.fid_max),
encoder=self.encoder,
inc_mutation_rate=self.sampler.mutation_rate,
inc_mutation_std=self.sampler.mutation_std,
eta=self.eta,
seed=None, # TODO
)

case PriorBandSampler():
config = self.sampler.sample_config(table, rung=rung)
config = {
**config,
**space.constants,
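The comments in `GPSampler.sample_config` above describe a two-stage selection: rank candidates with the acquisition function at the target fidelity, keep the top batch, then re-score those same configurations with the fidelity set to its maximum and return the single best one. A self-contained toy sketch of that logic, using plain tensors and a dummy acquisition function as a stand-in for the fitted GP and qLogNEI:

```python
# Toy sketch of the two-stage selection described in GPSampler.sample_config;
# `acq` stands in for the fitted GP + acquisition function, and the last
# encoded column is assumed to be the fidelity dimension.
from collections.abc import Callable

import torch


def two_stage_select(
    candidates: torch.Tensor,                     # (n_candidates, dim)
    acq: Callable[[torch.Tensor], torch.Tensor],  # (n, dim) -> (n,) scores
    n_keep: int,
    fidelity_max: float,
) -> torch.Tensor:
    # Stage 1: rank all candidates at their current (target) fidelity.
    top = candidates[acq(candidates).topk(n_keep).indices]

    # Stage 2: re-score the shortlist with the fidelity column set to its
    # maximum, and return the configuration that looks best there.
    at_max = top.clone()
    at_max[:, -1] = fidelity_max
    return top[acq(at_max).argmax()]


def dummy_acq(x: torch.Tensor) -> torch.Tensor:
    # Prefer configs near 0.5 in every non-fidelity dimension.
    return -(x[:, :-1] - 0.5).pow(2).sum(dim=-1)


torch.manual_seed(0)
cands = torch.rand(50, 4)
print(two_stage_select(cands, dummy_acq, n_keep=10, fidelity_max=1.0))
```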