From 4d4b47da8e72b287a0124dc83aa25f42ee57443d Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Wed, 8 Feb 2023 11:00:02 -0800
Subject: [PATCH] pass gen_candidates callable in optimize_acqf (#1655)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/1655

See title. This will support using stochastic optimization.

Differential Revision: https://internalfb.com/D41629164

fbshipit-source-id: 6db9499cff12e54393968246c7344af0820e5a40
---
 botorch/generation/gen.py                   |   4 +-
 botorch/optim/optimize.py                   |  43 +-
 test/acquisition/test_knowledge_gradient.py |   8 +
 test/optim/test_optimize.py                 | 534 ++++++++++++--------
 4 files changed, 358 insertions(+), 231 deletions(-)

diff --git a/botorch/generation/gen.py b/botorch/generation/gen.py
index 249ee8f8e6..a6670fe1ce 100644
--- a/botorch/generation/gen.py
+++ b/botorch/generation/gen.py
@@ -37,6 +37,8 @@

logger = _get_logger()

+TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]
+

def gen_candidates_scipy(
    initial_conditions: Tensor,
@@ -152,7 +154,6 @@ def gen_candidates_scipy(
                clamped_candidates
            )
            return clamped_candidates, batch_acquisition
-
    clamped_candidates = columnwise_clamp(
        X=initial_conditions, lower=lower_bounds, upper=upper_bounds
    )
@@ -360,7 +361,6 @@ def gen_candidates_torch(
                clamped_candidates
            )
            return clamped_candidates, batch_acquisition
-
    _clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
    clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
    _optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index b361d8d853..1498c110bd 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -23,13 +23,14 @@
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
-from botorch.generation.gen import gen_candidates_scipy
+from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
from botorch.logging import logger
from botorch.optim.initializers import (
    gen_batch_initial_conditions,
    gen_one_shot_kg_initial_conditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
+from botorch.optim.utils import _filter_kwargs
from torch import Tensor

INIT_OPTION_KEYS = {
@@ -64,6 +65,7 @@ def optimize_acqf(
    post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
    batch_initial_conditions: Optional[Tensor] = None,
    return_best_only: bool = True,
+    gen_candidates: Optional[TGenCandidates] = None,
    sequential: bool = False,
    **kwargs: Any,
) -> Tuple[Tensor, Tensor]:
@@ -103,6 +105,12 @@ def optimize_acqf(
            this if you do not want to use default initialization strategy.
        return_best_only: If False, outputs the solutions corresponding to all
            random restart initializations of the optimization.
+        gen_candidates: A callable for generating candidates (and their associated
+            acquisition values) given a tensor of initial conditions and an
+            acquisition function. Other common inputs include lower and upper bounds
+            and a dictionary of options, but refer to the documentation of specific
+            generation functions (e.g. gen_candidates_scipy and gen_candidates_torch)
+            for method-specific inputs. Default: `gen_candidates_scipy`
        sequential: If False, uses joint optimization, otherwise uses sequential
            optimization.
        kwargs: Additional keyword arguments.
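As a usage sketch of the new argument (not part of the patch itself): the surrogate and acquisition setup below are illustrative assumptions, but the `gen_candidates=gen_candidates_torch` keyword is what this diff enables; leaving it unset preserves the previous `gen_candidates_scipy` behavior.

    import torch
    from botorch.acquisition.monte_carlo import qExpectedImprovement
    from botorch.generation.gen import gen_candidates_torch
    from botorch.models import SingleTaskGP
    from botorch.optim.optimize import optimize_acqf

    # Illustrative surrogate: a GP on random data (assumption, not from the patch).
    train_X = torch.rand(20, 3, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)
    acqf = qExpectedImprovement(model=model, best_f=train_Y.max())

    bounds = torch.stack(
        [torch.zeros(3, dtype=torch.double), torch.ones(3, dtype=torch.double)]
    )
    # New in this patch: pass a candidate-generation callable; the torch-based
    # generator supports stochastic optimization of the acquisition function.
    candidates, acq_value = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=2,
        num_restarts=4,
        raw_samples=32,
        gen_candidates=gen_candidates_torch,
    )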
@@ -134,6 +142,9 @@ def optimize_acqf(
    """
    start_time: float = time.monotonic()
    timeout_sec = kwargs.pop("timeout_sec", None)
+    # using a default of None simplifies unit testing
+    if gen_candidates is None:
+        gen_candidates = gen_candidates_scipy

    if inequality_constraints is None:
        if not (bounds.ndim == 2 and bounds.shape[0] == 2):
@@ -229,6 +240,7 @@ def optimize_acqf(
                sequential=False,
                ic_generator=ic_gen,
                timeout_sec=timeout_sec,
+                gen_candidates=gen_candidates,
            )
            candidate_list.append(candidate)

@@ -277,6 +289,11 @@ def optimize_acqf(
    batch_limit: int = options.get(
        "batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
    )
+    has_parameter_constraints = (
+        inequality_constraints is not None
+        or equality_constraints is not None
+        or nonlinear_inequality_constraints is not None
+    )

    def _optimize_batch_candidates(
        timeout_sec: Optional[float],
@@ -288,24 +305,36 @@ def _optimize_batch_candidates(
        if timeout_sec is not None:
            timeout_sec = (timeout_sec - start_time) / len(batched_ics)

-        scipy_kws = {
+        gen_kwargs = {
            "acquisition_function": acq_function,
            "lower_bounds": None if bounds[0].isinf().all() else bounds[0],
            "upper_bounds": None if bounds[1].isinf().all() else bounds[1],
            "options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
-            "inequality_constraints": inequality_constraints,
-            "equality_constraints": equality_constraints,
-            "nonlinear_inequality_constraints": nonlinear_inequality_constraints,
            "fixed_features": fixed_features,
            "timeout_sec": timeout_sec,
        }
+        if has_parameter_constraints:
+            # only add parameter constraints to gen_kwargs if they are specified
+            # to avoid unnecessary warnings in _filter_kwargs
+            gen_kwargs.update(
+                {
+                    "inequality_constraints": inequality_constraints,
+                    "equality_constraints": equality_constraints,
+                    # wrapped in parentheses to keep the line within length limits
+                    "nonlinear_inequality_constraints": (
+                        nonlinear_inequality_constraints
+                    ),
+                }
+            )
+        filtered_gen_kwargs = _filter_kwargs(gen_candidates, **gen_kwargs)
+
        for i, batched_ics_ in enumerate(batched_ics):
            # optimize using random restart optimization
            with warnings.catch_warnings(record=True) as ws:
                warnings.simplefilter("always", category=OptimizationWarning)
-                batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
-                    initial_conditions=batched_ics_, **scipy_kws
+                batch_candidates_curr, batch_acq_values_curr = gen_candidates(
+                    initial_conditions=batched_ics_, **filtered_gen_kwargs
                )
            opt_warnings += ws
            batch_candidates_list.append(batch_candidates_curr)
diff --git a/test/acquisition/test_knowledge_gradient.py b/test/acquisition/test_knowledge_gradient.py
index 18107a83ee..94efb87c65 100644
--- a/test/acquisition/test_knowledge_gradient.py
+++ b/test/acquisition/test_knowledge_gradient.py
@@ -25,8 +25,10 @@
)
from botorch.acquisition.utils import project_to_sample_points
from botorch.exceptions.errors import UnsupportedError
+from botorch.generation.gen import gen_candidates_scipy
from botorch.models import SingleTaskGP
from botorch.optim.optimize import optimize_acqf
+from botorch.optim.utils import _filter_kwargs
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -593,7 +595,13 @@ def test_optimize_w_posterior_transform(self):
                torch.zeros(2, n_f + 1, 2, **tkwargs),
                torch.zeros(2, **tkwargs),
            ),
+        ), mock.patch(
+            f"{optimize_acqf.__module__}._filter_kwargs",
+            wraps=lambda f, **kwargs: _filter_kwargs(
function=gen_candidates_scipy, **kwargs + ), ): + candidate, value = optimize_acqf( acq_function=kg, bounds=bounds, diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 2e849d2d10..3ca99acca5 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -6,6 +6,7 @@ import itertools import warnings +from inspect import signature from unittest import mock import numpy as np @@ -15,6 +16,7 @@ OneShotAcquisitionFunction, ) from botorch.exceptions import InputDataError, UnsupportedError +from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch from botorch.optim.optimize import ( _filter_infeasible, _filter_invalid, @@ -90,10 +92,16 @@ def rounding_func(X: Tensor) -> Tensor: class TestOptimizeAcqf(BotorchTestCase): + @mock.patch("botorch.generation.gen.gen_candidates_torch") @mock.patch("botorch.optim.optimize.gen_batch_initial_conditions") @mock.patch("botorch.optim.optimize.gen_candidates_scipy") + @mock.patch("botorch.optim.utils.common.signature") def test_optimize_acqf_joint( - self, mock_gen_candidates, mock_gen_batch_initial_conditions + self, + mock_signature, + mock_gen_candidates_scipy, + mock_gen_batch_initial_conditions, + mock_gen_candidates_torch, ): q = 3 num_restarts = 2 @@ -101,93 +109,114 @@ def test_optimize_acqf_joint( options = {} mock_acq_function = MockAcquisitionFunction() cnt = 0 + for dtype in (torch.float, torch.double): - mock_gen_batch_initial_conditions.return_value = torch.zeros( - num_restarts, q, 3, device=self.device, dtype=dtype - ) - base_cand = torch.arange(3, device=self.device, dtype=dtype).expand(1, q, 3) - mock_candidates = torch.cat( - [i * base_cand for i in range(num_restarts)], dim=0 - ) - mock_acq_values = num_restarts - torch.arange( - num_restarts, device=self.device, dtype=dtype - ) - mock_gen_candidates.return_value = (mock_candidates, mock_acq_values) - bounds = torch.stack( - [ - torch.zeros(3, device=self.device, dtype=dtype), - 4 * torch.ones(3, device=self.device, dtype=dtype), - ] - ) - candidates, acq_vals = optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - options=options, - ) - self.assertTrue(torch.equal(candidates, mock_candidates[0])) - self.assertTrue(torch.equal(acq_vals, mock_acq_values[0])) - cnt += 1 - self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) + for mock_gen_candidates in ( + mock_gen_candidates_scipy, + mock_gen_candidates_torch, + ): + # Mocks don't have a __name__ attribute. 
+ # Set the attribute, since it is needed for testing _filter_kwargs + if mock_gen_candidates == mock_gen_candidates_torch: + mock_signature.return_value = signature(gen_candidates_torch) + else: + mock_signature.return_value = signature(gen_candidates_scipy) + mock_gen_candidates.__name__ = "gen_candidates" - # test generation with provided initial conditions - candidates, acq_vals = optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - options=options, - return_best_only=False, - batch_initial_conditions=torch.zeros( + mock_gen_batch_initial_conditions.return_value = torch.zeros( num_restarts, q, 3, device=self.device, dtype=dtype - ), - ) - self.assertTrue(torch.equal(candidates, mock_candidates)) - self.assertTrue(torch.equal(acq_vals, mock_acq_values)) - self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) - - # test fixed features - fixed_features = {0: 0.1} - mock_candidates[:, 0] = 0.1 - mock_gen_candidates.return_value = (mock_candidates, mock_acq_values) - candidates, acq_vals = optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - options=options, - fixed_features=fixed_features, - ) - self.assertEqual( - mock_gen_candidates.call_args[1]["fixed_features"], fixed_features - ) - self.assertTrue(torch.equal(candidates, mock_candidates[0])) - cnt += 1 - self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) + ) + base_cand = torch.arange(3, device=self.device, dtype=dtype).expand( + 1, q, 3 + ) + mock_candidates = torch.cat( + [i * base_cand for i in range(num_restarts)], dim=0 + ) + mock_acq_values = num_restarts - torch.arange( + num_restarts, device=self.device, dtype=dtype + ) + mock_gen_candidates.return_value = (mock_candidates, mock_acq_values) + bounds = torch.stack( + [ + torch.zeros(3, device=self.device, dtype=dtype), + 4 * torch.ones(3, device=self.device, dtype=dtype), + ] + ) + mock_gen_candidates.reset_mock() + candidates, acq_vals = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + gen_candidates=mock_gen_candidates, + ) + mock_gen_candidates.assert_called_once() + self.assertTrue(torch.equal(candidates, mock_candidates[0])) + self.assertTrue(torch.equal(acq_vals, mock_acq_values[0])) + cnt += 1 + self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) + + # test generation with provided initial conditions + candidates, acq_vals = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + return_best_only=False, + batch_initial_conditions=torch.zeros( + num_restarts, q, 3, device=self.device, dtype=dtype + ), + gen_candidates=mock_gen_candidates, + ) + self.assertTrue(torch.equal(candidates, mock_candidates)) + self.assertTrue(torch.equal(acq_vals, mock_acq_values)) + self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) + + # test fixed features + fixed_features = {0: 0.1} + mock_candidates[:, 0] = 0.1 + mock_gen_candidates.return_value = (mock_candidates, mock_acq_values) + candidates, acq_vals = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + fixed_features=fixed_features, + gen_candidates=mock_gen_candidates, + ) + self.assertEqual( + 
mock_gen_candidates.call_args[1]["fixed_features"], fixed_features + ) + self.assertTrue(torch.equal(candidates, mock_candidates[0])) + cnt += 1 + self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) - # test trivial case when all features are fixed - candidates, acq_vals = optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - options=options, - fixed_features={0: 0.1, 1: 0.2, 2: 0.3}, - ) - self.assertTrue( - torch.equal( - candidates, - torch.tensor( - [0.1, 0.2, 0.3], device=self.device, dtype=dtype - ).expand(3, 3), + # test trivial case when all features are fixed + candidates, acq_vals = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + fixed_features={0: 0.1, 1: 0.2, 2: 0.3}, + gen_candidates=mock_gen_candidates, ) - ) - self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) + self.assertTrue( + torch.equal( + candidates, + torch.tensor( + [0.1, 0.2, 0.3], device=self.device, dtype=dtype + ).expand(3, 3), + ) + ) + self.assertEqual(mock_gen_batch_initial_conditions.call_count, cnt) # test OneShotAcquisitionFunction mock_acq_function = MockOneShotAcquisitionFunction() @@ -198,6 +227,7 @@ def test_optimize_acqf_joint( num_restarts=num_restarts, raw_samples=raw_samples, options=options, + gen_candidates=mock_gen_candidates, ) self.assertTrue( torch.equal( @@ -214,107 +244,126 @@ def test_optimize_acqf_joint( q=q, num_restarts=num_restarts, options=options, + gen_candidates=mock_gen_candidates, ) @mock.patch("botorch.optim.optimize.gen_batch_initial_conditions") @mock.patch("botorch.optim.optimize.gen_candidates_scipy") + @mock.patch("botorch.generation.gen.gen_candidates_torch") + @mock.patch("botorch.optim.utils.common.signature") def test_optimize_acqf_sequential( self, + mock_signature, + mock_gen_candidates_torch, mock_gen_candidates_scipy, mock_gen_batch_initial_conditions, timeout_sec=None, ): - q = 3 - num_restarts = 2 - raw_samples = 10 - options = {} - for dtype in (torch.float, torch.double): - mock_acq_function = MockAcquisitionFunction() - mock_gen_batch_initial_conditions.side_effect = [ - torch.zeros(num_restarts, device=self.device, dtype=dtype) - for _ in range(q) - ] - gcs_return_vals = [ - ( - torch.tensor([[[1.1, 2.1, 3.1]]], device=self.device, dtype=dtype), - torch.tensor([i], device=self.device, dtype=dtype), + for mock_gen_candidates in ( + mock_gen_candidates_scipy, + mock_gen_candidates_torch, + ): + if mock_gen_candidates == mock_gen_candidates_torch: + mock_signature.return_value = signature(gen_candidates_torch) + else: + mock_signature.return_value = signature(gen_candidates_scipy) + mock_gen_candidates.__name__ = "gen_candidates" + q = 3 + num_restarts = 2 + raw_samples = 10 + options = {} + for dtype in (torch.float, torch.double): + mock_acq_function = MockAcquisitionFunction() + mock_gen_batch_initial_conditions.side_effect = [ + torch.zeros(num_restarts, 1, 3, device=self.device, dtype=dtype) + for _ in range(q) + ] + gcs_return_vals = [ + ( + torch.tensor( + [[[1.1, 2.1, 3.1]]], device=self.device, dtype=dtype + ), + torch.tensor([i], device=self.device, dtype=dtype), + ) + for i in range(q) + ] + mock_gen_candidates.side_effect = gcs_return_vals + expected_candidates = torch.cat( + [cands[0] for cands, _ in gcs_return_vals], dim=-2 + ).round() + bounds = torch.stack( + [ + torch.zeros(3, device=self.device, dtype=dtype), + 4 * torch.ones(3, 
device=self.device, dtype=dtype), + ] ) - for i in range(q) - ] - mock_gen_candidates_scipy.side_effect = gcs_return_vals - expected_candidates = torch.cat( - [cands[0] for cands, _ in gcs_return_vals], dim=-2 - ).round() - bounds = torch.stack( - [ - torch.zeros(3, device=self.device, dtype=dtype), - 4 * torch.ones(3, device=self.device, dtype=dtype), + inequality_constraints = [ + (torch.tensor([2]), torch.tensor([4]), torch.tensor(5)) ] - ) - inequality_constraints = [ - (torch.tensor([2]), torch.tensor([4]), torch.tensor(5)) - ] - candidates, acq_value = optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - options=options, - inequality_constraints=inequality_constraints, - post_processing_func=rounding_func, - sequential=True, - timeout_sec=timeout_sec, - ) - self.assertTrue(torch.equal(candidates, expected_candidates)) - self.assertTrue( - torch.equal( - acq_value, torch.cat([acqval for _, acqval in gcs_return_vals]) + mock_gen_candidates.reset_mock() + candidates, acq_value = optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + inequality_constraints=inequality_constraints, + post_processing_func=rounding_func, + sequential=True, + timeout_sec=timeout_sec, + gen_candidates=mock_gen_candidates, + ) + self.assertEqual(mock_gen_candidates.call_count, q) + self.assertTrue(torch.equal(candidates, expected_candidates)) + self.assertTrue( + torch.equal( + acq_value, torch.cat([acqval for _, acqval in gcs_return_vals]) + ) + ) + # verify error when using a OneShotAcquisitionFunction + with self.assertRaises(NotImplementedError): + optimize_acqf( + acq_function=mock.Mock(spec=OneShotAcquisitionFunction), + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + sequential=True, + ) + # Verify error for passing in incorrect bounds + with self.assertRaisesRegex( + ValueError, + "bounds should be a `2 x d` tensor", + ): + optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds.T, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + sequential=True, ) - ) - # verify error when using a OneShotAcquisitionFunction - with self.assertRaises(NotImplementedError): - optimize_acqf( - acq_function=mock.Mock(spec=OneShotAcquisitionFunction), - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - sequential=True, - ) - # Verify error for passing in incorrect bounds - with self.assertRaisesRegex( - ValueError, - "bounds should be a `2 x d` tensor", - ): - optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds.T, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - sequential=True, - ) - # Veryify error when using sequential=True in - # conjunction with user-supplied batch_initial_conditions - with self.assertRaisesRegex( - UnsupportedError, - "`batch_initial_conditions` is not supported for sequential " - "optimization. 
Either avoid specifying `batch_initial_conditions` " - "to use the custom initializer or use the `ic_generator` kwarg to " - "generate initial conditions for the case of " - "nonlinear inequality constraints.", - ): - optimize_acqf( - acq_function=mock_acq_function, - bounds=bounds, - q=q, - num_restarts=num_restarts, - raw_samples=raw_samples, - batch_initial_conditions=mock_gen_batch_initial_conditions, - sequential=True, - ) + # Verify error when using sequential=True in + # conjunction with user-supplied batch_initial_conditions + with self.assertRaisesRegex( + UnsupportedError, + "`batch_initial_conditions` is not supported for sequential " + "optimization. Either avoid specifying `batch_initial_conditions` " + "to use the custom initializer or use the `ic_generator` kwarg to " + "generate initial conditions for the case of " + "nonlinear inequality constraints.", + ): + optimize_acqf( + acq_function=mock_acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + batch_initial_conditions=mock_gen_batch_initial_conditions, + sequential=True, + ) def test_optimize_acqf_sequential_timeout(self): self.test_optimize_acqf_sequential(timeout_sec=1e-4) @@ -734,10 +783,16 @@ def nlc4(x): raw_samples=16, ) + @mock.patch("botorch.generation.gen.gen_candidates_torch") @mock.patch("botorch.optim.optimize.gen_batch_initial_conditions") @mock.patch("botorch.optim.optimize.gen_candidates_scipy") + @mock.patch("botorch.optim.utils.common.signature") def test_optimize_acqf_non_linear_constraints_sequential( - self, mock_gen_candidates_scipy, mock_gen_batch_initial_conditions + self, + mock_signature, + mock_gen_candidates_scipy, + mock_gen_batch_initial_conditions, + mock_gen_candidates_torch, ): def nlc(x): return 4 * x[..., 2] - 5 @@ -746,60 +801,95 @@ def nlc(x): num_restarts = 2 raw_samples = 10 options = {} - for dtype in (torch.float, torch.double): - mock_acq_function = MockAcquisitionFunction() - mock_gen_batch_initial_conditions.side_effect = [ - torch.zeros(num_restarts, device=self.device, dtype=dtype) - for _ in range(q) - ] - gcs_return_vals = [ - ( - torch.tensor([[[1.0, 2.0, 3.0]]], device=self.device, dtype=dtype), - torch.tensor([i], device=self.device, dtype=dtype), - ) - # for nonlinear inequality constraints the batch_limit variable is - # currently set to 1 by default and hence gen_candidates_scipy is - # called num_restarts*q times - for i in range(num_restarts * q) - ] - mock_gen_candidates_scipy.side_effect = gcs_return_vals - expected_candidates = torch.cat( - [cands[0] for cands, _ in gcs_return_vals[::num_restarts]], dim=-2 - ) - bounds = torch.stack( - [ - torch.zeros(3, device=self.device, dtype=dtype), - 4 * torch.ones(3, device=self.device, dtype=dtype), - ] - ) + for mock_gen_candidates in ( + mock_gen_candidates_torch, + mock_gen_candidates_scipy, + ): + # Mocks don't have a __name__ attribute. 
+            # Set the attribute, since it is needed for testing _filter_kwargs
+            if mock_gen_candidates == mock_gen_candidates_torch:
+                mock_signature.return_value = signature(gen_candidates_torch)
+            else:
+                mock_signature.return_value = signature(gen_candidates_scipy)
+            mock_gen_candidates.__name__ = "gen_candidates"
+            for dtype in (torch.float, torch.double):
-            candidates, acq_value = optimize_acqf(
-                acq_function=mock_acq_function,
-                bounds=bounds,
-                q=q,
-                num_restarts=num_restarts,
-                raw_samples=raw_samples,
-                options=options,
-                nonlinear_inequality_constraints=[nlc],
-                sequential=True,
-                ic_generator=mock_gen_batch_initial_conditions,
-            )
-            self.assertTrue(torch.equal(candidates, expected_candidates))
-            # Extract the relevant entries from gcs_return_vals to
-            # perform comparison with.
-            self.assertTrue(
-                torch.equal(
-                    acq_value,
-                    torch.cat(
-                        [
-                            expected_acq_value
-                            for _, expected_acq_value in gcs_return_vals[
-                                num_restarts - 1 :: num_restarts
+                mock_acq_function = MockAcquisitionFunction()
+                mock_gen_batch_initial_conditions.side_effect = [
+                    torch.zeros(num_restarts, 1, 3, device=self.device, dtype=dtype)
+                    for _ in range(q)
+                ]
+                gcs_return_vals = [
+                    (
+                        torch.tensor(
+                            [[[1.0, 2.0, 3.0]]], device=self.device, dtype=dtype
+                        ),
+                        torch.tensor([i], device=self.device, dtype=dtype),
+                    )
+                    # for nonlinear inequality constraints the batch_limit variable is
+                    # currently set to 1 by default and hence gen_candidates_scipy is
+                    # called num_restarts*q times
+                    for i in range(num_restarts * q)
+                ]
+                mock_gen_candidates.side_effect = gcs_return_vals
+                expected_candidates = torch.cat(
+                    [cands[0] for cands, _ in gcs_return_vals[::num_restarts]], dim=-2
+                )
+                bounds = torch.stack(
+                    [
+                        torch.zeros(3, device=self.device, dtype=dtype),
+                        4 * torch.ones(3, device=self.device, dtype=dtype),
+                    ]
+                )
+                with warnings.catch_warnings(record=True) as ws:
+                    candidates, acq_value = optimize_acqf(
+                        acq_function=mock_acq_function,
+                        bounds=bounds,
+                        q=q,
+                        num_restarts=num_restarts,
+                        raw_samples=raw_samples,
+                        options=options,
+                        nonlinear_inequality_constraints=[nlc],
+                        sequential=True,
+                        ic_generator=mock_gen_batch_initial_conditions,
+                        gen_candidates=mock_gen_candidates,
+                    )
+                if mock_gen_candidates == mock_gen_candidates_torch:
+                    self.assertEqual(len(ws), 3)
+                    message = (
+                        "Keyword arguments ['nonlinear_inequality_constraints',"
+                        " 'equality_constraints', 'inequality_constraints'] will"
+                        " be ignored because they are not allowed parameters for"
+                        " function gen_candidates. Allowed parameters are "
+                        " ['initial_conditions', 'acquisition_function', "
+                        "'lower_bounds', 'upper_bounds', 'optimizer', 'options',"
+                        " 'callback', 'fixed_features', 'timeout_sec']."
+                    )
+                    # use all() so the check is evaluated (a bare generator is truthy)
+                    expected_warning_raised = all(
+                        issubclass(w.category, UserWarning)
+                        and message == str(w.message)
+                        for w in ws
+                    )
+                    self.assertTrue(expected_warning_raised)
+                else:
+                    self.assertEqual(len(ws), 0)
+                self.assertTrue(torch.equal(candidates, expected_candidates))
+                # Extract the relevant entries from gcs_return_vals to
+                # perform comparison with.
+                self.assertTrue(
+                    torch.equal(
+                        acq_value,
+                        torch.cat(
+                            [
+                                expected_acq_value
+                                for _, expected_acq_value in gcs_return_vals[
+                                    num_restarts - 1 :: num_restarts
+                                ]
+                            ]
+                        ),
+                    ),
+                )

    def test_constraint_caching(self):
        def nlc(x):
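For context on the `_filter_kwargs` plumbing exercised by these tests: before dispatching to `gen_candidates`, `optimize_acqf` filters its keyword arguments against the callable's signature, so a generator such as `gen_candidates_torch` that accepts no parameter-constraint kwargs receives only arguments it can handle, with a UserWarning for anything dropped. A minimal sketch, assuming `_filter_kwargs` warns through the standard warnings module as the test above expects (the option value here is made up):

    import warnings

    from botorch.generation.gen import gen_candidates_torch
    from botorch.optim.utils import _filter_kwargs

    with warnings.catch_warnings(record=True) as ws:
        warnings.simplefilter("always")
        # inequality_constraints is not in gen_candidates_torch's signature,
        # so it is dropped with a warning; options passes through untouched.
        filtered = _filter_kwargs(
            gen_candidates_torch,
            options={"maxiter": 5},
            inequality_constraints=None,
        )
    assert "options" in filtered
    assert "inequality_constraints" not in filtered
    assert any("will be ignored" in str(w.message) for w in ws)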