
Commit d05af23
keep rag folder structure, remove the retrieval placeholder
ZiyueXu77 committed Sep 20, 2024
1 parent 1096beb commit d05af23
Showing 28 changed files with 1,640 additions and 9 deletions.
3 changes: 1 addition & 2 deletions examples/advanced/rag/README.md
@@ -1,4 +1,3 @@
# Federated Retrieval-Augmented Generation (RAG)
The examples in this directory illustrate how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for RAG tasks, including:
-- federated embedding model training
-- retrieval-augmented generation with federated context retrieval
+- federated embedding model training
15 changes: 10 additions & 5 deletions examples/advanced/rag/embedding/README.md
Original file line number Diff line number Diff line change
@@ -49,8 +49,9 @@ Once we have the iterative training script ready with "starting model" loading c

The major code modifications are for receiving the global model, setting it as the starting point for each round's training, and returning the trained model after each local training round.
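
In skeleton form, the per-round pattern looks like the sketch below; `make_local_trainer` is a hypothetical stand-in for the local trainer setup, and the full `train_fl.py` appears later in this commit.

```python
import nvflare.client as flare

# hypothetical helper standing in for the SentenceTransformerTrainer setup in train_fl.py
trainer = make_local_trainer()

flare.init()  # initialize the NVFlare Client API
while flare.is_running():
    input_model = flare.receive()                      # receive the global model from the server
    trainer.model.load_state_dict(input_model.params)  # use it as this round's starting point
    trainer.train()                                    # run local training
    # return the locally trained weights (and any metrics) to the server
    flare.send(flare.FLModel(params=trainer.model.state_dict()))
```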

-## Job for NVFlare FL Training
-With the local training script ready, we can go ahead and generate the NVFlare job configs by reusing the job templates.
+## Federated Training
+### With Job Template
+With the local training script ready, we can go ahead and generate the NVFlare job configs by using the job templates.

Let's set the job template path with the following command.
```bash
@@ -63,23 +64,27 @@ nvflare job list_templates
We can see that the "sag_pt_deploy_map" template is available; with it, we generate the job configs for embedding model training as follows:
```
nvflare job create -force \
-j "/tmp/embed/nvflare/job" -w "sag_pt_deploy_map" -sd "code" \
-j "/tmp/embed/nvflare/job" -w "sag_pt_deploy_map" -sd "src" \
-f meta.conf min_clients=3 \
-f app_1/config_fed_client.conf app_script="train_fl.py" app_config="--dataset_name nli" \
-f app_2/config_fed_client.conf app_script="train_fl.py" app_config="--dataset_name squad" \
-f app_3/config_fed_client.conf app_script="train_fl.py" app_config="--dataset_name quora" \
-f app_server/config_fed_server.conf model_class_path="st_model.SenTransModel" components[0].args.model.args.model_name="microsoft/mpnet-base" min_clients=3 num_rounds=7 key_metric="eval_loss" negate_key_metric=True
```
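
Each `-f` flag overrides fields in the corresponding template config: the client entries set the `app_script` and `app_config` variables that each generated `config_fed_client.conf` (shown later in this commit) substitutes into its launcher command, while the server entry specifies the global model class and the aggregation settings (3 clients, 7 rounds, tracking `eval_loss` as the key metric).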


For both client and server configs, we only set the necessary task-related parameters and leave the rest at their default values.

-## Federated Training
With the produced job, we run the federated training on a single machine using the NVFlare Simulator.
```
nvflare simulator -w /tmp/embed/nvflare/workspace -n 3 -t 3 /tmp/embed/nvflare/job
```

### With Python API
Alternatively, we can use the Python API to create and run the federated training job.
```
python3 train_fed.py
```
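
The contents of `train_fed.py` are not part of this commit's visible diff; as a rough sketch only, a job-API version of the same setup might look like the following, assuming NVFlare's `FedJob`, `FedAvg`, and `ScriptRunner` classes (the actual script may differ, and would also need to register the initial `st_model.SenTransModel` on the server):

```python
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.job_config.api import FedJob
from nvflare.job_config.script_runner import ScriptRunner

job = FedJob(name="embed")

# Server side: FedAvg over 3 clients for 7 rounds, matching the template-based job above
job.to(FedAvg(num_clients=3, num_rounds=7), "server")

# Client side: the same training script on each site, each with its own dataset
for i, dataset_name in enumerate(["nli", "squad", "quora"]):
    runner = ScriptRunner(script="src/train_fl.py", script_args=f"--dataset_name {dataset_name}")
    job.to(runner, f"site-{i + 1}")

# Run with the simulator, mirroring the `nvflare simulator` command above
job.simulator_run("/tmp/embed/nvflare/workspace")
```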

## Results
Evaluation results on two test datasets, [stsb](https://huggingface.co/datasets/sentence-transformers/stsb) with embedding-similarity evaluation and [NLI](https://huggingface.co/datasets/sentence-transformers/all-nli) with triplet-accuracy evaluation, are shown below.

43 changes: 43 additions & 0 deletions examples/advanced/rag/embedding/Temp/job/app_1/config_fed_client.conf
@@ -0,0 +1,43 @@
{
  format_version = 2
  app_script = "train_fl.py"
  app_config = "--dataset_name nli"
  executors = [
    {
      tasks = [
        "train"
      ]
      executor {
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
        args {
          launcher_id = "launcher"
          pipe_id = "pipe"
          heartbeat_timeout = 60
          params_exchange_format = "pytorch"
          params_transfer_type = "DIFF"
          train_with_evaluation = true
        }
      }
    }
  ]
  task_data_filters = []
  task_result_filters = []
  components = [
    {
      id = "launcher"
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
      args {
        script = "python3 -u custom/{app_script} {app_config} "
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"
      args {
        mode = "PASSIVE"
        root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
      }
    }
  ]
}
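
Together, these two components define how the Client API script runs: the `SubprocessLauncher` substitutes the variables above and launches `python3 -u custom/train_fl.py --dataset_name nli` once as a subprocess, while the `FilePipe` exchanges models and task metadata with that subprocess through files under the job workspace. The app_2 and app_3 configs that follow are identical apart from the `--dataset_name` value.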
43 changes: 43 additions & 0 deletions examples/advanced/rag/embedding/Temp/job/app_2/config_fed_client.conf
@@ -0,0 +1,43 @@
{
  format_version = 2
  app_script = "train_fl.py"
  app_config = "--dataset_name squad"
  executors = [
    {
      tasks = [
        "train"
      ]
      executor {
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
        args {
          launcher_id = "launcher"
          pipe_id = "pipe"
          heartbeat_timeout = 60
          params_exchange_format = "pytorch"
          params_transfer_type = "DIFF"
          train_with_evaluation = true
        }
      }
    }
  ]
  task_data_filters = []
  task_result_filters = []
  components = [
    {
      id = "launcher"
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
      args {
        script = "python3 -u custom/{app_script} {app_config} "
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"
      args {
        mode = "PASSIVE"
        root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
      }
    }
  ]
}
26 changes: 26 additions & 0 deletions examples/advanced/rag/embedding/Temp/job/app_2/custom/st_model.py
@@ -0,0 +1,26 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from sentence_transformers import SentenceTransformer


class SenTransModel(torch.nn.Module):
    def __init__(self, model_name):
        super(SenTransModel, self).__init__()
        self.model = SentenceTransformer(model_name)

    def forward(self, input_id):
        output = self.model(input_ids=input_id, return_dict=False)
        return output
158 changes: 158 additions & 0 deletions examples/advanced/rag/embedding/Temp/job/app_2/custom/train_fl.py
@@ -0,0 +1,158 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from transformers import trainer_utils

import nvflare.client as flare


def main():
    # argparse
    parser = argparse.ArgumentParser(description="Train a model on a dataset")
    parser.add_argument(
        "--model_name",
        type=str,
        default="microsoft/mpnet-base",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="nli",
    )
    args = parser.parse_args()
    model_name = args.model_name
    dataset_name = args.dataset_name

    # Load a model to finetune with
    model = SentenceTransformer(model_name)

    # Load training datasets
    if dataset_name == "nli":
        # (anchor, positive, negative)
        dataset_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:16000]")
        dataset_val = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
    elif dataset_name == "squad":
        # (question, answer)
        dataset_train = load_dataset("sentence-transformers/squad", split="train[:16000]")
        dataset_val = load_dataset("sentence-transformers/squad", split="train[16000:18000]")
    elif dataset_name == "quora":
        # (anchor, positive)
        dataset_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:16000]")
        dataset_val = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[16000:18000]")
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    # Load loss function
    loss = MultipleNegativesRankingLoss(model)

    base_model_name = model_name.split("/")[-1]
    output_dir = f"./models/{base_model_name}-{dataset_name}"
    unit_train_epochs = 0.25
    # Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=output_dir,
        # Optional training parameters:
        num_train_epochs=unit_train_epochs,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=1e-6,
        lr_scheduler_type="constant",
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=1,
        # logging parameters:
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=50,
        report_to="tensorboard",
    )

    # Define trainer
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        loss=loss,
    )

    # initializes NVFlare client API
    flare.init()

    while flare.is_running():
        # receives FLModel from NVFlare
        input_model = flare.receive()
        curr_round = input_model.current_round
        print(f"current_round={curr_round}")

        # Update the key name received from global model if using model def file
        global_model = copy.deepcopy(input_model.params)
        for key in list(global_model.keys()):
            global_model[key.replace("model.", "", 1)] = global_model.pop(key)

        # evaluate on received global model
        trainer.model.load_state_dict(global_model)
        eval_loss_dict = trainer.evaluate()
        eval_loss = float(eval_loss_dict["eval_loss"])
        print(f"Evaluation loss: {eval_loss}")
        # Save the global model
        model.save_pretrained(f"{output_dir}/global")

        # Train the model
        if curr_round == 0:
            # First round: start from scratch
            trainer.train()
        else:
            # Subsequent rounds: start from the previous model
            # Since we perform iterative training by using the "resume" functionality,
            # we need to replace the resume weights with global weights every round
            resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir)
            # update local record with global model weights
            trainer.model.save_pretrained(resume_from_checkpoint_folder)
            # increment the number of training epochs so that the trainer will continue training
            args.num_train_epochs += unit_train_epochs
            # continue training
            trainer.train(resume_from_checkpoint=True)

        # update the key name sent to global model
        out_param = trainer.model.state_dict()
        for key in list(out_param.keys()):
            out_param["model." + key] = out_param.pop(key).cpu()
        num_steps = trainer.train_dataset.num_rows * unit_train_epochs

        # construct trained FL model
        output_model = flare.FLModel(
            params=out_param,
            metrics={"eval_loss": eval_loss},
            meta={"NUM_STEPS_CURRENT_ROUND": num_steps},
        )
        # send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
43 changes: 43 additions & 0 deletions examples/advanced/rag/embedding/Temp/job/app_3/config_fed_client.conf
@@ -0,0 +1,43 @@
{
  format_version = 2
  app_script = "train_fl.py"
  app_config = "--dataset_name quora"
  executors = [
    {
      tasks = [
        "train"
      ]
      executor {
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
        args {
          launcher_id = "launcher"
          pipe_id = "pipe"
          heartbeat_timeout = 60
          params_exchange_format = "pytorch"
          params_transfer_type = "DIFF"
          train_with_evaluation = true
        }
      }
    }
  ]
  task_data_filters = []
  task_result_filters = []
  components = [
    {
      id = "launcher"
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
      args {
        script = "python3 -u custom/{app_script} {app_config} "
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"
      args {
        mode = "PASSIVE"
        root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
      }
    }
  ]
}
26 changes: 26 additions & 0 deletions examples/advanced/rag/embedding/Temp/job/app_3/custom/st_model.py
@@ -0,0 +1,26 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from sentence_transformers import SentenceTransformer


class SenTransModel(torch.nn.Module):
    def __init__(self, model_name):
        super(SenTransModel, self).__init__()
        self.model = SentenceTransformer(model_name)

    def forward(self, input_id):
        output = self.model(input_ids=input_id, return_dict=False)
        return output