diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py
index 0f3b85faa7..31a746b954 100644
--- a/botorch/acquisition/fixed_feature.py
+++ b/botorch/acquisition/fixed_feature.py
@@ -20,6 +20,35 @@
 from torch.nn import Module
 
 
+def get_dtype_of_sequence(values: Sequence[Union[Tensor, float]]) -> torch.dtype:
+    """
+    Return torch.float32 if everything is single-precision and torch.float64
+    otherwise.
+
+    Numbers (non-tensors) are double-precision.
+    """
+
+    def _is_single(value: Union[Tensor, float]) -> bool:
+        return isinstance(value, Tensor) and value.dtype == torch.float32
+
+    all_single_precision = all(_is_single(value) for value in values)
+    return torch.float32 if all_single_precision else torch.float64
+
+
+def get_device_of_sequence(values: Sequence[Union[Tensor, float]]) -> torch.device:
+    """
+    CPU if everything is on the CPU; CUDA otherwise.
+
+    Numbers (non-tensors) are considered to be on the CPU.
+    """
+
+    def _is_cuda(value: Union[Tensor, float]) -> bool:
+        return hasattr(value, "device") and value.device == torch.device("cuda")
+
+    any_cuda = any(_is_cuda(value) for value in values)
+    return torch.device("cuda") if any_cuda else torch.device("cpu")
+
+
 class FixedFeatureAcquisitionFunction(AcquisitionFunction):
     """A wrapper around AquisitionFunctions to fix a subset of features.
 
@@ -58,27 +87,25 @@ def __init__(
         """
         Module.__init__(self)
         self.acq_func = acq_function
-        dtype = torch.float
-        device = torch.device("cpu")
         self.d = d
+
         if isinstance(values, Tensor):
             new_values = values.detach().clone()
         else:
+
+            dtype = get_dtype_of_sequence(values)
+            device = get_device_of_sequence(values)
+
             new_values = []
             for value in values:
                 if isinstance(value, Number):
-                    new_values.append(torch.tensor([float(value)]))
+                    value = torch.tensor([value], dtype=dtype)
                 else:
-                    # if any value uses double, use double for all values
-                    # likewise if any value uses cuda, use cuda for all values
-                    dtype = value.dtype if value.dtype == torch.double else dtype
-                    device = value.device if value.device.type == "cuda" else device
                     if value.ndim == 0:  # since we can't broadcast with zero-d tensors
                         value = value.unsqueeze(0)
-                new_values.append(value.detach().clone())
-            # move all values to same device
-            for i, val in enumerate(new_values):
-                new_values[i] = val.to(dtype=dtype, device=device)
+                    value = value.detach().clone()
+
+                new_values.append(value.to(dtype=dtype, device=device))
 
         # There are 3 cases for when `values` is a `Sequence`.
         # 1) `values` == list of floats as earlier.
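For readers skimming the patch, here is a quick sketch of the promotion rules the two new helpers implement, taken straight from their docstrings: plain Python numbers count as double precision and as CPU-resident. The imports assume this patch has been applied:

```python
import torch
from botorch.acquisition.fixed_feature import (
    get_device_of_sequence,
    get_dtype_of_sequence,
)

single = torch.tensor([0.5], dtype=torch.float32)
double = torch.tensor([0.5], dtype=torch.float64)

# float32 survives only if every entry is a single-precision tensor
print(get_dtype_of_sequence([single, single]))  # torch.float32
print(get_dtype_of_sequence([single, double]))  # torch.float64

# plain numbers count as double precision, so one float promotes the rest
print(get_dtype_of_sequence([single, 0.5]))  # torch.float64

# plain numbers are considered CPU-resident
print(get_device_of_sequence([0.5, single]))  # cpu
```

Computing the target dtype and device in one pass up front replaces the old per-element logic, which only promoted based on tensor entries and therefore silently downcast plain floats to `torch.float`.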
diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py
index 8dcc02f1df..1f5218d228 100644
--- a/test/acquisition/test_fixed_feature.py
+++ b/test/acquisition/test_fixed_feature.py
@@ -6,14 +6,18 @@
 
 import torch
 from botorch.acquisition.analytic import ExpectedImprovement
-from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
+from botorch.acquisition.fixed_feature import (
+    FixedFeatureAcquisitionFunction,
+    get_device_of_sequence,
+    get_dtype_of_sequence,
+)
 from botorch.acquisition.monte_carlo import qExpectedImprovement
 from botorch.models import SingleTaskGP
-from botorch.utils.testing import BotorchTestCase
+from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction
 
 
 class TestFixedFeatureAcquisitionFunction(BotorchTestCase):
-    def test_fixed_features(self):
+    def test_fixed_features(self) -> None:
         train_X = torch.rand(5, 3, device=self.device)
         train_Y = train_X.norm(dim=-1, keepdim=True)
         model = SingleTaskGP(train_X, train_Y).to(device=self.device).eval()
@@ -132,3 +136,63 @@ def test_fixed_features(self):
         )
         with self.assertRaises(ValueError):
             EI_ff.X_pending
+
+    def test_values_dtypes(self) -> None:
+        acqf = MockAcquisitionFunction()
+
+        for input, d, expected_dtype in [
+            (torch.tensor([0.0], dtype=torch.float32), 1, torch.float32),
+            (torch.tensor([0.0], dtype=torch.float64), 1, torch.float64),
+            (
+                [
+                    torch.tensor([0.0], dtype=torch.float32),
+                    torch.tensor([0.0], dtype=torch.float64),
+                ],
+                2,
+                torch.float64,
+            ),
+            ([0.0], 1, torch.float64),
+            ([torch.tensor(0.0, dtype=torch.float32), 0.0], 2, torch.float64),
+        ]:
+            with self.subTest(input=input, d=d, expected_dtype=expected_dtype):
+                self.assertEqual(get_dtype_of_sequence(input), expected_dtype)
+                ff = FixedFeatureAcquisitionFunction(
+                    acqf, d=d, columns=[2], values=input
+                )
+                self.assertEqual(ff.values.dtype, expected_dtype)
+
+    def test_values_devices(self) -> None:
+
+        acqf = MockAcquisitionFunction()
+        cpu = torch.device("cpu")
+        cuda = torch.device("cuda")
+
+        test_cases = [
+            (torch.tensor([0.0], device=cpu), 1, cpu),
+            ([0.0], 1, cpu),
+            ([0.0, torch.tensor([0.0], device=cpu)], 2, cpu),
+        ]
+
+        # Can only properly test this when running CUDA tests
+        if self.device == torch.device("cuda"):
+            test_cases = test_cases + [
+                (torch.tensor([0.0], device=cuda), 1, cuda),
+                (
+                    [
+                        torch.tensor([0.0], device=cpu),
+                        torch.tensor([0.0], device=cuda),
+                    ],
+                    2,
+                    cuda,
+                ),
+                ([0.0], 1, cpu),
+                ([torch.tensor(0.0, device=cuda), 0.0], 2, cuda),
+            ]
+
+        for input, d, expected_device in test_cases:
+            with self.subTest(input=input, d=d, expected_device=expected_device):
+                self.assertEqual(get_device_of_sequence(input), expected_device)
+                ff = FixedFeatureAcquisitionFunction(
+                    acqf, d=d, columns=[2], values=input
+                )
+                self.assertEqual(ff.values.device, expected_device)
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index d29b63e05c..a17a28c60a 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -1763,3 +1763,23 @@ def test_optimize_acqf_discrete_local_search(self):
         )
         self.assertEqual(len(X), 20)
         self.assertAllClose(torch.unique(X, dim=0), X)
+
+    def test_no_precision_loss_with_fixed_features(self) -> None:
+
+        acqf = SquaredAcquisitionFunction()
+
+        val = 1e-1
+        fixed_features_list = [{0: val}]
+
+        bounds = torch.stack(
+            [torch.zeros(2, dtype=torch.float64), torch.ones(2, dtype=torch.float64)]
+        )
+        candidate, _ = optimize_acqf_mixed(
+            acqf,
+            bounds=bounds,
+            q=1,
+            num_restarts=1,
+            raw_samples=1,
+            fixed_features_list=fixed_features_list,
+        )
+        self.assertEqual(candidate[0, 0].item(), val)
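As a standalone illustration of the precision issue this regression test guards against (plain torch, no BoTorch required): `0.1` has no exact float32 representation, so the old default of materializing fixed values as `torch.float` broke exact round-trips, while double precision preserves them:

```python
import torch

val = 1e-1  # the fixed-feature value used in the test above

# The old code built fixed values as float32 by default; 0.1 is not
# exactly representable in float32, so the round-trip fails:
print(torch.tensor([val], dtype=torch.float32).item() == val)  # False

# With the patch, double inputs stay double, and the round-trip is exact:
print(torch.tensor([val], dtype=torch.float64).item() == val)  # True
```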