Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix periodicity batching bug for supervised_chi_loss in loss.py #382

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions openfold/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,15 +304,15 @@ def supervised_chi_loss(
angles_sin_cos:
[*, N, 7, 2] predicted angles
unnormalized_angles_sin_cos:
The same angles, but unnormalized
[*, N, 4, 2] predicted angles, but unnormalized
aatype:
[*, N] residue indices
seq_mask:
[*, N] sequence mask
chi_mask:
[*, N, 7] angle mask
[*, N, 4] angle mask
chi_angles_sin_cos:
[*, N, 7, 2] ground truth angles
[*, N, 4, 2] ground truth angles
chi_weight:
Weight for the angle component of the loss
angle_norm_weight:
Expand All @@ -326,7 +326,7 @@ def supervised_chi_loss(
residue_constants.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
"...ij,jk->...ik",
residue_type_one_hot.type(angles_sin_cos.dtype),
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
Expand All @@ -353,7 +353,7 @@ def supervised_chi_loss(
loss = chi_weight * sq_chi_loss

angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
torch.sum(unnormalized_angles_sin_cos[None] ** 2, dim=-1) + eps
)
norm_error = torch.abs(angle_norm - 1.0)
norm_error = norm_error.permute(
Expand Down