Skip to content

Commit

Permalink
Fix AugmentedSingleTaskGP test
Browse files Browse the repository at this point in the history
  • Loading branch information
esantorella committed Sep 9, 2024
1 parent 29df54a commit a074d29
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test_community/models/test_gp_regression_multisource.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _get_model_and_data(
None if train_Yvar else get_gaussian_likelihood_with_gamma_prior()
),
}
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=OptimizationWarning)
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
return model, model_kwargs

def test_data_init(self):
Expand Down Expand Up @@ -139,8 +141,8 @@ def test_get_reliable_observation(self):
self.assertListEqual(res.tolist(), true_res.tolist())

def test_gp(self):
bounds = torch.tensor([[-1.0], [1.0]])
d = 5
bounds = torch.stack((torch.full((d - 1,), -1), torch.ones(d - 1)))
for batch_shape, dtype, use_octf, use_intf, train_Yvar in itertools.product(
(torch.Size(), torch.Size([2])),
(torch.float, torch.double),
Expand All @@ -151,7 +153,7 @@ def test_gp(self):
tkwargs = {"device": self.device, "dtype": dtype}
octf = Standardize(m=1, batch_shape=torch.Size()) if use_octf else None
intf = (
Normalize(d=1, bounds=bounds.to(**tkwargs), transform_on_train=True)
Normalize(d=d - 1, bounds=bounds.to(**tkwargs), transform_on_train=True)
if use_intf
else None
)
Expand Down

0 comments on commit a074d29

Please sign in to comment.