From 02c0ce52b7eeb57e4352f3341a5ee7738edd7f4b Mon Sep 17 00:00:00 2001
From: David Eriksson
Date: Wed, 23 Oct 2024 05:32:01 -0700
Subject: [PATCH] SEBO bug fixes (#2935)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2935

This fixes several bugs in SEBO: the acquisition function is now always
`qLogNoisyExpectedHypervolumeImprovement` (dropping the old qEHVI-only
restriction), the training data for the penalty outcome is clamped to the
target point to absorb numerical noise, `cache_root=True` is rejected with a
warning, the homotopy callback that re-initialized the acquisition function
is removed, and the batch initial conditions for the homotopy optimization
now respect the search-space bounds.

Reviewed By: bletham

Differential Revision: D60089765

fbshipit-source-id: 231db9b67a0bce7ab9a99a2c03e8c617b488a408
---
 ax/models/torch/botorch_modular/sebo.py | 226 ++++++++++--------
 ax/models/torch/tests/test_sebo.py      | 160 +++++++++--------
 2 files changed, 187 insertions(+), 199 deletions(-)

diff --git a/ax/models/torch/botorch_modular/sebo.py b/ax/models/torch/botorch_modular/sebo.py
index 4c7d2ca31f8..0ad69be9afc 100644
--- a/ax/models/torch/botorch_modular/sebo.py
+++ b/ax/models/torch/botorch_modular/sebo.py
@@ -7,22 +7,26 @@
 # pyre-strict
 
 import functools
+import warnings
 from collections.abc import Callable
 from copy import deepcopy
 from functools import partial
+from logging import Logger
 from typing import Any
 
 import torch
 from ax.core.search_space import SearchSpaceDigest
+from ax.exceptions.core import AxWarning
 from ax.models.torch.botorch_modular.acquisition import Acquisition
 from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.models.torch_base import TorchOptConfig
 from ax.utils.common.constants import Keys
+from ax.utils.common.logger import get_logger
 from ax.utils.common.typeutils import not_none
 from botorch.acquisition.acquisition import AcquisitionFunction
-from botorch.acquisition.multi_objective.monte_carlo import (
-    qExpectedHypervolumeImprovement,
+from botorch.acquisition.multi_objective.logei import (
+    qLogNoisyExpectedHypervolumeImprovement,
 )
 from botorch.acquisition.penalized import L0Approximation
 from botorch.models.deterministic import GenericDeterministicModel
@@ -34,11 +38,12 @@
     optimize_acqf_homotopy,
 )
 from botorch.utils.datasets import SupervisedDataset
-from botorch.utils.multi_objective.pareto import is_non_dominated
+from botorch.utils.transforms import unnormalize
 from torch import Tensor
 from torch.quasirandom import SobolEngine
 
-CLAMP_TOL = 0.01
+CLAMP_TOL = 1e-2
+logger: Logger = get_logger(__name__)
 
 
 class SEBOAcquisition(Acquisition):
@@ -76,53 +81,58 @@ def __init__(
         # construct deterministic model for penalty term
         # pyre-fixme[4]: Attribute must be annotated.
         self.deterministic_model = self._construct_penalty()
-
         surrogate_f = deepcopy(surrogate)
+        # we need to clamp the training data to the target point here as it may
+        # be slightly off due to numerical issues.
+        X_sparse = clamp_to_target(
+            X=surrogate_f.Xs[0].clone(),
+            target_point=self.target_point,
+            clamp_tol=CLAMP_TOL,
+        )
         # update the training data in new surrogate
         not_none(surrogate_f._training_data).append(
             SupervisedDataset(
-                surrogate_f.Xs[0],
-                self.deterministic_model(surrogate_f.Xs[0]),
-                # append Yvar as zero for penalty term
-                Yvar=torch.zeros(surrogate_f.Xs[0].shape[0], 1, **tkwargs),
+                X=X_sparse,
+                Y=self.deterministic_model(X_sparse),
+                Yvar=torch.zeros(X_sparse.shape[0], 1, **tkwargs),  # noiseless
                 feature_names=surrogate_f.training_data[0].feature_names,
                 outcome_names=[self.penalty_name],
             )
         )
         # update the model in new surrogate
         surrogate_f._model = ModelList(surrogate.model, self.deterministic_model)
-        self.det_metric_indx = -1
-
 
         # update objective weights and thresholds in the torch config
         torch_opt_config_sebo = self._transform_torch_config(
-            torch_opt_config, **tkwargs
+            torch_opt_config=torch_opt_config, **tkwargs
         )
-        # instantiate botorch_acqf_class
-        if not issubclass(botorch_acqf_class, qExpectedHypervolumeImprovement):
-            raise ValueError("botorch_acqf_class must be qEHVI to use SEBO")
+        # Change some options (note: we do not want to do this in-place)
+        if options.get("cache_root", False):
+            warnings.warn(
+                "SEBO doesn't support `cache_root=True`. Changing it to `False`.",
+                AxWarning,
+                stacklevel=3,
+            )
+            options = {**options, "cache_root": False}
+
+        # Instantiate the `botorch_acqf_class`. We need to modify `a` before doing this
+        # (as it controls the L0 norm approximation) since the baseline will be pruned
+        # when the acquisition function is created. With a=1e-6 the deterministic model
+        # will be numerically close to the true L0 norm and we will select the
+        # baseline according to the last homotopy step.
+        if self.penalty_name == "L0_norm":
+            self.deterministic_model._f.a.fill_(1e-6)
         super().__init__(
             surrogates={"sebo": surrogate_f},
             search_space_digest=search_space_digest,
             torch_opt_config=torch_opt_config_sebo,
-            botorch_acqf_class=botorch_acqf_class,
+            botorch_acqf_class=qLogNoisyExpectedHypervolumeImprovement,
             options=options,
         )
-        if not isinstance(self.acqf, qExpectedHypervolumeImprovement):
-            raise ValueError("botorch_acqf_class must be qEHVI to use SEBO")
 
         # update objective threshold for deterministic model (penalty term)
         self.acqf.ref_point[-1] = self.sparsity_threshold * -1
-        # pyre-ignore
-        self._objective_thresholds[-1] = self.sparsity_threshold
-
-        Y_pareto = torch.cat(
-            [d.Y for d in self.surrogates["sebo"].training_data],
-            dim=-1,
-        )
-        ind_pareto = is_non_dominated(Y_pareto * self._full_objective_weights)
-        # pyre-ignore
-        self.X_pareto = self.surrogates["sebo"].Xs[0][ind_pareto].clone()
+        self._objective_thresholds[-1] = self.sparsity_threshold  # pyre-ignore
 
     def _construct_penalty(self) -> GenericDeterministicModel:
         """Construct a penalty term as deterministic model to be included in
@@ -150,41 +160,29 @@ def _transform_torch_config(
         """Transform torch config to include penalty term (deterministic model)
         as an additional outcome in the BoTorch model.
         """
-        # update objective weights by appending the weight (-1) of penalty term
-        # at the end
-        ow_sebo = torch.cat(
-            [torch_opt_config.objective_weights, torch.tensor([-1], **tkwargs)]
+        # update objective weights by appending the weight -1 for the sparsity objective.
+        objective_weights_sebo = torch.cat(
+            [torch_opt_config.objective_weights, -torch.ones(1, **tkwargs)]
         )
         if torch_opt_config.outcome_constraints is not None:
             # update the shape of A matrix in outcome_constraints
-            oc_sebo = (
-                torch.cat(
-                    [
-                        torch_opt_config.outcome_constraints[0],
-                        torch.zeros(
-                            # pyre-ignore
-                            torch_opt_config.outcome_constraints[0].shape[0],
-                            1,
-                            **tkwargs,
-                        ),
-                    ],
-                    dim=1,
-                ),
-                torch_opt_config.outcome_constraints[1],
+            A, b = not_none(torch_opt_config.outcome_constraints)
+            outcome_constraints_sebo = (
+                torch.cat([A, torch.zeros(A.shape[0], 1, **tkwargs)], dim=1),
+                b,
             )
         else:
-            oc_sebo = None
+            outcome_constraints_sebo = None
         if torch_opt_config.objective_thresholds is not None:
-            # append the sparsity threshold at the end if objective_thresholds
-            # is not None
-            ot_sebo = torch.cat(
+            objective_thresholds_sebo = torch.cat(
                 [
                     torch_opt_config.objective_thresholds,
                     torch.tensor([self.sparsity_threshold], **tkwargs),
                 ]
             )
         else:
-            ot_sebo = None
+            # NOTE: The reference point will be inferred in the base class.
+            objective_thresholds_sebo = None
 
         # update pending observations (if not none) by appending an obs for
         # the new penalty outcome
@@ -195,9 +193,9 @@
             ]
 
         return TorchOptConfig(
-            objective_weights=ow_sebo,
-            outcome_constraints=oc_sebo,
-            objective_thresholds=ot_sebo,
+            objective_weights=objective_weights_sebo,
+            outcome_constraints=outcome_constraints_sebo,
+            objective_thresholds=objective_thresholds_sebo,
             linear_constraints=torch_opt_config.linear_constraints,
             fixed_features=torch_opt_config.fixed_features,
             pending_observations=pending_observations,
@@ -266,12 +264,8 @@ def optimize(
         )
 
         # similarly, make sure this applies to the sparse dimensions only
-        candidates = clamp_candidates(
-            X=candidates,
-            target_point=self.target_point,
-            clamp_tol=CLAMP_TOL,
-            device=self.device,
-            dtype=self.dtype,
+        candidates = clamp_to_target(
+            X=candidates, target_point=self.target_point, clamp_tol=CLAMP_TOL
         )
         return candidates, expected_acquisition_value, weights
 
@@ -288,8 +282,7 @@ def _optimize_with_homotopy(
         _tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
         ssd = search_space_digest
         bounds = _tensorize(ssd.bounds).t()
-
-        homotopy_schedule = LogLinearHomotopySchedule(start=0.1, end=1e-3, num_steps=30)
+        homotopy_schedule = LogLinearHomotopySchedule(start=0.2, end=1e-3, num_steps=30)
 
         # Prepare arguments for optimizer
         optimizer_options_with_defaults = optimizer_argparse(
@@ -300,31 +293,6 @@
             optimizer="optimize_acqf_homotopy",
         )
 
-        def callback() -> None:
-            if (
-                self.acqf.cache_pending
-            ):  # If true, pending points are concatenated with X_baseline
-                if self.acqf._max_iep != 0:
-                    raise ValueError(
-                        "The maximum number of pending points (max_iep) must be 0"
-                    )
-                X_baseline = self.acqf._X_baseline_and_pending.clone()
-                self.acqf.__init__(  # pyre-ignore
-                    X_baseline=X_baseline,
-                    model=self.surrogates["sebo"].model,
-                    ref_point=self.acqf.ref_point,
-                    objective=self.acqf.objective,
-                )
-            else:  # We can directly get the pending points here
-                X_pending = self.acqf.X_pending
-                self.acqf.__init__(  # pyre-ignore
-                    X_baseline=self.X_observed,
-                    model=self.surrogates["sebo"].model,
-                    ref_point=self.acqf.ref_point,
-                    objective=self.acqf.objective,
-                )
-                self.acqf.set_X_pending(X_pending)
-
         homotopy = Homotopy(
             homotopy_parameters=[
                 HomotopyParameter(
@@ -332,16 +300,14 @@ def callback() -> None:
                     schedule=homotopy_schedule,
                 )
             ],
-            callbacks=[callback],
         )
-        # need to know sparse dimensions
         batch_initial_conditions = get_batch_initial_conditions(
             acq_function=self.acqf,
             raw_samples=optimizer_options_with_defaults["raw_samples"],
-            X_pareto=self.X_pareto,
+            X_pareto=self.acqf.X_baseline,
             target_point=self.target_point,
+            bounds=bounds,
             num_restarts=optimizer_options_with_defaults["num_restarts"],
-            **{"device": self.device, "dtype": self.dtype},
         )
         candidates, expected_acquisition_value = optimize_acqf_homotopy(
             q=n,
@@ -354,11 +320,10 @@ def callback() -> None:
             fixed_features=fixed_features,
             batch_initial_conditions=batch_initial_conditions,
         )
-
         return (
             candidates,
             expected_acquisition_value,
-            torch.ones(n, dtype=candidates.dtype),
+            torch.ones(n, device=candidates.device, dtype=candidates.dtype),
         )
 
 
@@ -370,15 +335,17 @@ def L1_norm_func(X: Tensor, init_point: Tensor) -> Tensor:
     return torch.linalg.norm((X - init_point), ord=1, dim=-1, keepdim=True)
 
 
-def clamp_candidates(
-    X: Tensor, target_point: Tensor, clamp_tol: float, **tkwargs: Any
-) -> Tensor:
-    """Clamp generated candidates within the given ranges to the target point."""
-    clamp_mask = (X - target_point).abs() < clamp_tol
-    clamp_mask = clamp_mask
-    X[clamp_mask] = (
-        target_point.clone().repeat(*X.shape[:-1], 1).to(**tkwargs)[clamp_mask]
-    )
+def clamp_to_target(X: Tensor, target_point: Tensor, clamp_tol: float) -> Tensor:
+    """Clamp generated candidates within the given ranges to the target point.
+
+    Args:
+        X: A `batch_shape x n x d`-dim input tensor `X`.
+        target_point: A tensor of size `d` corresponding to the target point.
+        clamp_tol: The clamping tolerance. Any value within `clamp_tol` of the
+            `target_point` will be clamped to the `target_point`.
+    """
+    clamp_mask = (X - target_point).abs() <= clamp_tol
+    X[clamp_mask] = target_point.clone().repeat(*X.shape[:-1], 1)[clamp_mask]
     return X
 
 
@@ -387,36 +354,35 @@ def get_batch_initial_conditions(
     raw_samples: int,
     X_pareto: Tensor,
     target_point: Tensor,
+    bounds: Tensor,
     num_restarts: int = 20,
-    **tkwargs: Any,
 ) -> Tensor:
     """Generate starting points for the SEBO acquisition function optimization."""
+    tkwargs: dict[str, Any] = {"device": X_pareto.device, "dtype": X_pareto.dtype}
     dim = X_pareto.shape[-1]  # dimension
 
-    # (1) Global Sobol points
-    X_cand1 = SobolEngine(dimension=dim, scramble=True).draw(raw_samples).to(**tkwargs)
-    X_cand1 = X_cand1[
-        acq_function(X_cand1.unsqueeze(1)).topk(num_restarts // 5).indices
-    ]
-    # (2) Global Sobol points with a Bernoulli mask
-    X_cand2 = SobolEngine(dimension=dim, scramble=True).draw(raw_samples).to(**tkwargs)
-    mask = torch.rand(X_cand2.shape, **tkwargs) < 0.5
-    X_cand2[mask] = target_point.repeat(len(X_cand2), 1).to(**tkwargs)[mask]
-    X_cand2 = X_cand2[
-        acq_function(X_cand2.unsqueeze(1)).topk(num_restarts // 5).indices
-    ]
-    # (3) Perturbations of points on the Pareto frontier (done by TuRBO and Spearmint)
-    X_cand3 = X_pareto.clone()[torch.randint(high=len(X_pareto), size=(raw_samples,))]
-    mask = X_cand3 != target_point
-    X_cand3[mask] += 0.2 * torch.randn(*X_cand3.shape, **tkwargs)[mask]
-    X_cand3 = torch.clamp(X_cand3, min=0.0, max=1.0)
-    X_cand3 = X_cand3[
-        acq_function(X_cand3.unsqueeze(1)).topk(num_restarts // 5).indices
+    num_sobol, num_local = num_restarts // 2, num_restarts - num_restarts // 2
+    # (1) Global sparse Sobol points
+    X_cand_sobol = (
+        SobolEngine(dimension=dim, scramble=True)
+        .draw(raw_samples, dtype=tkwargs["dtype"])
+        .to(**tkwargs)
+    )
+    X_cand_sobol = unnormalize(X_cand_sobol, bounds=bounds)
+    acq_vals = acq_function(X_cand_sobol.unsqueeze(1))
+    if len(X_pareto) == 0:
+        return X_cand_sobol[acq_vals.topk(num_restarts).indices]
+
+    X_cand_sobol = X_cand_sobol[acq_vals.topk(num_sobol).indices]
+    # (2) Perturbations of points on the Pareto frontier (done by TuRBO/Spearmint)
+    X_cand_local = X_pareto.clone()[
+        torch.randint(high=len(X_pareto), size=(raw_samples,))
     ]
-    # (4) Apply a Bernoulli mask to points on the Pareto frontier
-    X_cand4 = X_pareto.clone()[torch.randint(high=len(X_pareto), size=(raw_samples,))]
-    mask = torch.rand(X_cand4.shape, **tkwargs) < 0.5
-    X_cand4[mask] = target_point.repeat(len(X_cand4), 1).to(**tkwargs)[mask].clone()
-    X_cand4 = X_cand4[
-        acq_function(X_cand4.unsqueeze(1)).topk(num_restarts // 5).indices
+    mask = X_cand_local != target_point
+    X_cand_local[mask] += (
+        0.2 * ((bounds[1] - bounds[0]) * torch.randn_like(X_cand_local))[mask]
+    )
+    X_cand_local = torch.clamp(X_cand_local, min=bounds[0], max=bounds[1])
+    X_cand_local = X_cand_local[
+        acq_function(X_cand_local.unsqueeze(1)).topk(num_local).indices
     ]
-    return torch.cat((X_cand1, X_cand2, X_cand3, X_cand4), dim=0).unsqueeze(1)
+    return torch.cat((X_cand_sobol, X_cand_local), dim=0).unsqueeze(1)

diff --git a/ax/models/torch/tests/test_sebo.py b/ax/models/torch/tests/test_sebo.py
index 567222af670..2e3638ad041 100644
--- a/ax/models/torch/tests/test_sebo.py
+++ b/ax/models/torch/tests/test_sebo.py
@@ -10,6 +10,7 @@
 
 import dataclasses
 import functools
+import warnings
 from typing import Any
 from unittest import mock
 from unittest.mock import Mock
@@ -17,15 +18,18 @@
 import torch
 from ax.core.search_space import SearchSpaceDigest
 from ax.models.torch.botorch_modular.acquisition import Acquisition
-from ax.models.torch.botorch_modular.sebo import L1_norm_func, SEBOAcquisition
+from ax.models.torch.botorch_modular.sebo import (
+    clamp_to_target,
+    get_batch_initial_conditions,
+    L1_norm_func,
+    SEBOAcquisition,
+)
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.models.torch_base import TorchOptConfig
 from ax.utils.common.constants import Keys
 from ax.utils.common.testutils import TestCase
 from ax.utils.common.typeutils import not_none
 from ax.utils.testing.mock import fast_botorch_optimize
-from botorch.acquisition import PosteriorMean
-from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
 from botorch.acquisition.multi_objective.monte_carlo import (
     qNoisyExpectedHypervolumeImprovement,
 )
@@ -33,9 +37,7 @@
 from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.model import ModelList
-from botorch.optim import Homotopy, HomotopyParameter, LinearHomotopySchedule
 from botorch.utils.datasets import SupervisedDataset
-from torch.nn import Parameter
 
 
 SEBOACQUISITION_PATH: str = SEBOAcquisition.__module__
@@ -103,18 +105,25 @@ def setUp(self) -> None:
             linear_constraints=self.linear_constraints,
             fixed_features=self.fixed_features,
         )
+        self.torch_opt_config_2 = TorchOptConfig(
+            objective_weights=self.objective_weights,
+            objective_thresholds=self.objective_thresholds,
+            pending_observations=self.pending_observations,
+        )
 
     def get_acquisition_function(
         self,
         fixed_features: dict[int, float] | None = None,
        options: dict[str, str | float] | None = None,
+        torch_opt_config: TorchOptConfig | None = None,
     ) -> SEBOAcquisition:
         return SEBOAcquisition(
             botorch_acqf_class=qNoisyExpectedHypervolumeImprovement,
             surrogates={Keys.ONLY_SURROGATE: self.surrogates},
             search_space_digest=self.search_space_digest,
             torch_opt_config=dataclasses.replace(
-                self.torch_opt_config, fixed_features=fixed_features or {}
+                torch_opt_config or self.torch_opt_config,
+                fixed_features=fixed_features or {},
             ),
             options=options or self.options,
         )
@@ -129,11 +138,12 @@ def test_init(self) -> None:
         self.assertIsInstance(model_list, ModelList)
         self.assertIsInstance(model_list.models[0], SingleTaskGP)
         self.assertIsInstance(model_list.models[1], GenericDeterministicModel)
-        self.assertEqual(acquisition1.det_metric_indx, -1)
 
         # Check right penalty term is instantiated
         self.assertEqual(acquisition1.penalty_name, "L0_norm")
         self.assertIsInstance(model_list.models[1]._f, L0Approximation)
+        # `a` needs to be set to something small for the pruning to work as expected
+        self.assertEqual(model_list.models[-1]._f.a, 1e-6)
 
         # Check transformed objective threshold
         self.assertTrue(
@@ -141,7 +151,7 @@
            torch.equal(
                 # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `int`.
                 acquisition1.acqf.ref_point[-1],
                 # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `int`.
-                -1 * self.objective_thresholds_sebo[-1],
+                -self.objective_thresholds_sebo[-1],
             )
         )
         self.assertTrue(
@@ -190,6 +200,24 @@ def test_init(self) -> None:
         with self.assertRaisesRegex(ValueError, "please provide target point."):
             self.get_acquisition_function(options={"penalty": "L1_norm"})
 
+        # Check that `cache_root=True` is caught and flipped to `False`
+        with warnings.catch_warnings(record=True) as ws:
+            self.get_acquisition_function(
+                fixed_features=self.fixed_features,
+                options={"cache_root": True, "target_point": self.target_point},
+            )
+        self.assertEqual(len(ws), 1)
+        self.assertEqual(
+            "SEBO doesn't support `cache_root=True`. Changing it to `False`.",
+            str(ws[0].message),
+        )
+
+        # Test with no outcome constraints
+        self.get_acquisition_function(
+            options={"target_point": self.target_point},
+            torch_opt_config=self.torch_opt_config_2,
+        )
+
     @mock.patch(f"{ACQUISITION_PATH}.optimize_acqf")
     def test_optimize_l1(self, mock_optimize_acqf: Mock) -> None:
         mock_optimize_acqf.return_value = (
@@ -217,52 +245,6 @@ def test_optimize_l1(self, mock_optimize_acqf: Mock) -> None:
         self.assertEqual(kwargs["num_restarts"], self.optimizer_options["num_restarts"])
         self.assertEqual(kwargs["raw_samples"], self.optimizer_options["raw_samples"])
 
-    @mock.patch(
-        f"{SEBOACQUISITION_PATH}.get_batch_initial_conditions", return_value=None
-    )
-    @mock.patch(f"{SEBOACQUISITION_PATH}.Homotopy")
-    def test_optimize_l0_homotopy(
-        self,
-        mock_homotopy: Mock,
-        mock_get_batch_initial_conditions: Mock,
-    ) -> None:
-        tkwargs: dict[str, Any] = {"dtype": torch.double}
-        acquisition = self.get_acquisition_function(
-            fixed_features=self.fixed_features,
-            options={"penalty": "L0_norm", "target_point": self.target_point},
-        )
-        # overwrite acqf to validate homotopy
-        # pyre-fixme[61]: `p` is undefined, or not always defined.
-        # pyre-fixme[6]: For 1st argument expected `(Tensor) -> Tensor` but got `(x:
-        #  Any) -> int`.
-        model = GenericDeterministicModel(f=lambda x: 5 - (x - p) ** 2)
-        acqf = PosteriorMean(model=model)
-        acquisition.acqf = acqf
-
-        p = Parameter(-2 * torch.ones(1, **tkwargs))
-        hp = HomotopyParameter(
-            parameter=p,
-            schedule=LinearHomotopySchedule(start=4, end=0, num_steps=5),
-        )
-        mock_homotopy.return_value = Homotopy(homotopy_parameters=[hp])
-
-        search_space_digest = SearchSpaceDigest(
-            feature_names=["a"],
-            bounds=[(-10.0, 5.0)],
-        )
-        candidate, acqf_val, weights = acquisition._optimize_with_homotopy(
-            n=1,
-            search_space_digest=search_space_digest,
-            optimizer_options={
-                "num_restarts": 2,
-                "sequential": True,
-                "raw_samples": 16,
-            },
-        )
-        self.assertTrue(torch.allclose(candidate, torch.zeros(1, **tkwargs)))
-        self.assertTrue(torch.allclose(acqf_val, 5 * torch.ones(1, **tkwargs)))
-        self.assertEqual(weights, torch.ones(1, **tkwargs))
-
     @mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")
     def test_optimize_l0(self, mock_optimize_acqf_homotopy: Mock) -> None:
         mock_optimize_acqf_homotopy.return_value = (
@@ -276,8 +258,6 @@ def test_optimize_l0(self, mock_optimize_acqf_homotopy: Mock) -> None:
         acquisition.optimize(
             n=2,
             search_space_digest=self.search_space_digest,
-            # does not support in homotopy now
-            # inequality_constraints=self.inequality_constraints,
             fixed_features=self.fixed_features,
             rounding_func=self.rounding_func,
             optimizer_options=self.optimizer_options,
         )
@@ -320,7 +300,7 @@
         with self.assertRaisesRegex(
             NotImplementedError,
             "Homotopy does not support optimization with inequality "
-            + "constraints. Use L1 penalty norm instead.",
+            "constraints. Use L1 penalty norm instead.",
         ):
             acquisition.optimize(
                 n=2,
@@ -331,16 +311,58 @@
                optimizer_options=self.optimizer_options,
             )
 
-        # assert error when using a wrong botorch_acqf_class
-        with self.assertRaisesRegex(
-            ValueError, "botorch_acqf_class must be qEHVI to use SEBO"
-        ):
-            acquisition = SEBOAcquisition(
-                botorch_acqf_class=qNoisyExpectedImprovement,
-                surrogates={Keys.ONLY_SURROGATE: self.surrogates},
-                search_space_digest=self.search_space_digest,
-                torch_opt_config=dataclasses.replace(
-                    self.torch_opt_config, fixed_features=self.fixed_features
-                ),
-                options=self.options,
+    def test_clamp_to_target(self) -> None:
+        X = torch.tensor(
+            [[0.5, 0.01, 0.5], [0.05, 0.5, 0.95], [0.1, 0.02, 0.06]], **self.tkwargs
+        )
+        X_true = torch.tensor(
+            [[0.5, 0, 0.5], [0, 0.5, 0.95], [0.1, 0, 0.06]], **self.tkwargs
+        )
+        self.assertTrue(
+            torch.allclose(
+                clamp_to_target(X, torch.zeros(1, 3, **self.tkwargs), 0.05), X_true
             )
+        )
+
+    @mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")
+    @mock.patch(
+        f"{SEBOACQUISITION_PATH}.get_batch_initial_conditions",
+        wraps=get_batch_initial_conditions,
+    )
+    def test_get_batch_initial_conditions(
+        self, mock_get_batch_initial_conditions: Mock, mock_optimize_acqf_homotopy: Mock
+    ) -> None:
+        mock_optimize_acqf_homotopy.return_value = (
+            torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], dtype=torch.double),
+            torch.tensor([1.0, 2.0], dtype=torch.double),
+        )
+        acquisition = self.get_acquisition_function(
+            fixed_features=self.fixed_features,
+            options={"target_point": self.target_point},
+            torch_opt_config=self.torch_opt_config_2,
+        )
+        acquisition.optimize(
+            n=2,
+            search_space_digest=self.search_space_digest,
+            fixed_features=self.fixed_features,
+            rounding_func=self.rounding_func,
+            optimizer_options={Keys.NUM_RESTARTS: 3, Keys.RAW_SAMPLES: 32},
+        )
+        call_args = mock_get_batch_initial_conditions.call_args[1]
+        self.assertTrue(
+            torch.equal(
+                call_args["X_pareto"],
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.double),
+            )
+        )
+        self.assertTrue(torch.equal(call_args["target_point"], self.target_point))
+        self.assertEqual(call_args["raw_samples"], 32)
+        self.assertEqual(call_args["num_restarts"], 3)
+        # Check the batch initial conditions
+        batch_initial_conditions = mock_optimize_acqf_homotopy.call_args[1][
+            "batch_initial_conditions"
+        ]
+        self.assertEqual(batch_initial_conditions.shape, torch.Size([3, 1, 3]))
+        self.assertTrue(torch.all(batch_initial_conditions[:1] != 1.0))
+        self.assertTrue(torch.all(batch_initial_conditions[1:, :, 0] == 1.0))
+        self.assertTrue(torch.all(batch_initial_conditions[1:, :, 1:] != 1.0))
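
For reference, the new `clamp_to_target` helper is easy to exercise standalone.
The sketch below copies the function verbatim from the diff and reuses the
values from `test_clamp_to_target`: any coordinate within `clamp_tol` of the
target point snaps to it exactly, which keeps the sparse dimensions pinned to
the target.

    import torch
    from torch import Tensor

    def clamp_to_target(X: Tensor, target_point: Tensor, clamp_tol: float) -> Tensor:
        # Coordinates within `clamp_tol` of the target point are set to it exactly
        # (note that `X` is modified in place).
        clamp_mask = (X - target_point).abs() <= clamp_tol
        X[clamp_mask] = target_point.clone().repeat(*X.shape[:-1], 1)[clamp_mask]
        return X

    X = torch.tensor([[0.5, 0.01, 0.5], [0.05, 0.5, 0.95], [0.1, 0.02, 0.06]])
    print(clamp_to_target(X, target_point=torch.zeros(1, 3), clamp_tol=0.05))
    # tensor([[0.5000, 0.0000, 0.5000],
    #         [0.0000, 0.5000, 0.9500],
    #         [0.1000, 0.0000, 0.0600]])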
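
The `_transform_torch_config` changes are easiest to see on a toy config (the
tensors below are made up purely for illustration): the penalty becomes one
extra outcome that is minimized (weight -1), the outcome-constraint matrix `A`
gets a zero column so the penalty is left unconstrained, and the sparsity
threshold is appended as the penalty's objective threshold.

    import torch

    objective_weights = torch.tensor([1.0, 1.0])    # two real metrics, both maximized
    A = torch.tensor([[1.0, 0.0]])                  # one constraint over the two metrics
    objective_thresholds = torch.tensor([0.5, 0.5])
    sparsity_threshold = 3.0                        # e.g. the search-space dimension

    # Append the sparsity objective, mirroring the patched code above.
    objective_weights_sebo = torch.cat([objective_weights, -torch.ones(1)])
    A_sebo = torch.cat([A, torch.zeros(A.shape[0], 1)], dim=1)
    objective_thresholds_sebo = torch.cat(
        [objective_thresholds, torch.tensor([sparsity_threshold])]
    )
    print(objective_weights_sebo)     # tensor([ 1.,  1., -1.])
    print(A_sebo)                     # tensor([[1., 0., 0.]])
    print(objective_thresholds_sebo)  # tensor([0.5000, 0.5000, 3.0000])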
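
One of the fixes in `get_batch_initial_conditions` is that the Sobol
candidates and the local perturbations now respect the search-space bounds
instead of implicitly assuming the unit cube. A minimal sketch of that mapping
(the bounds here are toy values; `unnormalize` is the BoTorch helper imported
in the diff, which computes `X * (upper - lower) + lower`):

    import torch
    from botorch.utils.transforms import unnormalize
    from torch.quasirandom import SobolEngine

    bounds = torch.tensor([[0.0, -1.0, 0.0], [2.0, 1.0, 5.0]])  # 2 x d: lower, upper
    X_unit = SobolEngine(dimension=3, scramble=True).draw(8, dtype=torch.double)
    X = unnormalize(X_unit, bounds=bounds)  # map [0, 1]^d draws into the bounds
    assert torch.all((X >= bounds[0]) & (X <= bounds[1]))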
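
Finally, the homotopy now anneals the `L0Approximation` parameter `a` from 0.2
(previously 0.1) down to 1e-3 over 30 steps. Assuming
`LogLinearHomotopySchedule` spaces its values geometrically (log-linearly),
the schedule can be sketched as:

    import math
    import torch

    # Geometric spacing from 0.2 down to 1e-3 in 30 steps; early steps use a
    # smooth L0 approximation, later steps approach the true L0 norm.
    schedule = torch.logspace(math.log10(0.2), math.log10(1e-3), steps=30)
    print(schedule[0].item(), schedule[-1].item())  # ~0.2 ... ~0.001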