Commit

Merge branch 'records' of https://github.com/adap/flower into records
panh99 committed Jan 22, 2024
2 parents b61fbb4 + 1267d21 commit 5db284e
Showing 5 changed files with 36 additions and 34 deletions.
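Together, these changes migrate the `quickstart-huggingface` example from ad-hoc IMDB subsampling to partition-based data loading with `flwr-datasets`, threading a new `--node-id` argument through the README, `client.py`, the dependency files, and `run.sh`.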
6 changes: 3 additions & 3 deletions examples/quickstart-huggingface/README.md
````diff
@@ -1,6 +1,6 @@
 # Federated HuggingFace Transformers using Flower and PyTorch
 
-This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for detailed explaination for the transformer pipeline.
+This introductory example to using [HuggingFace](https://huggingface.co) Transformers with Flower with PyTorch. This example has been extended from the [quickstart-pytorch](https://flower.dev/docs/examples/quickstart-pytorch.html) example. The training script closely follows the [HuggingFace course](https://huggingface.co/course/chapter3?fw=pt), so you are encouraged to check that out for a detailed explanation of the transformer pipeline.
 
 Like `quickstart-pytorch`, running this example in itself is also meant to be quite easy.
 
@@ -62,13 +62,13 @@ Now you are ready to start the Flower clients which will participate in the lear
 Start client 1 in the first terminal:
 
 ```shell
-python3 client.py
+python3 client.py --node-id 0
 ```
 
 Start client 2 in the second terminal:
 
 ```shell
-python3 client.py
+python3 client.py --node-id 1
 ```
 
 You will see that PyTorch is starting a federated training.
````
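Each `--node-id` value selects one of the 1,000 IMDB partitions created in the updated `client.py` below, so the two clients train on disjoint subsets of the data.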
60 changes: 30 additions & 30 deletions examples/quickstart-huggingface/client.py
```diff
@@ -1,58 +1,48 @@
-from collections import OrderedDict
+import argparse
 import warnings
+from collections import OrderedDict

 import flwr as fl
 import torch
-import numpy as np
-
-import random
-from torch.utils.data import DataLoader
-
-from datasets import load_dataset
 from evaluate import load as load_metric
-
-from transformers import AutoTokenizer, DataCollatorWithPadding
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
 from transformers import AutoModelForSequenceClassification
-from transformers import AdamW
+from transformers import AutoTokenizer, DataCollatorWithPadding

+from flwr_datasets import FederatedDataset
+
 warnings.filterwarnings("ignore", category=UserWarning)
 DEVICE = torch.device("cpu")
 CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint


-def load_data():
+def load_data(node_id):
     """Load IMDB data (training and eval)"""
-    raw_datasets = load_dataset("imdb")
-    raw_datasets = raw_datasets.shuffle(seed=42)
-
-    # remove unnecessary data split
-    del raw_datasets["unsupervised"]
+    fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000})
+    partition = fds.load_partition(node_id)
+    # Divide data: 80% train, 20% test
+    partition_train_test = partition.train_test_split(test_size=0.2)

     tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

     def tokenize_function(examples):
         return tokenizer(examples["text"], truncation=True)

-    # random 100 samples
-    population = random.sample(range(len(raw_datasets["train"])), 100)
-
-    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
-    tokenized_datasets["train"] = tokenized_datasets["train"].select(population)
-    tokenized_datasets["test"] = tokenized_datasets["test"].select(population)
-
-    tokenized_datasets = tokenized_datasets.remove_columns("text")
-    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
+    partition_train_test = partition_train_test.map(tokenize_function, batched=True)
+    partition_train_test = partition_train_test.remove_columns("text")
+    partition_train_test = partition_train_test.rename_column("label", "labels")

     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
     trainloader = DataLoader(
-        tokenized_datasets["train"],
+        partition_train_test["train"],
         shuffle=True,
         batch_size=32,
         collate_fn=data_collator,
     )

     testloader = DataLoader(
-        tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
+        partition_train_test["test"], batch_size=32, collate_fn=data_collator
     )

     return trainloader, testloader
@@ -88,12 +78,12 @@ def test(net, testloader):
     return loss, accuracy


-def main():
+def main(node_id):
     net = AutoModelForSequenceClassification.from_pretrained(
         CHECKPOINT, num_labels=2
     ).to(DEVICE)

-    trainloader, testloader = load_data()
+    trainloader, testloader = load_data(node_id)

     # Flower client
     class IMDBClient(fl.client.NumPyClient):
@@ -122,4 +112,14 @@ def evaluate(self, parameters, config):


 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser(description="Flower")
+    parser.add_argument(
+        "--node-id",
+        choices=list(range(1_000)),
+        required=True,
+        type=int,
+        help="Partition of the dataset divided into 1,000 iid partitions created "
+        "artificially.",
+    )
+    node_id = parser.parse_args().node_id
+    main(node_id)
```
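The core of this change is replacing the manual `load_dataset("imdb")` call and its random 100-sample subset with Flower Datasets partitioning. As a minimal sketch of the new loading path, using only the calls that appear in the diff above (partition index `0` is an arbitrary choice for illustration):

```python
from flwr_datasets import FederatedDataset

# Split IMDB's train split into 1,000 IID partitions, as client.py now does.
fds = FederatedDataset(dataset="imdb", partitioners={"train": 1_000})

# Each client loads only its own partition, selected via --node-id.
partition = fds.load_partition(0)

# Locally split the partition 80/20 into train and eval, mirroring load_data().
partition_train_test = partition.train_test_split(test_size=0.2)
print(len(partition_train_test["train"]), len(partition_train_test["test"]))
```

Because every client reads a fixed partition instead of drawing its own random sample, the data assignment is deterministic and non-overlapping across clients.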
1 change: 1 addition & 0 deletions examples/quickstart-huggingface/pyproject.toml
```diff
@@ -14,6 +14,7 @@ authors = [
 [tool.poetry.dependencies]
 python = ">=3.8,<3.11"
 flwr = ">=1.0,<2.0"
+flwr-datasets = ">=0.0.2,<1.0.0"
 torch = ">=1.13.1,<2.0"
 transformers = ">=4.30.0,<5.0"
 evaluate = ">=0.4.0,<1.0"
```
1 change: 1 addition & 0 deletions examples/quickstart-huggingface/requirements.txt
```diff
@@ -1,4 +1,5 @@
 flwr>=1.0, <2.0
+flwr-datasets>=0.0.2, <1.0.0
 torch>=1.13.1, <2.0
 transformers>=4.30.0, <5.0
 evaluate>=0.4.0, <1.0
```
2 changes: 1 addition & 1 deletion examples/quickstart-huggingface/run.sh
```diff
@@ -6,7 +6,7 @@ sleep 3 # Sleep for 3s to give the server enough time to start

 for i in `seq 0 1`; do
     echo "Starting client $i"
-    python client.py &
+    python client.py --node-id ${i}&
 done

 # This will allow you to use CTRL+C to stop all background processes
```
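Since `seq 0 1` yields `0` and `1`, the loop expands to the same two commands the README gives, run in the background:

```shell
# What the run.sh loop launches: one background client per partition.
python client.py --node-id 0 &
python client.py --node-id 1 &
```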
