[FEAT][LinearAttention]
Kye committed Jan 13, 2024
1 parent 4708d72 commit b490ca8
Showing 5 changed files with 86 additions and 17 deletions.
7 changes: 5 additions & 2 deletions README.md
@@ -37,11 +37,14 @@ model = MambaTransformer(
    ff_mult=4, # Multiplier for the feed-forward layer dimension
    return_embeddings=False, # Whether to return the embeddings,
    transformer_depth=2, # Number of transformer blocks
    mamba_depth=10, # Number of Mamba blocks
    mamba_depth=10, # Number of Mamba blocks,
    use_linear_attn=True, # Whether to use linear attention
)

# Pass the input tensor through the model and print the output shape
print(model(x).shape)
out = model(x)

print(out.shape)


# to train
7 changes: 5 additions & 2 deletions example.py
@@ -16,8 +16,11 @@
    ff_mult=4, # Multiplier for the feed-forward layer dimension
    return_embeddings=False, # Whether to return the embeddings,
    transformer_depth=2, # Number of transformer blocks
    mamba_depth=10, # Number of Mamba blocks
    mamba_depth=10, # Number of Mamba blocks,
    use_linear_attn=True, # Whether to use linear attention
)

# Pass the input tensor through the model and print the output shape
print(model(x).shape)
out = model(x)

print(out.shape)
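
For context, a hedged sketch of what a full run of example.py with the new flag might look like. The hunk above shows only the tail of the constructor call, so the input tensor and the leading arguments (num_tokens, dim, heads, depth, dim_head, d_state, dropout) are illustrative assumptions, not part of this commit:

import torch
from mamba_transformer import MambaTransformer

# Assumed input: a batch of token ids of shape (batch, sequence_length)
x = torch.randint(0, 100, (1, 10))

model = MambaTransformer(
    num_tokens=100,  # assumed vocabulary size
    dim=512,  # assumed model dimension
    heads=8,  # assumed number of attention heads
    depth=4,  # assumed number of layers
    dim_head=64,  # assumed dimension per attention head
    d_state=512,  # assumed SSM state dimension
    dropout=0.1,  # assumed dropout rate
    ff_mult=4,  # Multiplier for the feed-forward layer dimension
    return_embeddings=False,  # Whether to return the embeddings
    transformer_depth=2,  # Number of transformer blocks
    mamba_depth=10,  # Number of Mamba blocks
    use_linear_attn=True,  # Whether to use linear attention
)

out = model(x)
print(out.shape)  # expected: logits of shape (1, 10, num_tokens) with return_embeddings=False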
7 changes: 5 additions & 2 deletions mamba_transformer/__init__.py
@@ -1,13 +1,16 @@
from mamba_transformer.blocks import LinearAttention

from mamba_transformer.model import (
    RMSNorm,
    MultiQueryTransformerBlock,
    TransformerBlock,
    MambaTransformerblock,
    MambaTransformer,
)

__all__ = [
"LinearAttention",
"RMSNorm",
"MultiQueryTransformerBlock",
"TransformerBlock",
"MambaTransformerblock",
"MambaTransformer",
]
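
With the new export in place, the block should be importable directly from the package root; a quick sanity check (assuming the package and its zeta dependency are installed):

from mamba_transformer import LinearAttention, TransformerBlock

attn = LinearAttention(dim=512)  # heads=4, dim_head=64, dropout=0.0 by default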
39 changes: 39 additions & 0 deletions mamba_transformer/blocks.py
@@ -0,0 +1,39 @@
from torch import nn, einsum

from einops import rearrange

from zeta.utils import exists

# linear attention


class LinearAttention(nn.Module):
    def __init__(self, dim, *, heads=4, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(
            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h),
            (q, k, v),
        )

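        # Scale q, then normalize q over the feature dimension and k over the sequence dimension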
        q = q * self.scale
        q, k = q.softmax(dim=-1), k.softmax(dim=-2)

        if exists(mask):
            k.masked_fill_(mask, 0.0)

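        # Build a (dim_head x dim_head) summary of q and k across the whole sequence,
        # then apply it to v at every position -- cost grows linearly with sequence length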
        context = einsum("b n d, b n e -> b d e", q, k)
        out = einsum("b d e, b n d -> b n e", context, v)
        out = rearrange(out, " (b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
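
Unlike standard softmax attention, this block never materializes an n x n attention matrix: the two einsums reduce q and k to a small per-head summary and then read it out against v, so compute and memory scale linearly with sequence length. A quick shape check, with sizes chosen purely for illustration:

import torch
from mamba_transformer.blocks import LinearAttention

attn = LinearAttention(dim=512, heads=4, dim_head=64, dropout=0.0)
x = torch.randn(2, 128, 512)  # (batch, sequence length, model dimension)

out = attn(x)
print(out.shape)  # torch.Size([2, 128, 512]) -- same shape as the input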
43 changes: 32 additions & 11 deletions mamba_transformer/model.py
@@ -1,7 +1,12 @@
import torch
from torch import nn, Tensor
from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
from zeta.nn import (
    MambaBlock,
    FeedForward,
    MultiQueryAttention,
)
import torch.nn.functional as F
from mamba_transformer.blocks import LinearAttention


class RMSNorm(nn.Module):
@@ -14,9 +19,9 @@ def forward(self, x: Tensor):
        return F.normalize(x, dim=-1) * self.scale * self.g


class MultiQueryTransformerBlock(nn.Module):
class TransformerBlock(nn.Module):
"""
MultiQueryTransformerBlock is a module that represents a single block of the Multi-Query Transformer.
TransformerBlock is a module that represents a single block of the Multi-Query Transformer.
It consists of a multi-query attention layer, a feed-forward network, and layer normalization.
Args:
@@ -38,7 +43,7 @@ class MultiQueryTransformerBlock(nn.Module):
    Methods:
        forward(x: Tensor) -> Tensor:
            Performs a forward pass of the MultiQueryTransformerBlock.
            Performs a forward pass of the TransformerBlock.
    """

@@ -49,6 +54,7 @@ def __init__(
        dim_head: int,
        dropout: float = 0.1,
        ff_mult: int = 4,
        use_linear_attn: bool = False,
        *args,
        **kwargs,
    ):
@@ -58,17 +64,23 @@ def __init__(
        self.dim_head = dim_head
        self.dropout = dropout
        self.ff_mult = ff_mult
        self.use_linear_attn = use_linear_attn

        self.attn = MultiQueryAttention(dim, heads, *args, **kwargs)

        # Linear Attention
        self.linear_attn = LinearAttention(
            dim=dim, heads=heads, dim_head=dim_head, dropout=dropout
        )

        self.ffn = FeedForward(dim, dim, ff_mult, *args, **kwargs)

        # Normalization
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Performs a forward pass of the MultiQueryTransformerBlock.
        Performs a forward pass of the TransformerBlock.

        Args:
@@ -77,9 +89,14 @@ def forward(self, x: Tensor) -> Tensor:
            Tensor: The output tensor.
        """
        x, _, _ = self.attn(x)
        x = self.norm(x)
        x = self.ffn(x)
        if self.use_linear_attn:
            x = self.linear_attn(x)
            x = self.norm(x)
            x = self.ffn(x)
        else:
            x, _, _ = self.attn(x)
            x = self.norm(x)
            x = self.ffn(x)
        return x


@@ -106,7 +123,7 @@ class MambaTransformerblock(nn.Module):
        dropout (float): The dropout rate.
        ff_mult (int): The multiplier for the feed-forward network dimension.
        mamba_blocks (nn.ModuleList): List of MambaBlock instances.
        transformer_blocks (nn.ModuleList): List of MultiQueryTransformerBlock instances.
        transformer_blocks (nn.ModuleList): List of TransformerBlock instances.
        ffn_blocks (nn.ModuleList): List of FeedForward instances.
        norm (nn.LayerNorm): Layer normalization module.
@@ -140,6 +157,7 @@ def __init__(
        d_state: int = None,
        transformer_depth: int = 1,
        mamba_depth: int = 1,
        use_linear_attn: bool = False,
        *args,
        **kwargs,
    ):
@@ -167,15 +185,16 @@ def __init__(
            self.ffn_blocks.append(
                FeedForward(dim, dim, ff_mult, *args, **kwargs)
            )

        for _ in range(transformer_depth):
            self.transformer_blocks.append(
                MultiQueryTransformerBlock(
                TransformerBlock(
                    dim,
                    heads,
                    dim_head,
                    dropout,
                    ff_mult,
                    use_linear_attn,
                    *args,
                    **kwargs,
                )
@@ -247,6 +266,7 @@ def __init__(
        return_embeddings: bool = False,
        transformer_depth: int = 1,
        mamba_depth: int = 1,
        use_linear_attn=False,
        *args,
        **kwargs,
    ):
@@ -274,6 +294,7 @@ def __init__(
            return_embeddings,
            transformer_depth,
            mamba_depth,
            use_linear_attn,
            *args,
            **kwargs,
        )
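
Taken together, the new use_linear_attn flag threads from MambaTransformer through MambaTransformerblock down to TransformerBlock, where it switches the attention path. A minimal sketch of toggling it on a standalone TransformerBlock (the dim and heads keyword names and the tensor sizes are assumed from context, not shown in this hunk):

import torch
from mamba_transformer import TransformerBlock

x = torch.randn(1, 64, 512)  # assumed (batch, sequence_length, dim)

# Default path: MultiQueryAttention -> LayerNorm -> FeedForward
block = TransformerBlock(dim=512, heads=8, dim_head=64, dropout=0.1, ff_mult=4)
print(block(x).shape)  # expected torch.Size([1, 64, 512])

# Path added by this commit: LinearAttention -> LayerNorm -> FeedForward
linear_block = TransformerBlock(
    dim=512, heads=8, dim_head=64, dropout=0.1, ff_mult=4, use_linear_attn=True
)
print(linear_block(x).shape)  # expected torch.Size([1, 64, 512])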
