Hi, I designed a plug-in version of rotary position embeddings:
```python
import math

import torch
from torch import nn, Tensor


def apply_rotary(x, sinusoidal_pos):
    sin, cos = sinusoidal_pos
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # When rotating query/key, the plain cat on the commented line below is enough,
    # because the subsequent matrix multiplication sums over this dimension anyway;
    # it only matters that each position in the last dim of query and key lines up.
    # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # When rotating value, you must stack and then flatten, because in the trained
    # model the last dim alternates between the two halves pairwise.
    return torch.stack(
        [x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1
    ).flatten(-2, -1)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)
        # Precomputed sin/cos tables, each of shape [max_len, d_model / 2].
        # Note: a plain Python list is not registered as a buffer, so these
        # tensors do not follow the module across .to(device) calls.
        self.sinusoidal_pos = [
            torch.sin(position * div_term),
            torch.cos(position * div_term),
        ]
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        # Slice the tables to the current sequence length and add a batch
        # dimension so they broadcast against [seq_len, batch_size, d_model / 2].
        sin, cos = self.sinusoidal_pos
        sin = sin[: x.size(0)].unsqueeze(1)
        cos = cos[: x.size(0)].unsqueeze(1)
        x = x + apply_rotary(x, (sin, cos))
        return self.dropout(x)
```
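For reference, a minimal sketch of how I plug it in (the shapes and hyperparameters below are just placeholders for illustration, not from a real model):

```python
import torch

# Toy sizes for illustration only; d_model must be even so the pairwise
# rotation in apply_rotary works.
seq_len, batch_size, d_model = 10, 4, 16

pos_enc = PositionalEncoding(d_model=d_model, dropout=0.1, max_len=5000)
x = torch.randn(seq_len, batch_size, d_model)  # [seq_len, batch_size, embedding_dim]

out = pos_enc(x)
print(out.shape)  # torch.Size([10, 4, 16])
```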
I have not run into any bugs so far. Is this a correct implementation? Thanks.