This repository contains the code used to train the models discussed in the article:
N. Leroux, P.-P. Manea, C. Sudarshan, J. Finkbeiner, S. Siegel, J. P. Strachan, and E. Neftci, Analog In-Memory Computing Attention Mechanism for Fast and Energy-Efficient Large Language Models, arXiv:2409.19315. https://arxiv.org/abs/2409.19315.
mkdir new_project
cd new_project
mkdir ./venv/
python -m venv ./venv/gpt2
git clone https://github.com/NathanLeroux-git/GainCellAttention.git
cd GainCellAttention/
source ../venv/gpt2/bin/activate
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
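Optionally, before preparing the datasets, you can check that the installed PyTorch build actually sees the GPUs (a generic PyTorch sanity check, not specific to this repository):

import torch

print(torch.__version__)          # the cu118 wheel typically reports a "+cu118" suffix
print(torch.cuda.is_available())  # should be True when the CUDA driver is set up correctly
print(torch.cuda.device_count())  # number of GPUs visible for (distributed) training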
python datasets/texts/shakespeare/prepare.py
python datasets/texts/openwebtext/prepare.py
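If the prepare scripts follow the usual nanoGPT convention of writing flat binary token files named train.bin and val.bin next to themselves (an assumption; check the scripts for their actual output), you can sanity-check the prepared data from Python, for example for Shakespeare:

import numpy as np

# assumed nanoGPT-style output: token ids stored as a flat uint16 binary file
train = np.memmap('datasets/texts/shakespeare/train.bin', dtype=np.uint16, mode='r')
val = np.memmap('datasets/texts/shakespeare/val.bin', dtype=np.uint16, mode='r')
print(f'{len(train):,} train tokens, {len(val):,} validation tokens')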
Training the models (total estimated time: ~4 days on 8 NVIDIA H100 80 GB GPUs, ~16 days on 8 NVIDIA RTX 4090 24 GB GPUs)
main_gpt.py automatically saves the results to wandb.ai (use --wandb_log=True and --wandb_offline=False). The trained models are checkpointed in "../saved_models/checkpoints".
Fine-tune GPT-2 on the Shakespeare dataset (single-process run):
python -m main_gpt ./configs/experiments/gpt-text/finetune_shakespeare.py --wandb_log=True --wandb_offline=False --batch_size=8
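The checkpoints are written with torch.save; in nanoGPT-style training scripts they are typically a dict holding the model weights plus training metadata such as the iteration count (an assumption here; see main_gpt.py for the exact layout). A minimal inspection sketch, with the path adjusted to the out_dir of your run:

import torch

# example path taken from the distributed run below; replace with <out_dir>/ckpt.pt of your own run
ckpt = torch.load('../saved_models/checkpoints/gpt2_LinearDRAMAttention/ckpt.pt',
                  map_location='cpu', weights_only=False)
print(list(ckpt.keys()))  # exact keys depend on main_gpt.py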
Distributed run with 4 GPUs for 3000 iterations to train the intermediate gain-cell model, fine-tuned from gpt2:
python -m torch.distributed.run --nproc_per_node 4 main_gpt.py ./configs/experiments/gpt-text/ft_lineardram_gpt2.py --wandb_log=True --wandb_offline=False --init_from='gpt2' --stop_saving_after=3000. --max_iters=3001 --out_dir='../saved_models/checkpoints/gpt2_LinearDRAMAttention'
After training the intermediate model, the saved checkpoint must be copied (or renamed) to ../saved_models/gpt2-LinearDRAMAttention.pt. For instance:
cp ../saved_models/checkpoints/gpt2_LinearDRAMAttention/ckpt.pt ../saved_models/gpt2-LinearDRAMAttention.pt
Finally, we fine-tune the intermediate model into the final gain-cell model (the adaptation algorithm is handled by main_gpt.py):
python -m torch.distributed.run --nproc_per_node 4 main_gpt.py ./configs/experiments/gpt-text/ft_dram_gpt2.py --wandb_log=True --wandb_offline=False --init_from='gpt2-LinearDRAMAttention' --stop_saving_after=13000. --max_iters=13001 --out_dir='../saved_models/checkpoints/gpt2_DRAMAttention'
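As a final sanity check, you can load the resulting checkpoint and count its parameters (again assuming a nanoGPT-style layout with the weights stored under a 'model' key and a checkpoint file named ckpt.pt; adapt to the actual format used by main_gpt.py):

import torch

ckpt = torch.load('../saved_models/checkpoints/gpt2_DRAMAttention/ckpt.pt',
                  map_location='cpu', weights_only=False)
# fall back to treating the file as a bare state dict if there is no 'model' key
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
n_params = sum(t.numel() for t in state_dict.values())
print(f'~{n_params / 1e6:.1f} M parameters in the saved model')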