Fix shape error in optimize_acqf_cyclic (pytorch#1648)
Summary:
## Motivation

Fixes pytorch#873

In the past, `optimize_acqf` implicitly required 3d inputs whenever there were equality or inequality constraints and `fixed_features` did not provide the trivial solution, even though it accepted 2d inputs (no b-batches) in other cases. `optimize_acqf_cyclic` passed it 2d inputs, which therefore did not work in general. I initially considered changing `optimize_acqf_cyclic` to pass 3d inputs, but since I found another place where 2d inputs were used, I decided to change `optimize_acqf` so it works with 2d inputs instead.
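For illustration, a minimal sketch of the call pattern that used to fail (the model, bounds, and constraint here are made-up stand-ins, not code from this PR):

```python
import torch
from botorch.acquisition import UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

# Toy model and acquisition function, for illustration only.
train_X = torch.rand(8, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_X.sum(dim=-1, keepdim=True))
acqf = UpperConfidenceBound(model, beta=0.1)

# 2d initial conditions (`q x d`, no b-batch) together with a linear
# constraint x[0] + x[1] >= 0.5: before this fix, the constraint path
# assumed 3d `b x q x d` initial conditions and raised a shape error.
candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,
    num_restarts=1,
    raw_samples=2,
    batch_initial_conditions=torch.rand(1, 2, dtype=torch.double),  # q x d
    inequality_constraints=[
        (torch.tensor([0, 1]), torch.ones(2, dtype=torch.double), 0.5)
    ],
)
```

With the fix, the 2d `q x d` initial conditions are accepted and promoted internally.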

This was not caught because the only usage of `optimize_acqf_cyclic` was in a test that mocked `optimize_acqf`, so `optimize_acqf_cyclic` was never actually run end-to-end. I changed the test for `optimize_acqf_cyclic` to be more end-to-end, at the cost of weaker testing of some intermediate properties. We could keep both versions, though.

[x] Better documentation of input shapes in docstrings
[x] Add a singleton leading b-dimension when initial conditions are 2d (see the sketch below)
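Conceptually, the second item amounts to a promotion along these lines (`promote_initial_conditions` is a hypothetical helper name, not the actual code in the diff below):

```python
import torch

def promote_initial_conditions(ics: torch.Tensor) -> torch.Tensor:
    # Give 2d `q x d` initial conditions a singleton leading b-dimension,
    # so downstream code can always assume `b x q x d`.
    return ics.unsqueeze(0) if ics.ndim == 2 else ics
```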

Pull Request resolved: pytorch#1648

Test Plan:
[x] A more end-to-end test of `optimize_acqf_cyclic` that doesn't stub in `optimize_acqf` (see above)
[x] More input validation, plus unit tests for the validation
[x] Ran the cases that now raise validation errors without the new error handling, to confirm they already errored before
[x] Make `_make_linear_constraints` work with 2d inputs so that `optimize_acqf` does too (previously, `optimize_acqf` worked with 2d inputs only in some cases)

Reviewed By: Balandat

Differential Revision: D42875942

Pulled By: esantorella

fbshipit-source-id: e3c650683a6b8d7c9e36fe1f14558db2854bab56
esantorella authored and facebook-github-bot committed Feb 8, 2023
1 parent ffcad4a commit 2f2b7e2
Showing 5 changed files with 197 additions and 98 deletions.
botorch/generation/gen.py (3 additions, 2 deletions)

@@ -56,7 +56,8 @@ def gen_candidates_scipy(
     using `scipy.optimize.minimize` via a numpy converter.
 
     Args:
-        initial_conditions: Starting points for optimization.
+        initial_conditions: Starting points for optimization, with shape
+            (b) x q x d.
         acquisition_function: Acquisition function to be used.
         lower_bounds: Minimum values for each column of initial_conditions.
         upper_bounds: Maximum values for each column of initial_conditions.
@@ -162,7 +163,7 @@ def gen_candidates_scipy(
         X=initial_conditions, lower_bounds=lower_bounds, upper_bounds=upper_bounds
     )
     constraints = make_scipy_linear_constraints(
-        shapeX=clamped_candidates.shape,
+        shapeX=shapeX,
        inequality_constraints=inequality_constraints,
        equality_constraints=equality_constraints,
    )
botorch/optim/optimize.py (17 additions, 2 deletions)

@@ -110,7 +110,9 @@ def optimize_acqf(
     Returns:
         A two-element tuple containing
-        - a `(num_restarts) x q x d`-dim tensor of generated candidates.
+        - A tensor of generated candidates. The shape is
+            -- `q x d` if `return_best_only` is True (default)
+            -- `num_restarts x q x d` if `return_best_only` is False
         - a tensor of associated acquisition values. If `sequential=False`,
             this is a `(num_restarts)`-dim tensor of joint acquisition values
             (with explicit restart dimension if `return_best_only=False`). If
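A quick check of the documented return shapes (an illustrative sketch with a toy model, not part of the diff):

```python
import torch
from botorch.acquisition import UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(8, 2, dtype=torch.double)
acqf = UpperConfidenceBound(
    SingleTaskGP(train_X, train_X.sum(dim=-1, keepdim=True)), beta=0.1
)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)

cands, _ = optimize_acqf(
    acqf, bounds, q=1, num_restarts=2, raw_samples=4, return_best_only=False
)
print(cands.shape)  # torch.Size([2, 1, 2]): num_restarts x q x d

cands, _ = optimize_acqf(acqf, bounds, q=1, num_restarts=2, raw_samples=4)
print(cands.shape)  # torch.Size([1, 2]): q x d, best restart only
```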
@@ -158,6 +160,19 @@
             "initial conditions for the case of nonlinear inequality constraints."
         )
 
+    d = bounds.shape[1]
+    if initial_conditions_provided:
+        if batch_initial_conditions.ndim not in (2, 3):
+            raise ValueError(
+                "batch_initial_conditions must be 2-dimensional or 3-dimensional. "
+                f"Its shape is {batch_initial_conditions.shape}."
+            )
+        if batch_initial_conditions.shape[-1] != d:
+            raise ValueError(
+                f"batch_initial_conditions.shape[-1] must be {d}. The "
+                f"shape is {batch_initial_conditions.shape}."
+            )
+
     # Sets initial condition generator ic_gen if initial conditions not provided
     if not initial_conditions_provided:
         ic_gen = kwargs.pop("ic_generator", None)
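The new validation fails fast; for example (illustrative, and relying on these checks running before the acquisition function is touched):

```python
import torch
from botorch.optim import optimize_acqf

# The shape checks run before the acquisition function is evaluated, so a
# placeholder object suffices for this demonstration (illustrative only).
try:
    optimize_acqf(
        acq_function=object(),  # never reached
        bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
        q=1,
        num_restarts=1,
        raw_samples=2,
        batch_initial_conditions=torch.ones(1, 1, 2, 1),  # 4d: invalid
    )
except ValueError as e:
    print(e)  # batch_initial_conditions must be 2-dimensional or 3-dimensional. ...
```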
@@ -298,7 +313,7 @@ def _optimize_batch_candidates(
             logger.info(f"Generated candidate batch {i+1} of {len(batched_ics)}.")
 
         batch_candidates = torch.cat(batch_candidates_list)
-        batch_acq_values = torch.cat(batch_acq_values_list)
+        batch_acq_values = torch.stack(batch_acq_values_list).flatten()
         return batch_candidates, batch_acq_values, opt_warnings
 
     batch_candidates, batch_acq_values, ws = _optimize_batch_candidates(timeout_sec)
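The switch from `torch.cat` to `torch.stack(...).flatten()` presumably handles the case where per-batch acquisition values come back as zero-dimensional tensors, which `torch.cat` rejects (my reading of the change, not stated in the PR); a quick demonstration of the two behaviors:

```python
import torch

scalar_vals = [torch.tensor(0.5), torch.tensor(1.5)]  # zero-dim tensors
# torch.cat(scalar_vals) raises:
#   RuntimeError: zero-dimensional tensor ... cannot be concatenated
print(torch.stack(scalar_vals).flatten())  # tensor([0.5000, 1.5000])

batched_vals = [torch.tensor([0.5]), torch.tensor([1.5])]  # 1d tensors
# For equal-length 1d inputs, stack + flatten matches the old cat behavior:
print(torch.stack(batched_vals).flatten())  # tensor([0.5000, 1.5000])
```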
botorch/optim/parameter_constraints.py (28 additions, 3 deletions)

@@ -73,7 +73,7 @@ def make_scipy_linear_constraints(
r"""Generate scipy constraints from torch representation.
Args:
shapeX: The shape of the torch.Tensor to optimize over (i.e. `b x q x d`)
shapeX: The shape of the torch.Tensor to optimize over (i.e. `(b) x q x d`)
inequality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`, where
@@ -219,10 +219,35 @@ def _make_linear_constraints(
             version of the input tensor `X`, returning a scalar.
         - "jac": A callable evaluating the constraint's Jacobian on `x`, a flattened
             version of the input tensor `X`, returning a numpy array.
+    >>> shapeX = torch.Size([3, 5, 4])
+    >>> constraints = _make_linear_constraints(
+    ...     indices=torch.tensor([1., 2.]),
+    ...     coefficients=torch.tensor([-0.5, 1.3]),
+    ...     rhs=0.49,
+    ...     shapeX=shapeX,
+    ...     eq=True
+    ... )
+    >>> len(constraints)
+    15
+    >>> constraints[0].keys()
+    dict_keys(['type', 'fun', 'jac'])
+    >>> x = np.arange(60).reshape(shapeX)
+    >>> constraints[0]["fun"](x)
+    1.61  # 1 * -0.5 + 2 * 1.3 - 0.49
+    >>> constraints[0]["jac"](x)
+    [0., -0.5, 1.3, 0., 0., ...]
+    >>> constraints[1]["fun"](x)
+    4.81
     """
-    if len(shapeX) != 3:
-        raise UnsupportedError("`shapeX` must be `b x q x d`")
+    if len(shapeX) not in (2, 3):
+        raise UnsupportedError(
+            f"`shapeX` must be `(b) x q x d` (at least two-dimensional). It is "
+            f"{shapeX}."
+        )
     q, d = shapeX[-2:]
+    if len(shapeX) == 2:
+        shapeX = torch.Size([1, q, d])
     n = shapeX.numel()
     constraints: List[ScipyConstraintDict] = []
     coeffs = _arrayify(coefficients)
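With the 2d promotion above, a `q x d` shape yields one constraint per q-batch element, consistent with the `3 * 5 = 15` in the new doctest. A small check in the same spirit (illustrative; `_make_linear_constraints` is a private helper):

```python
import torch
from botorch.optim.parameter_constraints import _make_linear_constraints

constraints = _make_linear_constraints(
    indices=torch.tensor([1, 2]),
    coefficients=torch.tensor([-0.5, 1.3]),
    rhs=0.49,
    shapeX=torch.Size([5, 4]),  # 2d `q x d`, promoted internally to 1 x 5 x 4
    eq=True,
)
assert len(constraints) == 5  # one equality constraint per q-batch element
```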
test/optim/test_optimize.py (80 additions, 35 deletions)

@@ -334,24 +334,64 @@ def test_optimize_acqf_sequential_notimplemented(self):
         )
 
     def test_optimize_acqf_runs_given_batch_initial_conditions(self):
-        num_restarts, raw_samples, dim = 1, 1, 1
+        num_restarts, raw_samples, dim = 1, 2, 3
 
         opt_x = 2 / np.pi
-        # start near one (of many) optima
-        initial_conditions = (opt_x * 1.01) * torch.ones(
-            (num_restarts, raw_samples, dim)
-        )
+        # -x[i] * 1 >= -opt_x * 1.01 => x[i] <= opt_x * 1.01
+        inequality_constraints = [
+            (torch.tensor([i]), -torch.tensor([1]), -opt_x * 1.01) for i in range(dim)
+        ] + [
+            # x[i] * 1 >= opt_x * .99
+            (torch.tensor([i]), torch.tensor([1]), opt_x * 0.99)
+            for i in range(dim)
+        ]
+        q = 1
+
+        ic_shapes = [(1, 2, dim), (2, 1, dim), (1, dim)]
 
         torch.manual_seed(0)
-        batch_candidates, acq_value_list = optimize_acqf(
-            acq_function=SinOneOverXAcqusitionFunction(),
-            bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
-            q=1,
-            num_restarts=num_restarts,
-            raw_samples=raw_samples,
-            batch_initial_conditions=initial_conditions,
-        )
-        self.assertAlmostEqual(batch_candidates.item(), opt_x, delta=1e-5)
-        self.assertAlmostEqual(acq_value_list.item(), 1)
+        for shape in ic_shapes:
+            with self.subTest(shape=shape):
+                # start near one (of many) optima
+                initial_conditions = (opt_x * 1.01) * torch.ones(shape)
+                batch_candidates, acq_value_list = optimize_acqf(
+                    acq_function=SinOneOverXAcqusitionFunction(),
+                    bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
+                    q=q,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    batch_initial_conditions=initial_conditions,
+                    inequality_constraints=inequality_constraints,
+                )
+                self.assertAllClose(
+                    batch_candidates,
+                    opt_x * torch.ones_like(batch_candidates),
+                    # must be at least 50% closer to the optimum than it started
+                    atol=0.004,
+                    rtol=0.005,
+                )
+                self.assertAlmostEqual(acq_value_list.item(), 1, places=3)
+
+    def test_optimize_acqf_wrong_ic_shape_inequality_constraints(self) -> None:
+        dim = 3
+        ic_shapes = [(1, 2, dim + 1), (1, 2, dim, 1), (1, dim + 1), (1, 1), (dim,)]
+
+        for shape in ic_shapes:
+            with self.subTest(shape=shape):
+                initial_conditions = torch.ones(shape)
+                expected_error = (
+                    rf"batch_initial_conditions.shape\[-1\] must be {dim}\."
+                    if len(shape) in (2, 3)
+                    else r"batch_initial_conditions must be 2\-dimensional or "
+                )
+                with self.assertRaisesRegex(ValueError, expected_error):
+                    optimize_acqf(
+                        acq_function=MockAcquisitionFunction(),
+                        bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]),
+                        q=4,
+                        batch_initial_conditions=initial_conditions,
+                        num_restarts=1,
+                    )
 
     def test_optimize_acqf_warns_on_opt_failure(self):
         """

def test_optimize_acqf_warns_on_opt_failure(self):
"""
@@ -808,15 +848,20 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
         tkwargs = {"device": self.device}
         bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
         inequality_constraints = [
-            [torch.tensor([3]), torch.tensor([4]), torch.tensor(5)]
+            [torch.tensor([2], dtype=int), torch.tensor([4.0]), torch.tensor(5.0)]
         ]
         mock_acq_function = MockAcquisitionFunction()
         for q, dtype in itertools.product([1, 3], (torch.float, torch.double)):
-            inequality_constraints[0] = [
-                t.to(**tkwargs) for t in inequality_constraints[0]
+            tkwargs["dtype"] = dtype
+            inequality_constraints = [
+                (
+                    # indices can't be floats or doubles
+                    inequality_constraints[0][0],
+                    inequality_constraints[0][1].to(**tkwargs),
+                    inequality_constraints[0][2].to(**tkwargs),
+                )
             ]
             mock_optimize_acqf.reset_mock()
-            tkwargs["dtype"] = dtype
             bounds = bounds.to(**tkwargs)
             candidate_rvs = []
             acq_val_rvs = []
@@ -855,23 +900,23 @@ def test_optimize_acqf_cyclic(self, mock_optimize_acqf):
                     post_processing_func=rounding_func,
                     cyclic_options={"maxiter": num_cycles},
                 )
-                # check that X_pending is set correctly in cyclic optimization
-                if q > 1:
-                    x_pending_call_args_list = mock_set_X_pending.call_args_list
-                    idxr = torch.ones(q, dtype=torch.bool, device=self.device)
-                    for i in range(len(x_pending_call_args_list) - 1):
-                        idxr[i] = 0
-                        self.assertTrue(
-                            torch.equal(
-                                x_pending_call_args_list[i][0][0], orig_candidates[idxr]
-                            )
-                        )
-                        idxr[i] = 1
-                        orig_candidates[i] = candidate_rvs[i + 1]
-                    # check reset to base_X_pendingg
-                    self.assertIsNone(x_pending_call_args_list[-1][0][0])
-                else:
-                    mock_set_X_pending.assert_not_called()
+            # check that X_pending is set correctly in cyclic optimization
+            if q > 1:
+                x_pending_call_args_list = mock_set_X_pending.call_args_list
+                idxr = torch.ones(q, dtype=torch.bool, device=self.device)
+                for i in range(len(x_pending_call_args_list) - 1):
+                    idxr[i] = 0
+                    self.assertTrue(
+                        torch.equal(
+                            x_pending_call_args_list[i][0][0], orig_candidates[idxr]
+                        )
+                    )
+                    idxr[i] = 1
+                    orig_candidates[i] = candidate_rvs[i + 1]
+                # check reset to base_X_pendingg
+                self.assertIsNone(x_pending_call_args_list[-1][0][0])
+            else:
+                mock_set_X_pending.assert_not_called()
             # check final candidates
             expected_candidates = (
                 torch.cat(candidate_rvs[-q:], dim=0) if q > 1 else candidate_rvs[0]