diff --git a/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl index 314da2120c53..56bac8543c50 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.hf.py.tpl @@ -1,6 +1,7 @@ """$project_name: A Flower / HuggingFace Transformers app.""" from flwr.client import ClientApp, NumPyClient +from flwr.common import Context from transformers import AutoModelForSequenceClassification from $import_name.task import ( @@ -38,12 +39,15 @@ class FlowerClient(NumPyClient): return float(loss), len(self.testloader), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): # Load model and data net = AutoModelForSequenceClassification.from_pretrained( CHECKPOINT, num_labels=2 ).to(DEVICE) - trainloader, valloader = load_data(int(cid), 2) + + partition_id = int(context.node_config['partition-id']) + num_partitions = int(context.node_config['num-partitions]) + trainloader, valloader = load_data(partition_id, num_partitions) # Return Client instance return FlowerClient(net, trainloader, valloader).to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl index 3c6d2f03637a..48b667665f3f 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl @@ -2,6 +2,7 @@ import jax from flwr.client import NumPyClient, ClientApp +from flwr.common import Context from $import_name.task import ( evaluation, @@ -44,7 +45,7 @@ class FlowerClient(NumPyClient): ) return float(loss), num_examples, {"loss": float(loss)} -def client_fn(cid): +def client_fn(context: Context): # Return Client instance return FlowerClient().to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl index 1722561370a8..37207c940d83 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl @@ -4,6 +4,7 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from flwr.client import NumPyClient, ClientApp +from flwr.common import Context from $import_name.task import ( batch_iterate, @@ -57,8 +58,10 @@ class FlowerClient(NumPyClient): return loss.item(), len(self.test_images), {"accuracy": accuracy.item()} -def client_fn(cid): - data = load_data(int(cid), 2) +def client_fn(context: Context): + partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) + data = load_data(partition_id, num_partitions) # Return Client instance return FlowerClient(data).to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl index 232c305fc2a9..1dd83e108bb5 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl @@ -1,6 +1,7 @@ """$project_name: A Flower / NumPy app.""" from flwr.client import NumPyClient, ClientApp +from flwr.common import Context import numpy as np @@ -15,7 +16,7 @@ class FlowerClient(NumPyClient): return float(0.0), 1, {"accuracy": float(1.0)} -def client_fn(cid: str): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl index c68974efaadf..addc71023a09 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.pytorch.py.tpl @@ -1,6 +1,7 @@ """$project_name: A Flower / PyTorch app.""" from flwr.client import NumPyClient, ClientApp +from flwr.common import Context from $import_name.task import ( Net, @@ -31,10 +32,12 @@ class FlowerClient(NumPyClient): return loss, len(self.valloader.dataset), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): # Load model and data net = Net().to(DEVICE) - trainloader, valloader = load_data(int(cid), 2) + partition_id = int(context.node_config["partition-id"]) + num_partitions = int(context.node_config["num-partitions"]) + trainloader, valloader = load_data(partition_id, num_partitions) # Return Client instance return FlowerClient(net, trainloader, valloader).to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl index 9181389cad1c..a1eefa034e7b 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl @@ -4,6 +4,7 @@ import warnings import numpy as np from flwr.client import NumPyClient, ClientApp +from flwr.common import Context from flwr_datasets import FederatedDataset from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss @@ -68,8 +69,9 @@ class FlowerClient(NumPyClient): fds = FederatedDataset(dataset="mnist", partitioners={"train": 2}) -def client_fn(cid: str): - dataset = fds.load_partition(int(cid), "train").with_format("numpy") +def client_fn(context: Context): + partition_id = int(context.node_config["partition-id"]) + dataset = fds.load_partition(partition_id, "train").with_format("numpy") X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"] diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl index dc55d4ca6569..0fe1c405a110 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -1,6 +1,7 @@ """$project_name: A Flower / TensorFlow app.""" from flwr.client import NumPyClient, ClientApp +from flwr.common import Context from $import_name.task import load_data, load_model @@ -28,10 +29,12 @@ class FlowerClient(NumPyClient): return loss, len(self.x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): # Load model and data net = load_model() - x_train, y_train, x_test, y_test = load_data(int(cid), 2) + + partition_id = int(context.node_config["partition-id"]) + x_train, y_train, x_test, y_test = load_data(partition_id, 2) # Return Client instance return FlowerClient(net, x_train, y_train, x_test, y_test).to_client() diff --git a/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl index 8e89add66835..eb43acfce976 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.hf.py.tpl @@ -16,9 +16,9 @@ DEVICE = torch.device("cpu") CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint -def load_data(partition_id, num_clients): +def load_data(partition_id: int, num_partitions: int): """Load IMDB data (training and eval)""" - fds = FederatedDataset(dataset="imdb", partitioners={"train": num_clients}) + fds = FederatedDataset(dataset="imdb", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id) # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl index bcd4dde93310..88053b0cd590 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.mlx.py.tpl @@ -43,8 +43,8 @@ def batch_iterate(batch_size, X, y): yield X[ids], y[ids] -def load_data(partition_id, num_clients): - fds = FederatedDataset(dataset="mnist", partitioners={"train": num_clients}) +def load_data(partition_id: int, num_partitions: int): + fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id) partition_splits = partition.train_test_split(test_size=0.2, seed=42) diff --git a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl index b30c65a285b5..d5971ffb6ce5 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.pytorch.py.tpl @@ -34,7 +34,7 @@ class Net(nn.Module): return self.fc3(x) -def load_data(partition_id, num_partitions): +def load_data(partition_id: int, num_partitions: int): """Load partition CIFAR10 data.""" fds = FederatedDataset(dataset="cifar10", partitioners={"train": num_partitions}) partition = fds.load_partition(partition_id)