
Is mamba slower than transformer? #657

Open
@Lynnzake

Description


GPU: A100
Mamba config: the default MambaConfig, except vocab_size set to 108192
CUDA: 12.1
PyTorch: 2.3.1
Python: 3.11
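
For reference, a minimal sketch of that configuration, assuming the Hugging Face `transformers` Mamba classes (`MambaConfig` / `MambaForCausalLM`); if the model was built from `mamba_ssm` directly, the same idea applies with its `MambaLMHeadModel`:

```python
# Minimal sketch: default MambaConfig with only vocab_size changed.
# Class names assume the Hugging Face `transformers` Mamba implementation.
from transformers import MambaConfig, MambaForCausalLM

config = MambaConfig(vocab_size=108192)   # everything else left at defaults
model = MambaForCausalLM(config)

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")
```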

I trained a two-tower BERT with about 230M parameters in total on 1.5B data; training completed in about 3 days. Then I trained a Mamba model from scratch, resizing vocab_size to 108192 and setting the other parameters to be comparable to my BERT model. Both models were trained with the Hugging Face Trainer.

With Mamba, however, the batch size has to be far smaller than with BERT: I had to drop it from 4096 to 512 to avoid CUDA OOM errors. Second, training takes about five times as long as the BERT run.
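
For concreteness, here is a hedged sketch of the Trainer setup described above, with the batch size dropped to 512 for the Mamba run. The dataset is a tiny dummy stand-in for the real corpus, and everything other than the Trainer/TrainingArguments API itself is illustrative:

```python
# Hedged sketch of the training loop: same Hugging Face Trainer for both
# models, Mamba batch size reduced from 4096 to 512 to avoid CUDA OOM.
import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

class DummyLMDataset(Dataset):
    """Random token ids standing in for the real training corpus (illustrative only)."""
    def __len__(self):
        return 2048
    def __getitem__(self, idx):
        ids = torch.randint(0, 108192, (512,))
        return {"input_ids": ids, "labels": ids.clone()}

args = TrainingArguments(
    output_dir="mamba-from-scratch",
    per_device_train_batch_size=512,   # BERT run used 4096; Mamba OOMs above 512
    gradient_accumulation_steps=8,     # optional: recovers the effective 4096 batch
    bf16=True,                         # A100 supports bf16
    logging_steps=100,
)

trainer = Trainer(
    model=model,                  # the MambaForCausalLM from the sketch above
    args=args,
    train_dataset=DummyLMDataset(),
)
trainer.train()
```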

I installed the kernels with `pip install mamba-ssm[causal-conv1d]`. What could be wrong in my setup, or what might be the cause on Mamba's side?
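
One thing worth ruling out: if the fused CUDA kernels from `causal-conv1d` and `mamba-ssm` are not importable in your environment, Mamba can fall back to a much slower pure-PyTorch reference path, which alone can explain a several-fold slowdown. A quick, hedged check (import paths follow the mamba-ssm repo; adjust if your version differs):

```python
# Check whether the fused CUDA kernels are actually importable; if either
# import fails, Mamba may be running on the slow reference path.
def check_fast_path():
    try:
        from causal_conv1d import causal_conv1d_fn  # noqa: F401
        print("causal_conv1d kernel: available")
    except ImportError as exc:
        print("causal_conv1d kernel: MISSING ->", exc)
    try:
        from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # noqa: F401
        print("selective_scan kernel: available")
    except ImportError as exc:
        print("selective_scan kernel: MISSING ->", exc)

check_fast_path()
```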
