Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fedrag example with embedding training #2915

Merged
merged 15 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/advanced/rag/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Federated Retrieval-Augmented Generation (RAG)
The examples in this directory illustrate how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for RAG tasks, including:
- federated embedding model training
- retrieval-augmented generation with federated context retrieval
94 changes: 94 additions & 0 deletions examples/advanced/rag/embedding/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Embedding Model Tuning via SentenceTransformers Trainer
This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for embedding tuning tasks, a critical component of Retrieval-Augmented Generation (RAG).

It illustrates how to adapt a local training script with [SentenceTransformers](https://github.com/UKPLab/sentence-transformers) trainer to NVFlare.

## Introduction
[SentenceTransformers](https://sbert.net/) is a widely used framework for computing dense vector representations for texts.
The models are based on transformer, achieving state-of-the-art performance in various tasks.

One major application is to embed the text in vector space for later clustering and/or retrieval using similarity metrics.

This example illustrates a supervised fine-tuning (SFT) scheme for an embedding model with various training datasets.

## Setup
Please make sure you set up virtual environment following [example root readme](../../../README.md).
Install additional requirements (if you already have a specific version of nvflare installed in your environment, you may want to remove nvflare in the requirements to avoid reinstalling nvflare):
```
python3 -m pip install -r requirements.txt
```
Models and data will be loaded directly from Huggingface, so no need to download them manually.

## Centralized Training
### Single-session training
Centralized trainings, as the baseline for comparison with FL results, are done with the following command:
```
bash train_single_session.sh
```

### Adaptation Step 1: iterative training
To adapt the centralized training script to federated application, under `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
For this purpose, we provided `utils/train_iterative.py` as an example, which is a modified version of `utils/train_single_session.py`.

In the iterative training script, the `trainer.train()` call is replaced by a `for` loop, and the training epochs are split into six rounds, `unit_train_epochs = 0.25` epoch per round, in total `0.25 * 6 = 1.5` epochs, same as single session setting.

The first round is trained with `trainer.train()`, then from the second round,
we call `trainer.train(resume_from_checkpoint=True)` with `args.num_train_epochs` incremented by `unit_train_epochs` to continue training from the last checkpoint.

To run iterative training, we use the following command:
```
bash train_iterative.sh
```

The training loss curves are shown below, single session and iterative scripts align with each other.

![iter_single](./figs/iter_single.png)

### Adaptation Step 2: federated with NVFlare
Once we have the iterative training script ready with "starting model" loading capability, it can be easily adapted to a NVFlare trainer by using [Client API](../../../hello-world/ml-to-fl/pt/README.md).

The major code modifications are for receiving the global model, set it as the starting point for each round's training, and returning the trained model after each local training round.

## Job for NVFlare FL Training
With the local training script ready, we can go ahead to generate the NVFlare job configs by reusing the job templates.

Let's set the job template path with the following command.
```bash
nvflare config -jt ./job_template/
```
Then we can check the available templates with the following command.
```bash
nvflare job list_templates
```
We can see the "sag_pt_deploy_map" template is available, with which we further generate job configs for embedding model training as:
```
nvflare job create -force \
-j "/tmp/embed/nvflare/job" -w "sag_pt_deploy_map" -sd "code" \
-f meta.conf min_clients=3 \
-f app_1/config_fed_client.conf app_script="train_fl.py" app_config="--dataset_name nli" \
-f app_2/config_fed_client.conf app_script="train_fl.py" app_config="--dataset_name squad" \
-f app_3/config_fed_client.conf app_script="train_fl.py" app_config="--dataset_name quora" \
-f app_server/config_fed_server.conf model_class_path="st_model.SenTransModel" components[0].args.model.args.model_name="microsoft/mpnet-base" min_clients=3 num_rounds=7 key_metric="eval_loss" negate_key_metric=True
```
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved


For both client and server configs, we only set the necessary task-related parameters tasks, and leave the rest to the default values.

## Federated Training
With the produced job, we run the federated training on a single client using NVFlare Simulator.
```
nvflare simulator -w /tmp/embed/nvflare/workspace -n 3 -t 3 /tmp/embed/nvflare/job
```

## Results
The evaluation on two test datasets - [stsb](https://huggingface.co/datasets/sentence-transformers/stsb) with embedding similarity evaluation, and [NLI](https://huggingface.co/datasets/sentence-transformers/all-nli) with triplet accuracy evaluation, are shown below.

TrainData | STSB_pearson_cos | STSB_spearman_euc | NLI_cos_acc | NLI_euc_acc
--- |------------------|-------------------|-------------| ---
NLI | 0.7586 | 0.7895 | 0.8033 | 0.8045
Squad | 0.8206 | 0.8154 | 0.8051 | 0.8042
Quora | 0.8161 | 0.8121 | 0.7891 | 0.7854
All | 0.8497 | 0.8523 | 0.8426 | 0.8384
Federated | 0.8444 | 0.8368 | 0.8269 | 0.8246

As shown, the federated training results are better than individual site's, and can be close to the centralized training results, demonstrating the effectiveness of NVFlare in embedding model tuning tasks.
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
26 changes: 26 additions & 0 deletions examples/advanced/rag/embedding/code/st_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from sentence_transformers import SentenceTransformer


class SenTransModel(torch.nn.Module):
def __init__(self, model_name):
super(SenTransModel, self).__init__()
self.model = SentenceTransformer(model_name)

def forward(self, input_id):
output = self.model(input_ids=input_id, return_dict=False)
return output
158 changes: 158 additions & 0 deletions examples/advanced/rag/embedding/code/train_fl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from transformers import trainer_utils

import nvflare.client as flare


def main():
# argparse
parser = argparse.ArgumentParser(description="Train a model on a dataset")
parser.add_argument(
"--model_name",
type=str,
default="microsoft/mpnet-base",
)
parser.add_argument(
"--dataset_name",
type=str,
default="nli",
)
args = parser.parse_args()
model_name = args.model_name
dataset_name = args.dataset_name

# Load a model to finetune with
model = SentenceTransformer(model_name)

# Load training datasets
if dataset_name == "nli":
# (anchor, positive, negative)
dataset_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:16000]")
dataset_val = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
elif dataset_name == "squad":
# (question, answer)
dataset_train = load_dataset("sentence-transformers/squad", split="train[:16000]")
dataset_val = load_dataset("sentence-transformers/squad", split="train[16000:18000]")
elif dataset_name == "quora":
# (anchor, positive)
dataset_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:16000]")
dataset_val = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[16000:18000]")
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")

# Load loss function
loss = MultipleNegativesRankingLoss(model)

base_model_name = model_name.split("/")[-1]
output_dir = f"./models/{base_model_name}-{dataset_name}"
unit_train_epochs = 0.25
# Specify training arguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=output_dir,
# Optional training parameters:
num_train_epochs=unit_train_epochs,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=1e-6,
lr_scheduler_type="constant",
bf16=True,
batch_sampler=BatchSamplers.NO_DUPLICATES,
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=50,
save_total_limit=1,
# logging parameters:
logging_dir=f"{output_dir}/logs",
logging_strategy="steps",
logging_steps=50,
report_to="tensorboard",
)

# Define trainer
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=dataset_train,
eval_dataset=dataset_val,
loss=loss,
)

# initializes NVFlare client API
flare.init()

while flare.is_running():
# receives FLModel from NVFlare
input_model = flare.receive()
curr_round = input_model.current_round
print(f"current_round={curr_round}")

# Update the key name received from global model if using model def file
global_model = copy.deepcopy(input_model.params)
for key in list(global_model.keys()):
global_model[key.replace("model.", "", 1)] = global_model.pop(key)

# evaluate on received global model
trainer.model.load_state_dict(global_model)
eval_loss_dict = trainer.evaluate()
eval_loss = float(eval_loss_dict["eval_loss"])
print(f"Evaluation loss: {eval_loss}")
# Save the global model
model.save_pretrained(f"{output_dir}/global")

# Train the model
if curr_round == 0:
# First round: start from scratch
trainer.train()
else:
# Subsequent rounds: start from the previous model
# Since we perform iterative training by using "resume" functionality
# we need to replace the resume weights with global weights every round
resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir)
# update local record with global model weights
trainer.model.save_pretrained(resume_from_checkpoint_folder)
# increment the number of training epochs so that the trainer will continue training
args.num_train_epochs += unit_train_epochs
# continue training
trainer.train(resume_from_checkpoint=True)

# update the key name sent to global model
out_param = trainer.model.state_dict()
for key in list(out_param.keys()):
out_param["model." + key] = out_param.pop(key).cpu()
num_steps = trainer.train_dataset.num_rows * unit_train_epochs

# construct trained FL model
output_model = flare.FLModel(
params=out_param,
metrics={"eval_loss": eval_loss},
meta={"NUM_STEPS_CURRENT_ROUND": num_steps},
)
# send model back to NVFlare
flare.send(output_model)


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions examples/advanced/rag/embedding/eval_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
for dataset_name in nli squad quora all
do
echo "Evaluation on ${dataset_name} with model"
python utils/eval_model.py --model_path /tmp/embed/cen/models_single/mpnet-base-${dataset_name}/final
done

python utils/eval_model.py --model_path /tmp/embed/nvflare/workspace/site-1/simulate_job/app_site-1/models/mpnet-base-nli/global
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
{
ZiyueXu77 marked this conversation as resolved.
Show resolved Hide resolved
# version of the configuration
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "cifar10.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = ""

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = ["train"]

# This particular executor
executor {

# This is an executor for Client API. The underline data exchange is using Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

args {
# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "pytorch"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "DIFF"

# if train_with_evaluation is true, the executor will expect
# the custom code need to send back both the trained parameters and the evaluation metric
# otherwise only trained parameters are expected
train_with_evaluation = true
}
}
}
],

# this defined an array of task data filters. If provided, it will control the data from server controller to client executor
task_data_filters = []

# this defined an array of task result filters. If provided, it will control the result from client executor to server controller
task_result_filters = []

components = [
{
# This "launcher" component
id = "launcher"

# the class path of the component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 -u custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
}
{
id = "pipe"

path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"

args {
# Mode of the endpoint. A pipe has two endpoints.
# An endpoint can be either the one that initiates communication or the one listening.
# PASSIVE is the one listening.
mode = "PASSIVE"

# root_path: is the directory location of the data exchange.
# If empty string, it will be set to the app_dir of the running job.
# You can also set it to an absolute path in your system.
root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
}
}
]
}
Loading
Loading