
Question: initialization for the case of multi-head attention #8

Open
t-taniai opened this issue Jan 7, 2022 · 1 comment
t-taniai commented Jan 7, 2022

Hi, thanks for the good paper and for releasing the code.
I'm reading both the paper and the code, and I have a few questions about the multi-head attention case.

It seems that the code follows the initialization procedure of Sec. 3.2 exactly (i.e., initialize weights with Xavier and scale v by 0.67 * N**(-1/4) or (9*N)**(-1/4)), even for multi-head attention. But when v (v_proj.weight) is defined with shape (embed_dim, embed_dim), xavier_uniform_ draws it from a uniform distribution U(-a, a), where a = sqrt(6/(fan_in + fan_out)) = sqrt(6/(embed_dim + embed_dim)). In multi-head attention, however, v is actually used as num_heads separate matrices of shape (head_dim, embed_dim). In that case, shouldn't v be initialized from U(-a, a) with a = sqrt(6/(embed_dim + head_dim))? For num_heads = 8, this would increase the scale of v's weights by a factor of 4/3 (= sqrt(2/(1 + 1/num_heads))).
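
To make the mismatch concrete, here is a minimal PyTorch sketch (embed_dim = 512 and num_heads = 8 are hypothetical sizes chosen for illustration, not taken from the repo):

```python
import math

import torch.nn as nn

embed_dim, num_heads = 512, 8          # hypothetical sizes, for illustration only
head_dim = embed_dim // num_heads      # 64

# Bound the code actually gets: v_proj.weight has shape (embed_dim, embed_dim),
# so xavier_uniform_ uses a = sqrt(6 / (embed_dim + embed_dim)).
a_full = math.sqrt(6.0 / (embed_dim + embed_dim))

# Bound if each head's (head_dim, embed_dim) slice were initialized on its own:
# a = sqrt(6 / (embed_dim + head_dim)).
a_per_head = math.sqrt(6.0 / (embed_dim + head_dim))

print(a_per_head / a_full)                        # 1.333... = 4/3
print(math.sqrt(2.0 / (1.0 + 1.0 / num_heads)))   # same ratio for num_heads = 8

# If the per-head bound is the intended one, it can be obtained through the
# gain argument, since xavier_uniform_'s bound scales linearly with gain:
v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
nn.init.xavier_uniform_(v_proj.weight,
                        gain=math.sqrt(2.0 / (1.0 + 1.0 / num_heads)))
```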

Other questions:

  • In the paper, I assumed that d and d' in Eq. 5 correspond to d_model and d_k, respectively, in the original Transformer paper, and to embed_dim and head_dim in the code. Is that correct? If so, should 1/sqrt(d) in Eq. 5 perhaps be 1/sqrt(d')?
  • In the code, TransformerEncoderLayer.self_attn.v_proj.weight is scaled by (0.67 * (en_layers) ** (- 1. / 4.)) * (param * (2**0.5)) for the encoder and by (9 * de_layers) ** (- 1. / 4.) * (param * (2**0.5)) for the decoder. I assumed that the extra * (2**0.5) is there to cancel the gain=1/math.sqrt(2) option used in xavier_uniform_ when initializing v_proj.weight. Is that correct? (See the sketch below.)
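
A quick numeric check of that second point, assuming PyTorch (embed_dim = 512 is again just an illustrative size): multiplying by 2**0.5 after xavier_uniform_ with gain = 1/sqrt(2) restores the standard Xavier bound, so the two factors do cancel.

```python
import math

import torch
import torch.nn as nn

embed_dim = 512                        # hypothetical size, for illustration only
w = torch.empty(embed_dim, embed_dim)

# Xavier init with gain = 1/sqrt(2): the uniform bound becomes
# sqrt(6 / (2 * embed_dim)) / sqrt(2).
nn.init.xavier_uniform_(w, gain=1.0 / math.sqrt(2.0))

# The rescaling quoted above multiplies by 2**0.5 (besides the depth-dependent
# factor), which exactly cancels the 1/sqrt(2) gain and restores the standard
# Xavier bound sqrt(6 / (2 * embed_dim)).
w_rescaled = w * (2 ** 0.5)

a_standard = math.sqrt(6.0 / (embed_dim + embed_dim))
print(w.abs().max().item() <= a_standard / math.sqrt(2.0) + 1e-7)   # True
print(w_rescaled.abs().max().item() <= a_standard + 1e-7)           # True
```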

Best,
Tatsunori

@petrgeiger-incieve

+1
