diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 198228409a..ae4f054321 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -335,15 +335,16 @@ def prune_inferior_points( marginalize_dim=marginalize_dim, ) if infeas.any(): - # set infeasible points to worse than worst objective - # across all samples + # set infeasible points to worse than worst objective across all samples + # Use clone() here to avoid deprecated `index_put_` on an expanded tensor + obj_vals = obj_vals.clone() obj_vals[infeas] = obj_vals.min() - 1 is_best = torch.argmax(obj_vals, dim=-1) idcs, counts = torch.unique(is_best, return_counts=True) if len(idcs) > max_points: - counts, order_idcs = torch.sort(counts, descending=True) + counts, order_idcs = torch.sort(counts, stable=True, descending=True) idcs = order_idcs[:max_points] return X[idcs] diff --git a/test/acquisition/multi_objective/test_utils.py b/test/acquisition/multi_objective/test_utils.py index acdfddbc95..786c72ad9c 100644 --- a/test/acquisition/multi_objective/test_utils.py +++ b/test/acquisition/multi_objective/test_utils.py @@ -130,13 +130,14 @@ def test_prune_inferior_points_multi_objective(self): X_pruned = prune_inferior_points_multi_objective( model=mm, X=X, ref_point=ref_point, max_frac=2 / 3 ) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue( - torch.equal(X_pruned, X[[2, 1]]) or torch.equal(X_pruned, X[[1, 2]]) + # sorting has different order on cuda + X_expected = X[1:3] if self.device.type == "cuda" else X[:2] + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X_expected, stable=True).values, ) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index c8a6484cca..c9552da886 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -270,11 +270,14 @@ def test_prune_inferior_points(self): with mock.patch.object(MockPosterior, "rsample", return_value=samples): mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3) - if self.device.type == "cuda": - # sorting has different order on cuda - self.assertTrue(torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0))) - else: - self.assertTrue(torch.equal(X_pruned, X[:2])) + # sorting has different order on cuda + X_expected = X[1:3] if self.device.type == "cuda" else X[:2] + self.assertTrue( + torch.equal( + torch.sort(X_pruned, stable=True).values, + torch.sort(X_expected, stable=True).values, + ) + ) # test that zero-probability is in fact pruned samples[2, 0, 0] = 10 with mock.patch.object(MockPosterior, "rsample", return_value=samples): @@ -289,11 +292,7 @@ def test_prune_inferior_points(self): device=self.device, dtype=dtype, ) - mm = MockModel( - MockPosterior( - samples=samples, - ) - ) + mm = MockModel(MockPosterior(samples=samples)) X_pruned = prune_inferior_points( model=mm, X=X,