From 90d734c90ccfc772c96ad1c886082d54006efd52 Mon Sep 17 00:00:00 2001 From: Alexey Y Date: Sun, 3 Dec 2023 17:22:27 -0800 Subject: [PATCH 1/6] Fixed data extraction for dKG --- gpytorch/models/exact_prediction_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 2b716d73f..21beae64b 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,7 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - mean_cache = mean_cache.squeeze(1) + mean_cache = mean_cache.permute(2,0,1,3)[:,0:1,0,:] # Handle NaNs nan_policy = settings.observation_nan_policy.value() From 8b8759757300c88b5fd2f39e05fbdd785d13bd8f Mon Sep 17 00:00:00 2001 From: Alexey Y Date: Sun, 3 Dec 2023 17:30:23 -0800 Subject: [PATCH 2/6] Fixed linter --- gpytorch/models/exact_prediction_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 21beae64b..12c9c3c21 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,7 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - mean_cache = mean_cache.permute(2,0,1,3)[:,0:1,0,:] + mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0:1, 0, :] # Handle NaNs nan_policy = settings.observation_nan_policy.value() From 547ed4ad5143e0e920b43a0ea0a1e85e50a2b49a Mon Sep 17 00:00:00 2001 From: Alexey Y Date: Sun, 3 Dec 2023 18:13:22 -0800 Subject: [PATCH 3/6] Re-run unittests --- gpytorch/models/exact_prediction_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 12c9c3c21..fb45aae46 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,7 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0:1, 0, :] + mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0, 0:1, :] # Handle NaNs nan_policy = settings.observation_nan_policy.value() From 45ee74d414a57ee4e476fe299e9fad2525d86d77 Mon Sep 17 00:00:00 2001 From: Alexey Y Date: Sun, 3 Dec 2023 18:23:41 -0800 Subject: [PATCH 4/6] Testing UT --- gpytorch/models/exact_prediction_strategies.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index fb45aae46..bd3d5e36e 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,8 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0, 0:1, :] + #mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0, 0:1, :] + mean_cache = mean_cache.squeeze(1) # Handle NaNs nan_policy = settings.observation_nan_policy.value() From 867255e779494372f773768b524fb1d06ae18068 Mon Sep 17 00:00:00 2001 From: Alexey Y Date: Sun, 3 Dec 2023 18:27:34 -0800 Subject: [PATCH 5/6] No code changed --- gpytorch/models/exact_prediction_strategies.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index bd3d5e36e..2b716d73f 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,6 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - #mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0, 0:1, :] mean_cache = mean_cache.squeeze(1) # Handle NaNs From 6dd874778575b02ea8ab4d35b315bcc629b39d41 Mon Sep 17 00:00:00 2001 From: Alexey Y Date: Sun, 3 Dec 2023 18:32:38 -0800 Subject: [PATCH 6/6] Brought back my change --- gpytorch/models/exact_prediction_strategies.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 2b716d73f..fb45aae46 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -339,7 +339,7 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 mean_cache = self.mean_cache if len(mean_cache.shape) == 4: - mean_cache = mean_cache.squeeze(1) + mean_cache = mean_cache.permute(2, 0, 1, 3)[:, 0, 0:1, :] # Handle NaNs nan_policy = settings.observation_nan_policy.value()