diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py
index f58c7f33..0c5102b2 100644
--- a/openfold/model/primitives.py
+++ b/openfold/model/primitives.py
@@ -386,12 +386,12 @@ def _prep_qkv(self,
         k = self.linear_k(kv_x)
         v = self.linear_v(kv_x)
 
-        # [*, Q/K, H, C_hidden]
+        # [*, Q/K/V, H, C_hidden]
         q = q.view(q.shape[:-1] + (self.no_heads, -1))
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))
 
-        # [*, H, Q/K, C_hidden]
+        # [*, H, Q/K/V, C_hidden]
         q = q.transpose(-2, -3)
         k = k.transpose(-2, -3)
         v = v.transpose(-2, -3)
diff --git a/openfold/model/triangular_attention.py b/openfold/model/triangular_attention.py
index 9f96032b..e0522750 100644
--- a/openfold/model/triangular_attention.py
+++ b/openfold/model/triangular_attention.py
@@ -37,7 +37,7 @@ def __init__(
             c_in:
                 Input channel dimension
             c_hidden:
-                Overall hidden channel dimension (not per-head)
+                Per-head hidden channel dimension
             no_heads:
                 Number of attention heads
         """
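
For context on the corrected comments and docstring: the view/transpose pair in _prep_qkv splits the flat projection width into no_heads chunks of c_hidden each, so c_hidden is indeed the per-head width. The following is a minimal standalone sketch of that reshape, not OpenFold's actual module; the tensor sizes (batch, n_res, etc.) are illustrative assumptions.

import torch

# Illustrative sizes (assumptions, not taken from the patch)
batch, n_res, no_heads, c_hidden = 2, 16, 4, 8

# A projected tensor as produced by linear_q/k/v: [*, Q/K/V, H * C_hidden]
q = torch.randn(batch, n_res, no_heads * c_hidden)

# [*, Q/K/V, H, C_hidden] -- split the flat hidden dim into per-head chunks
q = q.view(q.shape[:-1] + (no_heads, -1))
assert q.shape == (batch, n_res, no_heads, c_hidden)

# [*, H, Q/K/V, C_hidden] -- move the head dim ahead of the sequence dim
q = q.transpose(-2, -3)
assert q.shape == (batch, no_heads, n_res, c_hidden)

The final dimension after the split is c_hidden per head, which is why the triangular_attention.py docstring now reads "Per-head hidden channel dimension".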