
Is mamba slower than transformer? #657

Open
Lynnzake opened this issue Dec 27, 2024 · 4 comments

@Lynnzake

GPU: A100
Mamba config: default MambaConfig, except vocab_size set to 108192
CUDA: 12.1
PyTorch: 2.3.1
Python: 3.11

I trained a two-tower BERT with about 230M parameters in total, with 1.5B data, and training completed in about 3 days. I then trained a Mamba model from scratch, resizing vocab_size to 108192 and setting the other parameters to be comparable to my BERT model. Both were trained with the Hugging Face Trainer.

With Mamba, however, the batch size first had to be far smaller than for BERT: I had to drop it from 4096 to 512 to avoid CUDA OOM errors. Second, the training time is about five times that of the BERT training.

I installed it with pip install mamba-ssm[causal-conv1d]. What could be the problem with my setup, or what about Mamba itself could cause this?
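
For reference, here is a minimal sketch of the kind of setup described above, using the Mamba integration in transformers. This is an assumption about the setup, not the actual training code: the dataset below is a tiny random stand-in for the real corpus, and the batch size is shrunk to fit the toy example.

```python
# Sketch only: DummyLMDataset and the batch size are placeholders, not the real
# 1.5B-sample corpus or the 512 batch size mentioned above.
import torch
from torch.utils.data import Dataset
from transformers import MambaConfig, MambaForCausalLM, Trainer, TrainingArguments

class DummyLMDataset(Dataset):
    """Random tokens standing in for the real tokenized corpus."""
    def __init__(self, vocab_size, seq_len=150, n=64):
        self.data = torch.randint(0, vocab_size, (n, seq_len))
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return {"input_ids": self.data[i], "labels": self.data[i]}

config = MambaConfig(vocab_size=108192)      # default config except the enlarged vocab
model = MambaForCausalLM(config)

args = TrainingArguments(
    output_dir="mamba-from-scratch",
    per_device_train_batch_size=8,           # 512 in the real run, after OOM at 4096
    num_train_epochs=1,
    report_to="none",
)

Trainer(model=model, args=args, train_dataset=DummyLMDataset(config.vocab_size)).train()
```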


klae01 commented Dec 29, 2024

Could you clarify the sequence length used during training? To my knowledge, Mamba and Transformer models demonstrate comparable speeds when the sequence length reaches 2048 or somewhat longer. However, for shorter sequences, Mamba may experience slower performance compared to Transformers. This could partially explain the increased training time you observed.
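
A rough way to check where that crossover happens on your own GPU is a micro-benchmark like the sketch below (an illustration with made-up layer sizes, not the measurement behind this comment). It times one Mamba block from mamba-ssm against one nn.TransformerEncoderLayer per forward pass:

```python
# Rough timing sketch: one Mamba block vs. one Transformer encoder layer.
# Sizes (d_model=768, 12 heads, batch 8) are illustrative assumptions.
import time
import torch
from torch import nn
from mamba_ssm import Mamba

dim, batch = 768, 8
mamba_block = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2).to("cuda")
attn_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=12, batch_first=True).to("cuda")

@torch.no_grad()
def ms_per_forward(layer, seq_len, iters=20):
    x = torch.randn(batch, seq_len, dim, device="cuda")
    for _ in range(3):                        # warmup
        layer(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        layer(x)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3   # ms per forward pass

for seq_len in (128, 512, 2048, 8192):
    print(f"seq_len={seq_len:5d}  mamba {ms_per_forward(mamba_block, seq_len):7.2f} ms  "
          f"attention {ms_per_forward(attn_layer, seq_len):7.2f} ms")
```

At short lengths like the ~150-token sequences mentioned above, the attention layer is expected to win; the gap should close and then reverse as seq_len grows.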

@Lynnzake (Author)

It's short, no more than 150 words, so maybe that explains it, as you suggest. Does Mamba2 solve this problem, i.e. this kind of performance degradation compared to Transformers? I haven't tried Mamba2 yet.


klae01 commented Jan 1, 2025

The transition from Mamba1 to Mamba2 does not show significant improvements for short sequence lengths. As seen in the attached image, Mamba2 still performs slower than Transformers for shorter sequences. This is primarily due to Mamba's architecture, which has a large constant overhead. Only when the sequence length becomes sufficiently long does this overhead become negligible.

Mamba2 demonstrates meaningful performance advantages over both Transformers and Mamba1 for sufficiently long sequences, particularly beyond 4k tokens. This aligns with its design goal of optimizing performance for long sequence processing.

For tasks involving short sequences, Transformers remain the faster and more practical choice. However, for applications requiring longer sequences, Mamba2 offers significant performance benefits.

[Figure: speed comparison of Transformer, Mamba, and Mamba2 across sequence lengths]
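
For anyone wanting to rerun the comparison themselves, the Mamba2 block is a drop-in swap following the mamba-ssm README usage pattern. The dimensions below are illustrative, not taken from this thread; the README also notes that d_model * expand / headdim should be a multiple of 8, which holds here with the default headdim of 64.

```python
# Minimal Mamba2 block usage sketch; dimensions are illustrative, not from this thread.
import torch
from mamba_ssm import Mamba2

dim = 768
block = Mamba2(
    d_model=dim,   # model dimension
    d_state=64,    # SSM state expansion factor (typically 64 or 128)
    d_conv=4,      # local convolution width
    expand=2,      # block expansion factor
).to("cuda")

x = torch.randn(8, 4096, dim, device="cuda")   # ~4k tokens and beyond is where Mamba2 pulls ahead
y = block(x)                                   # same (batch, length, dim) interface as the Mamba1 block
```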


Lynnzake commented Jan 3, 2025

Thanks for the answer, it helps a lot.
