diff --git a/botorch/models/model.py b/botorch/models/model.py index 67f10fc9dc..8b20ff0edd 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -576,14 +576,13 @@ def fantasize( The constructed fantasy model. """ if evaluation_mask is not None: - if ( - evaluation_mask.ndim != 2 - and evaluation_mask.shape[0] != X.shape[-2] - and evaluation_mask.shape[1] != self.num_outputs + if evaluation_mask.ndim != 2 or evaluation_mask.shape != torch.Size( + [X.shape[-2], self.num_outputs] ): raise BotorchTensorDimensionError( - f"Expected evaluation_mask of shape {X.shape[0]} " - f"x {self.num_outputs}, but got {evaluation_mask.shape}." + f"Expected evaluation_mask of shape `{X.shape[0]} " + f"x {self.num_outputs}`, but got `" + f"{' x '.join(str(i) for i in evaluation_mask.shape)}`." ) if not isinstance(sampler, ListSampler): raise ValueError("Decoupled fantasization requires a list of samplers.") diff --git a/test/models/test_model_list_gp_regression.py b/test/models/test_model_list_gp_regression.py index 1547f2a212..09b058a521 100644 --- a/test/models/test_model_list_gp_regression.py +++ b/test/models/test_model_list_gp_regression.py @@ -555,7 +555,32 @@ def test_fantasize_with_outcome_transform_fixed_noise(self) -> None: FixedNoiseGP(X, Y, yvar, outcome_transform=Standardize(m=1)), FixedNoiseGP(X, Y2, yvar2, outcome_transform=Standardize(m=1)), ) + # test exceptions + eval_mask = torch.zeros( + 3, 2, 2, dtype=torch.bool, device=tkwargs["device"] + ) + msg = ( + f"Expected evaluation_mask of shape `{X.shape[0]} x " + f"{model.num_outputs}`, but got `" + f"{' x '.join(str(i) for i in eval_mask.shape)}`." + ) + with self.assertRaisesRegex(BotorchTensorDimensionError, msg): + model.fantasize( + X, + evaluation_mask=eval_mask, + sampler=ListSampler( + IIDNormalSampler(n_fants, seed=0), + IIDNormalSampler(n_fants, seed=0), + ), + ) + msg = "Decoupled fantasization requires a list of samplers." + with self.assertRaisesRegex(ValueError, msg): + model.fantasize( + X, + evaluation_mask=eval_mask[0], + sampler=IIDNormalSampler(n_fants, seed=0), + ) model.posterior(torch.zeros((1, 1), **tkwargs)) for decoupled in (False, True): if decoupled: