From 495e054f7bd239fc7b195a2ea54e3bb85a320760 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Thu, 7 Dec 2023 19:21:12 -0600 Subject: [PATCH] Fix periodicity batching bug for `supervised_chi_loss` in loss.py --- openfold/utils/loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index ab422aa8..b36d12d0 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -304,15 +304,15 @@ def supervised_chi_loss( angles_sin_cos: [*, N, 7, 2] predicted angles unnormalized_angles_sin_cos: - The same angles, but unnormalized + [*, N, 4, 2] predicted angles, but unnormalized aatype: [*, N] residue indices seq_mask: [*, N] sequence mask chi_mask: - [*, N, 7] angle mask + [*, N, 4] angle mask chi_angles_sin_cos: - [*, N, 7, 2] ground truth angles + [*, N, 4, 2] ground truth angles chi_weight: Weight for the angle component of the loss angle_norm_weight: @@ -326,7 +326,7 @@ def supervised_chi_loss( residue_constants.restype_num + 1, ) chi_pi_periodic = torch.einsum( - "...ij,jk->ik", + "...ij,jk->...ik", residue_type_one_hot.type(angles_sin_cos.dtype), angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic), ) @@ -353,7 +353,7 @@ def supervised_chi_loss( loss = chi_weight * sq_chi_loss angle_norm = torch.sqrt( - torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps + torch.sum(unnormalized_angles_sin_cos[None] ** 2, dim=-1) + eps ) norm_error = torch.abs(angle_norm - 1.0) norm_error = norm_error.permute(