Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding flash attention and xformer memory efficient through PT SDPA #97

Merged
merged 9 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,21 @@ torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_

Here we use FSDP as discussed in the next section which can be used along with PEFT methods. To make use of PEFT methods with FSDP make sure to pass `use_peft` and `peft_method` args along with `enable_fsdp`. Here we are using `BF16` for training.

## Flash Attention and Xformer Memory Efficient Kernels

Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).

```bash
torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
```

### Fine-tuning using FSDP Only

If you are interested in running full parameter fine-tuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.

```bash

torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels

```

Expand Down
1 change: 1 addition & 0 deletions configs/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class train_config:
dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
save_optimizer: bool=False # will be used if using FSDP
use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels



Expand Down
12 changes: 12 additions & 0 deletions docs/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ The inference folder also includes a chat completion example, that adds built-in
python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json --quantization --use_auditnlg
```

## Flash Attention and Xformer Memory Efficient Kernels

Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up inference when used for batched inputs. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).

```bash
python inference/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file inference/chats.json --quantization --use_auditnlg --use_fast_kernels

python inference/inference.py --model_name <training_config.output_dir> --peft_model <training_config.output_dir> --prompt_file <test_prompt_file> --use_auditnlg --use_fast_kernels

```

## Loading back FSDP checkpoints

In case you have fine-tuned your model with pure FSDP and saved the checkpoints with "SHARDED_STATE_DICT" as shown [here](../configs/fsdp.py), you can use this converter script to convert the FSDP Sharded checkpoints into HuggingFace checkpoints. This enables you to use the inference script normally as mentioned above.
Expand Down
9 changes: 8 additions & 1 deletion docs/mutli_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,21 @@ The args used in the command above are:

We use `torchrun` here to spawn multiple processes for FSDP.

## Flash Attention and Xformer Memory Efficient Kernels

Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).

```bash
torchrun --nnodes 1 --nproc_per_node 4 ../llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels
```

### Fine-tuning using FSDP Only

If interested in running full parameter finetuning without making use of PEFT methods, please use the following command. Make sure to change the `nproc_per_node` to your available GPUs. This has been tested with `BF16` on 8xA100, 40GB GPUs.

```bash

torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16
torchrun --nnodes 1 --nproc_per_node 8 llama_finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --pure_bf16 --use_fast_kernels

```

Expand Down
14 changes: 14 additions & 0 deletions inference/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from peft import PeftModel, PeftConfig
from transformers import LlamaConfig, LlamaTokenizer, LlamaForCausalLM
from optimum.bettertransformer import BetterTransformer
HamidShojanazeri marked this conversation as resolved.
Show resolved Hide resolved
from safety_utils import get_safety_checker
from model_utils import load_model, load_peft_model
from chat_utils import read_dialogs_from_file, format_tokens
Expand All @@ -34,6 +35,7 @@ def main(
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs
):
if prompt_file is not None:
Expand All @@ -57,6 +59,18 @@ def main(
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
model = load_model(model_name, quantization)
if use_fast_kernels:
chauhang marked this conversation as resolved.
Show resolved Hide resolved
"""
Setting 'use_fast_kernels' will enable
using of Flash Attention or Xformer memory-efficient kernels
based on the hardware being used. This would speed up inference when used for batched inputs.
"""
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

model = BetterTransformer.transform(model)
HamidShojanazeri marked this conversation as resolved.
Show resolved Hide resolved
if peft_model:
model = load_peft_model(model, peft_model)
tokenizer = LlamaTokenizer.from_pretrained(model_name)
Expand Down
13 changes: 13 additions & 0 deletions inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def main(
enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
**kwargs
):
if prompt_file is not None:
Expand All @@ -51,6 +52,18 @@ def main(
torch.manual_seed(seed)

model = load_model(model_name, quantization)
if use_fast_kernels:
"""
Setting 'use_fast_kernels' will enable
using of Flash Attention or Xformer memory-efficient kernels
based on the hardware being used. This would speed up inference when used for batched inputs.
"""
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")

model = BetterTransformer.transform(model)
HamidShojanazeri marked this conversation as resolved.
Show resolved Hide resolved
tokenizer = LlamaTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(
{
Expand Down
13 changes: 12 additions & 1 deletion llama_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
BitsAndBytesConfig
)
import torch.distributed as dist
from optimum.bettertransformer import BetterTransformer
HamidShojanazeri marked this conversation as resolved.
Show resolved Hide resolved

# Unused imports removed
from utils.train_utils import (
Expand Down Expand Up @@ -95,7 +96,17 @@ def main(**kwargs):
load_in_8bit=True if train_config.quantization else None,
device_map="auto" if train_config.quantization else None,
)

if train_config.enable_fsdp and train_config.use_fast_kernels:
"""
For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
using of Flash Attention or Xformer memory-efficient kernels
based on the hardware being used. This would speed up fine-tuning.
"""
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
model = BetterTransformer.transform(model)
HamidShojanazeri marked this conversation as resolved.
Show resolved Hide resolved
print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

# Prepare the model for int8 training if quantization is enabled
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ transformers>=4.31.0
sentencepiece
py7zr
scipy

optimum
31 changes: 30 additions & 1 deletion scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1089,4 +1089,33 @@ fragmentations
intra
nightlies
recenly
uncomment
uncomment
BFloat
DDP
LLM
Xformer
accuracies
activations
anyprecision
aplaca
assembels
boolean
checkpoining
defatults
gradinets
itermediate
recommond
scaler
sharding
slurm
summarization
theJfleg
xA
Jupyter
LLM
Xformer
dataset's
jupyter
mutli
summarization
xA
23 changes: 17 additions & 6 deletions utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from typing import List
import yaml
import time

import fire
import torch
Expand Down Expand Up @@ -73,9 +74,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
train_loss = []
val_prep = []
val_loss =[]
epoch_times = []
checkpoint_times = []
results = {}
best_val_loss = float("inf")
for epoch in range(train_config.num_epochs):
epoch_start_time = time.perf_counter()
with MemoryTrace() as memtrace: # track the memory usage
model.train()
total_loss = 0.0
Expand Down Expand Up @@ -106,7 +110,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"\n step {step} is completed and loss is {loss.detach().float()}")
else:
print(f"\n step {step} is completed and loss is {loss.detach().float()}")

epoch_end_time = time.perf_counter()-epoch_start_time
epoch_times.append(epoch_end_time)
# Reducing total_loss across all devices if there's more than one CUDA device
if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
Expand All @@ -117,6 +122,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

train_prep.append(train_perplexity)
train_loss.append(train_epoch_loss)

if train_config.enable_fsdp:
if rank==0:
print(f"Max CUDA memory allocated was {memtrace.peak} GB")
Expand All @@ -136,6 +142,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

if train_config.run_validation:
eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
checkpoint_start_time = time.perf_counter()
if train_config.save_model and eval_epoch_loss < best_val_loss:
if train_config.enable_fsdp:
dist.barrier()
Expand Down Expand Up @@ -176,7 +183,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print("=====================================================")
if train_config.enable_fsdp:
dist.barrier()

checkpoint_end_time = time.perf_counter() - checkpoint_start_time
checkpoint_times.append(checkpoint_end_time)
if eval_epoch_loss < best_val_loss:
best_val_loss = eval_epoch_loss
if train_config.enable_fsdp:
Expand All @@ -189,10 +197,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche

if train_config.enable_fsdp:
if rank==0:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
else:
print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")

print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
avg_epoch_time = sum(epoch_times)/ len(epoch_times)
avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)
avg_train_prep = sum(train_prep)/len(train_prep)
avg_train_loss = sum(train_loss)/len(train_loss)
if train_config.run_validation:
Expand All @@ -204,7 +213,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
if train_config.run_validation:
results['avg_eval_prep'] = avg_eval_prep
results['avg_eval_loss'] = avg_eval_loss

results["avg_epoch_time"] = avg_epoch_time
results["avg_checkpoint_time"] = avg_checkpoint_time

#saving the training params including fsdp setting for reference.
if train_config.enable_fsdp and not train_config.use_peft:
save_train_params(train_config, fsdp_config, rank)
Expand Down
Loading