diff --git a/botorch/optim/optimize_acqf_mixed.py b/botorch/optim/optimize_acqf_mixed.py new file mode 100644 index 0000000000..ac153b9e23 --- /dev/null +++ b/botorch/optim/optimize_acqf_mixed.py @@ -0,0 +1,716 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses +import warnings +from typing import Any, Callable + +import torch +from botorch.acquisition import AcquisitionFunction +from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError +from botorch.exceptions.warnings import OptimizationWarning +from botorch.generation.gen import gen_candidates_scipy +from botorch.optim.initializers import initialize_q_batch +from botorch.optim.optimize import ( + _optimize_acqf, + _validate_sequential_inputs, + OptimizeAcqfInputs, +) +from botorch.optim.utils.acquisition_utils import fix_features, get_X_baseline +from botorch.utils.sampling import ( + draw_sobol_samples, + HitAndRunPolytopeSampler, + sparse_to_dense_constraints, +) +from botorch.utils.transforms import unnormalize +from pyre_extensions import assert_is_instance, none_throws +from torch import Tensor +from torch.quasirandom import SobolEngine + +# Default values. +# NOTE: When changing a default, update the corresponding value in the docstrings. +STD_CONT_PERTURBATION = 0.1 +RAW_SAMPLES = 1024 # Number of candidates from which to select starting points. +NUM_RESTARTS = 20 # Number of restarts of optimizer with different starting points. +MAX_BATCH_SIZE = 2048 # Maximum batch size. +MAX_ITER_ALTER = 64 # Maximum number of alternating iterations. +MAX_ITER_DISCRETE = 4 # Maximum number of discrete iterations. +MAX_ITER_CONT = 8 # Maximum number of continuous iterations. +# Maximum number of iterations for optimizing the continuous relaxation +# during initialization +MAX_ITER_INIT = 100 +CONVERGENCE_TOL = 1e-8 # Optimizer convergence tolerance. +DUPLICATE_TOL = 1e-6 # Tolerance for deduplicating initial candidates. + +SUPPORTED_OPTIONS = { + "initialization_strategy", + "tol", + "maxiter_alternating", + "maxiter_discrete", + "maxiter_continuous", + "maxiter_init", + "num_spray_points", + "std_cont_perturbation", + "batch_limit", + "init_batch_limit", +} +SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"} + + +def _filter_infeasible( + X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]] | None +) -> Tensor: + r"""Filters infeasible points from a set of points. + + NOTE: This function only supports intra-point constraints. This is validated + in `optimize_acqf_mixed_alternating`, so we do not repeat the + validation in here. + + Args: + X: A tensor of points of shape `n x d`. + inequality_constraints: A list of tuples (indices, coefficients, rhs), + with each tuple encoding an inequality constraint of the form + `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and + `coefficients` should be torch tensors. See the docstring of + `make_scipy_linear_constraints` for an example. + + Returns: + The tensor `X` with infeasible points removed. 
+ """ + if inequality_constraints is None: + return X + is_feasible = torch.ones(X.shape[:-1], device=X.device, dtype=torch.bool) + for idx, coef, rhs in inequality_constraints: + is_feasible &= (X[..., idx] * coef).sum(dim=-1) >= rhs + return X[is_feasible] + + +def get_nearest_neighbors( + current_x: Tensor, + bounds: Tensor, + discrete_dims: Tensor, +) -> Tensor: + r"""Generate all 1-Manhattan distance neighbors of a given input. The neighbors + are generated for the discrete dimensions only. + + NOTE: This assumes that `current_x` is detached and uses in-place operations, + which are known to be incompatible with autograd. + + Args: + current_x: The design to find the neighbors of. A tensor of shape `d`. + bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`. + discrete_dims: A tensor of indices corresponding to binary and + integer parameters. + + Returns: + A tensor of shape `num_neighbors x d`, denoting all unique 1-Manhattan + distance neighbors. + """ + num_discrete = len(discrete_dims) + diag_ones = torch.eye(num_discrete, dtype=current_x.dtype, device=current_x.device) + # Neighbors obtained by increasing a discrete dimension by one. + plus_neighbors = current_x.repeat(num_discrete, 1) + plus_neighbors[:, discrete_dims] += diag_ones + plus_neighbors.clamp_(max=bounds[1]) + # Neighbors obtained by decreasing a discrete dimension by one. + minus_neighbors = current_x.repeat(num_discrete, 1) + minus_neighbors[:, discrete_dims] -= diag_ones + minus_neighbors.clamp_(min=bounds[0]) + unique_neighbors = torch.cat([minus_neighbors, plus_neighbors], dim=0).unique(dim=0) + # Also remove current_x if it is in unique_neighbors. + unique_neighbors = unique_neighbors[~(unique_neighbors == current_x).all(dim=-1)] + return unique_neighbors + + +def get_spray_points( + X_baseline: Tensor, + cont_dims: Tensor, + discrete_dims: Tensor, + bounds: Tensor, + num_spray_points: int, + std_cont_perturbation: float = STD_CONT_PERTURBATION, +) -> Tensor: + r"""Generate spray points by perturbing the Pareto optimal points. + + Given the points on the Pareto frontier, we create perturbations (spray points) + by adding Gaussian perturbation to the continuous parameters and 1-Manhattan + distance neighbors of the discrete (binary and integer) parameters. + + Args: + X_baseline: Tensor of best acquired points across BO run. + cont_dims: Indices of continuous parameters/input dimensions. + discrete_dims: Indices of binary/integer parameters/input dimensions. + bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`. + num_spray_points: Number of spray points to return. + std_cont_perturbation: standard deviation of Normal perturbations of + continuous dimensions. Default is STD_CONT_PERTURBATION = 0.2. + + Returns: + A (num_spray_points x d)-dim tensor of perturbed points. 
+ """ + dim = bounds.shape[-1] + device, dtype = X_baseline.device, X_baseline.dtype + perturb_nbors = torch.zeros(0, dim, device=device, dtype=dtype) + for x in X_baseline: + discrete_perturbs = get_nearest_neighbors( + current_x=x, bounds=bounds, discrete_dims=discrete_dims + ) + discrete_perturbs = discrete_perturbs[ + torch.randint(len(discrete_perturbs), (num_spray_points,), device=device) + ] + cont_perturbs = x[cont_dims] + std_cont_perturbation * torch.randn( + num_spray_points, len(cont_dims), device=device, dtype=dtype + ) + cont_perturbs = cont_perturbs.clamp_( + min=bounds[0, cont_dims], max=bounds[1, cont_dims] + ) + nbds = torch.zeros(num_spray_points, dim, device=device, dtype=dtype) + nbds[..., discrete_dims] = discrete_perturbs[..., discrete_dims] + nbds[..., cont_dims] = cont_perturbs + perturb_nbors = torch.cat([perturb_nbors, nbds], dim=0) + return perturb_nbors + + +def sample_feasible_points( + opt_inputs: OptimizeAcqfInputs, + discrete_dims: Tensor, + num_points: int, +) -> Tensor: + r"""Sample feasible points from the optimization domain. + + Feasibility is determined according to the discrete dimensions taking + integer values and the inequality constraints being satisfied. + + If there are no inequality constraints, Sobol is used to generate the base points. + Otherwise, we use the polytope sampler to generate the base points. The base points + are then rounded to the nearest integer values for the discrete dimensions, and + the infeasible points are filtered out (in case rounding leads to infeasibility). + + This method will do 10 attempts to generate `num_points` feasible points, and + return the points generated so far. If no points are generated, it will error out. + + Args: + opt_inputs: Common set of arguments for acquisition optimization. + discrete_dims: A tensor of indices corresponding to binary and + integer parameters. + num_points: The number of points to sample. + + Returns: + A tensor of shape `num_points x d` containing the sampled points. + """ + bounds = opt_inputs.bounds + all_points = torch.empty( + 0, bounds.shape[-1], device=bounds.device, dtype=bounds.dtype + ) + constraints = opt_inputs.inequality_constraints + if constraints is None: + # Generate base points using Sobol. + sobol_engine = SobolEngine(dimension=bounds.shape[-1], scramble=True) + + def generator(n: int) -> Tensor: + samples = sobol_engine.draw(n=n, dtype=bounds.dtype).to(bounds.device) + return unnormalize(X=samples, bounds=bounds) + + else: + # Generate base points using polytope sampler. + # Since we may generate many times, we initialize the sampler with burn-in + # to reduce the start-up cost for subsequent calls. + A, b = sparse_to_dense_constraints(d=bounds.shape[-1], constraints=constraints) + polytope_sampler = HitAndRunPolytopeSampler( + bounds=bounds, inequality_constraints=(-A, -b) + ) + + def generator(n: int) -> Tensor: + return polytope_sampler.draw(n=n) + + for _ in range(10): + num_remaining = num_points - len(all_points) + if num_remaining <= 0: + break + # Generate twice as many, since we're likely to filter out some points. + base_points = generator(n=num_remaining * 2) + # Round the discrete dimensions to the nearest integer. + base_points[:, discrete_dims] = base_points[:, discrete_dims].round() + # Fix the fixed features. + base_points = fix_features( + X=base_points, fixed_features=opt_inputs.fixed_features + ) + # Filter out infeasible points. 
+ feasible_points = _filter_infeasible( + X=base_points, inequality_constraints=constraints + ) + all_points = torch.cat([all_points, feasible_points], dim=0) + + if len(all_points) == 0: + raise CandidateGenerationError( + "Could not generate any feasible starting points for mixed optimizer." + ) + return all_points[:num_points] + + +def generate_starting_points( + opt_inputs: OptimizeAcqfInputs, + discrete_dims: Tensor, + cont_dims: Tensor, +) -> tuple[Tensor, Tensor]: + """Generate initial starting points for the alternating optimization. + + This method attempts to generate the initial points using the specified + options and completes any missing points using `sample_feasible_points`. + + Args: + opt_inputs: Common set of arguments for acquisition optimization. + This function utilizes `acq_function`, `bounds`, `num_restarts`, + `raw_samples`, `options`, `fixed_features` and constraints + from `opt_inputs`. + discrete_dims: A tensor of indices corresponding to integer and + binary parameters. + cont_dims: A tensor of indices corresponding to continuous parameters. + + Returns: + A tuple of two tensors: a (num_restarts x d)-dim tensor of starting points + and a (num_restarts)-dim tensor of their respective acquisition values. + In rare cases, this method may return fewer than `num_restarts` points. + """ + bounds = opt_inputs.bounds + binary_dims = [] + for dim in discrete_dims: + if bounds[0, dim] == 0 and bounds[1, dim] == 1: + binary_dims.append(dim) + num_binary = len(binary_dims) + num_integer = len(discrete_dims) - num_binary + num_restarts = opt_inputs.num_restarts + raw_samples = none_throws(opt_inputs.raw_samples) + + options = opt_inputs.options or {} + initialization_strategy = options.get( + "initialization_strategy", + ( + "equally_spaced" + if num_integer == 0 and num_binary >= 2 + else "continuous_relaxation" + ), + ) + if initialization_strategy not in SUPPORTED_INITIALIZATION: + raise UnsupportedError( # pragma: no cover + f"Unsupported initialization strategy: {initialization_strategy}." + f"Supported strategies are: {SUPPORTED_INITIALIZATION}." + ) + + # Initialize `x_init_candts` here so that it's always defined as a tensor. + x_init_candts = torch.empty( + 0, bounds.shape[-1], device=bounds.device, dtype=bounds.dtype + ) + if initialization_strategy == "continuous_relaxation": + try: + # Optimize the acquisition function with continuous relaxation. + updated_opt_inputs = dataclasses.replace( + opt_inputs, + q=1, + return_best_only=False, + options={ + "maxiter": options.get("maxiter_init", MAX_ITER_INIT), + "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE), + "init_batch_limit": options.get("init_batch_limit", MAX_BATCH_SIZE), + }, + ) + x_init_candts, _ = _optimize_acqf(opt_inputs=updated_opt_inputs) + x_init_candts = x_init_candts.squeeze(-2).detach() + except Exception as e: + warnings.warn( + "Failed to initialize using continuous relaxation. Using " + "`sample_feasible_points` for initialization. Original error " + f"message: {e}", + OptimizationWarning, + stacklevel=2, + ) + + if len(x_init_candts) == 0: + # Generate Sobol points as a fallback for `continuous_relaxation` and for + # further refinement in `equally_spaced` strategy. + x_init_candts = draw_sobol_samples(bounds=bounds, n=raw_samples, q=1) + x_init_candts = x_init_candts.squeeze(-2) + + if initialization_strategy == "equally_spaced": + if num_integer > 0: + raise ValueError( # pragma: no cover + "Equally spaced initialization is not supported with non-binary " + "discrete variables." 
+ ) + # Picking initial points by equally spaced number of features/binary inputs. + k = torch.randint( + low=0, + high=num_binary, + size=(raw_samples,), + dtype=torch.int64, + device=bounds.device, + ) + x_init_candts[:, binary_dims] = 0 + binary_dims_t = torch.as_tensor(binary_dims, device=bounds.device) + for i, xi in enumerate(x_init_candts): + rand_binary_dims = binary_dims_t[ + torch.randperm(num_binary, device=xi.device)[: k[i]] + ] + x_init_candts[i, rand_binary_dims] = 1 + + num_spray_points = assert_is_instance( + options.get("num_spray_points", 20 if num_integer == 0 else 0), int + ) + if ( + num_spray_points > 0 + and (X_baseline := get_X_baseline(acq_function=opt_inputs.acq_function)) + is not None + ): + perturb_nbors = get_spray_points( + X_baseline=X_baseline, + cont_dims=cont_dims, + discrete_dims=discrete_dims, + bounds=bounds, + num_spray_points=num_spray_points, + std_cont_perturbation=assert_is_instance( + options.get("std_cont_perturbation", STD_CONT_PERTURBATION), float + ), + ) + x_init_candts = torch.cat([x_init_candts, perturb_nbors], dim=0) + + # Process the candidates to make sure they are all feasible. + x_init_candts[..., discrete_dims] = x_init_candts[..., discrete_dims].round() + x_init_candts = fix_features( + X=x_init_candts, fixed_features=opt_inputs.fixed_features + ) + x_init_candts = _filter_infeasible( + X=x_init_candts, inequality_constraints=opt_inputs.inequality_constraints + ) + + # If there are fewer than `num_restarts` feasible points, attempt to generate more. + if len(x_init_candts) < num_restarts: + new_x_init = sample_feasible_points( + opt_inputs=opt_inputs, + discrete_dims=discrete_dims, + num_points=num_restarts - len(x_init_candts), + ) + x_init_candts = torch.cat([x_init_candts, new_x_init], dim=0) + + with torch.no_grad(): + acq_vals = torch.cat( + [ + opt_inputs.acq_function(X_.unsqueeze(-2)) + for X_ in x_init_candts.split( + options.get("init_batch_limit", MAX_BATCH_SIZE) + ) + ] + ) + if len(x_init_candts) > num_restarts: + # If there are more than `num_restarts` feasible points, select a diverse + # set of initializers using Boltzmann sampling. + x_init_candts, acq_vals = initialize_q_batch( + X=x_init_candts, acq_vals=acq_vals, n=num_restarts + ) + return x_init_candts, acq_vals + + +def discrete_step( + opt_inputs: OptimizeAcqfInputs, + discrete_dims: Tensor, + current_x: Tensor, +) -> tuple[Tensor, Tensor]: + """Discrete nearest neighbour search. + + Args: + opt_inputs: Common set of arguments for acquisition optimization. + This function utilizes `acq_function`, `bounds`, `options` + and constraints from `opt_inputs`. + discrete_dims: A tensor of indices corresponding to binary and + integer parameters. + current_x: Starting point. A tensor of shape `d`. + + Returns: + A tuple of two tensors: a (d)-dim tensor of optimized point + and a scalar tensor of correspondins acquisition value. + """ + with torch.no_grad(): + current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0)) + options = opt_inputs.options or {} + for _ in range( + assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int) + ): + x_neighbors = get_nearest_neighbors( + current_x=current_x.detach(), + bounds=opt_inputs.bounds, + discrete_dims=discrete_dims, + ) + x_neighbors = _filter_infeasible( + X=x_neighbors, inequality_constraints=opt_inputs.inequality_constraints + ) + if x_neighbors.numel() == 0: + # Exit gracefully with last point if there are no feasible neighbors. 
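+                # (`current_x` and `current_acqval` from the last completed
+                # iteration are returned unchanged below.)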
+ break + with torch.no_grad(): + acq_vals = torch.cat( + [ + opt_inputs.acq_function(X_.unsqueeze(-2)) + for X_ in x_neighbors.split( + options.get("init_batch_limit", MAX_BATCH_SIZE) + ) + ] + ) + argmax = acq_vals.argmax() + improvement = acq_vals[argmax] - current_acqval + if improvement > 0: + current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax] + if improvement <= options.get("tol", CONVERGENCE_TOL): + break + return current_x, current_acqval + + +def continuous_step( + opt_inputs: OptimizeAcqfInputs, + discrete_dims: Tensor, + current_x: Tensor, +) -> tuple[Tensor, Tensor]: + """Continuous search using L-BFGS-B through optimize_acqf. + + Args: + opt_inputs: Common set of arguments for acquisition optimization. + This function utilizes `acq_function`, `bounds`, `options`, + `fixed_features` and constraints from `opt_inputs`. + discrete_dims: A tensor of indices corresponding to binary and + integer parameters. + current_x: Starting point. A tensor of shape `d`. + + Returns: + A tuple of two tensors: a (1 x d)-dim tensor of optimized points + and a (1)-dim tensor of acquisition values. + """ + bounds = opt_inputs.bounds + options = opt_inputs.options or {} + if (current_x < bounds[0]).any() or (current_x > bounds[1]).any(): + raise ValueError("continuous_step requires current_x to be within bounds.") + if len(discrete_dims) == len(current_x): # nothing continuous to optimize + with torch.no_grad(): + return current_x, opt_inputs.acq_function(current_x.unsqueeze(0)) + + updated_opt_inputs = dataclasses.replace( + opt_inputs, + q=1, + num_restarts=1, + batch_initial_conditions=current_x.unsqueeze(0), + fixed_features={ + **dict(zip(discrete_dims.tolist(), current_x[discrete_dims])), + **(opt_inputs.fixed_features or {}), + }, + options={ + "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT), + "tol": options.get("tol", CONVERGENCE_TOL), + "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE), + }, + ) + return _optimize_acqf(opt_inputs=updated_opt_inputs) + + +def optimize_acqf_mixed_alternating( + acq_function: AcquisitionFunction, + bounds: Tensor, + discrete_dims: list[int], + options: dict[str, Any] | None = None, + q: int = 1, + raw_samples: int = RAW_SAMPLES, + num_restarts: int = NUM_RESTARTS, + post_processing_func: Callable[[Tensor], Tensor] | None = None, + sequential: bool = True, + fixed_features: dict[int, float] | None = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, +) -> tuple[Tensor, Tensor]: + r""" + Optimizes acquisition function over mixed binary and continuous input spaces. + Multiple random restarting starting points are picked by evaluating a large set + of initial candidates. From each starting point, alternating discrete local search + and continuous optimization via (L-BFGS) is performed for a fixed number of + iterations. + + NOTE: This method assumes that all discrete variables are integer valued. + + # TODO: Support categorical variables. + + Args: + acq_function: BoTorch Acquisition function. + bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`. + discrete_dims: A list of indices corresponding to integer and binary parameters. + options: Dictionary specifying optimization options. Supports the following: + - "initialization_strategy": Strategy used to generate the initial candidates. + "random", "continuous_relaxation" or "equally_spaced" (linspace style). 
+ - "tol": The algorithm terminates if the absolute improvement in acquisition + value of one iteration is smaller than this number. + - "maxiter_alternating": Number of alternating steps. Defaults to 64. + - "maxiter_discrete": Maximum number of iterations in each discrete step. + Defaults to 4. + - "maxiter_continuous": Maximum number of iterations in each continuous step. + Defaults to 8. + - "num_spray_points": Number of spray points (around `X_baseline`) to add to + the points generated by the initialization strategy. Defaults to 20 if + all discrete variables are binary and to 0 otherwise. + - "std_cont_perturbation": Standard deviation of the normal perturbations of + the continuous variables used to generate the spray points. + Defaults to 0.1. + - "batch_limit": The maximum batch size for jointly evaluating candidates + during optimization. + - "init_batch_limit": The maximum batch size for jointly evaluating candidates + during initialization. During initialization, candidates are evaluated + in a `no_grad` context, which reduces memory usage. As a result, + `init_batch_limit` can be set to a larger value than `batch_limit`. + Defaults to `batch_limit`, if given. + q: Number of candidates. + raw_samples: Number of initial candidates used to select starting points from. + Defaults to 1024. + num_restarts: Number of random restarts. Defaults to 20. + post_processing_func: A function that post-processes an optimization result + appropriately (i.e., according to `round-trip` transformations). + sequential: Whether to use joint or sequential optimization across q-batch. + This currently only supports sequential optimization. + fixed_features: A map `{feature_index: value}` for features that + should be fixed to a particular value during generation. + inequality_constraints: A list of tuples (indices, coefficients, rhs), + with each tuple encoding an inequality constraint of the form + `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and + `coefficients` should be torch tensors. See the docstring of + `make_scipy_linear_constraints` for an example. + + Returns: + A tuple of two tensors: a (q x d)-dim tensor of optimized points + and a (q)-dim tensor of their respective acquisition values. + """ + if sequential is False: # pragma: no cover + raise NotImplementedError( + "`optimize_acqf_mixed_alternating` only supports " + "sequential optimization." + ) + + fixed_features = fixed_features or {} + options = options or {} + options.setdefault("batch_limit", MAX_BATCH_SIZE) + options.setdefault("init_batch_limit", options["batch_limit"]) + if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS): + unsupported_keys = keys.difference(SUPPORTED_OPTIONS) + raise UnsupportedError( + f"Received an unsupported option {unsupported_keys}. {SUPPORTED_OPTIONS=}." 
+ ) + + opt_inputs = OptimizeAcqfInputs( + acq_function=acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + inequality_constraints=inequality_constraints, + equality_constraints=None, + nonlinear_inequality_constraints=None, + fixed_features=fixed_features, + post_processing_func=post_processing_func, + batch_initial_conditions=None, + return_best_only=True, + gen_candidates=gen_candidates_scipy, + sequential=sequential, + ) + _validate_sequential_inputs(opt_inputs=opt_inputs) + + base_X_pending = acq_function.X_pending if q > 1 else None + dim = bounds.shape[-1] + tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype} + # Remove fixed features from dims, so they don't get optimized. + discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features] + if len(discrete_dims) == 0: + raise ValueError("There must be at least one discrete parameter.") + if not ( + isinstance(discrete_dims, list) + and len(set(discrete_dims)) == len(discrete_dims) + and min(discrete_dims) >= 0 + and max(discrete_dims) <= dim - 1 + ): + raise ValueError( + "`discrete_dims` must be a list with unique integers " + "between 0 and num_dims - 1." + ) + discrete_dims_t = torch.tensor( + discrete_dims, dtype=torch.long, device=tkwargs["device"] + ) + cont_dims = complement_indices_like(indices=discrete_dims_t, d=dim) + # Fixed features are all in cont_dims. Remove them, so they don't get optimized. + ff_idcs = torch.tensor( + list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"] + ) + cont_dims = cont_dims[(cont_dims.unsqueeze(-1) != ff_idcs).all(dim=-1)] + candidates = torch.empty(0, dim, **tkwargs) + for _q in range(q): + # Generate starting points. + best_X, best_acq_val = generate_starting_points( + opt_inputs=opt_inputs, + discrete_dims=discrete_dims_t, + cont_dims=cont_dims, + ) + + # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially + # be handled by concatenating them rather than stacking, and keeping a list + # of indices. + for i in range(num_restarts): + alternate_steps = 0 + while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER): + starting_acq_val = best_acq_val[i].clone() + alternate_steps += 1 + for step in (discrete_step, continuous_step): + best_X[i], best_acq_val[i] = step( + opt_inputs=opt_inputs, + discrete_dims=discrete_dims_t, + current_x=best_X[i], + ) + + improvement = best_acq_val[i] - starting_acq_val + if improvement < options.get("tol", CONVERGENCE_TOL): + # Check for convergence + break + + new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0) + candidates = torch.cat([candidates, new_candidate], dim=-2) + # Update pending points to include the new candidate. + if q > 1: + acq_function.set_X_pending( + torch.cat([base_X_pending, candidates], dim=-2) + if base_X_pending is not None + else candidates + ) + if q > 1: + acq_function.set_X_pending(base_X_pending) + + if post_processing_func is not None: + candidates = post_processing_func(candidates) + + with torch.no_grad(): + acq_value = acq_function(candidates) # compute joint acquisition value + return candidates, acq_value + + +def complement_indices_like(indices: Tensor, d: int) -> Tensor: + r"""Computes a tensor of complement indices: {range(d) \\ indices}. + Same as complement_indices but returns an integer tensor like indices. 
+ """ + return torch.tensor( + complement_indices(indices.tolist(), d), + device=indices.device, + dtype=indices.dtype, + ) + + +def complement_indices(indices: list[int], d: int) -> list[int]: + r"""Computes a list of complement indices: {range(d) \\ indices}. + + Args: + indices: a list of integers. + d: an integer dimension in which to compute the complement. + + Returns: + A list of integer indices. + """ + return sorted(set(range(d)).difference(indices)) diff --git a/sphinx/source/optim.rst b/sphinx/source/optim.rst index 581207a623..8401ce3b3b 100644 --- a/sphinx/source/optim.rst +++ b/sphinx/source/optim.rst @@ -41,6 +41,11 @@ Acquisition Function Optimization with Homotopy .. automodule:: botorch.optim.optimize_homotopy :members: +Acquisition Function Optimization with Mixed Integer Variables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.optim.optimize_acqf_mixed + :members: + Closures ------------------------------------------- diff --git a/test/optim/test_optimize_acqf_mixed.py b/test/optim/test_optimize_acqf_mixed.py new file mode 100644 index 0000000000..a9b18435c2 --- /dev/null +++ b/test/optim/test_optimize_acqf_mixed.py @@ -0,0 +1,719 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product +from typing import Any, Callable +from unittest import mock + +import torch +from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.logei import qLogNoisyExpectedImprovement +from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError +from botorch.exceptions.warnings import OptimizationWarning +from botorch.generation.gen import gen_candidates_scipy +from botorch.models.deterministic import DeterministicModel +from botorch.models.gp_regression import SingleTaskGP +from botorch.optim.optimize import _optimize_acqf, OptimizeAcqfInputs +from botorch.optim.optimize_acqf_mixed import ( + complement_indices, + continuous_step, + discrete_step, + generate_starting_points, + get_nearest_neighbors, + get_spray_points, + optimize_acqf_mixed_alternating, + sample_feasible_points, +) +from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction +from pyre_extensions import assert_is_instance +from torch import Tensor + +OPT_MODULE = f"{optimize_acqf_mixed_alternating.__module__}" + + +def _make_opt_inputs( + acq_function: AcquisitionFunction, + bounds: Tensor, + q: int = 1, + num_restarts: int = 20, + raw_samples: int | None = 1024, + options: dict[str, bool | float | int | str] | None = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None, + fixed_features: dict[int, float] | None = None, +) -> OptimizeAcqfInputs: + r"""Helper to construct `OptimizeAcqfInputs` from limited inputs.""" + return OptimizeAcqfInputs( + acq_function=acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options or {}, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + nonlinear_inequality_constraints=nonlinear_inequality_constraints, + fixed_features=fixed_features or {}, + 
post_processing_func=None, + batch_initial_conditions=None, + return_best_only=True, + gen_candidates=gen_candidates_scipy, + sequential=True, + ) + + +def get_hamming_neighbors(x_discrete: Tensor) -> Tensor: + r"""Generate all 1-Hamming distance neighbors of a binary input.""" + aye = torch.eye( + x_discrete.shape[-1], dtype=x_discrete.dtype, device=x_discrete.device + ) + X_loc = (x_discrete - aye).abs() + return X_loc + + +class QuadraticDeterministicModel(DeterministicModel): + """A simple quadratic model for testing.""" + + def __init__(self, root: Tensor): + """Initialize the model with the given root.""" + super().__init__() + self.register_buffer("root", root) + self._num_outputs = 1 + + def forward(self, X: Tensor): + # `keepdim`` is necessary for optimize_acqf to work correctly. + return -(X - self.root).square().sum(dim=-1, keepdim=True) + + +class TestOptimizeAcqfMixed(BotorchTestCase): + def setUp(self): + super().setUp() + self.single_bound = torch.tensor([[0.0], [1.0]], device=self.device) + self.tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double} + + def _get_random_binary(self, d: int, k: int) -> Tensor: + """d: dimensionality of vector, k: number of ones.""" + X = torch.zeros(d, device=self.device) + X[:k] = 1 + return X[torch.randperm(d, device=self.device)] + + def _get_data(self) -> tuple[Tensor, Tensor, list[int], list[int]]: + with torch.random.fork_rng(): + torch.manual_seed(0) + binary_dims, cont_dims, dim = [0, 3, 4], [1, 2], 5 + train_X = torch.rand(3, dim, **self.tkwargs) + train_X[:, binary_dims] = train_X[:, binary_dims].round() + train_Y = train_X.sin().sum(dim=-1).unsqueeze(-1) + return train_X, train_Y, binary_dims, cont_dims + + def test_get_nearest_neighbors(self) -> None: + # For binary inputs, this should be equivalent to get_hamming_neighbors, + # with potentially different ordering of the outputs. + current_x = self._get_random_binary(16, 7) + bounds = self.single_bound.repeat(1, 16) + discrete_dims = torch.arange(16, dtype=torch.long, device=self.device) + self.assertTrue( + torch.equal( + get_nearest_neighbors( + current_x=current_x, bounds=bounds, discrete_dims=discrete_dims + ) + .sort(dim=0) + .values, + get_hamming_neighbors(current_x).sort(dim=0).values, + ) + ) + # Test with integer and continuous inputs. + current_x = torch.tensor([1.0, 0.0, 0.5], device=self.device) + bounds = torch.tensor([[0.0, 0.0, 0.0], [3.0, 2.0, 1.0]], device=self.device) + discrete_dims = torch.tensor([0, 1], device=self.device) + expected_neighbors = torch.tensor( + [[0.0, 0.0, 0.5], [2.0, 0.0, 0.5], [1.0, 1.0, 0.5]], device=self.device + ) + self.assertTrue( + torch.equal( + expected_neighbors.sort(dim=0).values, + get_nearest_neighbors( + current_x=current_x, bounds=bounds, discrete_dims=discrete_dims + ) + .sort(dim=0) + .values, + ) + ) + + def test_sample_feasible_points(self, with_constraints: bool = False) -> None: + bounds = torch.tensor([[0.0, 2.0, 0.0], [1.0, 5.0, 1.0]], **self.tkwargs) + opt_inputs = _make_opt_inputs( + acq_function=MockAcquisitionFunction(), + bounds=bounds, + fixed_features={0: 0.5}, + inequality_constraints=( + [ + ( # X[1] >= 4.0 + torch.tensor([1], device=self.device), + torch.tensor([1.0], **self.tkwargs), + 4.0, + ) + ] + if with_constraints + else None + ), + ) + # Check for error if feasible points cannot be found. 
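+        # (`_filter_infeasible` is patched to always return an empty tensor, so all
+        # ten sampling attempts come back empty and the error is raised.)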
+ with self.assertRaisesRegex( + CandidateGenerationError, "Could not generate" + ), mock.patch( + f"{OPT_MODULE}._filter_infeasible", + return_value=torch.empty(0, 3, **self.tkwargs), + ): + sample_feasible_points( + opt_inputs=opt_inputs, + discrete_dims=torch.tensor([0, 2], device=self.device), + num_points=10, + ) + # Generate a number of points. + X = sample_feasible_points( + opt_inputs=opt_inputs, + discrete_dims=torch.tensor([1], device=self.device), + num_points=10, + ) + self.assertEqual(X.shape, torch.Size([10, 3])) + self.assertTrue(torch.all(X[..., 0] == 0.5)) + if with_constraints: + self.assertTrue(torch.all(X[..., 1] >= 4.0)) + self.assertAllClose(X[..., 1], X[..., 1].round()) + + def test_sample_feasible_points_with_constraints(self) -> None: + self.test_sample_feasible_points(with_constraints=True) + + def test_discrete_step(self): + d = 16 + bounds = self.single_bound.repeat(1, d) + root = torch.zeros(d, device=self.device) + model = QuadraticDeterministicModel(root) + k = 7 # number of ones + X = self._get_random_binary(d, k) + best_f = model(X) + ei = ExpectedImprovement(model, best_f=best_f) + + # this just tests that the quadratic model + ei works correctly + ei_x_none = ei(X[None]) + self.assertAllClose(ei_x_none, torch.zeros_like(ei_x_none), atol=1e-3) + self.assertGreaterEqual(ei_x_none.min(), 0.0) + ei_root_none = ei(root[None]) + self.assertAllClose(ei_root_none, torch.full_like(ei_root_none, k)) + self.assertGreaterEqual(ei_root_none.min(), 0.0) + + # each discrete step should reduce the best_f value by exactly 1 + binary_dims = torch.arange(d) + for i in range(k): + X, ei_val = discrete_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_discrete": 1, "tol": 0, "init_batch_limit": 32}, + ), + discrete_dims=binary_dims, + current_x=X, + ) + ei_x_none = ei(X[None]) + self.assertAllClose(ei_x_none, torch.full_like(ei_x_none, i + 1)) + self.assertGreaterEqual(ei_x_none.min(), 0.0) + + self.assertAllClose(X, root) + + # Test with integer variables. + bounds[1, :2] = 2.0 + X = self._get_random_binary(d, k) + for i in range(k): + X, ei_val = discrete_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_discrete": 1, "tol": 0, "init_batch_limit": 2}, + ), + discrete_dims=binary_dims, + current_x=X, + ) + ei_x_none = ei(X[None]) + self.assertAllClose(ei_x_none, torch.full_like(ei_x_none, i + 1)) + + self.assertAllClose(X, root) + + # Testing that convergence_tol exits early. + X = self._get_random_binary(d, k) + X_clone = X.clone() + # Setting convergence_tol to above one should ensure that we only take one step. + mock_acqf = mock.MagicMock(wraps=ei) + discrete_step( + opt_inputs=_make_opt_inputs( + acq_function=mock_acqf, + bounds=bounds, + options={"maxiter_discrete": 1, "tol": 1.5}, + ), + discrete_dims=binary_dims, + current_x=X_clone, + ) + # One call when entering, one call in the loop. + self.assertEqual(mock_acqf.call_count, 2) + + # Test that no steps are taken if there's no improvement. 
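+        # (The mocked acquisition function is constant zero, so no neighbor improves
+        # on the current point and `discrete_step` returns it unchanged.)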
+ mock_acqf = mock.MagicMock( + side_effect=lambda x: torch.zeros( + x.shape[:-1], device=x.device, dtype=x.dtype + ) + ) + X_clone, _ = discrete_step( + opt_inputs=_make_opt_inputs( + acq_function=mock_acqf, + bounds=bounds, + options={"maxiter_discrete": 1, "tol": 1.5, "init_batch_limit": 2}, + ), + discrete_dims=binary_dims, + current_x=X_clone, + ) + self.assertAllClose(X_clone, X) + + # test with fixed continuous dimensions + X = self._get_random_binary(d, k) + X[:2] = 1.0 # To satisfy the constraint. + k = int(X.sum().item()) + X_cont = torch.rand(3, device=self.device) + X = torch.cat((X, X_cont)) # appended continuous dimensions + + root = torch.zeros(d + 3, device=self.device) + bounds = self.single_bound.repeat(1, d + 3) + model = QuadraticDeterministicModel(root) + best_f = model(X) + ei = ExpectedImprovement(model, best_f=best_f) + for i in range(k - 2): + X, ei_val = discrete_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_discrete": 1, "tol": 0, "init_batch_limit": 2}, + inequality_constraints=[ + ( # X[..., 0] + X[..., 1] >= 2. + torch.arange(2, dtype=torch.long, device=self.device), + torch.ones(2, device=self.device), + 2.0, + ) + ], + ), + discrete_dims=binary_dims, + current_x=X, + ) + self.assertAllClose(ei_val, torch.full_like(ei_val, i + 1)) + self.assertAllClose( + X[:2], torch.ones(2, device=self.device) + ) # satisfies constraints. + self.assertAllClose(X[2:d], root[2:d]) # binary optimized + self.assertAllClose(X[d:], X_cont) # continuous unchanged + + # Test with super-tight constraint. + X = torch.ones(d + 3, device=self.device) + X_new, _ = discrete_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + inequality_constraints=[ # sum(X) >= d + 3 + ( + torch.arange(d + 3, dtype=torch.long, device=self.device), + torch.ones(d + 3, device=self.device), + d + 3, + ) + ], + ), + discrete_dims=binary_dims, + current_x=X, + ) + # No feasible neighbors, so we should get the same point back. + self.assertAllClose(X_new, X) + + def test_continuous_step(self): + d_cont = 16 + d_bin = 5 + d = d_cont + d_bin + bounds = self.single_bound.repeat(1, d) + + root = torch.rand(d, device=self.device) + model = QuadraticDeterministicModel(root) + + indices = torch.randperm(d, device=self.device) + binary_dims = indices[:d_bin] + cont_dims = indices[d_bin:] + + X = torch.zeros(d, device=self.device) + k = 7 # number of ones in binary vector + X[binary_dims] = self._get_random_binary(d_bin, k) + X[cont_dims] = torch.rand(d_cont, device=self.device) + + best_f = model(X) + ei = ExpectedImprovement(model, best_f=best_f) + X_clone, ei_val = continuous_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_continuous": 32}, + ), + discrete_dims=binary_dims, + current_x=X.clone(), + ) + self.assertAllClose(X_clone[cont_dims], root[cont_dims]) + self.assertAllClose(X_clone[binary_dims], X[binary_dims]) + + # Test with fixed features and constraints. + fixed_binary = int(binary_dims[0]) + fixed_cont = int(cont_dims[0]) + X_ = X.clone() + X_[:2] = 1.0 # To satisfy the constraint. + X_clone, ei_val = continuous_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_continuous": 32}, + fixed_features={fixed_binary: 1, fixed_cont: 0.5}, + inequality_constraints=[ + ( # X[..., 0] + X[..., 1] >= 2. 
+ torch.tensor([0, 1], device=self.device), + torch.ones(2, device=self.device), + 2.0, + ) + ], + ), + discrete_dims=binary_dims, + current_x=X_, + ) + self.assertTrue( + torch.equal( + X_clone[[fixed_binary, fixed_cont]], + torch.tensor([1.0, 0.5], device=self.device), + ) + ) + self.assertAllClose(X_clone[:2], X_[:2]) + + # test edge case when all parameters are binary + root = torch.rand(d_bin) + model = QuadraticDeterministicModel(root) + ei = ExpectedImprovement(model, best_f=best_f) + X = self._get_random_binary(d_bin, k) + bounds = self.single_bound.repeat(1, d_bin) + binary_dims = torch.arange(d_bin) + X_out, ei_val = continuous_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_continuous": 32}, + ), + discrete_dims=binary_dims, + current_x=X, + ) + self.assertTrue(X is X_out) # testing pointer equality for due to short cut + self.assertAllClose(ei_val, ei(X[None])) + + # Input outside of bounds raises error. + invalid_X = X.clone() + invalid_X[2] = 2 + with self.assertRaisesRegex( + ValueError, + "continuous_step requires current_x to be", + ): + X_clone, ei_val = continuous_step( + opt_inputs=_make_opt_inputs( + acq_function=ei, + bounds=bounds, + options={"maxiter_continuous": 32}, + ), + discrete_dims=binary_dims, + current_x=invalid_X, + ) + + def test_optimize_acqf_mixed_binary_only(self) -> None: + train_X, train_Y, binary_dims, cont_dims = self._get_data() + dim = len(binary_dims) + len(cont_dims) + bounds = self.single_bound.repeat(1, dim) + torch.manual_seed(0) + model = SingleTaskGP(train_X=train_X, train_Y=train_Y) + acqf = ExpectedImprovement(model=model, best_f=torch.max(train_Y)) + options = { + "initialization_strategy": "random", + "maxiter_alternating": 2, + "maxiter_discrete": 8, + "maxiter_continuous": 32, + "num_spray_points": 32, + "std_cont_perturbation": 1e-2, + } + X_baseline = train_X[torch.argmax(train_Y)].unsqueeze(0) + + # testing spray points + perturb_nbors = get_spray_points( + X_baseline=X_baseline, + discrete_dims=binary_dims, + cont_dims=cont_dims, + bounds=bounds, + num_spray_points=assert_is_instance(options["num_spray_points"], int), + ) + self.assertEqual(perturb_nbors.shape, (options["num_spray_points"], dim)) + # get single candidate + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=binary_dims, + options=options, + q=1, + raw_samples=32, + num_restarts=2, + ) + self.assertEqual(candidates.shape[-1], dim) + c_binary = candidates[:, binary_dims] + self.assertTrue(((c_binary == 0) | (c_binary == 1)).all()) + + # testing that continuous perturbations lead to lower acquisition values + std_pert = 1e-2 + perturbed_candidates = candidates.clone() + perturbed_candidates[..., cont_dims] += std_pert * torch.randn_like( + perturbed_candidates[..., cont_dims], device=self.device + ) + perturbed_candidates.clamp_(0, 1) + # Needs a loose tolerance to avoid flakiness + self.assertLess((acqf(perturbed_candidates) - acqf(candidates)).max(), 0.0) + + # testing that a discrete perturbation leads to a lower acquisition values + for i in binary_dims: + perturbed_candidates = candidates.clone() + perturbed_candidates[..., i] = 0 if perturbed_candidates[..., i] == 1 else 1 + self.assertLess((acqf(perturbed_candidates) - acqf(candidates)).max(), 0.0) + + # get multiple candidates + root = torch.zeros(dim, device=self.device) + model = QuadraticDeterministicModel(root) + acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X) + 
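+        # Switch to the equally spaced initialization strategy, which is supported
+        # here because all discrete dimensions are binary.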
options["initialization_strategy"] = "equally_spaced" + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=binary_dims, + options=options, + q=3, + raw_samples=32, + num_restarts=2, + ) + self.assertEqual(candidates.shape, torch.Size([3, dim])) + c_binary = candidates[:, binary_dims] + self.assertTrue(((c_binary == 0) | (c_binary == 1)).all()) + + # testing that continuous perturbations lead to lower acquisition values + perturbed_candidates = candidates.clone() + perturbed_candidates[..., cont_dims] += std_pert * torch.randn_like( + perturbed_candidates[..., cont_dims], device=self.device + ) + # need to project continuous variables into [0, 1] for test to work + # since binaries are in [0, 1] too, we can clamp the entire tensor + perturbed_candidates.clamp_(0, 1) + self.assertLess((acqf(perturbed_candidates) - acqf(candidates)).max(), 0.0) + + # testing that any bit flip leads to a lower acquisition values + for i in binary_dims: + perturbed_candidates = candidates.clone() + perturbed_candidates[..., i] = torch.where( + perturbed_candidates[..., i] == 1, 0, 1 + ) + self.assertLess((acqf(perturbed_candidates) - acqf(candidates)).max(), 0.0) + + # Test only using one continuous variable + cont_dims = [1] + binary_dims = complement_indices(cont_dims, dim) + X_baseline[:, binary_dims] = X_baseline[:, binary_dims].round() + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=binary_dims, + options=options, + q=1, + raw_samples=20, + num_restarts=2, + post_processing_func=lambda x: x, + ) + self.assertEqual(candidates.shape[-1], dim) + c_binary = candidates[:, binary_dims + [2]] + self.assertTrue(((c_binary == 0) | (c_binary == 1)).all()) + # Only continuous parameters will raise an error. + with self.assertRaisesRegex( + ValueError, + "There must be at least one discrete parameter", + ): + optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=[], + options=options, + q=1, + raw_samples=20, + num_restarts=20, + ) + # Only discrete works fine. + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=list(range(dim)), + options=options, + q=1, + raw_samples=20, + num_restarts=20, + ) + self.assertTrue(((candidates == 0) | (candidates == 1)).all()) + # Invalid indices will raise an error. + with self.assertRaisesRegex( + ValueError, + "with unique integers between 0 and num_dims - 1", + ): + optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=[-1], + options=options, + q=1, + raw_samples=20, + num_restarts=2, + ) + + def test_optimize_acqf_mixed_integer(self) -> None: + # Testing with integer variables. + train_X, train_Y, binary_dims, cont_dims = self._get_data() + dim = len(binary_dims) + len(cont_dims) + # Update the data to introduce integer dimensions. + binary_dims = [0] + integer_dims = [3, 4] + discrete_dims = binary_dims + integer_dims + bounds = self.single_bound.repeat(1, dim) + bounds[1, 3:5] = 4.0 + # Update the model to have a different optimizer. 
+ root = torch.tensor([0.0, 0.0, 0.0, 4.0, 4.0], device=self.device) + model = QuadraticDeterministicModel(root) + acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X) + with mock.patch( + f"{OPT_MODULE}._optimize_acqf", wraps=_optimize_acqf + ) as wrapped_optimize: + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=discrete_dims, + q=3, + raw_samples=32, + num_restarts=4, + options={ + "batch_limit": 5, + "init_batch_limit": 20, + "maxiter_alternating": 1, + }, + ) + self.assertEqual(candidates.shape, torch.Size([3, dim])) + self.assertEqual(candidates.shape[-1], dim) + c_binary = candidates[:, binary_dims] + self.assertTrue(((c_binary == 0) | (c_binary == 1)).all()) + c_integer = candidates[:, integer_dims] + self.assertTrue(torch.equal(c_integer, c_integer.round())) + self.assertTrue((c_integer == 4.0).any()) + # Check that we used continuous relaxation for initialization. + first_call_options = ( + wrapped_optimize.call_args_list[0].kwargs["opt_inputs"].options + ) + self.assertEqual( + first_call_options, + {"maxiter": 100, "batch_limit": 5, "init_batch_limit": 20}, + ) + + # Testing that continuous perturbations lead to lower acquisition values. + perturbed_candidates = candidates.clone() + perturbed_candidates[..., cont_dims] += 1e-2 * torch.randn_like( + perturbed_candidates[..., cont_dims], device=self.device + ) + perturbed_candidates[..., cont_dims].clamp_(0, 1) + self.assertLess((acqf(perturbed_candidates) - acqf(candidates)).max(), 1e-12) + # Testing that integer value change leads to a lower acquisition values. + for i, j in product(integer_dims, range(3)): + perturbed_candidates = candidates.repeat(2, 1, 1) + perturbed_candidates[0, j, i] += 1.0 + perturbed_candidates[1, j, i] -= 1.0 + perturbed_candidates.clamp_(bounds[0], bounds[1]) + self.assertLess( + (acqf(perturbed_candidates) - acqf(candidates)).max(), 1e-12 + ) + + # Test gracious fallback when continuous relaxation fails. + with mock.patch( + f"{OPT_MODULE}._optimize_acqf", + side_effect=RuntimeError, + ), self.assertWarnsRegex(OptimizationWarning, "Failed to initialize"): + candidates, _ = generate_starting_points( + opt_inputs=_make_opt_inputs( + acq_function=acqf, + bounds=bounds, + raw_samples=32, + num_restarts=4, + options={"batch_limit": 2, "init_batch_limit": 2}, + ), + discrete_dims=torch.tensor(discrete_dims, device=self.device), + cont_dims=torch.tensor(cont_dims, device=self.device), + ) + self.assertEqual(candidates.shape, torch.Size([4, dim])) + + # Test unsupported options. + with self.assertRaisesRegex(UnsupportedError, "unsupported option"): + optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=discrete_dims, + options={"invalid": 5, "init_batch_limit": 20}, + ) + + # Test with fixed features and constraints. Using both discrete and continuous. + constraint = ( # X[..., 0] + X[..., 1] >= 1. + torch.tensor([0, 1], device=self.device), + torch.ones(2, device=self.device), + 1.0, + ) + candidates, _ = optimize_acqf_mixed_alternating( + acq_function=acqf, + bounds=bounds, + discrete_dims=integer_dims, + q=3, + raw_samples=32, + num_restarts=4, + options={"batch_limit": 5, "init_batch_limit": 20}, + fixed_features={1: 0.5, 3: 2}, + inequality_constraints=[constraint], + ) + self.assertAllClose( + candidates[:, [0, 1, 3]], + torch.tensor([0.5, 0.5, 2.0], device=self.device).repeat(3, 1), + ) + + # Test fallback when initializer cannot generate enough feasible points. 
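+        # (`_optimize_acqf` is patched to return all-zero candidates, which violate
+        # the constraint X[0] + X[1] >= 1, so `sample_feasible_points` must supply
+        # all `num_restarts` starting points.)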
+ with mock.patch( + f"{OPT_MODULE}._optimize_acqf", + return_value=( + torch.zeros(4, 1, dim, **self.tkwargs), + torch.zeros(4, **self.tkwargs), + ), + ), mock.patch( + f"{OPT_MODULE}.sample_feasible_points", wraps=sample_feasible_points + ) as wrapped_sample_feasible: + generate_starting_points( + opt_inputs=_make_opt_inputs( + acq_function=acqf, + bounds=bounds, + raw_samples=32, + num_restarts=4, + inequality_constraints=[constraint], + ), + discrete_dims=torch.tensor(discrete_dims, device=self.device), + cont_dims=torch.tensor(cont_dims, device=self.device), + ) + wrapped_sample_feasible.assert_called_once() + # Should request 4 candidates, since all 4 are infeasible. + self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4)
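For reference, a minimal usage sketch of the new entry point (illustrative only, not part of the diff; the toy model and dimension split mirror the tests above):

import torch
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim.optimize_acqf_mixed import optimize_acqf_mixed_alternating

# Toy problem: 5 dimensions in [0, 1], with dims 0, 3, 4 binary and dims 1, 2 continuous.
bounds = torch.tensor([[0.0] * 5, [1.0] * 5], dtype=torch.double)
train_X = torch.rand(10, 5, dtype=torch.double)
train_X[:, [0, 3, 4]] = train_X[:, [0, 3, 4]].round()
train_Y = train_X.sin().sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
acqf = ExpectedImprovement(model=model, best_f=train_Y.max())

# Alternates discrete neighborhood search with continuous L-BFGS-B refinement.
candidates, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[0, 3, 4],
    q=1,
    raw_samples=32,
    num_restarts=2,
    options={"maxiter_alternating": 16},
)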