From 7cb65d70d5f9ac8caed9d9ee95096dbe9a33b044 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 11 Sep 2024 11:58:47 +0200 Subject: [PATCH] refactor(framework) Update `huggingface` template for `flwr new` (#4169) Co-authored-by: Chong Shen Ng Co-authored-by: Daniel J. Beutel --- .../app/code/client.huggingface.py.tpl | 48 ++++++++----------- .../app/code/server.huggingface.py.tpl | 21 ++++++-- .../app/code/task.huggingface.py.tpl | 29 ++++++----- .../app/pyproject.huggingface.toml.tpl | 10 +++- 4 files changed, 62 insertions(+), 46 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl index 3041a69e3aaa..840f938b4ecc 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl @@ -1,18 +1,11 @@ """$project_name: A Flower / $framework_str app.""" +import torch from flwr.client import ClientApp, NumPyClient from flwr.common import Context from transformers import AutoModelForSequenceClassification -from $import_name.task import ( - get_weights, - load_data, - set_weights, - train, - test, - CHECKPOINT, - DEVICE, -) +from $import_name.task import get_weights, load_data, set_weights, test, train # Flower client @@ -22,37 +15,34 @@ class FlowerClient(NumPyClient): self.trainloader = trainloader self.testloader = testloader self.local_epochs = local_epochs - - def get_parameters(self, config): - return get_weights(self.net) - - def set_parameters(self, parameters): - set_weights(self.net, parameters) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.net.to(self.device) def fit(self, parameters, config): - self.set_parameters(parameters) - train( - self.net, - self.trainloader, - epochs=self.local_epochs, - ) - return self.get_parameters(config={}), len(self.trainloader), {} + set_weights(self.net, parameters) + train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device) + return get_weights(self.net), len(self.trainloader), {} def evaluate(self, parameters, config): - self.set_parameters(parameters) - loss, accuracy = test(self.net, self.testloader) + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.testloader, self.device) return float(loss), len(self.testloader), {"accuracy": accuracy} def client_fn(context: Context): - # Load model and data - net = AutoModelForSequenceClassification.from_pretrained( - CHECKPOINT, num_labels=2 - ).to(DEVICE) + # Get this client's dataset partition partition_id = context.node_config["partition-id"] num_partitions = context.node_config["num-partitions"] - trainloader, valloader = load_data(partition_id, num_partitions) + model_name = context.run_config["model-name"] + trainloader, valloader = load_data(partition_id, num_partitions, model_name) + + # Load model + num_labels = context.run_config["num-labels"] + net = AutoModelForSequenceClassification.from_pretrained( + model_name, num_labels=num_labels + ) + local_epochs = context.run_config["local-epochs"] # Return Client instance diff --git a/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl index 5491f6616160..16f94f0a64e9 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl @@ -1,18 +1,33 @@ """$project_name: A Flower / $framework_str app.""" -from flwr.common import Context -from flwr.server.strategy import FedAvg +from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.strategy import FedAvg +from transformers import AutoModelForSequenceClassification + +from $import_name.task import get_weights def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + fraction_fit = context.run_config["fraction-fit"] + + # Initialize global model + model_name = context.run_config["model-name"] + num_labels = context.run_config["num-labels"] + net = AutoModelForSequenceClassification.from_pretrained( + model_name, num_labels=num_labels + ) + + weights = get_weights(net) + initial_parameters = ndarrays_to_parameters(weights) # Define strategy strategy = FedAvg( - fraction_fit=1.0, + fraction_fit=fraction_fit, fraction_evaluate=1.0, + initial_parameters=initial_parameters, ) config = ServerConfig(num_rounds=num_rounds) diff --git a/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl index ad52e2c3fe21..1c50e85d7103 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl @@ -4,24 +4,25 @@ import warnings from collections import OrderedDict import torch +import transformers +from datasets.utils.logging import disable_progress_bar from evaluate import load as load_metric +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner from torch.optim import AdamW from torch.utils.data import DataLoader from transformers import AutoTokenizer, DataCollatorWithPadding -from flwr_datasets import FederatedDataset -from flwr_datasets.partitioner import IidPartitioner - - warnings.filterwarnings("ignore", category=UserWarning) -DEVICE = torch.device("cpu") -CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint +warnings.filterwarnings("ignore", category=FutureWarning) +disable_progress_bar() +transformers.logging.set_verbosity_error() fds = None # Cache FederatedDataset -def load_data(partition_id: int, num_partitions: int): +def load_data(partition_id: int, num_partitions: int, model_name: str): """Load IMDB data (training and eval)""" # Only initialize `FederatedDataset` once global fds @@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int): # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) - tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT) + tokenizer = AutoTokenizer.from_pretrained(model_name) def tokenize_function(examples): - return tokenizer(examples["text"], truncation=True) + return tokenizer( + examples["text"], truncation=True, add_special_tokens=True, max_length=512 + ) partition_train_test = partition_train_test.map(tokenize_function, batched=True) partition_train_test = partition_train_test.remove_columns("text") @@ -59,12 +62,12 @@ def load_data(partition_id: int, num_partitions: int): return trainloader, testloader -def train(net, trainloader, epochs): +def train(net, trainloader, epochs, device): optimizer = AdamW(net.parameters(), lr=5e-5) net.train() for _ in range(epochs): for batch in trainloader: - batch = {k: v.to(DEVICE) for k, v in batch.items()} + batch = {k: v.to(device) for k, v in batch.items()} outputs = net(**batch) loss = outputs.loss loss.backward() @@ -72,12 +75,12 @@ def train(net, trainloader, epochs): optimizer.zero_grad() -def test(net, testloader): +def test(net, testloader, device): metric = load_metric("accuracy") loss = 0 net.eval() for batch in testloader: - batch = {k: v.to(DEVICE) for k, v in batch.items()} + batch = {k: v.to(device) for k, v in batch.items()} with torch.no_grad(): outputs = net(**batch) logits = outputs.logits diff --git a/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl index 15dc2af87a3f..af1e4d005114 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl @@ -8,7 +8,7 @@ version = "1.0.0" description = "" license = "Apache-2.0" dependencies = [ - "flwr[simulation]>=1.10.0", + "flwr[simulation]>=1.11.0", "flwr-datasets>=0.3.0", "torch==2.2.1", "transformers>=4.30.0,<5.0", @@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = 3 +fraction-fit = 0.5 local-epochs = 1 +model-name = "prajjwal1/bert-tiny" # Set a larger model if you have access to more GPU resources +num-labels = 2 [tool.flwr.federations] default = "localhost" [tool.flwr.federations.localhost] options.num-supernodes = 10 + +[tool.flwr.federations.localhost-gpu] +options.num-supernodes = 10 +options.backend.client-resources.num-cpus = 4 # each ClientApp assumes to use 4CPUs +options.backend.client-resources.num-gpus = 0.25 # at most 4 ClientApps will run in a given GPU