
Simple, minimal implementation of the Mamba SSM in one PyTorch file. More efficient than a sequential for-loop scan, but probably less efficient than an associative scan.

NolanBrad/mamba-tiny

mamba-tiny

Tiny implementation of Mamba in PyTorch.

Featuring:

  • Numerical output equivalent to the official implementation for both the forward and backward passes
  • Simplified, readable, annotated code
  • An alternative to parallel scan (not currently available in PyTorch) computed with cumsum, inspired by heinsen_sequence
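The cumsum alternative rests on a closed form for the linear recurrence h_t = a_t·h_{t-1} + b_t. Writing a*_t = log a_0 + ... + log a_t, the state is h_t = exp(a*_t)·Σ_{s≤t} exp(−a*_s)·b_s, which is two cumulative sums instead of a sequential loop. The sketch below checks that identity against a plain loop. It illustrates the idea only and is not the code in model.py; the mapping log a = ΔA, b = ΔB·x is an assumption based on the Mamba paper's discretization.

```python
import torch

def scan_loop(log_a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference recurrence h_t = exp(log_a[t]) * h_{t-1} + b[t], with h_{-1} = 0, over dim 0."""
    h = torch.zeros_like(b[0])
    states = []
    for t in range(b.shape[0]):
        h = torch.exp(log_a[t]) * h + b[t]
        states.append(h)
    return torch.stack(states)

def scan_cumsum(log_a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Same recurrence in closed form: two cumulative sums, no Python loop over time."""
    # a_star[t] = log(a_0 * ... * a_t), so a_{s+1} * ... * a_t = exp(a_star[t] - a_star[s])
    a_star = torch.cumsum(log_a, dim=0)
    # h_t = exp(a_star[t]) * sum_{s <= t} exp(-a_star[s]) * b[s]
    return torch.exp(a_star) * torch.cumsum(torch.exp(-a_star) * b, dim=0)

# Assumed Mamba mapping: log_a = delta * A (negative), b = delta * B * x
torch.manual_seed(0)
log_a = -0.5 * torch.rand(16, 4, dtype=torch.float64)
b = torch.randn(16, 4, dtype=torch.float64)
print(torch.allclose(scan_loop(log_a, b), scan_cumsum(log_a, b)))  # True
```

Because exp(−a*_s) grows exponentially with sequence length when ΔA < 0, this direct form can overflow on long sequences, which is one reason to evaluate the scan in log space instead.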

Does NOT include:

  • Recurrent mode of the network intended for inference. The demo code (sentence generation) effectively runs the network as in the training forward pass, reprocessing the whole sequence for each generated token, which is much slower than recurrent mode.
  • Kernel fusion. This repo makes no attempt to fuse the selective scan operations with the other dense operations, so all internal states of the model are explicitly materialized and memory usage may become a bottleneck.
  • Proper parameter initialization (though this could be added without sacrificing readability)
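For comparison, a recurrent mode would carry a fixed-size state per layer and update it once per token, so each generated token costs the same regardless of how long the prompt is. Below is a minimal sketch of one selective-SSM step, assuming the Mamba paper's discretization (Ā = exp(ΔA), with the common simplification B̄ ≈ ΔB). The names ssm_step and run_recurrent are hypothetical and not part of this repo.

```python
import torch

def ssm_step(h, x_t, delta_t, A, B_t, C_t, D):
    """One recurrent step of a selective SSM for a single sequence (hypothetical helper).

    Shapes: h, A: (d_inner, d_state); x_t, delta_t, D: (d_inner,); B_t, C_t: (d_state,).
    """
    dA = torch.exp(delta_t[:, None] * A)                  # zero-order-hold discretization of A
    dBx = delta_t[:, None] * B_t[None, :] * x_t[:, None]  # simplified discretized B times the input
    h = dA * h + dBx                                      # state update; h never grows with position
    y_t = h @ C_t + D * x_t                               # readout plus skip connection
    return h, y_t

def run_recurrent(x, delta, A, B, C, D):
    """Process a (seq_len, d_inner) sequence one token at a time, carrying only h."""
    h = torch.zeros(A.shape, dtype=x.dtype)
    ys = []
    for t in range(x.shape[0]):
        h, y_t = ssm_step(h, x[t], delta[t], A, B[t], C[t], D)
        ys.append(y_t)
    return torch.stack(ys)
```

A complete inference mode would also need to cache the last few inputs of each block's causal conv1d, which this sketch leaves out.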

Currently, the supposedly more stable logcumsumexp scan mode (Heinsen sequence) works for sentence generation (demo.ipynb) on the GPU but, for reasons not yet understood, diverges on the CPU.
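The logcumsumexp mode evaluates the same closed form as the cumsum scan in log space, log h_t = a*_t + logcumsumexp_{s≤t}(log b_s − a*_s), where a*_t is the running sum of log a, so the huge factor exp(−a*_s) is never formed explicitly. Because log b_s is only real for b_s > 0, this sketch splits b into positive and negative parts and scans each separately (the recurrence is linear in b). That sign split is a swapped-in simplification: Heinsen's formulation handles negative values with a complex logarithm instead. The function name is hypothetical and this is not the repo's implementation.

```python
import torch

def scan_logcumsumexp(log_a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """h_t = exp(log_a[t]) * h_{t-1} + b[t] (h_{-1} = 0), evaluated in log space over dim 0."""
    a_star = torch.cumsum(log_a, dim=0)  # running sum of log a

    def scan_nonneg(v: torch.Tensor) -> torch.Tensor:
        # Requires v >= 0. log(0) = -inf, and those entries contribute exp(-inf) = 0.
        log_h = a_star + torch.logcumsumexp(torch.log(v) - a_star, dim=0)
        return torch.exp(log_h)

    # Sign split (a stand-in for a complex log): b = b_pos - b_neg, and the recurrence is linear in b
    return scan_nonneg(b.clamp(min=0)) - scan_nonneg((-b).clamp(min=0))

# Long float32 sequence: a_star reaches roughly -500, so exp(-a_star) on its own would overflow to inf
torch.manual_seed(0)
log_a = -0.5 * torch.rand(2000, 4)
b = torch.randn(2000, 4)
h = scan_logcumsumexp(log_a, b)
print(torch.isfinite(h).all())
```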

Demo

See demo.ipynb for examples of prompt completions.

```python
from model import Mamba
from transformers import AutoTokenizer

model = Mamba.from_pretrained('state-spaces/mamba-370m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

# generate() is not imported here; it is assumed to be the helper defined in demo.ipynb
generate(model, tokenizer, 'Mamba is the')
```

Mamba is the world's longest venomous snake with an estimated length of over 150 m. With such a large size and a venomous bite, Mamba kills by stabbing the victim (which is more painful and less effective than a single stab of the bite)

150 meters... 🫢 scary!
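The generate helper used in the demo is not part of the snippet (it is presumably defined in demo.ipynb). A minimal greedy stand-in might look like the following. It assumes model(input_ids) returns logits of shape (batch, seq_len, vocab_size), and n_tokens_to_gen is a hypothetical parameter name. Greedy decoding is chosen here for simplicity and may not reproduce the completion above. The loop also makes the cost of lacking a recurrent mode explicit: every new token re-runs the whole sequence through the network.

```python
import torch

@torch.no_grad()
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 50) -> str:
    """Hypothetical greedy decoder; assumes model(input_ids) returns (batch, seq_len, vocab) logits."""
    model.eval()
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    for _ in range(n_tokens_to_gen):
        # No recurrent mode: the full, growing sequence is processed again at every step
        logits = model(input_ids)[:, -1]                 # logits at the last position
        next_id = logits.argmax(dim=-1, keepdim=True)    # greedy choice, shape (batch, 1)
        input_ids = torch.cat([input_ids, next_id], dim=1)
    return tokenizer.decode(input_ids[0])
```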

References

The Mamba architecture was introduced in Mamba: Linear-Time Sequence Modeling with Selective State Spaces by Albert Gu and Tri Dao.

The official implementation is here: https://github.com/state-spaces/mamba/tree/main
