Commit

Update initialize_q_batch methods to return both candidates and the corresponding acquisition values (#2571)

Summary:
As titled. Returning the acquisition values alongside the selected candidates avoids re-computing them after sub-selection, in the cases where those values are needed downstream.
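
For illustration, a minimal sketch of the calling pattern this change enables. This is not code from the commit; `acq_function` and the tensor shapes are hypothetical placeholders:

    import torch
    from botorch.optim.initializers import initialize_q_batch

    # Hypothetical setup: `acq_function` is any batch acquisition function
    # mapping a `b x q x d` tensor to `b` acquisition values.
    X_rnd = torch.rand(100, 3, 6)
    with torch.no_grad():
        acq_vals = acq_function(X_rnd)

    # Previously only the selected candidates were returned, so callers that
    # needed their acquisition values had to re-evaluate `acq_function`.
    # Now a single evaluation yields both.
    X_init, acq_init = initialize_q_batch(X=X_rnd, acq_vals=acq_vals, n=10)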


Differential Revision: D64333367

Pulled By: saitcakmak
saitcakmak authored and facebook-github-bot committed Oct 14, 2024
1 parent 85d8996 commit c292356
Showing 4 changed files with 421 additions and 366 deletions.
4 changes: 2 additions & 2 deletions botorch/acquisition/multi_step_lookahead.py
@@ -656,8 +656,8 @@ def mixin_tree(T: Tensor, bounds: Tensor, alpha: float) -> Tensor:
     )
 
     with torch.no_grad():
-        Y_full = acq_function(X_full)
-    X_init = initialize_q_batch(X=X_full, Y=Y_full, n=num_restarts, eta=1.0)
+        acq_vals = acq_function(X_full)
+    X_init, _ = initialize_q_batch(X=X_full, acq_vals=acq_vals, n=num_restarts, eta=1.0)
     return X_init[:raw_samples]
100 changes: 58 additions & 42 deletions botorch/optim/initializers.py
@@ -393,7 +393,8 @@ def gen_batch_initial_conditions(
                 ],
                 dim=0,
             )
-            X_rnd = fix_features(X_rnd, fixed_features=fixed_features)
+            # Keep X on CPU for consistency & to limit GPU memory usage.
+            X_rnd = fix_features(X_rnd, fixed_features=fixed_features).cpu()
            if fixed_X_fantasies is not None:
                if (d_f := fixed_X_fantasies.shape[-1]) != (d_r := X_rnd.shape[-1]):
                    raise BotorchTensorDimensionError(
@@ -415,16 +416,17 @@
                 batch_limit = X_rnd.shape[0]
             # Evaluate the acquisition function on `X_rnd` using `batch_limit`
             # sized chunks.
-            Y_rnd = torch.cat(
+            acq_vals = torch.cat(
                 [
                     acq_function(x_.to(device=device)).cpu()
                     for x_ in X_rnd.split(split_size=batch_limit, dim=0)
                 ],
                 dim=0,
             )
-            batch_initial_conditions = init_func(
-                X=X_rnd, Y=Y_rnd, n=num_restarts, **init_kwargs
-            ).to(device=device)
+            batch_initial_conditions, _ = init_func(
+                X=X_rnd, acq_vals=acq_vals, n=num_restarts, **init_kwargs
+            )
+            batch_initial_conditions = batch_initial_conditions.to(device=device)
        if not any(issubclass(w.category, BadInitialCandidatesWarning) for w in ws):
            return batch_initial_conditions
        if factor < max_factor:
@@ -884,20 +886,24 @@ def gen_value_function_initial_conditions(
 
     # evaluate the raw samples
     with torch.no_grad():
-        Y_rnd = acq_function(X_rnd)
+        acq_vals = acq_function(X_rnd)
 
     # select the restart points using the heuristic
-    return initialize_q_batch(
-        X=X_rnd, Y=Y_rnd, n=num_restarts, eta=options.get("eta", 2.0)
+    X_init, _ = initialize_q_batch(
+        X=X_rnd, acq_vals=acq_vals, n=num_restarts, eta=options.get("eta", 2.0)
     )
+    return X_init
 
 
-def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor:
+def initialize_q_batch(
+    X: Tensor, acq_vals: Tensor, n: int, eta: float = 1.0
+) -> tuple[Tensor, Tensor]:
     r"""Heuristic for selecting initial conditions for candidate generation.
 
     This heuristic selects points from `X` (without replacement) with probability
-    proportional to `exp(eta * Z)`, where `Z = (Y - mean(Y)) / std(Y)` and `eta`
-    is a temperature parameter.
+    proportional to `exp(eta * Z)`, where
+    `Z = (acq_vals - mean(acq_vals)) / std(acq_vals)`
+    and `eta` is a temperature parameter.
 
     When using an acquisition function that is non-negative and possibly zero
     over large areas of the feature space (e.g. qEI), you should use
@@ -907,22 +913,23 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
         X: A `b x batch_shape x q x d` tensor of `b` - `batch_shape` samples of
             `q`-batches from a `d`-dim feature space. Typically, these are generated
             using qMC sampling.
-        Y: A tensor of `b x batch_shape` outcomes associated with the samples.
+        acq_vals: A tensor of `b x batch_shape` outcomes associated with the samples.
             Typically, this is the value of the batch acquisition function to be
             maximized.
         n: The number of initial conditions to be generated. Must be less than `b`.
         eta: Temperature parameter for weighting samples.
 
     Returns:
-        A `n x batch_shape x q x d` tensor of `n` - `batch_shape` `q`-batch initial
-        conditions, where each batch of `n x q x d` samples is selected independently.
+        - An `n x batch_shape x q x d` tensor of `n` - `batch_shape` `q`-batch initial
+          conditions, where each batch of `n x q x d` samples is selected independently.
+        - An `n x batch_shape` tensor of the corresponding acquisition values.
 
     Example:
         >>> # To get `n=10` starting points of q-batch size `q=3`
         >>> # for model with `d=6`:
         >>> qUCB = qUpperConfidenceBound(model, beta=0.1)
-        >>> Xrnd = torch.rand(500, 3, 6)
-        >>> Xinit = initialize_q_batch(Xrnd, qUCB(Xrnd), 10)
+        >>> X_rnd = torch.rand(500, 3, 6)
+        >>> X_init, acq_init = initialize_q_batch(X=X_rnd, acq_vals=qUCB(X_rnd), n=10)
     """
     n_samples = X.shape[0]
     batch_shape = X.shape[1:-2] or torch.Size()
@@ -932,20 +939,21 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
             f"provided samples ({n_samples})"
         )
     elif n == n_samples:
-        return X
+        return X, acq_vals
 
-    Ystd = Y.std(dim=0)
+    Ystd = acq_vals.std(dim=0)
     if torch.any(Ystd == 0):
         warnings.warn(
             "All acquisition values for raw sample points are the same for "
             "at least one batch. Choosing initial conditions at random.",
             BadInitialCandidatesWarning,
             stacklevel=3,
         )
-        return X[torch.randperm(n=n_samples, device=X.device)][:n]
+        idcs = torch.randperm(n=n_samples, device=X.device)[:n]
+        return X[idcs], acq_vals[idcs]
 
-    max_val, max_idx = torch.max(Y, dim=0)
-    Z = (Y - Y.mean(dim=0)) / Ystd
+    max_val, max_idx = torch.max(acq_vals, dim=0)
+    Z = (acq_vals - acq_vals.mean(dim=0)) / Ystd
     etaZ = eta * Z
     weights = torch.exp(etaZ)
     while torch.isinf(weights).any():
@@ -961,28 +969,30 @@ def initialize_q_batch(X: Tensor, Y: Tensor, n: int, eta: float = 1.0) -> Tensor
     if max_idx not in idcs:
         idcs[-1] = max_idx
     if batch_shape == torch.Size():
-        return X[idcs]
+        return X[idcs], acq_vals[idcs]
     else:
-        return X.gather(
+        X_select = X.gather(
             dim=0, index=idcs.view(*idcs.shape, 1, 1).expand(n, *X.shape[1:])
         )
+        acq_select = acq_vals.gather(dim=0, index=idcs)
+        return X_select, acq_select
 
 
 def initialize_q_batch_nonneg(
-    X: Tensor, Y: Tensor, n: int, eta: float = 1.0, alpha: float = 1e-4
-) -> Tensor:
+    X: Tensor, acq_vals: Tensor, n: int, eta: float = 1.0, alpha: float = 1e-4
+) -> tuple[Tensor, Tensor]:
     r"""Heuristic for selecting initial conditions for non-neg. acquisition functions.
 
     This function is similar to `initialize_q_batch`, but designed specifically
     for acquisition functions that are non-negative and possibly zero over
     large areas of the feature space (e.g. qEI). All samples for which
-    `Y < alpha * max(Y)` will be ignored (assuming that `Y` contains at least
-    one positive value).
+    `acq_vals < alpha * max(acq_vals)` will be ignored (assuming that `acq_vals`
+    contains at least one positive value).
 
     Args:
         X: A `b x q x d` tensor of `b` samples of `q`-batches from a `d`-dim.
            feature space. Typically, these are generated using qMC.
-        Y: A tensor of `b` outcomes associated with the samples. Typically, this
+        acq_vals: A tensor of `b` outcomes associated with the samples. Typically, this
            is the value of the batch acquisition function to be maximized.
         n: The number of initial conditions to be generated. Must be less than `b`.
         eta: Temperature parameter for weighting samples.
@@ -991,54 +1001,60 @@ def initialize_q_batch_nonneg(
             `Y < alpha * max(Y)` will be ignored.
 
     Returns:
-        A `n x q x d` tensor of `n` `q`-batch initial conditions.
+        - An `n x q x d` tensor of `n` `q`-batch initial conditions.
+        - An `n` tensor of the corresponding acquisition values.
 
     Example:
         >>> # To get `n=10` starting points of q-batch size `q=3`
         >>> # for model with `d=6`:
         >>> qEI = qExpectedImprovement(model, best_f=0.2)
-        >>> Xrnd = torch.rand(500, 3, 6)
-        >>> Xinit = initialize_q_batch(Xrnd, qEI(Xrnd), 10)
+        >>> X_rnd = torch.rand(500, 3, 6)
+        >>> X_init, acq_init = initialize_q_batch_nonneg(
+        ...     X=X_rnd, acq_vals=qEI(X_rnd), n=10
+        ... )
     """
     n_samples = X.shape[0]
     if n > n_samples:
         raise RuntimeError("n cannot be larger than the number of provided samples")
     elif n == n_samples:
-        return X
+        return X, acq_vals
 
-    max_val, max_idx = torch.max(Y, dim=0)
+    max_val, max_idx = torch.max(acq_vals, dim=0)
     if torch.any(max_val <= 0):
         warnings.warn(
             "All acquisition values for raw sampled points are nonpositive, so "
             "initial conditions are being selected randomly.",
             BadInitialCandidatesWarning,
             stacklevel=3,
         )
-        return X[torch.randperm(n=n_samples, device=X.device)][:n]
+        idcs = torch.randperm(n=n_samples, device=X.device)[:n]
+        return X[idcs], acq_vals[idcs]
 
     # make sure there are at least `n` points with positive acquisition values
-    pos = Y > 0
+    pos = acq_vals > 0
     num_pos = pos.sum().item()
     if num_pos < n:
         # select all positive points and then fill remaining quota with randomly
         # selected points
         remaining_indices = (~pos).nonzero(as_tuple=False).view(-1)
-        rand_indices = torch.randperm(remaining_indices.shape[0], device=Y.device)
+        rand_indices = torch.randperm(
+            remaining_indices.shape[0], device=acq_vals.device
+        )
         sampled_remaining_indices = remaining_indices[rand_indices[: n - num_pos]]
         pos[sampled_remaining_indices] = 1
-        return X[pos]
+        return X[pos], acq_vals[pos]
     # select points within alpha of max_val, iteratively decreasing alpha by a
     # factor of 10 as necessary
-    alpha_pos = Y >= alpha * max_val
+    alpha_pos = acq_vals >= alpha * max_val
     while alpha_pos.sum() < n:
         alpha = 0.1 * alpha
-        alpha_pos = Y >= alpha * max_val
-    alpha_pos_idcs = torch.arange(len(Y), device=Y.device)[alpha_pos]
-    weights = torch.exp(eta * (Y[alpha_pos] / max_val - 1))
+        alpha_pos = acq_vals >= alpha * max_val
+    alpha_pos_idcs = torch.arange(len(acq_vals), device=acq_vals.device)[alpha_pos]
+    weights = torch.exp(eta * (acq_vals[alpha_pos] / max_val - 1))
     idcs = alpha_pos_idcs[torch.multinomial(weights, n)]
     if max_idx not in idcs:
         idcs[-1] = max_idx
-    return X[idcs]
+    return X[idcs], acq_vals[idcs]
 
 
 def sample_points_around_best(
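
As context for the functions reworked above, here is a minimal, self-contained sketch of the selection heuristic that `initialize_q_batch` implements for the non-batched case. This is a simplified illustration, not the commit's code: it omits the batch-shape, constant-value, and infinite-weight handling shown in the diff, and `select_restarts` is a hypothetical name.

    import torch

    def select_restarts(X: torch.Tensor, acq_vals: torch.Tensor, n: int, eta: float = 1.0):
        # Standardize the acquisition values and form Boltzmann-style weights.
        Z = (acq_vals - acq_vals.mean(dim=0)) / acq_vals.std(dim=0)
        weights = torch.exp(eta * Z)
        # Sample `n` indices without replacement, proportional to the weights.
        idcs = torch.multinomial(weights, n)
        # Always retain the best point, as the real implementation does.
        max_idx = torch.argmax(acq_vals)
        if max_idx not in idcs:
            idcs[-1] = max_idx
        # Return both candidates and their acquisition values (the new behavior).
        return X[idcs], acq_vals[idcs]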
82 changes: 44 additions & 38 deletions test/optim/test_initializers.py
@@ -89,40 +89,42 @@ def test_initialize_q_batch_nonneg(self):
         for dtype in (torch.float, torch.double):
             # basic test
             X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
-            Y = torch.rand(5, device=self.device, dtype=dtype)
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
-            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
-            self.assertEqual(ics.device, X.device)
-            self.assertEqual(ics.dtype, X.dtype)
+            acq_vals = torch.rand(5, device=self.device, dtype=dtype)
+            ics_X, ics_acq_vals = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
+            self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
+            self.assertEqual(ics_X.device, X.device)
+            self.assertEqual(ics_X.dtype, X.dtype)
+            self.assertEqual(ics_acq_vals.shape, torch.Size([2]))
+            self.assertEqual(ics_acq_vals.device, acq_vals.device)
+            self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
             # ensure nothing happens if we want all samples
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=5)
-            self.assertTrue(torch.equal(X, ics))
+            ics_X, ics_acq_vals = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=5)
+            self.assertTrue(torch.equal(X, ics_X))
+            self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
             # make sure things work with constant inputs
-            Y = torch.ones(5, device=self.device, dtype=dtype)
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
+            acq_vals = torch.ones(5, device=self.device, dtype=dtype)
+            ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics.device, X.device)
             self.assertEqual(ics.dtype, X.dtype)
             # ensure raises correct warning
-            Y = torch.zeros(5, device=self.device, dtype=dtype)
+            acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
             with warnings.catch_warnings(record=True) as w, settings.debug(True):
-                ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
-                self.assertEqual(len(w), 1)
-                self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
+                ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
+            self.assertEqual(len(w), 1)
+            self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             with self.assertRaises(RuntimeError):
-                initialize_q_batch_nonneg(X=X, Y=Y, n=10)
+                initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=10)
             # test less than `n` positive acquisition values
-            Y = torch.arange(5, device=self.device, dtype=dtype) - 3
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2)
-            self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
-            self.assertEqual(ics.device, X.device)
-            self.assertEqual(ics.dtype, X.dtype)
+            acq_vals = torch.arange(5, device=self.device, dtype=dtype) - 3
+            ics_X, ics_acq_vals = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2)
+            self.assertEqual(ics_X.shape, torch.Size([2, 3, 4]))
             # check that we chose the point with the positive acquisition value
-            self.assertTrue(torch.equal(ics[0], X[-1]) or torch.equal(ics[1], X[-1]))
+            self.assertTrue((ics_acq_vals > 0).any())
             # test less than `n` alpha_pos values
-            Y = torch.arange(5, device=self.device, dtype=dtype)
-            ics = initialize_q_batch_nonneg(X=X, Y=Y, n=2, alpha=1.0)
+            acq_vals = torch.arange(5, device=self.device, dtype=dtype)
+            ics, _ = initialize_q_batch_nonneg(X=X, acq_vals=acq_vals, n=2, alpha=1.0)
             self.assertEqual(ics.shape, torch.Size([2, 3, 4]))
             self.assertEqual(ics.device, X.device)
             self.assertEqual(ics.dtype, X.dtype)
@@ -132,32 +134,36 @@ def test_initialize_q_batch(self):
             for batch_shape in (torch.Size(), [3, 2], (2,), torch.Size([2, 3, 4]), []):
                 # basic test
                 X = torch.rand(5, *batch_shape, 3, 4, device=self.device, dtype=dtype)
-                Y = torch.rand(5, *batch_shape, device=self.device, dtype=dtype)
-                ics = initialize_q_batch(X=X, Y=Y, n=2)
-                self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4]))
-                self.assertEqual(ics.device, X.device)
-                self.assertEqual(ics.dtype, X.dtype)
+                acq_vals = torch.rand(5, *batch_shape, device=self.device, dtype=dtype)
+                ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=2)
+                self.assertEqual(ics_X.shape, torch.Size([2, *batch_shape, 3, 4]))
+                self.assertEqual(ics_X.device, X.device)
+                self.assertEqual(ics_X.dtype, X.dtype)
+                self.assertEqual(ics_acq_vals.shape, torch.Size([2, *batch_shape]))
+                self.assertEqual(ics_acq_vals.device, acq_vals.device)
+                self.assertEqual(ics_acq_vals.dtype, acq_vals.dtype)
                 # ensure nothing happens if we want all samples
-                ics = initialize_q_batch(X=X, Y=Y, n=5)
-                self.assertTrue(torch.equal(X, ics))
+                ics_X, ics_acq_vals = initialize_q_batch(X=X, acq_vals=acq_vals, n=5)
+                self.assertTrue(torch.equal(X, ics_X))
+                self.assertTrue(torch.equal(acq_vals, ics_acq_vals))
                 # ensure raises correct warning
-                Y = torch.zeros(5, device=self.device, dtype=dtype)
+                acq_vals = torch.zeros(5, device=self.device, dtype=dtype)
                 with warnings.catch_warnings(record=True) as w, settings.debug(True):
-                    ics = initialize_q_batch(X=X, Y=Y, n=2)
-                    self.assertEqual(len(w), 1)
-                    self.assertTrue(
-                        issubclass(w[-1].category, BadInitialCandidatesWarning)
-                    )
+                    ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2)
+                self.assertEqual(len(w), 1)
+                self.assertTrue(issubclass(w[-1].category, BadInitialCandidatesWarning))
                 self.assertEqual(ics.shape, torch.Size([2, *batch_shape, 3, 4]))
                 with self.assertRaises(RuntimeError):
-                    initialize_q_batch(X=X, Y=Y, n=10)
+                    initialize_q_batch(X=X, acq_vals=acq_vals, n=10)
 
     def test_initialize_q_batch_largeZ(self):
         for dtype in (torch.float, torch.double):
             # testing large eta*Z
             X = torch.rand(5, 3, 4, device=self.device, dtype=dtype)
-            Y = torch.tensor([-1e12, 0, 0, 0, 1e12], device=self.device, dtype=dtype)
-            ics = initialize_q_batch(X=X, Y=Y, n=2, eta=100)
+            acq_vals = torch.tensor(
+                [-1e12, 0, 0, 0, 1e12], device=self.device, dtype=dtype
+            )
+            ics, _ = initialize_q_batch(X=X, acq_vals=acq_vals, n=2, eta=100)
             self.assertEqual(ics.shape[0], 2)
(diff for the fourth changed file did not load and is not shown)
