Description
GPU: A100
Mamba config: the default MambaConfig, except vocab_size set to 108192 (see the sketch after this list)
CUDA: 12.1
PyTorch: 2.3.1
Python: 3.11
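
For reference, a minimal sketch of the model setup described above, assuming the `transformers` MambaConfig / MambaForCausalLM classes are the ones in use (that choice, and the parameter-count print, are my assumptions rather than confirmed details of the original setup):

```python
# Minimal sketch, assuming transformers' Mamba classes (an assumption, not confirmed above).
from transformers import MambaConfig, MambaForCausalLM

# Default MambaConfig with only the vocabulary size changed, as described above.
config = MambaConfig(vocab_size=108192)
model = MambaForCausalLM(config)

# Rough size check to compare against the ~230M-parameter two-tower BERT.
n_params = sum(p.numel() for p in model.parameters())
print(f"Mamba parameters: {n_params / 1e6:.1f}M")
```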
I trained a two-tower BERT with about 230M parameters in total on 1.5B training data; training completed in roughly 3 days. I then trained a Mamba model from scratch, resizing vocab_size to 108192 and setting the other parameters so the model is comparable to my BERT model. Both models were trained with the Hugging Face Trainer.
With Mamba, however, the batch size first has to be far smaller than for BERT: I had to reduce it from 4096 to 512 to avoid CUDA OOM errors. Second, training takes about five times as long as the BERT training.
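
For context, an illustrative Trainer configuration along the lines of the setup described above; everything except the 512 and 4096 batch sizes (output directory, mixed-precision flag, accumulation steps) is my assumption, shown only to make the batch-size comparison concrete:

```python
# Illustrative only; values other than the 512/4096 batch sizes are assumptions.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="mamba-from-scratch",   # hypothetical output directory
    per_device_train_batch_size=512,   # largest size that fits without CUDA OOM
    gradient_accumulation_steps=8,     # 512 * 8 = 4096, the batch size used for BERT
    bf16=True,                         # assumed mixed-precision setting on the A100
)
```

Gradient accumulation like this only keeps the effective batch size comparable between the two runs; it does not by itself address the memory or wall-clock difference.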
I have run `pip install mamba-ssm[causal-conv1d]`. What could be wrong with my setup, or what about Mamba itself could cause this?
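
One thing I can check (a diagnostic sketch of my own, not something confirmed in this issue) is whether the fused CUDA kernels actually import in this environment, since the `transformers` Mamba implementation falls back to a much slower non-fused path when they are missing:

```python
# Quick diagnostic (my own sketch): verify that the fused kernels from
# mamba-ssm and causal-conv1d can be imported. If either import fails,
# the slow fallback path is a likely cause of the large training-time gap.
try:
    from causal_conv1d import causal_conv1d_fn  # noqa: F401
    print("causal_conv1d_fn: available")
except ImportError as exc:
    print(f"causal_conv1d_fn: MISSING ({exc})")

try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn  # noqa: F401
    print("selective_scan_fn: available")
except ImportError as exc:
    print(f"selective_scan_fn: MISSING ({exc})")
```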