refactor(framework) Update huggingface template for flwr new (#4169)
Co-authored-by: Chong Shen Ng <[email protected]>
Co-authored-by: Daniel J. Beutel <[email protected]>
3 people authored Sep 11, 2024
1 parent c85417a commit 5dffa05
Showing 4 changed files with 62 additions and 46 deletions.
48 changes: 19 additions & 29 deletions src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl
@@ -1,18 +1,11 @@
 """$project_name: A Flower / $framework_str app."""

+import torch
 from flwr.client import ClientApp, NumPyClient
 from flwr.common import Context
 from transformers import AutoModelForSequenceClassification

-from $import_name.task import (
-    get_weights,
-    load_data,
-    set_weights,
-    train,
-    test,
-    CHECKPOINT,
-    DEVICE,
-)
+from $import_name.task import get_weights, load_data, set_weights, test, train


 # Flower client
@@ -22,37 +15,34 @@ class FlowerClient(NumPyClient):
         self.trainloader = trainloader
         self.testloader = testloader
         self.local_epochs = local_epochs
-
-    def get_parameters(self, config):
-        return get_weights(self.net)
-
-    def set_parameters(self, parameters):
-        set_weights(self.net, parameters)
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.net.to(self.device)

     def fit(self, parameters, config):
-        self.set_parameters(parameters)
-        train(
-            self.net,
-            self.trainloader,
-            epochs=self.local_epochs,
-        )
-        return self.get_parameters(config={}), len(self.trainloader), {}
+        set_weights(self.net, parameters)
+        train(self.net, self.trainloader, epochs=self.local_epochs, device=self.device)
+        return get_weights(self.net), len(self.trainloader), {}

     def evaluate(self, parameters, config):
-        self.set_parameters(parameters)
-        loss, accuracy = test(self.net, self.testloader)
+        set_weights(self.net, parameters)
+        loss, accuracy = test(self.net, self.testloader, self.device)
         return float(loss), len(self.testloader), {"accuracy": accuracy}


 def client_fn(context: Context):
-    # Load model and data
-    net = AutoModelForSequenceClassification.from_pretrained(
-        CHECKPOINT, num_labels=2
-    ).to(DEVICE)
-
+    # Get this client's dataset partition
     partition_id = context.node_config["partition-id"]
     num_partitions = context.node_config["num-partitions"]
-    trainloader, valloader = load_data(partition_id, num_partitions)
+    model_name = context.run_config["model-name"]
+    trainloader, valloader = load_data(partition_id, num_partitions, model_name)
+
+    # Load model
+    num_labels = context.run_config["num-labels"]
+    net = AutoModelForSequenceClassification.from_pretrained(
+        model_name, num_labels=num_labels
+    )
+
     local_epochs = context.run_config["local-epochs"]

     # Return Client instance
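Note: the hunk collapses the template's closing lines. Based on the visible `__init__` signature and the standard Flower client pattern, the tail presumably looks like the sketch below; the exact template wording is not shown in this diff.

    # Assumed tail of client.huggingface.py.tpl (collapsed above):
    # return a Client built from FlowerClient, then wrap client_fn in a ClientApp
    return FlowerClient(net, trainloader, valloader, local_epochs).to_client()


# Flower ClientApp
app = ClientApp(client_fn=client_fn)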
21 changes: 18 additions & 3 deletions src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl
@@ -1,18 +1,33 @@
 """$project_name: A Flower / $framework_str app."""

-from flwr.common import Context
-from flwr.server.strategy import FedAvg
+from flwr.common import Context, ndarrays_to_parameters
 from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+from flwr.server.strategy import FedAvg
+from transformers import AutoModelForSequenceClassification
+
+from $import_name.task import get_weights


 def server_fn(context: Context):
     # Read from config
     num_rounds = context.run_config["num-server-rounds"]
+    fraction_fit = context.run_config["fraction-fit"]
+
+    # Initialize global model
+    model_name = context.run_config["model-name"]
+    num_labels = context.run_config["num-labels"]
+    net = AutoModelForSequenceClassification.from_pretrained(
+        model_name, num_labels=num_labels
+    )
+
+    weights = get_weights(net)
+    initial_parameters = ndarrays_to_parameters(weights)

     # Define strategy
     strategy = FedAvg(
-        fraction_fit=1.0,
+        fraction_fit=fraction_fit,
         fraction_evaluate=1.0,
+        initial_parameters=initial_parameters,
     )
     config = ServerConfig(num_rounds=num_rounds)
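Note: this template's tail is likewise collapsed. Flower server apps conventionally return a `ServerAppComponents` bundle from `server_fn` and wrap it in a `ServerApp`; a sketch assuming this template follows that convention:

    # Assumed tail of server.huggingface.py.tpl (collapsed above)
    return ServerAppComponents(strategy=strategy, config=config)


# Create ServerApp
app = ServerApp(server_fn=server_fn)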
29 changes: 16 additions & 13 deletions src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl
@@ -4,24 +4,25 @@ import warnings
 from collections import OrderedDict

 import torch
+import transformers
+from datasets.utils.logging import disable_progress_bar
 from evaluate import load as load_metric
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import IidPartitioner
 from torch.optim import AdamW
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, DataCollatorWithPadding

-from flwr_datasets import FederatedDataset
-from flwr_datasets.partitioner import IidPartitioner
-
-warnings.filterwarnings("ignore", category=UserWarning)
-DEVICE = torch.device("cpu")
-CHECKPOINT = "distilbert-base-uncased"  # transformer model checkpoint
+warnings.filterwarnings("ignore", category=FutureWarning)
+disable_progress_bar()
+transformers.logging.set_verbosity_error()


 fds = None  # Cache FederatedDataset


-def load_data(partition_id: int, num_partitions: int):
+def load_data(partition_id: int, num_partitions: int, model_name: str):
     """Load IMDB data (training and eval)"""
     # Only initialize `FederatedDataset` once
     global fds
@@ -35,10 +36,12 @@ def load_data(partition_id: int, num_partitions: int):
     # Divide data: 80% train, 20% test
     partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

-    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)

     def tokenize_function(examples):
-        return tokenizer(examples["text"], truncation=True)
+        return tokenizer(
+            examples["text"], truncation=True, add_special_tokens=True, max_length=512
+        )

     partition_train_test = partition_train_test.map(tokenize_function, batched=True)
     partition_train_test = partition_train_test.remove_columns("text")
@@ -59,25 +62,25 @@ def load_data(partition_id: int, num_partitions: int):
     return trainloader, testloader


-def train(net, trainloader, epochs):
+def train(net, trainloader, epochs, device):
     optimizer = AdamW(net.parameters(), lr=5e-5)
     net.train()
     for _ in range(epochs):
         for batch in trainloader:
-            batch = {k: v.to(DEVICE) for k, v in batch.items()}
+            batch = {k: v.to(device) for k, v in batch.items()}
             outputs = net(**batch)
             loss = outputs.loss
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()


-def test(net, testloader):
+def test(net, testloader, device):
     metric = load_metric("accuracy")
     loss = 0
     net.eval()
     for batch in testloader:
-        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        batch = {k: v.to(device) for k, v in batch.items()}
         with torch.no_grad():
             outputs = net(**batch)
             logits = outputs.logits
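Note: `test` is truncated just after the logits are computed, and the unchanged `get_weights`/`set_weights` helpers that the client and server templates import are collapsed entirely. For orientation, a sketch of how these pieces conventionally look in Flower's PyTorch examples (assumed, not taken from this diff):

        # inside test(): accumulate loss and per-batch predictions
        predictions = torch.argmax(logits, dim=-1)
        loss += outputs.loss.item()
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader.dataset)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy


def get_weights(net):
    # Serialize model parameters to a list of NumPy arrays
    return [val.cpu().numpy() for _, val in net.state_dict().items()]


def set_weights(net, parameters):
    # Load a list of NumPy arrays back into the model
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)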
10 changes: 9 additions & 1 deletion src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl
@@ -8,7 +8,7 @@ version = "1.0.0"
 description = ""
 license = "Apache-2.0"
 dependencies = [
-    "flwr[simulation]>=1.10.0",
+    "flwr[simulation]>=1.11.0",
     "flwr-datasets>=0.3.0",
     "torch==2.2.1",
     "transformers>=4.30.0,<5.0",
@@ -29,10 +29,18 @@ clientapp = "$import_name.client_app:app"
 [tool.flwr.app.config]
 num-server-rounds = 3
+fraction-fit = 0.5
 local-epochs = 1
+model-name = "prajjwal1/bert-tiny"  # Set a larger model if you have access to more GPU resources
+num-labels = 2

 [tool.flwr.federations]
 default = "localhost"

 [tool.flwr.federations.localhost]
 options.num-supernodes = 10
+
+[tool.flwr.federations.localhost-gpu]
+options.num-supernodes = 10
+options.backend.client-resources.num-cpus = 4  # each ClientApp is assumed to use 4 CPUs
+options.backend.client-resources.num-gpus = 0.25  # at most 4 ClientApps will run on a given GPU

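With this configuration, `flwr run .` executes the app on the default `localhost` federation (10 SuperNodes, CPU only), while `flwr run . localhost-gpu` should target the GPU federation: with `num-gpus = 0.25`, the simulation backend schedules at most four ClientApps per GPU, each assumed to use 4 CPUs.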