diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 2b716d73f..fb45aae46 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,7 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - mean_cache = mean_cache.squeeze(1) + mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0, 0:1, :] # Handle NaNs nan_policy = settings.observation_nan_policy.value()