Skip to content

Commit

Permalink
Add test of correct result for SEBO (#2943)
Browse files Browse the repository at this point in the history
Summary:

Adds a unit test  that SEBO produces the correct result in a
controlled test case. Changes the existing test case data to match the desired
test case for this test, which is a linear function with two unused variables.
We verify that we maximize the used variable, while setting the unused
variables to their target point values.

Depends on D60089765

Reviewed By: dme65

Differential Revision: D64720328
  • Loading branch information
bletham authored and facebook-github-bot committed Oct 24, 2024
1 parent 322ea96 commit be80e88
Showing 1 changed file with 52 additions and 18 deletions.
70 changes: 52 additions & 18 deletions ax/models/torch/tests/test_sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,34 @@ def setUp(self) -> None:
tkwargs: dict[str, Any] = {"dtype": torch.double}
self.botorch_model_class = SingleTaskGP
self.surrogates = Surrogate(botorch_model_class=self.botorch_model_class)
self.X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs)
self.target_point = torch.tensor([1.0, 1.0, 1.0], **tkwargs)
self.Y = torch.tensor([[3.0], [4.0]], **tkwargs)
self.Yvar = torch.tensor([[0.0], [2.0]], **tkwargs)
# Function is f(x) = x_1 on [0, 1]^3, target point of [0.5, 0.5, 0.5].
# Optimal soln is [1.0, 0.5, 0.5]
self.X = torch.tensor(
[
[0.0, 0.0, 0.0],
[0.6, 0.0, 0.0],
[0.9, 0.0, 0.0],
[0.0, 0.5, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 0.5],
[0.0, 0.0, 1.0],
[0.9, 1.0, 1.0],
[0.9, 0.0, 1.0],
[0.9, 1.0, 0.0],
],
**tkwargs,
)
self.target_point = torch.tensor([0.5, 0.5, 0.5], **tkwargs)
self.Y = self.X[:, 0].unsqueeze(-1)
self.Yvar = 0.1 * torch.ones_like(self.Y)
self.training_data = [
SupervisedDataset(
X=self.X, Y=self.Y, feature_names=["a", "b", "c"], outcome_names=["m1"]
)
]
self.search_space_digest = SearchSpaceDigest(
feature_names=["a", "b", "c"],
bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)],
bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
target_values={2: 1.0},
)
self.surrogates.fit(
Expand All @@ -79,20 +95,20 @@ def setUp(self) -> None:
self.objective_thresholds = torch.tensor([1.0], **tkwargs)
self.objective_thresholds_sebo = torch.tensor([1.0, 3.0], **tkwargs)

self.pending_observations = [torch.tensor([[1.0, 3.0, 4.0]], **tkwargs)]
self.pending_observations = [torch.tensor([[0.5, 0.5, 0.5]], **tkwargs)]
self.outcome_constraints = (
torch.tensor([[1.0]], **tkwargs),
torch.tensor([[0.5]], **tkwargs),
torch.tensor([[2.0]], **tkwargs),
)
self.outcome_constraints_sebo = (
torch.tensor([[1.0, 0.0]], **tkwargs),
torch.tensor([[0.5]], **tkwargs),
torch.tensor([[2.0]], **tkwargs),
)
self.linear_constraints = None
self.fixed_features = {1: 2.0}
self.fixed_features = {1: 1.0}
self.options = {"best_f": 0.0, "target_point": self.target_point}
self.inequality_constraints = [
(torch.tensor([0, 1], **tkwargs), torch.tensor([-1.0, 1.0], **tkwargs), 1)
(torch.tensor([0, 1], **tkwargs), torch.tensor([1.0, -1.0], **tkwargs), 0)
]
self.rounding_func = lambda x: x
self.optimizer_options = {Keys.NUM_RESTARTS: 40, Keys.RAW_SAMPLES: 1024}
Expand Down Expand Up @@ -206,7 +222,7 @@ def test_init(self) -> None:
@mock.patch(f"{ACQUISITION_PATH}.optimize_acqf")
def test_optimize_l1(self, mock_optimize_acqf: Mock) -> None:
mock_optimize_acqf.return_value = (
torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], dtype=torch.double),
torch.tensor([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]], dtype=torch.double),
torch.tensor([1.0, 2.0], dtype=torch.double),
)
acquisition = self.get_acquisition_function(
Expand All @@ -233,7 +249,7 @@ def test_optimize_l1(self, mock_optimize_acqf: Mock) -> None:
@mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")
def test_optimize_l0(self, mock_optimize_acqf_homotopy: Mock) -> None:
mock_optimize_acqf_homotopy.return_value = (
torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], dtype=torch.double),
torch.tensor([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]], dtype=torch.double),
torch.tensor([1.0, 2.0], dtype=torch.double),
)
acquisition = self.get_acquisition_function(
Expand Down Expand Up @@ -296,6 +312,25 @@ def test_optimize_l0(self, mock_optimize_acqf_homotopy: Mock) -> None:
optimizer_options=self.optimizer_options,
)

def test_optimization_result(self) -> None:
acquisition = self.get_acquisition_function(
options={"penalty": "L0_norm", "target_point": self.target_point},
)
# Function is f(x) = x_1 on [0, 1]^3, target point of [0.5, 0.5, 0.5].
# Optimal soln is [1.0, 0.5, 0.5]
optimizer_options = {Keys.NUM_RESTARTS: 1, Keys.RAW_SAMPLES: 1000}
xopt, _, _ = acquisition.optimize(
n=1,
search_space_digest=self.search_space_digest,
optimizer_options=optimizer_options, # pyre-ignore
)
self.assertTrue(
torch.allclose(xopt[0, 1:], torch.tensor([0.5, 0.5], **self.tkwargs))
)
self.assertTrue(
torch.isclose(xopt[0, 0], torch.tensor(1.0, **self.tkwargs), atol=0.3)
)

def test_clamp_to_target(self) -> None:
X = torch.tensor(
[[0.5, 0.01, 0.5], [0.05, 0.5, 0.95], [0.1, 0.02, 0.06]], **self.tkwargs
Expand All @@ -318,11 +353,11 @@ def test_get_batch_initial_conditions(
self, mock_get_batch_initial_conditions: Mock, mock_optimize_acqf_homotopy: Mock
) -> None:
mock_optimize_acqf_homotopy.return_value = (
torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], dtype=torch.double),
torch.tensor([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]], dtype=torch.double),
torch.tensor([1.0, 2.0], dtype=torch.double),
)
acquisition = self.get_acquisition_function(
fixed_features=self.fixed_features,
fixed_features={1: 0.5},
options={"target_point": self.target_point},
torch_opt_config=self.torch_opt_config_2,
)
Expand All @@ -337,7 +372,7 @@ def test_get_batch_initial_conditions(
self.assertTrue(
torch.equal(
call_args["X_pareto"],
torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.double),
torch.tensor([[0.5, 0.5, 0.5]], dtype=torch.double),
)
)
self.assertTrue(torch.equal(call_args["target_point"], self.target_point))
Expand All @@ -348,6 +383,5 @@ def test_get_batch_initial_conditions(
"batch_initial_conditions"
]
self.assertEqual(batch_initial_conditions.shape, torch.Size([3, 1, 3]))
self.assertTrue(torch.all(batch_initial_conditions[:1] != 1.0))
self.assertTrue(torch.all(batch_initial_conditions[1:, :, 0] == 1.0))
self.assertTrue(torch.all(batch_initial_conditions[1:, :, 1:] != 1.0))
self.assertTrue(torch.all(batch_initial_conditions[:1] != 0.5))
self.assertTrue(torch.all(batch_initial_conditions[1:, :, 1] == 0.5))

0 comments on commit be80e88

Please sign in to comment.