From b490ca832c9a42b0674fb902424a17e8c21c6888 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 13 Jan 2024 13:16:14 -0500
Subject: [PATCH] [FEAT][LinearAttention]

---
 README.md                     |  7 ++++--
 example.py                    |  7 ++++--
 mamba_transformer/__init__.py |  7 ++++--
 mamba_transformer/blocks.py   | 39 +++++++++++++++++++++++++++++++
 mamba_transformer/model.py    | 43 ++++++++++++++++++++++++++---------
 5 files changed, 86 insertions(+), 17 deletions(-)
 create mode 100644 mamba_transformer/blocks.py

diff --git a/README.md b/README.md
index f2e7ca4..c84cc1e 100644
--- a/README.md
+++ b/README.md
@@ -37,11 +37,14 @@ model = MambaTransformer(
     ff_mult=4,  # Multiplier for the feed-forward layer dimension
     return_embeddings=False,  # Whether to return the embeddings,
     transformer_depth=2,  # Number of transformer blocks
-    mamba_depth=10,  # Number of Mamba blocks
+    mamba_depth=10,  # Number of Mamba blocks,
+    use_linear_attn=True,  # Whether to use linear attention
 )
 
 # Pass the input tensor through the model and print the output shape
-print(model(x).shape)
+out = model(x)
+
+print(out.shape)
 
 
 # to train
diff --git a/example.py b/example.py
index 0380c6d..67bf66b 100644
--- a/example.py
+++ b/example.py
@@ -16,8 +16,11 @@
     ff_mult=4,  # Multiplier for the feed-forward layer dimension
     return_embeddings=False,  # Whether to return the embeddings,
     transformer_depth=2,  # Number of transformer blocks
-    mamba_depth=10,  # Number of Mamba blocks
+    mamba_depth=10,  # Number of Mamba blocks,
+    use_linear_attn=True,  # Whether to use linear attention
 )
 
 # Pass the input tensor through the model and print the output shape
-print(model(x).shape)
+out = model(x)
+
+print(out.shape)
diff --git a/mamba_transformer/__init__.py b/mamba_transformer/__init__.py
index 4207f69..1e05fac 100644
--- a/mamba_transformer/__init__.py
+++ b/mamba_transformer/__init__.py
@@ -1,13 +1,16 @@
+from mamba_transformer.blocks import LinearAttention
+
 from mamba_transformer.model import (
     RMSNorm,
-    MultiQueryTransformerBlock,
+    TransformerBlock,
     MambaTransformerblock,
     MambaTransformer,
 )
 
 __all__ = [
+    "LinearAttention",
     "RMSNorm",
-    "MultiQueryTransformerBlock",
+    "TransformerBlock",
     "MambaTransformerblock",
     "MambaTransformer",
 ]
diff --git a/mamba_transformer/blocks.py b/mamba_transformer/blocks.py
new file mode 100644
index 0000000..746230c
--- /dev/null
+++ b/mamba_transformer/blocks.py
@@ -0,0 +1,39 @@
+from torch import nn, einsum
+
+from einops import rearrange
+
+from zeta.utils import exists
+
+# linear attention
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, *, heads=4, dim_head=64, dropout=0.0):
+        super().__init__()
+        inner_dim = heads * dim_head
+        self.heads = heads
+        self.scale = dim_head**-0.5
+
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim), nn.Dropout(dropout)
+        )
+
+    def forward(self, x, mask=None):
+        h = self.heads
+        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h),
+            (q, k, v),
+        )
+
+        q = q * self.scale
+        q, k = q.softmax(dim=-1), k.softmax(dim=-2)
+
+        if exists(mask):
+            k.masked_fill_(mask, 0.0)
+
+        context = einsum("b n d, b n e -> b d e", q, k)
+        out = einsum("b d e, b n d -> b n e", context, v)
+        out = rearrange(out, " (b h) n d -> b n (h d)", h=h)
+        return self.to_out(out)
diff --git a/mamba_transformer/model.py b/mamba_transformer/model.py
index 4eabaff..12e8100 100644
--- a/mamba_transformer/model.py
+++ b/mamba_transformer/model.py
@@ -1,7 +1,12 @@
 import torch
 from torch import nn, Tensor
-from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
+from zeta.nn import (
+    MambaBlock,
+    FeedForward,
+    MultiQueryAttention,
+)
 import torch.nn.functional as F
+from mamba_transformer.blocks import LinearAttention
 
 
 class RMSNorm(nn.Module):
@@ -14,9 +19,9 @@ def forward(self, x: Tensor):
         return F.normalize(x, dim=-1) * self.scale * self.g
 
 
-class MultiQueryTransformerBlock(nn.Module):
+class TransformerBlock(nn.Module):
     """
-    MultiQueryTransformerBlock is a module that represents a single block of the Multi-Query Transformer.
+    TransformerBlock is a module that represents a single block of the Multi-Query Transformer.
     It consists of a multi-query attention layer, a feed-forward network, and layer normalization.
 
     Args:
@@ -38,7 +43,7 @@ class MultiQueryTransformerBlock(nn.Module):
 
     Methods:
         forward(x: Tensor) -> Tensor:
-            Performs a forward pass of the MultiQueryTransformerBlock.
+            Performs a forward pass of the TransformerBlock.
 
     """
 
@@ -49,6 +54,7 @@ def __init__(
         dim_head: int,
         dropout: float = 0.1,
         ff_mult: int = 4,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -58,9 +64,15 @@ def __init__(
         self.dim_head = dim_head
         self.dropout = dropout
         self.ff_mult = ff_mult
+        self.use_linear_attn = use_linear_attn
 
         self.attn = MultiQueryAttention(dim, heads, *args, **kwargs)
 
+        # Linear Attention
+        self.linear_attn = LinearAttention(
+            dim=dim, heads=heads, dim_head=dim_head, dropout=dropout
+        )
+
         self.ffn = FeedForward(dim, dim, ff_mult, *args, **kwargs)
 
         # Normalization
@@ -68,7 +80,7 @@ def __init__(
 
     def forward(self, x: Tensor) -> Tensor:
         """
-        Performs a forward pass of the MultiQueryTransformerBlock.
+        Performs a forward pass of the TransformerBlock.
 
         Args:
             x (Tensor): The input tensor.
@@ -77,9 +89,14 @@ def forward(self, x: Tensor) -> Tensor:
             Tensor: The output tensor.
         """
 
-        x, _, _ = self.attn(x)
-        x = self.norm(x)
-        x = self.ffn(x)
+        if self.use_linear_attn:
+            x = self.linear_attn(x)
+            x = self.norm(x)
+            x = self.ffn(x)
+        else:
+            x, _, _ = self.attn(x)
+            x = self.norm(x)
+            x = self.ffn(x)
         return x
 
 
@@ -106,7 +123,7 @@ class MambaTransformerblock(nn.Module):
         dropout (float): The dropout rate.
         ff_mult (int): The multiplier for the feed-forward network dimension.
         mamba_blocks (nn.ModuleList): List of MambaBlock instances.
-        transformer_blocks (nn.ModuleList): List of MultiQueryTransformerBlock instances.
+        transformer_blocks (nn.ModuleList): List of TransformerBlock instances.
         ffn_blocks (nn.ModuleList): List of FeedForward instances.
         norm (nn.LayerNorm): Layer normalization module.
 
@@ -140,6 +157,7 @@ def __init__(
         d_state: int = None,
         transformer_depth: int = 1,
         mamba_depth: int = 1,
+        use_linear_attn: bool = False,
         *args,
         **kwargs,
     ):
@@ -167,15 +185,16 @@ def __init__(
            self.ffn_blocks.append(
                 FeedForward(dim, dim, ff_mult, *args, **kwargs)
             )
-
+
         for _ in range(transformer_depth):
             self.transformer_blocks.append(
-                MultiQueryTransformerBlock(
+                TransformerBlock(
                     dim,
                     heads,
                     dim_head,
                     dropout,
                     ff_mult,
+                    use_linear_attn,
                     *args,
                     **kwargs,
                 )
@@ -247,6 +266,7 @@ def __init__(
         return_embeddings: bool = False,
         transformer_depth: int = 1,
         mamba_depth: int = 1,
+        use_linear_attn=False,
         *args,
         **kwargs,
     ):
@@ -274,6 +294,7 @@ def __init__(
             return_embeddings,
             transformer_depth,
             mamba_depth,
+            use_linear_attn,
             *args,
             **kwargs,
         )
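
A quick post-apply sanity check for the new LinearAttention block (a minimal sketch, not part of the patch; it assumes the patch has been applied and that torch, einops, and zeta are installed so mamba_transformer.blocks imports cleanly; all sizes below are illustrative):

import torch

from mamba_transformer.blocks import LinearAttention

# The block projects dim -> heads * dim_head for attention and back to dim,
# so the output shape matches the input: (batch, seq_len, dim).
attn = LinearAttention(dim=512, heads=8, dim_head=64, dropout=0.1)
x = torch.randn(2, 64, 512)  # (batch, seq_len, dim)

out = attn(x)
print(out.shape)  # expected: torch.Size([2, 64, 512])

Setting use_linear_attn=True on MambaTransformer (or TransformerBlock directly) routes the attention step through this module instead of MultiQueryAttention, as shown in the README.md and example.py hunks above.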