Commit 30e19a8
Add PosteriorTransform to get_optimal_samples and optimize_posterior_samples (pytorch#2576)

Summary:

Added a `posterior_transform` argument to `get_optimal_samples` to enable posterior sampling-based acquisition functions (xES, TestSet IG) on minimization problems. Intended for use in one-shot settings.
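
For illustration, a minimal sketch of the intended minimization use, assuming a standard fitted `SingleTaskGP` (the model, data, and bounds here are made up for the example and are not part of this diff):

```python
import torch

from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP

# Illustrative single-output model on the unit square.
train_X = torch.rand(10, 2, dtype=torch.float64)
train_Y = torch.randn(10, 1, dtype=torch.float64)
model = SingleTaskGP(train_X, train_Y)

# A weight of -1 negates the sampled outputs, so maximizing the
# transformed sample paths minimizes the original objective.
minimize = ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=torch.float64))

optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64),
    num_optima=8,
    posterior_transform=minimize,
)
```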

Reviewed By: saitcakmak

Differential Revision: D64266499
Carl Hvarfner authored and facebook-github-bot committed Oct 22, 2024
1 parent 24f659c commit 30e19a8
Showing 7 changed files with 274 additions and 76 deletions.
3 changes: 0 additions & 3 deletions botorch/acquisition/input_constructors.py
@@ -1800,7 +1800,6 @@ def construct_inputs_qJES(
     model: Model,
     bounds: list[tuple[float, float]],
     num_optima: int = 64,
-    maximize: bool = True,
     condition_noiseless: bool = True,
     X_pending: Tensor | None = None,
     estimation_type: str = "LB",
@@ -1811,15 +1810,13 @@ def construct_inputs_qJES(
         model=model,
         bounds=torch.as_tensor(bounds, dtype=dtype).T,
         num_optima=num_optima,
-        maximize=maximize,
     )
 
     inputs = {
         "model": model,
         "optimal_inputs": optimal_inputs,
         "optimal_outputs": optimal_outputs,
         "condition_noiseless": condition_noiseless,
-        "maximize": maximize,
         "X_pending": X_pending,
         "estimation_type": estimation_type,
         "num_samples": num_samples,
56 changes: 42 additions & 14 deletions botorch/acquisition/utils.py
@@ -18,6 +18,7 @@
     IdentityMCObjective,
     MCAcquisitionObjective,
     PosteriorTransform,
+    ScalarizedPosteriorTransform,
 )
 from botorch.exceptions.errors import (
     BotorchTensorDimensionError,
@@ -28,10 +29,11 @@
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
 from botorch.sampling.get_sampler import get_sampler
-from botorch.sampling.pathwise import draw_matheron_paths
+from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
 from botorch.utils.objective import compute_feasibility_indicator
 from botorch.utils.sampling import optimize_posterior_samples
 from botorch.utils.transforms import is_ensemble, normalize_indices
+from gpytorch.models import GP
 from torch import Tensor

@@ -486,36 +488,62 @@ def project_to_sample_points(X: Tensor, sample_points: Tensor) -> Tensor:
 
 
 def get_optimal_samples(
-    model: Model,
+    model: GP,
     bounds: Tensor,
     num_optima: int,
     raw_samples: int = 1024,
     num_restarts: int = 20,
-    maximize: bool = True,
+    posterior_transform: ScalarizedPosteriorTransform | None = None,
+    objective: MCAcquisitionObjective | None = None,
+    return_transformed: bool = False,
 ) -> tuple[Tensor, Tensor]:
     """Draws sample paths from the posterior and maximizes the samples using GD.
 
     Args:
-        model (Model): The model from which samples are drawn.
-        bounds: (Tensor): Bounds of the search space. If the model inputs are
+        model: The model from which samples are drawn.
+        bounds: Bounds of the search space. If the model inputs are
             normalized, the bounds should be normalized as well.
-        num_optima (int): The number of paths to be drawn and optimized.
-        raw_samples (int, optional): The number of candidates randomly sample.
+        num_optima: The number of paths to be drawn and optimized.
+        raw_samples: The number of candidates randomly sampled.
             Defaults to 1024.
-        num_restarts (int, optional): The number of candidates to do gradient-based
+        num_restarts: The number of candidates to do gradient-based
             optimization on. Defaults to 20.
-        maximize: Whether to maximize or minimize the samples.
+        posterior_transform: A ScalarizedPosteriorTransform (may e.g. be used to
+            scalarize multi-output models or negate the objective).
+        objective: An MCAcquisitionObjective, used to negate the objective or
+            otherwise transform sample outputs. Cannot be combined with
+            `posterior_transform`.
+        return_transformed: If True, return the transformed samples.
 
     Returns:
-        Tuple[Tensor, Tensor]: The optimal input locations and corresponding
-        outputs, x* and f*.
+        The optimal input locations and corresponding outputs, x* and f*.
     """
-    paths = draw_matheron_paths(model, sample_shape=torch.Size([num_optima]))
+    if posterior_transform and not isinstance(
+        posterior_transform, ScalarizedPosteriorTransform
+    ):
+        raise ValueError(
+            "Only the ScalarizedPosteriorTransform is supported for "
+            "get_optimal_samples."
+        )
+    if posterior_transform and objective:
+        raise ValueError(
+            "Only one of `posterior_transform` and `objective` can be specified."
+        )
+
+    if posterior_transform:
+        sample_transform = posterior_transform.evaluate
+    elif objective:
+        sample_transform = objective
+    else:
+        sample_transform = None
+
+    paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
-        paths,
+        paths=paths,
         bounds=bounds,
         raw_samples=raw_samples,
        num_restarts=num_restarts,
-        maximize=maximize,
+        sample_transform=sample_transform,
+        return_transformed=return_transformed,
     )
     return optimal_inputs, optimal_outputs
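
A hedged sketch of the alternative `objective` code path added above (the model and data are again illustrative): a `GenericMCObjective` that negates the samples plays the same role as a negative-weight `ScalarizedPosteriorTransform`; the two arguments are mutually exclusive, and `return_transformed=True` skips the final re-evaluation of the paths.

```python
import torch

from botorch.acquisition.objective import GenericMCObjective
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP

model = SingleTaskGP(
    torch.rand(10, 2, dtype=torch.float64),
    torch.randn(10, 1, dtype=torch.float64),
)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)

# An MC objective that negates the (single-output) samples.
neg_objective = GenericMCObjective(lambda Y, X=None: -Y.squeeze(-1))

# With return_transformed=True, f* are the maxima of the negated samples,
# i.e. the negatives of the sample-path minima.
X_opt, f_opt = get_optimal_samples(
    model=model,
    bounds=bounds,
    num_optima=8,
    objective=neg_objective,
    return_transformed=True,
)

# Passing both `objective` and `posterior_transform` raises a ValueError.
```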
68 changes: 39 additions & 29 deletions botorch/utils/sampling.py
@@ -19,9 +19,9 @@
 import warnings
 
 from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable
+from collections.abc import Callable, Generator, Iterable
 from contextlib import contextmanager
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 import numpy as np
 import numpy.typing as npt
@@ -37,7 +37,9 @@
 
 
 if TYPE_CHECKING:
-    from botorch.sampling.pathwise.paths import SamplePath  # pragma: no cover
+    from botorch.models.deterministic import (  # pragma: no cover
+        GenericDeterministicModel,
+    )


@contextmanager
@@ -989,68 +991,76 @@ def sparse_to_dense_constraints(
 
 
 def optimize_posterior_samples(
-    paths: SamplePath,
+    paths: GenericDeterministicModel,
     bounds: Tensor,
     candidates: Tensor | None = None,
-    raw_samples: int | None = 1024,
+    raw_samples: int = 1024,
     num_restarts: int = 20,
-    maximize: bool = True,
-    **kwargs: Any,
+    sample_transform: Callable[[Tensor], Tensor] | None = None,
+    return_transformed: bool = False,
 ) -> tuple[Tensor, Tensor]:
-    r"""Cheaply maximizes posterior samples by random querying followed by vanilla
-    gradient descent on the best num_restarts points.
+    r"""Cheaply maximizes posterior samples by random querying followed by
+    gradient-based optimization using SciPy's L-BFGS-B routine.
 
     Args:
         paths: Random Fourier Feature-based sample paths from the GP.
         bounds: The bounds on the search space.
         candidates: A priori good candidates (typically previous design points)
             which act as extra initial guesses for the optimization routine.
         raw_samples: The number of samples with which to query the samples initially.
         num_restarts: The number of points selected for gradient-based optimization.
-        maximize: Boolean indicating whether to maimize or minimize
+        sample_transform: A callable transform of the sample outputs (e.g.
+            MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
+            negate the objective or otherwise transform the output.
+        return_transformed: A boolean indicating whether to return the transformed
+            or non-transformed samples.
 
     Returns:
         A two-element tuple containing:
             - X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
-            - f_opt: A `num_optima x [batch_size] x 1`-dim tensor of optimal outputs f*.
+            - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
+                `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
     """
-    if maximize:
-
-        def path_func(x):
-            return paths(x)
-
-    else:
-
-        def path_func(x):
-            return -paths(x)
+
+    def path_func(x) -> Tensor:
+        res = paths(x)
+        if sample_transform:
+            res = sample_transform(res)
+
+        return res.squeeze(-1)
 
     candidate_set = unnormalize(
-        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(raw_samples), bounds
+        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
+        bounds=bounds,
     )
 
     # queries all samples on all candidates - output shape
     # raw_samples * num_optima * num_models
     candidate_queries = path_func(candidate_set)
     argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
     X_top_k = candidate_set[argtop_k, :]
 
     # to avoid circular import, the import occurs here
-    from botorch.generation.gen import gen_candidates_torch
+    from botorch.generation.gen import gen_candidates_scipy
 
-    X_top_k, f_top_k = gen_candidates_torch(
-        X_top_k, path_func, lower_bounds=bounds[0], upper_bounds=bounds[1], **kwargs
+    X_top_k, f_top_k = gen_candidates_scipy(
+        X_top_k,
+        path_func,
+        lower_bounds=bounds[0],
+        upper_bounds=bounds[1],
     )
     f_opt, arg_opt = f_top_k.max(dim=-1, keepdim=True)
 
     # For each sample (and possibly for every model in the batch of models), this
     # retrieves the argmax. We flatten, pick out the indices and then reshape to
     # the original batch shapes (so instead of picking out the argmax of a
     # (3, 7, num_restarts, D)) along the num_restarts dim, we pick it out of a
-    # (21 , num_restarts, D)
+    # (21, num_restarts, D)
     final_shape = candidate_queries.shape[:-1]
     X_opt = X_top_k.reshape(final_shape.numel(), num_restarts, -1)[
         torch.arange(final_shape.numel()), arg_opt.flatten()
     ].reshape(*final_shape, -1)
-    if not maximize:
-        f_opt = -f_opt
 
+    # if we return transformed, we do not need to pass the samples through
+    # paths a second time but rather just return the transformed optimal values
+    if return_transformed:
+        return X_opt, f_opt
+
+    f_opt = paths(X_opt.unsqueeze(-2)).squeeze(-2)
     return X_opt, f_opt
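
And a minimal sketch of calling the refactored low-level routine directly, mirroring what `get_optimal_samples` now does internally (model and bounds are again illustrative):

```python
import torch

from botorch.models import SingleTaskGP
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
from botorch.utils.sampling import optimize_posterior_samples

model = SingleTaskGP(
    torch.rand(10, 2, dtype=torch.float64),
    torch.randn(10, 1, dtype=torch.float64),
)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)

# A GenericDeterministicModel holding num_optima Matheron path samples.
paths = get_matheron_path_model(model=model, sample_shape=torch.Size([4]))

# Negating the sampled outputs makes the internal maximization a minimization;
# with return_transformed=False the returned f* are the untransformed path
# values re-evaluated at X_opt.
X_opt, f_opt = optimize_posterior_samples(
    paths=paths,
    bounds=bounds,
    sample_transform=lambda Y: -Y,
    return_transformed=False,
)
```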
1 change: 0 additions & 1 deletion botorch_community/acquisition/input_constructors.py
@@ -72,7 +72,6 @@ def construct_inputs_SCoreBO(
         model=model,
         bounds=torch.as_tensor(bounds, dtype=dtype).T,
         num_optima=num_optima,
-        maximize=maximize,
     )
 
     inputs = {
2 changes: 0 additions & 2 deletions test/acquisition/test_input_constructors.py
@@ -1620,10 +1620,8 @@ def test_construct_inputs_jes(self) -> None:
             training_data=self.blockX_blockY,
             bounds=self.bounds,
             num_optima=17,
-            maximize=False,
         )
 
-        self.assertFalse(kwargs["maximize"])
         self.assertEqual(self.blockX_blockY[0].X.dtype, kwargs["optimal_inputs"].dtype)
         self.assertEqual(len(kwargs["optimal_inputs"]), 17)
         self.assertEqual(len(kwargs["optimal_outputs"]), 17)
(The remaining two changed files are not shown.)
