Skip to content

Commit

Permalink
sign fix for gradient blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
m-julian committed Oct 16, 2024
1 parent d501c28 commit 7061ce2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
5 changes: 2 additions & 3 deletions gpytorch/kernels/matern52_kernel_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,14 @@ def forward(self, x1, x2, diag=False, **params):

# 2) First gradient block, cov(f^m, omega^n_d)
outer1 = outer.view(*batch_shape, n1, n2 * d)
K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
K[..., :n1, n2:] = outer1 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
[*([1] * (n_batch_dims + 1)), d]
)

# 3) Second gradient block, cov(omega^m_d, f^n)
outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
outer2 = outer2.transpose(-1, -2)
# the - signs on -outer2 and -five_thirds cancel out
K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
K[..., n1:, :n2] = -outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
[*([1] * n_batch_dims), d, 1]
)

Expand Down
12 changes: 6 additions & 6 deletions test/kernels/test_matern52_kernel_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def test_kernel(self, cuda=False):

actual = torch.tensor(
[
[0.3056225, -0.0000000, 0.5822443, 0.0188260, -0.0209871, 0.0419742],
[0.0000000, 0.5822443, 0.0000000, 0.0209871, -0.0056045, 0.0531832],
[-0.5822443, 0.0000000, -0.8515886, -0.0419742, 0.0531832, -0.0853792],
[0.1304891, -0.2014212, -0.2014212, 0.0336440, -0.0815567, -0.0000000],
[0.2014212, -0.1754366, -0.3768578, 0.0815567, -0.1870145, -0.0000000],
[0.2014212, -0.3768578, -0.1754366, 0.0000000, -0.0000000, 0.0407784],
[0.3056225, 0.0000000, -0.5822443, 0.0188260, 0.0209871, -0.0419742],
[-0.0000000, 0.5822443, 0.0000000, -0.0209871, -0.0056045, 0.0531832],
[0.5822443, 0.0000000, -0.8515886, 0.0419742, 0.0531832, -0.0853792],
[0.1304891, 0.2014212, 0.2014212, 0.0336440, 0.0815567, 0.0000000],
[-0.2014212, -0.1754366, -0.3768578, -0.0815567, -0.1870145, -0.0000000],
[-0.2014212, -0.3768578, -0.1754366, -0.0000000, -0.0000000, 0.0407784],
]
)

Expand Down

0 comments on commit 7061ce2

Please sign in to comment.