Tiny implementation of Mamba in PyTorch.
Featuring:
- Numerical output equivalent to the official implementation for both the forward and backward pass
- Simplified, readable, annotated code
- An alternative to parallel scan (not currently available in PyTorch) based on cumsum, inspired by heisen_sequence
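To illustrate the cumsum idea from the list above, here is a minimal sketch (not the code from this repo) of a first-order linear recurrence h_t = a_t * h_{t-1} + b_t computed with cumulative sums instead of a parallel scan. The function names `scan_sequential` and `scan_cumsum` are hypothetical, and the sketch assumes every a_t is strictly positive so that its logarithm exists.

```python
import torch

def scan_sequential(a, b):
    """Reference: h_t = a_t * h_{t-1} + b_t with h_0 = 0, one step at a time."""
    h = torch.zeros_like(b[0])
    out = []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return torch.stack(out)

def scan_cumsum(a, b):
    """Same recurrence via cumulative sums (assumption: every a_t > 0).

    With A_t = a_1 * ... * a_t, the solution is h_t = A_t * sum_{s<=t} b_s / A_s.
    """
    log_a_cum = torch.cumsum(torch.log(a), dim=0)  # log A_t
    return torch.exp(log_a_cum) * torch.cumsum(b * torch.exp(-log_a_cum), dim=0)

torch.manual_seed(0)
a = torch.rand(16, 4, dtype=torch.float64) * 0.5 + 0.5  # decay factors in [0.5, 1)
b = torch.randn(16, 4, dtype=torch.float64)
print(torch.allclose(scan_sequential(a, b), scan_cumsum(a, b)))
```

Dividing by A_t is the weak point of this plain form: over long sequences A_t can underflow and 1 / A_t can overflow, which is one reason a log-space variant is attractive.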
Does NOT include:
- Recurrent mode of the network intended for inference. The demo code (sentence generation) effectively runs the network as if it were the forward pass during training, which is much slower than the recurrent mode.
- Kernel fusion. This repo makes no attempt to fuse the selective scan operations with the other dense operations, so all internal states of the model are explicitly materialized and memory usage may become a bottleneck.
- Proper parameter initialization (though this could be added without sacrificing readability)
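The gap between a training-style forward pass and the recurrent mode mentioned in the list above can be sketched on a toy recurrence. This is an illustrative assumption, not the model in this repo: `full_pass` and `recurrent_step` are hypothetical names, and the state update is a plain h_t = a_t * h_{t-1} + b_t with readout y_t = sum(h_t * c_t).

```python
import torch

def full_pass(a, b, c):
    """Training-style pass: scan the whole sequence and return every output."""
    h = torch.zeros_like(b[0])
    ys = []
    for a_t, b_t, c_t in zip(a, b, c):
        h = a_t * h + b_t
        ys.append((h * c_t).sum(-1))
    return torch.stack(ys)

def recurrent_step(h, a_t, b_t, c_t):
    """Inference-style step: update a cached state with a single new token."""
    h = a_t * h + b_t
    return h, (h * c_t).sum(-1)

torch.manual_seed(0)
T, N = 8, 4  # sequence length and state size (toy values)
a = torch.rand(T, N, dtype=torch.float64)
b = torch.randn(T, N, dtype=torch.float64)
c = torch.randn(T, N, dtype=torch.float64)

# Recurrent mode: one state update per generated token.
h = torch.zeros(N, dtype=torch.float64)
ys_recurrent = []
for t in range(T):
    h, y_t = recurrent_step(h, a[t], b[t], c[t])
    ys_recurrent.append(y_t)
ys_recurrent = torch.stack(ys_recurrent)

# Replaying the full pass on every prefix recomputes all earlier states each time.
ys_replayed = torch.stack(
    [full_pass(a[: t + 1], b[: t + 1], c[: t + 1])[-1] for t in range(T)]
)
print(torch.allclose(ys_recurrent, ys_replayed))
```

Both loops produce the same outputs, but for T tokens the replayed version performs T * (T + 1) / 2 state updates while the recurrent version performs T.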
Currently, the supposedly more stable logcumsumexp scan mode (heisen sequence) only works on the GPU for sentence generation (demo.ipynb) but somehow diverges on the CPU.
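For intuition on why a logcumsumexp formulation is considered more stable, here is a hedged sketch on the toy recurrence h_t = a_t * h_{t-1} + b_t. It is not the scan used in this repo: `scan_plain` and `scan_log_space` are hypothetical names, and the sketch assumes both a_t and b_t are strictly positive so that real-valued logarithms suffice (handling negative inputs needs extra care that is omitted here).

```python
import torch

def scan_plain(a, b):
    """Cumsum form: h_t = A_t * sum_{s<=t} b_s / A_s, where A_t = a_1 * ... * a_t."""
    log_a_cum = torch.cumsum(torch.log(a), dim=0)
    return torch.exp(log_a_cum) * torch.cumsum(b * torch.exp(-log_a_cum), dim=0)

def scan_log_space(a, b):
    """Same quantity kept in log space (assumption: a_t > 0 and b_t > 0).

    log h_t = log A_t + logcumsumexp_s(log b_s - log A_s)
    """
    log_a_cum = torch.cumsum(torch.log(a), dim=0)
    log_h = log_a_cum + torch.logcumsumexp(torch.log(b) - log_a_cum, dim=0)
    return torch.exp(log_h)

torch.manual_seed(0)
T = 2000
a = torch.full((T, 1), 0.9)   # float32 decay factors
b = torch.rand(T, 1) + 0.1    # strictly positive inputs

plain = scan_plain(a, b)
stable = scan_log_space(a, b)
# Compare whether each variant stays finite over the whole sequence.
print(torch.isfinite(plain).all().item(), torch.isfinite(stable).all().item())
```

The plain form has to represent 1 / A_t explicitly, which grows exponentially with the sequence length, whereas the log-space form only exponentiates the final, moderately sized result.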
See demo.ipynb for examples of prompt completions.
from model import Mamba
from transformers import AutoTokenizer
model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
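# generate() is not imported above; it is assumed to be the sentence-generation helper from demo.ipynb
generate(model, tokenizer, 'Mamba is the')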
Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)
150 meters... 🫢 scary!
The Mamba architecture was introduced in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.
The official implementation is here: https://github.com/state-spaces/mamba/tree/main