Dtype fix (floats should be double-precision), cleanup, units
esantorella committed Jul 24, 2023
1 parent 81f7c98 commit a5946e3
Showing 3 changed files with 125 additions and 14 deletions.
49 changes: 38 additions & 11 deletions botorch/acquisition/fixed_feature.py
@@ -20,6 +20,35 @@
 from torch.nn import Module
+
+
+def get_dtype_of_sequence(values: Sequence[Union[Tensor, float]]) -> torch.dtype:
+    """
+    Return torch.float32 if everything is single-precision and torch.float64
+    otherwise. Numbers (non-tensors) are double-precision.
+    """
+
+    def _is_single(value: Union[Tensor, float]) -> bool:
+        return isinstance(value, Tensor) and value.dtype == torch.float32
+
+    all_single_precision = all(_is_single(value) for value in values)
+    return torch.float32 if all_single_precision else torch.float64
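A quick sketch of the promotion rule (not part of the diff; consistent with the unit tests added below): a sequence stays single-precision only if every element is a float32 tensor, and any bare Python number or float64 tensor promotes the result to torch.float64, since Python floats are 64-bit doubles.

>>> import torch
>>> get_dtype_of_sequence([torch.tensor([0.0], dtype=torch.float32)])
torch.float32
>>> get_dtype_of_sequence([torch.tensor([0.0], dtype=torch.float32), 0.0])
torch.float64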


+def get_device_of_sequence(values: Sequence[Union[Tensor, float]]) -> torch.device:
+    """
+    Return the CPU device if everything is on the CPU, and the CUDA device
+    otherwise. Numbers (non-tensors) are considered to be on the CPU.
+    """
+
+    def _is_cuda(value: Union[Tensor, float]) -> bool:
+        return hasattr(value, "device") and value.device.type == "cuda"
+
+    any_cuda = any(_is_cuda(value) for value in values)
+    return torch.device("cuda") if any_cuda else torch.device("cpu")
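Analogously for devices, one CUDA tensor anywhere in the sequence promotes everything to CUDA. A sketch (not part of the diff; the second call assumes a CUDA-capable machine):

>>> get_device_of_sequence([0.0, torch.tensor([0.0])])
device(type='cpu')
>>> get_device_of_sequence([0.0, torch.tensor([0.0], device="cuda")])
device(type='cuda')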


 class FixedFeatureAcquisitionFunction(AcquisitionFunction):
     """A wrapper around AcquisitionFunctions to fix a subset of features.
@@ -58,27 +87,25 @@ def __init__(
"""
Module.__init__(self)
self.acq_func = acq_function
dtype = torch.float
device = torch.device("cpu")
self.d = d

if isinstance(values, Tensor):
new_values = values.detach().clone()
else:

dtype = get_dtype_of_sequence(values)
device = get_device_of_sequence(values)

new_values = []
for value in values:
if isinstance(value, Number):
new_values.append(torch.tensor([float(value)]))
value = torch.tensor([value], dtype=dtype)
else:
# if any value uses double, use double for all values
# likewise if any value uses cuda, use cuda for all values
dtype = value.dtype if value.dtype == torch.double else dtype
device = value.device if value.device.type == "cuda" else device
if value.ndim == 0: # since we can't broadcast with zero-d tensors
value = value.unsqueeze(0)
new_values.append(value.detach().clone())
# move all values to same device
for i, val in enumerate(new_values):
new_values[i] = val.to(dtype=dtype, device=device)
value = value.detach().clone()

new_values.append(value.to(dtype=dtype, device=device))

# There are 3 cases for when `values` is a `Sequence`.
# 1) `values` == list of floats as earlier.
Expand Down
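The upshot, sketched (not part of the diff; `acqf` stands in for any 3-dimensional acquisition function): mixing a plain Python float with a double-precision tensor now yields double-precision fixed values instead of silently truncating to float32.

>>> import torch
>>> ff = FixedFeatureAcquisitionFunction(
...     acqf, d=3, columns=[0, 2],
...     values=[1e-1, torch.tensor([0.5], dtype=torch.float64)],
... )
>>> ff.values.dtype
torch.float64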
70 changes: 67 additions & 3 deletions test/acquisition/test_fixed_feature.py
@@ -6,14 +6,18 @@

 import torch
 from botorch.acquisition.analytic import ExpectedImprovement
-from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
+from botorch.acquisition.fixed_feature import (
+    FixedFeatureAcquisitionFunction,
+    get_device_of_sequence,
+    get_dtype_of_sequence,
+)
 from botorch.acquisition.monte_carlo import qExpectedImprovement
 from botorch.models import SingleTaskGP
-from botorch.utils.testing import BotorchTestCase
+from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction


 class TestFixedFeatureAcquisitionFunction(BotorchTestCase):
-    def test_fixed_features(self):
+    def test_fixed_features(self) -> None:
         train_X = torch.rand(5, 3, device=self.device)
         train_Y = train_X.norm(dim=-1, keepdim=True)
         model = SingleTaskGP(train_X, train_Y).to(device=self.device).eval()
@@ -132,3 +136,63 @@ def test_fixed_features(self):
             )
         with self.assertRaises(ValueError):
             EI_ff.X_pending

+    def test_values_dtypes(self) -> None:
+        acqf = MockAcquisitionFunction()
+
+        for input, d, expected_dtype in [
+            (torch.tensor([0.0], dtype=torch.float32), 1, torch.float32),
+            (torch.tensor([0.0], dtype=torch.float64), 1, torch.float64),
+            (
+                [
+                    torch.tensor([0.0], dtype=torch.float32),
+                    torch.tensor([0.0], dtype=torch.float64),
+                ],
+                2,
+                torch.float64,
+            ),
+            ([0.0], 1, torch.float64),
+            ([torch.tensor(0.0, dtype=torch.float32), 0.0], 2, torch.float64),
+        ]:
+            with self.subTest(input=input, d=d, expected_dtype=expected_dtype):
+                self.assertEqual(get_dtype_of_sequence(input), expected_dtype)
+                ff = FixedFeatureAcquisitionFunction(
+                    acqf, d=d, columns=[2], values=input
+                )
+                self.assertEqual(ff.values.dtype, expected_dtype)

+    def test_values_devices(self) -> None:
+        acqf = MockAcquisitionFunction()
+        cpu = torch.device("cpu")
+        cuda = torch.device("cuda")
+
+        test_cases = [
+            (torch.tensor([0.0], device=cpu), 1, cpu),
+            ([0.0], 1, cpu),
+            ([0.0, torch.tensor([0.0], device=cpu)], 2, cpu),
+        ]
+
+        # Can only properly test this when running CUDA tests
+        if self.device.type == "cuda":
+            test_cases = test_cases + [
+                (torch.tensor([0.0], device=cuda), 1, cuda),
+                (
+                    [
+                        torch.tensor([0.0], device=cpu),
+                        torch.tensor([0.0], device=cuda),
+                    ],
+                    2,
+                    cuda,
+                ),
+                ([0.0], 1, cpu),
+                ([torch.tensor(0.0, device=cuda), 0.0], 2, cuda),
+            ]
+
+        for input, d, expected_device in test_cases:
+            with self.subTest(input=input, d=d, expected_device=expected_device):
+                self.assertEqual(get_device_of_sequence(input), expected_device)
+                ff = FixedFeatureAcquisitionFunction(
+                    acqf, d=d, columns=[2], values=input
+                )
+                self.assertEqual(ff.values.device, expected_device)
20 changes: 20 additions & 0 deletions test/optim/test_optimize.py
@@ -1763,3 +1763,23 @@ def test_optimize_acqf_discrete_local_search(self):
         )
         self.assertEqual(len(X), 20)
         self.assertAllClose(torch.unique(X, dim=0), X)

+    def test_no_precision_loss_with_fixed_features(self) -> None:
+        acqf = SquaredAcquisitionFunction()
+
+        val = 1e-1
+        fixed_features_list = [{0: val}]
+
+        bounds = torch.stack(
+            [torch.zeros(2, dtype=torch.float64), torch.ones(2, dtype=torch.float64)]
+        )
+        candidate, _ = optimize_acqf_mixed(
+            acqf,
+            bounds=bounds,
+            q=1,
+            num_restarts=1,
+            raw_samples=1,
+            fixed_features_list=fixed_features_list,
+        )
+        self.assertEqual(candidate[0, 0].item(), val)
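A quick check of the mechanism this test guards against (not part of the diff): a Python float is a 64-bit double, so routing it through a float32 tensor changes its value, and an exact equality check catches any single-precision round trip.

>>> import torch
>>> val = 1e-1
>>> torch.tensor([val], dtype=torch.float32).double().item() == val
False
>>> torch.tensor([val], dtype=torch.float64).item() == val
True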
