Skip to content

Commit

Permalink
Fix flaky test_get_best_f_mc test (#1969)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1969

This test started to become flaky recently, presumably b/c to some minor changes in the numerics on GPU computations. Replacing equality with closeness check fixes this.

Reviewed By: saitcakmak

Differential Revision: D48033053

fbshipit-source-id: 5261ec92ca67aef82567c994aeb6ca802a1e07a4
  • Loading branch information
Balandat authored and facebook-github-bot committed Aug 3, 2023
1 parent f20697a commit ffb7bd9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/acquisition/test_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,22 +140,22 @@ def test_get_best_f_mc(self):
best_f = get_best_f_mc(training_data=self.blockX_blockY)
self.assertEqual(best_f, get_best_f_mc(self.blockX_blockY[0]))

best_f_expected = self.blockX_blockY[0].Y().squeeze().max()
self.assertEqual(best_f, best_f_expected)
best_f_expected = self.blockX_blockY[0].Y().max(dim=0).values
self.assertAllClose(best_f, best_f_expected)
with self.assertRaisesRegex(UnsupportedError, "require an objective"):
get_best_f_mc(training_data=self.blockX_multiY)
obj = LinearMCObjective(weights=torch.rand(2))
best_f = get_best_f_mc(training_data=self.blockX_multiY, objective=obj)

multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
best_f_expected = (multi_Y @ obj.weights).amax(dim=-1, keepdim=True)
self.assertEqual(best_f, best_f_expected)
self.assertAllClose(best_f, best_f_expected)
post_tf = ScalarizedPosteriorTransform(weights=torch.ones(2))
best_f = get_best_f_mc(
training_data=self.blockX_multiY, posterior_transform=post_tf
)
best_f_expected = (multi_Y.sum(dim=-1)).amax(dim=-1, keepdim=True)
self.assertEqual(best_f, best_f_expected)
self.assertAllClose(best_f, best_f_expected)

@mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
def test_optimize_objective(self, mock_optimize_acqf):
Expand Down

0 comments on commit ffb7bd9

Please sign in to comment.