Can we have a plug-in based roformer? #51

Open
@HelloWorldLTY

Description

Hi, I designed a plug-in version of rotary position embeddings:

import math

import torch
from torch import Tensor, nn


def apply_rotary(x, sinusoidal_pos):
    sin, cos = sinusoidal_pos
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # For query/key, the cat below is enough: the attention matmul sums over this dim,
    # so it only matters that query and key positions in the last dim line up.
    # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # For value, stack then flatten, because the trained model interleaves pairs in the last dim.
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(max_len, 1, d_model)

        # sin/cos tables of shape [max_len, d_model/2], consumed by apply_rotary
        self.sinusoidal_pos = [torch.sin(position * div_term), torch.cos(position * div_term)]
        # standard additive sinusoidal table (registered as a buffer but not used in forward)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + apply_rotary(x, self.sinusoidal_pos)
        return self.dropout(x)

I did not run into any bugs; is this a correct implementation? Thanks.
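
For comparison, here is a minimal, self-contained sketch of how rotary embeddings are usually applied, assuming a seq-first layout [seq_len, batch, d_model]; the helper names (build_sinusoidal, rotate) and shapes are my own for illustration, not your API. The main point is that in the RoFormer formulation the rotation is applied to query and key (rather than added onto the token embeddings), so the q·k dot product becomes a function of relative position:

import math
import torch

def build_sinusoidal(seq_len: int, d_model: int):
    # sin/cos tables of shape [seq_len, 1, d_model/2] so they broadcast over batch
    position = torch.arange(seq_len).unsqueeze(1)                    # [seq_len, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    angles = position * div_term                                     # [seq_len, d_model/2]
    return angles.sin().unsqueeze(1), angles.cos().unsqueeze(1)

def rotate(x, sin, cos):
    # x: [seq_len, batch, d_model]; rotate interleaved (x1, x2) pairs
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)

seq_len, batch, d_model = 16, 4, 64
q = torch.randn(seq_len, batch, d_model)
k = torch.randn(seq_len, batch, d_model)
sin, cos = build_sinusoidal(seq_len, d_model)

# Rotate q and k, then take the attention scores on the rotated tensors.
q_rot, k_rot = rotate(q, sin, cos), rotate(k, sin, cos)
scores = torch.einsum("ibd,jbd->bij", q_rot, k_rot) / math.sqrt(d_model)
print(scores.shape)  # torch.Size([4, 16, 16])

The sin/cos tables carry a singleton batch axis so they broadcast cleanly; whether a plain cat is enough for q/k, or the stack-then-flatten interleaving is required, depends on how the pretrained weights are laid out, as your comment notes.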
