PenutChen/GaLore

 
 

GaLore

This repo contains the pre-release version of the GaLore algorithm, proposed in the paper GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection.

Gradient Low-Rank Projection (GaLore) is a memory-efficient low-rank training strategy that allows full-parameter learning but is more memory-efficient than common low-rank adaptation methods, such as LoRA. As a gradient projection method, GaLore is independent of the choice of optimizers and can be easily plugged into existing ones with only two lines of code, as shown in Algorithm 1 below.

[Image: Algorithm 1]
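To make the projection idea concrete, here is a minimal sketch of one GaLore-style update for a single weight matrix in plain PyTorch. It is an illustration under stated assumptions, not the code shipped in galore_torch: the helper names `compute_projector` and `projected_sgd_update` are hypothetical, plain SGD stands in for the inner optimizer, and the projection is always applied on the left (the matrix here has fewer rows than columns).

```python
import torch

def compute_projector(grad: torch.Tensor, rank: int) -> torch.Tensor:
    """Top-`rank` left singular vectors of the gradient, shape (m, rank)."""
    U, _, _ = torch.linalg.svd(grad, full_matrices=False)
    return U[:, :rank]

def projected_sgd_update(grad, projector, lr=0.01, scale=0.25):
    """Weight delta for one step; plain SGD stands in for the inner optimizer."""
    low_rank_grad = projector.T @ grad            # (rank, n): project the gradient
    low_rank_step = lr * low_rank_grad            # the optimizer works in the small space
    return -scale * (projector @ low_rank_step)   # (m, n): project back to full size

torch.manual_seed(0)
weight = torch.randn(32, 64)
grad = torch.randn(32, 64)

# In a training loop the projector would be recomputed only occasionally
# (assumed here to be the role of `update_proj_gap`), not on every step.
projector = compute_projector(grad, rank=4)
delta = projected_sgd_update(grad, projector)
weight = weight + delta
```

The point of the sketch is that any optimizer state would be kept for `low_rank_grad`, which is `rank x n`, rather than for the full `m x n` gradient.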

News

Thanks everyone for the interest in GaLore!

We are working on the official release of GaLore. In the meantime, please feel free to try the pre-release version and send us feedback. Currently, the pre-release version (e.g., the GaLore optimizers) should provide a decent memory reduction and an accurate simulation of the GaLore algorithm.

The official release of GaLore will include:

  1. Per-layer weight updates for multi-GPU training (DDP and FSDP) (working with PyTorch).
  2. Memory-efficient low-rank gradient accumulation (working with PyTorch).
  3. Optimized GaLoreAdamW8bit (working with bitsandbytes).

We would like to express our gratitude to the community members who have been actively working on integrating GaLore into different platforms, including HuggingFace, LLaMA-Factory, and Axolotl. Join our Slack workspace GaLore-Social to engage in discussions with us.

Discussion (GaLore-Social)

We welcome any discussions, questions, and feedback on GaLore. Please join our Slack workspace GaLore-Social to discuss with us and the community.

Installation

Install GaLore optimizer

Install from pip:

pip install galore-torch

or if you want to install from source:

git clone git@github.com:jiaweizzhao/GaLore.git
cd GaLore
pip install -e .

Install experiment dependencies

pip install -r exp_requirements.txt

Usage

Save optimizer memory using GaLore optimizers

from galore_torch import GaLoreAdamW, GaLoreAdamW8bit, GaLoreAdafactor
# define param groups as galore_params and non_galore_params
param_groups = [{'params': non_galore_params}, 
                {'params': galore_params, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}]
optimizer = GaLoreAdamW(param_groups, lr=0.01)
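The snippet above leaves `galore_params` and `non_galore_params` undefined. One way to build them, sketched here under the assumption that only the 2-D weights of `nn.Linear` layers should be projected, is to walk the model's modules. The toy model and the helper name `split_galore_params` are hypothetical, and the repository's training script may select layers differently.

```python
import torch.nn as nn

def split_galore_params(model: nn.Module):
    """Split parameters into linear-layer weights (projected) and everything else."""
    galore_params = [
        module.weight
        for module in model.modules()
        if isinstance(module, nn.Linear)
    ]
    galore_ids = {id(p) for p in galore_params}
    non_galore_params = [p for p in model.parameters() if id(p) not in galore_ids]
    return galore_params, non_galore_params

# Toy model for illustration only; a real run would pass the LLaMA model here.
model = nn.Sequential(
    nn.Embedding(100, 16),
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 4),
)
galore_params, non_galore_params = split_galore_params(model)

# Same group layout as in the snippet above; rank is small to suit the toy model.
param_groups = [
    {"params": non_galore_params},
    {"params": galore_params, "rank": 4, "update_proj_gap": 200,
     "scale": 0.25, "proj_type": "std"},
]
```

Biases, embeddings, and other parameters land in the first group and are handled by the optimizer without projection.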

Save weight gradient memory using per-layer weight updates

We use register_post_accumulate_grad_hook provided by PyTorch to enable per-layer weight updates. An example is shown below:

from galore_torch import GaLoreAdamW

# define an optimizer for each parameter p, and store them in optimizer_dict
optimizer_dict = {}
for p in model.parameters():
    if p.requires_grad:
        optimizer_dict[p] = GaLoreAdamW([{'params': p, 'rank': 128, 'update_proj_gap': 200, 'scale': 0.25, 'proj_type': 'std'}], lr=0.01)

# define a hook function to update the parameter p during the backward pass
def optimizer_hook(p):
    if p.grad is None:
        return
    optimizer_dict[p].step()
    optimizer_dict[p].zero_grad()

# register the hook on every trainable parameter
for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optimizer_hook)

More details can be found in torchrun_main.py.

Benchmark 1: Pre-Training LLaMA on C4 dataset

torchrun_main.py is the main script for training LLaMA models on C4 with GaLore. Our benchmark scripts for various model sizes are in the scripts/benchmark_c4 folder. For example, to train a 60M model on C4, run the following:

# LLaMA-60M, GaLore-Adam, 1 A100, 1 Node
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_60m.json \
    --lr 0.01 \
    --galore_scale 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer galore_adamw 
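The command sets both --batch_size and --total_batch_size. Assuming --batch_size is the per-process micro-batch and --total_batch_size is the effective batch per optimizer step (an interpretation of the flag names, not something this README states), the number of gradient accumulation steps follows from simple arithmetic. The function name below is hypothetical.

```python
def gradient_accumulation_steps(total_batch_size: int, batch_size: int,
                                world_size: int = 1) -> int:
    """Micro-batches to accumulate per optimizer step (assumed flag semantics)."""
    per_pass = batch_size * world_size  # samples processed in one forward/backward
    if total_batch_size % per_pass != 0:
        raise ValueError("total_batch_size must be a multiple of batch_size * world_size")
    return total_batch_size // per_pass

# Values from the LLaMA-60M command above: one process, 256 per pass, 512 total.
steps_60m = gradient_accumulation_steps(total_batch_size=512, batch_size=256, world_size=1)
```

Under these assumptions the 60M run accumulates gradients over two passes before each optimizer step.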

Train 7B model with a single GPU with 24GB memory

To train a 7B model on a single GPU such as an NVIDIA RTX 4090, all you need to do is specify --optimizer=galore_adamw8bit_per_layer, which enables GaLoreAdamW8bit with per-layer weight updates. With activation checkpointing, a batch size of 16 fits in memory (tested on an NVIDIA RTX 4090).

# LLaMA-7B, 8-bit GaLore-Adam, single GPU, activation checkpointing
# bsz=16, 22.8G
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_7b.json \
    --lr 0.005 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 500 \
    --batch_size 16 \
    --total_batch_size 512 \
    --activation_checkpointing \
    --num_training_steps 150000 \
    --warmup_steps 15000 \
    --weight_decay 0 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --single_gpu \
    --optimizer galore_adamw8bit_per_layer

Currently, the per-layer weight update technique is supported only for single-GPU training (--single_gpu), without nn.parallel.DistributedDataParallel. We are working on supporting multi-GPU training with per-layer weight updates.
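For a sense of why a low rank saves optimizer memory, the following back-of-the-envelope sketch counts optimizer-state elements for one weight matrix. It assumes that Adam keeps two moment tensors per parameter, and that GaLore keeps those moments only for the projected gradient plus one projection matrix along the smaller dimension. It ignores 8-bit quantization, weights, gradients, and activations. The function names and the 4096 x 4096 shape are illustrative and are not taken from the repository's configs.

```python
def adam_state_elements(m: int, n: int) -> int:
    """Two moment tensors, each the size of the (m, n) weight."""
    return 2 * m * n

def galore_adam_state_elements(m: int, n: int, rank: int) -> int:
    """Projector along the smaller dimension plus two low-rank moment tensors."""
    small, large = min(m, n), max(m, n)
    rank = min(rank, small)        # rank cannot exceed the smaller dimension
    projector = small * rank       # (small, rank) projection matrix
    moments = 2 * rank * large     # two moments for the (rank, large) gradient
    return projector + moments

# Illustrative square matrix with the rank used in the 7B command above.
full = adam_state_elements(4096, 4096)
low = galore_adam_state_elements(4096, 4096, rank=1024)
ratio = low / full
```

Under these assumptions the ratio depends only on the rank and the matrix shape, so smaller ranks shrink the optimizer state further at the cost of a coarser gradient approximation.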

Benchmark 2: Fine-Tuning RoBERTa on GLUE tasks

run_glue.py is the main script for fine-tuning RoBERTa models on GLUE tasks with GaLore. An example script is shown below:

python run_glue.py \
    --model_name_or_path roberta-base \
    --task_name mrpc \
    --enable_galore \
    --lora_all_modules \
    --max_length 512 \
    --seed=1234 \
    --lora_r 4 \
    --galore_scale 4 \
    --per_device_train_batch_size 16 \
    --update_proj_gap 500 \
    --learning_rate 3e-5 \
    --num_train_epochs 30 \
    --output_dir results/ft/roberta_base/mrpc

Citation

@misc{zhao2024galore,
      title={GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection}, 
      author={Jiawei Zhao and Zhenyu Zhang and Beidi Chen and Zhangyang Wang and Anima Anandkumar and Yuandong Tian},
      year={2024},
      eprint={2403.03507},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
