Can we have a plug-in based roformer? #51

Open
HelloWorldLTY opened this issue Sep 15, 2023 · 0 comments

Hi, I designed a plug-in version of rotary position embeddings:

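import math

import torch
from torch import nn, Tensor


def apply_rotary(x, sinusoidal_pos):
    sin, cos = sinusoidal_pos
    # Split the last dim into (even, odd) pairs; each pair is rotated by its angle.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # When rotating query/key, a plain cat (the commented line below) is enough:
    # the later matrix multiplication sums over this dimension, so it only matters
    # that each position in the last dim of query and key lines up.
    # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # When rotating value, stack and then flatten as below, because the last dim
    # of a trained model is interleaved in pairs.
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)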

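class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, d_model // 2]

        # Register sin/cos as buffers so they move with the module across devices.
        # The singleton middle dim broadcasts over batch_size.
        self.register_buffer("sin", torch.sin(angles).unsqueeze(1))
        self.register_buffer("cos", torch.cos(angles).unsqueeze(1))

        # Standard additive sinusoidal table; registered but not used in forward().
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        pe[:, 0, 1::2] = torch.cos(angles)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        # Slice the tables to the actual sequence length so shapes broadcast.
        seq_len = x.size(0)
        sinusoidal_pos = (self.sin[:seq_len], self.cos[:seq_len])
        x = x + apply_rotary(x, sinusoidal_pos)
        return self.dropout(x)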

I did not run into any bugs. Is this a correct implementation? Thanks.
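
One way to probe this question is to test the property rotary embeddings are usually expected to have: when the rotation is applied to a query at position m and a key at position n, their dot product should depend only on the offset m - n. The sketch below is a minimal sanity check, not a definitive answer. The helper names (`rotary_tables`, `rotate_interleaved`, `rotate_cat`, `score`, `score_additive`) and the sizes are hypothetical, chosen only for illustration, and the rotation formula is copied from `apply_rotary` above. It also checks the claim in the code comments that cat and stack+flatten give the same query-key score, and whether the additive form `x + apply_rotary(x, ...)` used in `forward()` keeps the offset-only property.

```python
import math
import torch


def rotary_tables(seq_len: int, dim: int):
    """Return (sin, cos), each of shape [seq_len, dim // 2] (hypothetical helper)."""
    position = torch.arange(seq_len, dtype=torch.float64).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, dtype=torch.float64) * (-math.log(10000.0) / dim)
    )
    angles = position * div_term
    return torch.sin(angles), torch.cos(angles)


def rotate_interleaved(x, sin, cos):
    """Same formula as apply_rotary above: rotate each (even, odd) pair."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)


def rotate_cat(x, sin, cos):
    """Variant from the commented-out line: concatenate instead of interleaving."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)


torch.manual_seed(0)
dim, seq_len = 8, 16
sin, cos = rotary_tables(seq_len, dim)
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(dim, dtype=torch.float64)


def score(m: int, n: int, rotate=rotate_interleaved):
    """Dot product of q rotated to position m with k rotated to position n."""
    return torch.dot(rotate(q, sin[m], cos[m]), rotate(k, sin[n], cos[n]))


# 1) Relative-position property: the same offset (m - n = 3) gives the same score.
same_offset = torch.allclose(score(5, 2), score(12, 9))

# 2) A pure rotation preserves the vector norm.
norm_kept = torch.allclose(rotate_interleaved(q, sin[7], cos[7]).norm(), q.norm())

# 3) For query-key scores, cat and stack+flatten agree: the same products are
#    summed, only in a different order.
cat_matches = torch.allclose(score(5, 2), score(5, 2, rotate=rotate_cat))

# 4) Additive form, as in forward() above: x + rotate(x). A fixed unit vector is
#    used here so the outcome does not depend on the random draw.
e0 = torch.zeros(dim, dtype=torch.float64)
e0[0] = 1.0


def score_additive(m: int, n: int):
    qm = e0 + rotate_interleaved(e0, sin[m], cos[m])
    kn = e0 + rotate_interleaved(e0, sin[n], cos[n])
    return torch.dot(qm, kn)


additive_same_offset = torch.allclose(score_additive(5, 2), score_additive(12, 9))

print(same_offset, norm_kept, cat_matches, additive_same_offset)
```

Under these assumptions, checks 1 to 3 hold for the pure rotation, while the additive form introduces cross terms that depend on the absolute positions m and n, so check 4 does not hold. That suggests the rotation in `apply_rotary` is fine on its own, but adding its output back onto the input embeddings is a different operation from rotating queries and keys inside attention.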
