Is mamba slower than transformer? #657
Comments
Could you clarify the
It's short, no more than 150 words, so maybe that explains your insights. Does Mamba2 solve this problem, i.e. this kind of performance degradation compared to the Transformer? I haven't tried Mamba2 yet.
The transition from Mamba1 to Mamba2 does not bring significant improvements at short sequence lengths. As the attached image shows, Mamba2 is still slower than Transformers on shorter sequences. This is primarily due to Mamba's architecture, which carries a large constant overhead; only when the sequence length is long enough does that overhead become negligible. Mamba2 shows meaningful speed advantages over both Transformers and Mamba1 for sufficiently long sequences, particularly beyond 4k tokens, which aligns with its design goal of optimizing long-sequence processing. For tasks involving short sequences, Transformers remain the faster and more practical choice; for applications requiring longer sequences, Mamba2 offers significant performance benefits.
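For reference, a rough way to see that crossover yourself is to time a single Mamba2 block against a single attention layer at several sequence lengths. This is only a minimal sketch, assuming mamba-ssm is installed and a CUDA GPU is available; the sizes below (d_model=768, 12 heads, batch 8) are illustrative choices, not the configuration discussed in this thread.

```python
# Rough timing sketch (not a rigorous benchmark): one Mamba2 block vs. one
# PyTorch TransformerEncoderLayer at several sequence lengths.
# Assumes mamba-ssm is installed and a CUDA GPU is available; all sizes here
# are illustrative, not the configuration from this issue.
import time
import torch
from mamba_ssm import Mamba2

device = "cuda"
d_model, batch = 768, 8

mamba2 = Mamba2(d_model=d_model, d_state=64, d_conv=4, expand=2).to(device)
attn = torch.nn.TransformerEncoderLayer(
    d_model=d_model, nhead=12, batch_first=True
).to(device)

@torch.no_grad()
def ms_per_forward(module, seqlen, iters=20):
    x = torch.randn(batch, seqlen, d_model, device=device)
    for _ in range(3):                 # warmup
        module(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        module(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3   # milliseconds per forward pass

for seqlen in (128, 512, 2048, 8192):
    print(f"seqlen={seqlen:5d}  mamba2={ms_per_forward(mamba2, seqlen):7.2f} ms  "
          f"attention={ms_per_forward(attn, seqlen):7.2f} ms")
```

At small seqlen the fixed kernel-launch and scan overhead dominates the Mamba2 timing, while the attention layer's quadratic cost only catches up once the sequence gets long.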
Appreciate the answer, it helps a lot.
GPU: A100
Mamba config: the default MambaConfig, except vocab_size set to 108192 (see the sketch after this list)
CUDA: 12.1
PyTorch: 2.3.1
Python: 3.11
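For concreteness, the configuration above would look roughly like the sketch below. This assumes the MambaConfig that ships with mamba-ssm (mamba_ssm.models.config_mamba.MambaConfig); if the Hugging Face transformers MambaConfig was meant instead, the defaults and resulting parameter count differ. Printing the parameter count is a quick way to confirm the model is actually comparable in size to the 230M BERT baseline.

```python
# Sketch of the reported setup, assuming the MambaConfig shipped with mamba-ssm;
# if the Hugging Face transformers MambaConfig was meant, the defaults differ.
# Only vocab_size is changed from the defaults, as described above.
import torch
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

config = MambaConfig(vocab_size=108192)
model = MambaLMHeadModel(config, device="cuda", dtype=torch.bfloat16)

# Sanity check: is this Mamba model actually comparable in size to the ~230M BERT?
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```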
I trained a two-tower BERT with about 230M parameters in total on 1.5B data samples; training completed in about 3 days. Then I trained a Mamba model from scratch, resizing vocab_size to 108192 and setting the other parameters to be comparable to my BERT model. Both were trained with the Hugging Face Trainer.
With Mamba, however, the batch_size first has to be far smaller than for BERT: it had to be reduced from 4096 to 512 to avoid CUDA OOM errors. Second, the training time is about five times that of the BERT run.
I installed with pip install mamba-ssm[causal-conv1d]. What could be wrong in my setup, or what about Mamba could cause this?
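One thing worth checking on the training side: when the per-device batch has to drop from 4096 to 512 to avoid OOM, gradient accumulation can keep the effective batch size (and optimization behaviour) comparable to the BERT run, although it will not by itself close the wall-clock gap. Below is a minimal sketch, assuming the transformers MambaForCausalLM + Trainer path; the toy dataset and all hyperparameters are placeholders, not the actual setup from this issue.

```python
# Minimal sketch, assuming the transformers MambaForCausalLM + Trainer path.
# The toy dataset and hyperparameters are placeholders, not the actual setup.
import torch
from transformers import MambaConfig, MambaForCausalLM, Trainer, TrainingArguments

class ToyDataset(torch.utils.data.Dataset):
    """Placeholder for the real tokenized corpus."""
    def __len__(self):
        return 1024

    def __getitem__(self, idx):
        ids = torch.randint(0, 108192, (128,))
        return {"input_ids": ids, "labels": ids}

config = MambaConfig(vocab_size=108192)          # defaults otherwise
model = MambaForCausalLM(config)

args = TrainingArguments(
    output_dir="mamba-from-scratch",
    per_device_train_batch_size=512,             # what fits without OOM
    gradient_accumulation_steps=8,               # 512 * 8 = 4096 effective batch
    bf16=True,                                   # A100 supports bfloat16
    logging_steps=100,
)

Trainer(model=model, args=args, train_dataset=ToyDataset()).train()
```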