refactor(framework) Update FlowerTune LLM example (#4046)
Co-authored-by: jafermarq <[email protected]>
yan-gao-GY and jafermarq authored Aug 27, 2024
1 parent ebf0d3f commit 5b3784a
Showing 17 changed files with 384 additions and 475 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -143,7 +143,7 @@ Other [examples](https://github.com/adap/flower/tree/main/examples):
- [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated)
- [Vertical FL](https://github.com/adap/flower/tree/main/examples/vertical-fl)
- [Federated Finetuning of OpenAI's Whisper](https://github.com/adap/flower/tree/main/examples/whisper-federated-finetuning)
- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/llm-flowertune)
- [Federated Finetuning of Large Language Model](https://github.com/adap/flower/tree/main/examples/flowertune-llm)
- [Federated Finetuning of a Vision Transformer](https://github.com/adap/flower/tree/main/examples/flowertune-vit)
- [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow)
- [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)
1 change: 1 addition & 0 deletions examples/doc/source/conf.py
@@ -66,6 +66,7 @@
"quickstart-mxnet": "index.html",
"mxnet-from-centralized-to-federated": "index.html",
"app-secure-aggregation": "flower-secure-aggregation.html",
"llm-flowertune": "flowertune-llm.html",
}


118 changes: 118 additions & 0 deletions examples/flowertune-llm/README.md
@@ -0,0 +1,118 @@
---
tags: [llm, nlp, LLaMA]
dataset: [Alpaca-GPT4]
framework: [PEFT, torch]
---

# FlowerTune LLM: Federated LLM Fine-tuning with Flower

Large language models (LLMs), which have been trained on vast amounts of publicly accessible data, have shown remarkable effectiveness in a wide range of areas.
However, although more data typically leads to better performance, there is a concerning prospect that the supply of high-quality public data will be depleted within a few years.
Federated LLM training could unlock access to an endless pool of distributed private data by allowing multiple data owners to collaboratively train a shared model without the need to exchange raw data.

This introductory example conducts federated instruction tuning with pretrained [OpenLLaMA](https://huggingface.co/openlm-research) models on the [Alpaca-GPT4](https://huggingface.co/datasets/vicgalle/alpaca-gpt4) dataset.
We implement FlowerTune LLM by integrating a bundle of techniques: 1) We use [Flower Datasets](https://flower.dev/docs/datasets/) to download, partition and preprocess the dataset. 2) The fine-tuning is done using the [🤗PEFT](https://huggingface.co/docs/peft/en/index) library. 3) We use Flower's Simulation Engine to simulate the LLM fine-tuning process in a federated way,
which allows users to perform the training on a single GPU.
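
As a rough illustration of 1) and 2), the sketch below loads one client's partition of Alpaca-GPT4 with Flower Datasets and attaches a LoRA adapter with PEFT. The model name, partition count, and LoRA hyperparameters here are illustrative; the example's actual values live in `pyproject.toml`, `flowertune_llm/dataset.py`, and `flowertune_llm/models.py`.

```python
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# 1) Download Alpaca-GPT4 and split it IID across (here) 10 virtual clients
fds = FederatedDataset(
    dataset="vicgalle/alpaca-gpt4",
    partitioners={"train": IidPartitioner(num_partitions=10)},
)
partition = fds.load_partition(partition_id=0, split="train")

# 2) Wrap a pretrained OpenLLaMA model with a small LoRA adapter (PEFT), so that
#    only the adapter weights are trained and exchanged between clients and server
base_model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b_v2")
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
```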

## Set up the project

Start by cloning the example project:

```shell
git clone --depth=1 https://github.com/adap/flower.git _tmp \
&& mv _tmp/examples/flowertune-llm . \
&& rm -rf _tmp \
&& cd flowertune-llm
```

This will create a new directory called `flowertune-llm` with the following structure:

```shell
flowertune-llm
├── flowertune_llm
│   ├── __init__.py
│   ├── client_app.py   # Defines your ClientApp
│   ├── server_app.py   # Defines your ServerApp
│   ├── dataset.py      # Defines your dataset and tokenizer
│   └── models.py       # Defines your models
├── pyproject.toml      # Project metadata like dependencies and configs
├── test.py             # Test the fine-tuned model
└── README.md
```

### Install dependencies and project

Install the dependencies defined in `pyproject.toml` as well as the `flowertune_llm` package.

```bash
pip install -e .
```

## Run the project

You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine.

### Run with the Simulation Engine

```bash
flwr run .
```

This command will run FL simulations with a 4-bit [OpenLLaMA 3Bv2](https://huggingface.co/openlm-research/open_llama_3b_v2) model involving 2 clients per round for 100 FL rounds. You can override configuration parameters directly from the command line. Below are a few settings you might want to test:

```bash
# Use OpenLLaMA-7B instead of 3B and 8-bit quantization
flwr run . --run-config "model.name='openlm-research/open_llama_7b_v2' model.quantization=8"

# Run for 50 rounds and increase the fraction of clients that participate per round to 25%
flwr run . --run-config "num-server-rounds=50 strategy.fraction-fit=0.25"
```

### Run with the Deployment Engine

> [!NOTE]
> An update to this example will show how to run this Flower application with the Deployment Engine and TLS certificates, or with Docker.

## Expected results

![](_static/train_loss_smooth.png)

As expected, the OpenLLaMA-7B model performs better than its 3B counterpart, reaching a lower training loss. With the hyperparameters tested, the 8-bit variant of the smaller 3B model also seems to deliver a lower training loss than its 4-bit variant.

## VRAM consumption

| Models | 7-billion (8-bit) | 7-billion (4-bit) | 3-billion (8-bit) | 3-billion (4-bit) |
| :----: | :---------------: | :---------------: | :---------------: | :---------------: |
| VRAM | ~22.00 GB | ~16.50 GB | ~13.50 GB | ~10.60 GB |

We make use of the [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) library in conjunction with [PEFT](https://huggingface.co/docs/peft/en/index) to derive LLMs that can be fine-tuned efficiently.
The above table shows the VRAM consumption per client for the different models considered in this example.
You can adjust the CPU/GPU resources you assign to each of the clients based on your device.
For example, it is easy to train 2 concurrent clients on each GPU (24 GB VRAM) if you choose the 3-billion (4-bit) model.
You can assign 50% of the GPU's VRAM to each client by setting `options.backend.clientapp-gpus = 0.5` under `[tool.flwr.federations.local-simulation]` in `pyproject.toml`, as sketched below.
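
A minimal sketch of the relevant `pyproject.toml` section (only this option is shown; the federation name follows the example's default local simulation, and any other options the file defines are omitted):

```toml
[tool.flwr.federations.local-simulation]
# Give each ClientApp half a GPU, so two clients can train concurrently on a 24 GB card
options.backend.clientapp-gpus = 0.5
```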

## Test with your Questions

We provide a script to test your trained model by passing it questions of your choice. For example:

```bash
python test.py --peft-path=/path/to/trained-model-dir/ \
--question="What is the ideal 1-day plan in London?"
```
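
Under the hood, such a test roughly amounts to loading the base model, attaching the saved PEFT adapter, and generating a reply. Below is a minimal sketch, not the example's actual `test.py`: the base model name, paths, and generation settings are assumptions, and `test.py` additionally wraps the question in the training prompt template.

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "openlm-research/open_llama_7b_v2"  # assumed base model
peft_path = "/path/to/trained-model-dir/"             # LoRA adapter saved during federated training

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name, torch_dtype=torch.float16, device_map="auto"
)
# Attach the fine-tuned adapter on top of the frozen base model
model = PeftModel.from_pretrained(base_model, peft_path)

question = "What is the ideal 1-day plan in London?"
inputs = tokenizer(question, return_tensors="pt").to(base_model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```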

An answer generated by the 7-billion (8-bit) OpenLLaMA model after federated fine-tuning:

```
Great choice.
London has so much to offer, and you can really soak up all the sights and sounds in just a single day.
Here's a suggested itinerary for you.
Start your day off with a hearty breakfast at an authentic British diner.
Then head to the iconic Big Ben and the Houses of Parliament to learn about the history of the city.
Next, make your way to Westminster Abbey to see the many historical monuments and memorials.
From there, cross the river Thames to the Tower of London, which is home to the Crown Jewels of England and Scotland.
Finally, end your day with a relaxing visit to the London Eye, the tallest Ferris wheel in Europe, for a beautiful view of the city.
```

The [`Vicuna`](https://huggingface.co/lmsys/vicuna-13b-v1.1) template we used in this example is for a chat assistant.
The generated answer is expected to be a multi-turn conversation. Feel free to try more interesting questions!
1 change: 1 addition & 0 deletions examples/flowertune-llm/flowertune_llm/__init__.py
@@ -0,0 +1 @@
"""flowertune_llm."""
Changes to examples/flowertune-llm/flowertune_llm/client_app.py
@@ -1,21 +1,40 @@
"""flowertune-llm: A Flower / FlowerTune app."""

import os
import warnings
from typing import Dict, Tuple
from collections import OrderedDict
from typing import Callable, Dict, Tuple

import flwr as fl
import torch
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from flwr.common.config import unflatten_dict
from flwr.common.typing import NDArrays, Scalar
from omegaconf import DictConfig

from peft import get_peft_model_state_dict, set_peft_model_state_dict
from transformers import TrainingArguments
from trl import SFTTrainer

from models import cosine_annealing, get_model
from flowertune_llm.dataset import (
get_tokenizer_and_data_collator_and_propt_formatting,
load_data,
replace_keys,
)
from flowertune_llm.models import (
cosine_annealing,
get_model,
)

# Avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["RAY_DISABLE_DOCKER_CPU_WARNING"] = "1"
warnings.filterwarnings("ignore", category=UserWarning)


# pylint: disable=too-many-arguments
class FlowerClient(
fl.client.NumPyClient
): # pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-instance-attributes
class FlowerClient(NumPyClient):
"""Standard Flower client for CNN training."""

def __init__(
@@ -26,27 +45,20 @@ def __init__(
tokenizer,
formatting_prompts_func,
data_collator,
save_path,
num_rounds,
): # pylint: disable=too-many-arguments
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.train_cfg = train_cfg
self.training_argumnets = TrainingArguments(**train_cfg.training_arguments)
self.tokenizer = tokenizer
self.formatting_prompts_func = formatting_prompts_func
self.data_collator = data_collator
self.save_path = save_path
self.num_rounds = num_rounds
self.trainset = trainset

# instantiate model
self.model = get_model(model_cfg)

self.trainset = trainset

def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
"""Return the parameters of the current net."""

state_dict = get_peft_model_state_dict(self.model)
return [val.cpu().numpy() for _, val in state_dict.items()]

def fit(
self, parameters: NDArrays, config: Dict[str, Scalar]
) -> Tuple[NDArrays, int, Dict]:
@@ -55,13 +67,13 @@ def fit(

new_lr = cosine_annealing(
int(config["current_round"]),
self.train_cfg.num_rounds,
self.num_rounds,
self.train_cfg.learning_rate_max,
self.train_cfg.learning_rate_min,
)

self.training_argumnets.learning_rate = new_lr
self.training_argumnets.output_dir = self.save_path
self.training_argumnets.output_dir = config["save_path"]

# Construct trainer
trainer = SFTTrainer(
@@ -78,7 +90,7 @@
results = trainer.train()

return (
self.get_parameters({}),
get_parameters(self.model),
len(self.trainset),
{"train_loss": results.training_loss},
)
@@ -92,38 +104,37 @@ def set_parameters(model, parameters: NDArrays) -> None:
set_peft_model_state_dict(model, state_dict)


def gen_client_fn(
fds,
tokenizer,
formatting_prompts_func,
data_collator,
model_cfg: DictConfig,
train_cfg: DictConfig,
save_path: str,
partition_id: int = 0,
api: bool = False,
) -> Callable[[str], FlowerClient]: # pylint: disable=too-many-arguments
"""Generate the client function that creates the Flower Clients."""

def client_fn(cid: str) -> FlowerClient:
"""Create a Flower client representing a single organization."""

# Let's get the partition corresponding to the i-th client
client_trainset = (
fds.load_partition(partition_id, "train")
if api
else fds.load_partition(int(cid), "train")
)
client_trainset = client_trainset.rename_column("output", "response")

return FlowerClient(
model_cfg,
train_cfg,
client_trainset,
tokenizer,
formatting_prompts_func,
data_collator,
save_path,
).to_client()

return client_fn
def get_parameters(model) -> NDArrays:
"""Return the parameters of the current net."""
state_dict = get_peft_model_state_dict(model)
return [val.cpu().numpy() for _, val in state_dict.items()]


def client_fn(context: Context) -> FlowerClient:
"""Create a Flower client representing a single organization."""
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
num_rounds = context.run_config["num-server-rounds"]
cfg = DictConfig(replace_keys(unflatten_dict(context.run_config)))

# Let's get the client partition
client_trainset = load_data(partition_id, num_partitions, cfg.dataset.name)
(
tokenizer,
data_collator,
formatting_prompts_func,
) = get_tokenizer_and_data_collator_and_propt_formatting(cfg.model.name)

return FlowerClient(
cfg.model,
cfg.train,
client_trainset,
tokenizer,
formatting_prompts_func,
data_collator,
num_rounds,
).to_client()


# Flower ClientApp
app = ClientApp(client_fn)
Changes to examples/flowertune-llm/flowertune_llm/dataset.py
@@ -1,6 +1,11 @@
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

from flwr_datasets.partitioner import IidPartitioner
from flwr_datasets import FederatedDataset

FDS = None # Cache FederatedDataset


def formatting_prompts_func(example):
output_texts = []
@@ -27,3 +32,31 @@ def get_tokenizer_and_data_collator_and_propt_formatting(model_name: str):
)

return tokenizer, data_collator, formatting_prompts_func


def load_data(partition_id: int, num_partitions: int, dataset_name: str):
"""Load partition data."""
# Only initialize `FederatedDataset` once
global FDS
if FDS is None:
partitioner = IidPartitioner(num_partitions=num_partitions)
FDS = FederatedDataset(
dataset=dataset_name,
partitioners={"train": partitioner},
)
client_trainset = FDS.load_partition(partition_id, "train")
client_trainset = client_trainset.rename_column("output", "response")

return client_trainset


def replace_keys(input_dict, match="-", target="_"):
"""Recursively replace match string with target string in dictionary keys."""
new_dict = {}
for key, value in input_dict.items():
new_key = key.replace(match, target)
if isinstance(value, dict):
new_dict[new_key] = replace_keys(value, match, target)
else:
new_dict[new_key] = value
return new_dict
Changes to examples/flowertune-llm/flowertune_llm/models.py
@@ -2,7 +2,10 @@

import torch
from omegaconf import DictConfig
from peft import LoraConfig, get_peft_model
from peft import (
LoraConfig,
get_peft_model,
)
from peft.utils import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

