Skip to content

Commit

Permalink
refactor(framework) Update flwr new templates with new client_fn
Browse files Browse the repository at this point in the history
…signature (#3795)

Co-authored-by: Daniel J. Beutel <[email protected]>
  • Loading branch information
jafermarq and danieljanes authored Jul 13, 2024
1 parent bdc7602 commit 7d5d6f3
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 17 deletions.
8 changes: 6 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""$project_name: A Flower / HuggingFace Transformers app."""

from flwr.client import ClientApp, NumPyClient
from flwr.common import Context
from transformers import AutoModelForSequenceClassification

from $import_name.task import (
Expand Down Expand Up @@ -38,12 +39,15 @@ class FlowerClient(NumPyClient):
return float(loss), len(self.testloader), {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
# Load model and data
net = AutoModelForSequenceClassification.from_pretrained(
CHECKPOINT, num_labels=2
).to(DEVICE)
trainloader, valloader = load_data(int(cid), 2)

partition_id = int(context.node_config['partition-id'])
num_partitions = int(context.node_config['num-partitions])
trainloader, valloader = load_data(partition_id, num_partitions)

# Return Client instance
return FlowerClient(net, trainloader, valloader).to_client()
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import jax
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from $import_name.task import (
evaluation,
Expand Down Expand Up @@ -44,7 +45,7 @@ class FlowerClient(NumPyClient):
)
return float(loss), num_examples, {"loss": float(loss)}

def client_fn(cid):
def client_fn(context: Context):
# Return Client instance
return FlowerClient().to_client()

Expand Down
7 changes: 5 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from $import_name.task import (
batch_iterate,
Expand Down Expand Up @@ -57,8 +58,10 @@ class FlowerClient(NumPyClient):
return loss.item(), len(self.test_images), {"accuracy": accuracy.item()}


def client_fn(cid):
data = load_data(int(cid), 2)
def client_fn(context: Context):
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
data = load_data(partition_id, num_partitions)

# Return Client instance
return FlowerClient(data).to_client()
Expand Down
3 changes: 2 additions & 1 deletion src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""$project_name: A Flower / NumPy app."""

from flwr.client import NumPyClient, ClientApp
from flwr.common import Context
import numpy as np


Expand All @@ -15,7 +16,7 @@ class FlowerClient(NumPyClient):
return float(0.0), 1, {"accuracy": float(1.0)}


def client_fn(cid: str):
def client_fn(context: Context):
return FlowerClient().to_client()


Expand Down
7 changes: 5 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""$project_name: A Flower / PyTorch app."""

from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from $import_name.task import (
Net,
Expand Down Expand Up @@ -31,10 +32,12 @@ class FlowerClient(NumPyClient):
return loss, len(self.valloader.dataset), {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
# Load model and data
net = Net().to(DEVICE)
trainloader, valloader = load_data(int(cid), 2)
partition_id = int(context.node_config["partition-id"])
num_partitions = int(context.node_config["num-partitions"])
trainloader, valloader = load_data(partition_id, num_partitions)

# Return Client instance
return FlowerClient(net, trainloader, valloader).to_client()
Expand Down
6 changes: 4 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import warnings

import numpy as np
from flwr.client import NumPyClient, ClientApp
from flwr.common import Context
from flwr_datasets import FederatedDataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
Expand Down Expand Up @@ -68,8 +69,9 @@ class FlowerClient(NumPyClient):

fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})

def client_fn(cid: str):
dataset = fds.load_partition(int(cid), "train").with_format("numpy")
def client_fn(context: Context):
partition_id = int(context.node_config["partition-id"])
dataset = fds.load_partition(partition_id, "train").with_format("numpy")

X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""$project_name: A Flower / TensorFlow app."""

from flwr.client import NumPyClient, ClientApp
from flwr.common import Context

from $import_name.task import load_data, load_model

Expand Down Expand Up @@ -28,10 +29,12 @@ class FlowerClient(NumPyClient):
return loss, len(self.x_test), {"accuracy": accuracy}


def client_fn(cid):
def client_fn(context: Context):
# Load model and data
net = load_model()
x_train, y_train, x_test, y_test = load_data(int(cid), 2)

partition_id = int(context.node_config["partition-id"])
x_train, y_train, x_test, y_test = load_data(partition_id, 2)

# Return Client instance
return FlowerClient(net, x_train, y_train, x_test, y_test).to_client()
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ DEVICE = torch.device("cpu")
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint


def load_data(partition_id, num_clients):
def load_data(partition_id: int, num_partitions: int):
"""Load IMDB data (training and eval)"""
fds = FederatedDataset(dataset="imdb", partitioners={"train": num_clients})
fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions})
partition = fds.load_partition(partition_id)
# Divide data: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2, seed=42)
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def batch_iterate(batch_size, X, y):
yield X[ids], y[ids]


def load_data(partition_id, num_clients):
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients})
def load_data(partition_id: int, num_partitions: int):
fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})
partition = fds.load_partition(partition_id)
partition_splits = partition.train_test_split(test_size=0.2, seed=42)

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Net(nn.Module):
return self.fc3(x)


def load_data(partition_id, num_partitions):
def load_data(partition_id: int, num_partitions: int):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions})
partition = fds.load_partition(partition_id)
Expand Down

0 comments on commit 7d5d6f3

Please sign in to comment.