From 6e4a18c99f6141c08a2b86726721f57019fb8e50 Mon Sep 17 00:00:00 2001 From: kfeng123 <446100240@qq.com> Date: Fri, 16 Jun 2023 09:54:53 +0800 Subject: [PATCH] Improve comments, especially for TriangleAttention: c_hidden is the per-head hidden dimension, not overall hidden dimension. --- openfold/model/primitives.py | 4 ++-- openfold/model/triangular_attention.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index f58c7f33f..0c5102b2f 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -386,12 +386,12 @@ def _prep_qkv(self, k = self.linear_k(kv_x) v = self.linear_v(kv_x) - # [*, Q/K, H, C_hidden] + # [*, Q/K/V, H, C_hidden] q = q.view(q.shape[:-1] + (self.no_heads, -1)) k = k.view(k.shape[:-1] + (self.no_heads, -1)) v = v.view(v.shape[:-1] + (self.no_heads, -1)) - # [*, H, Q/K, C_hidden] + # [*, H, Q/K/V, C_hidden] q = q.transpose(-2, -3) k = k.transpose(-2, -3) v = v.transpose(-2, -3) diff --git a/openfold/model/triangular_attention.py b/openfold/model/triangular_attention.py index 9f96032b8..e05227509 100644 --- a/openfold/model/triangular_attention.py +++ b/openfold/model/triangular_attention.py @@ -37,7 +37,7 @@ def __init__( c_in: Input channel dimension c_hidden: - Overall hidden channel dimension (not per-head) + Per-head hidden channel dimension no_heads: Number of attention heads """