As described in our technical report, training this model proceeds in three steps:

1. Generate pseudo labels from a teacher model, meta-llama/Meta-Llama-3-8B-Instruct. We provide the pseudo labels generated from the seed datasets of UltraChat and UltraFeedback here. Please download them and change `train_datasets_path` in `llama3_0.25_mamba.yaml` and `llama3_0.50_mamba.yaml` to the paths of your downloaded `llama3_ultrafeedback` and `llama3_ultrachat` (a scripted version of this config edit is sketched after this list).
2. Apply SFT to the distilled model. We collected the SFT dataset from multiple sources and preprocessed it in our style. The SFT dataset can be found here. The result is an SFT model such as `JunxiongWang/llama3_mamba_0_5_sft`.
3. Align the SFT model to AI feedback via DPO on a preprocessed version of the UltraFeedback dataset (link). Following the Zephyr paper, we tested two hyperparameter settings:
   - 1 epoch with `beta=0.01`, resulting in the DPO model here.
   - 3 epochs with `beta=0.1`, resulting in the DPO model here.
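If you prefer to script the config change in step 1, here is a minimal sketch. It assumes PyYAML is installed and that `train_datasets_path` accepts a list of local dataset paths; check the actual schema of the config files, and treat the local paths below as placeholders.

```python
# Minimal sketch: point the distillation configs at the downloaded pseudo-label
# datasets. Assumes `train_datasets_path` takes a list of local dataset paths;
# verify the actual YAML schema before running this.
import yaml  # PyYAML

DATASET_PATHS = [
    "/data/llama3_ultrachat",      # placeholder: your downloaded llama3_ultrachat
    "/data/llama3_ultrafeedback",  # placeholder: your downloaded llama3_ultrafeedback
]

for cfg_path in ["mamba_llama/llama3_0.25_mamba.yaml", "mamba_llama/llama3_0.50_mamba.yaml"]:
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    cfg["train_datasets_path"] = DATASET_PATHS
    with open(cfg_path, "w") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)
```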
Here are detailed commands to reproduce those models. Make sure you are in the root folder of the project.
We start with meta-llama/Meta-Llama-3-8B-Instruct. First, we replace 25% of the attention layers with Mamba (`llama3_0.25_mamba.yaml`), and then replace another 25% (`llama3_0.50_mamba.yaml`), by running the following commands; a sketch of the layer-replacement idea follows them.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file multi_gpu.yaml train_mamba/train_hybrid.py mamba_llama/llama3_0.25_mamba.yaml
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file multi_gpu.yaml train_mamba/train_hybrid.py mamba_llama/llama3_0.50_mamba.yaml
This should take roughly 10 hours on 8x80GB A100s.
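For intuition about what this stage does, the sketch below shows the layer-replacement idea in isolation. The attribute names follow the Hugging Face Llama layout (`model.model.layers[i].self_attn`) and `make_mamba_block` is a hypothetical factory for a Mamba mixer with the same hidden size; neither is the repo's actual API, and which layers get replaced is controlled by the YAML configs.

```python
# Illustrative sketch, not the repo's actual API: swap the self-attention module
# of selected decoder layers for a Mamba block, leaving the remaining layers intact.
def replace_attention_with_mamba(model, layer_indices, make_mamba_block):
    """Replace self-attention with a Mamba mixer in the given decoder layers.

    Assumes a Llama-style layout (model.model.layers[i].self_attn); the resulting
    hybrid is then distilled on the teacher-generated pseudo labels.
    """
    for i in layer_indices:
        hidden_size = model.config.hidden_size
        model.model.layers[i].self_attn = make_mamba_block(hidden_size)  # hypothetical factory
    return model

# The first command above replaces 25% of the attention layers; the second replaces
# another 25%, producing the 50% attention / 50% Mamba hybrid described below.
```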
Now, we have a distilled hybrid Mamba model with 50% attention and 50% Mamba. We will then want to align it with human feedback.
This model is available here.
We explore two approaches for this:
Approach 1: SFT using CE loss on GPT-4 synthetic data
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_sft.py mamba_llama/llama3_0.50_mamba_sft.yaml
This should take roughly 4 days on 8x80GB A100s. This model is available here.
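For reference, Approach 1 is standard supervised fine-tuning: a token-level cross-entropy loss on the GPT-4 synthetic responses, with prompt and padding positions masked out. Below is a minimal sketch; the -100 masking value follows the usual Hugging Face convention, and the function name is illustrative rather than the repo's.

```python
import torch.nn.functional as F

def sft_ce_loss(model, input_ids, attention_mask, labels):
    """Cross-entropy SFT loss (illustrative). `labels` equals `input_ids` with prompt
    and padding positions set to -100, so only response tokens contribute to the loss."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    shift_logits = logits[:, :-1, :].contiguous()  # predict token t+1 from prefix up to t
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```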
Approach 2: SFT using KL loss against a larger teacher model, for example Llama-70B-Instruct.
Please check train_mamba/train_distill.py and the Mamba-Llama-3.1 models for details. It should give better results compared with SFT using CE loss on GPT-4 synthetic data.
If you did not run the layerwise distillation phase, set with_distill to False and the model will be initialized from the attention linear layers. If you already ran the layerwise distillation phase, set with_distill to True and it will load the model trained in the first phase.
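A minimal sketch of the Approach 2 objective is below: a word-level KL divergence between the hybrid student and a frozen larger teacher. It is illustrative only; the exact loss in train_mamba/train_distill.py (temperature, masking, any CE mixing term) may differ.

```python
import torch
import torch.nn.functional as F

def kl_distill_loss(student, teacher, input_ids, attention_mask, temperature=1.0):
    """Token-level KL(teacher || student), averaged over non-padded positions (illustrative)."""
    with torch.no_grad():
        t_logits = teacher(input_ids=input_ids, attention_mask=attention_mask).logits
    s_logits = student(input_ids=input_ids, attention_mask=attention_mask).logits

    s_logp = F.log_softmax(s_logits / temperature, dim=-1)
    t_prob = F.softmax(t_logits / temperature, dim=-1)
    kl_per_token = F.kl_div(s_logp, t_prob, reduction="none").sum(dim=-1)

    mask = attention_mask.float()
    return (kl_per_token * mask).sum() / mask.sum() * temperature ** 2
```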
The Zephyr recipe provides two hyperparameter settings; choose one of the two configs below (a sketch of how beta enters the DPO loss follows the commands).
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_dpo.py mamba_llama/llama3_0.50_mamba_dpo_ep1.yaml
This model is available here.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_dpo.py mamba_llama/llama3_0.50_mamba_dpo_ep3.yaml
This model is available here.
This should take roughly a few hours on 8x80GB A100s.
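For context on the two configs: beta is the DPO temperature that scales the implicit reward, i.e. how strongly the policy is penalized for drifting from the SFT reference model, so beta=0.01 over 1 epoch tolerates larger deviations per unit of preference gain than beta=0.1 over 3 epochs. Below is a minimal sketch of the standard DPO loss, assuming the per-sequence log-probabilities of the chosen and rejected responses have already been computed.

```python
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.01):
    """Standard DPO objective: -log sigmoid(beta * (policy margin - reference margin)).

    Each argument is a tensor of summed response-token log-probs under the policy
    being trained or the frozen reference (SFT) model; beta controls how sharply
    preference margins are rewarded relative to staying close to the reference.
    """
    policy_margin = policy_chosen_logps - policy_rejected_logps
    ref_margin = ref_chosen_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()
```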
We use the distilled SFT model with 50% attention to initialize this model (75% Mamba).
We explore two approaches for this:
Approach 1: SFT using CE loss on GPT-4 synthetic data
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_sft.py mamba_llama/llama3_0.75_mamba_sft.yaml
This model is available here.
Approach 2: SFT using KL loss against a larger teacher model, for example Llama-70B-Instruct.
Please check train_mamba/train_distill.py and the Mamba-Llama-3.1 models for details. It should give better results compared with SFT using CE loss on GPT-4 synthetic data.
If you did not run the layerwise distillation phase, set with_distill to False and the model will be initialized from the attention linear layers. If you already ran the layerwise distillation phase, set with_distill to True and it will load the model trained in the first phase.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_dpo.py mamba_llama/llama3_0.75_mamba_dpo_ep1.yaml
This model is available here.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_dpo.py mamba_llama/llama3_0.75_mamba_dpo_ep3.yaml
This model is available here.
We use the distilled SFT model with 25% attention to initialize this model (87.5% Mamba).
Approach 1: SFT using CE loss on GPT-4 synthetic data
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_sft.py mamba_llama/llama3_0.875_mamba_sft.yaml
This model is available here.
Approach 2: SFT using KL loss against a larger teacher model, for example Llama-70B-Instruct.
Please check train_mamba/train_distill.py and the Mamba-Llama-3.1 models for details. It should give better results compared with SFT using CE loss on GPT-4 synthetic data.
If you did not run the layerwise distillation phase, set with_distill to False and the model will be initialized from the attention linear layers. If you already ran the layerwise distillation phase, set with_distill to True and it will load the model trained in the first phase.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_dpo.py mamba_llama/llama3_0.875_mamba_dpo_ep1.yaml
This model is available here.
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file deepspeed_zero3.yaml train_mamba/train_dpo.py mamba_llama/llama3_0.875_mamba_dpo_ep3.yaml
This model is available here.
Please follow the instructions here.
For Mamba 2 models, you can simply change the script path from train_mamba/*.py to train_mamba2/*.py.
After releasing our models, we found that using a larger teacher model and minimizing the KL divergence to its output distribution leads to a better model in the distillation phase. Please check train_mamba/train_distill.py, train_mamba2/train_distill.py, and our Mamba 3.2B models for more details.