From ac2e64d6b0fe6fd6894f3e5a0e5a5f40235650e4 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 11 Jul 2024 15:37:17 +0200 Subject: [PATCH 1/4] Update client_fn args in e2e tests --- e2e/bare-client-auth/client.py | 4 +++- e2e/bare-https/client.py | 3 ++- e2e/bare/client.py | 3 ++- e2e/docker/client.py | 3 ++- e2e/framework-fastai/client.py | 3 ++- e2e/framework-jax/client.py | 4 ++-- e2e/framework-opacus/client.py | 3 ++- e2e/framework-pandas/client.py | 4 ++-- e2e/framework-pytorch-lightning/client.py | 3 ++- e2e/framework-pytorch/client.py | 3 ++- e2e/framework-scikit-learn/client.py | 3 ++- e2e/framework-tensorflow/client.py | 3 ++- e2e/strategies/client.py | 3 ++- e2e/strategies/test.py | 4 ++-- 14 files changed, 29 insertions(+), 17 deletions(-) diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index e82f17088bd9..85ace32aa3cc 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import flwr as fl @@ -23,7 +25,7 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 8f5c1412fd01..a073e9134a20 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Optional import numpy as np @@ -25,7 +26,7 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 402d775ac3a9..50246b88175d 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional import numpy as np @@ -51,7 +52,7 @@ def evaluate(self, parameters, config): ) -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/docker/client.py b/e2e/docker/client.py index 8451b810416b..35ee6d69a014 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict +from typing import Optional import torch import torch.nn as nn @@ -122,7 +123,7 @@ def evaluate(self, parameters, config): return loss, len(testloader.dataset), {"accuracy": accuracy} -def client_fn(cid: str): +def client_fn(node_id: int, partition_id: Optional[int]): """Create and return an instance of Flower `Client`.""" return FlowerClient().to_client() diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py index 1d98a1134941..d0c00cf17eeb 100644 --- a/e2e/framework-fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -1,5 +1,6 @@ import warnings from collections import OrderedDict +from typing import Optional import numpy as np import torch @@ -49,7 +50,7 @@ def evaluate(self, parameters, config): return loss, len(dls.valid), {"accuracy": 1 - error_rate} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py index 347a005d923a..877cd6d1be66 100644 --- a/e2e/framework-jax/client.py +++ b/e2e/framework-jax/client.py @@ -1,6 +1,6 @@ """Flower client example using JAX for linear regression.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import jax import jax_training @@ -48,7 +48,7 @@ def evaluate( return float(loss), num_examples, {"loss": float(loss)} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py index c9ebe319063a..fd64c43f7594 100644 --- a/e2e/framework-opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -1,5 +1,6 @@ import math from collections import OrderedDict +from typing import Optional import torch import torch.nn as nn @@ -139,7 +140,7 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): model = Net() return FlowerClient(model).to_client() diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py index 19e15f5a3b11..5368379d069c 100644 --- a/e2e/framework-pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -32,7 +32,7 @@ def fit( ) -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py index fdd55b3dc344..59d37e19a21e 100644 --- a/e2e/framework-pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from typing import Optional import mnist import pytorch_lightning as pl @@ -51,7 +52,7 @@ def _set_parameters(model, parameters): model.load_state_dict(state_dict, strict=True) -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): model = mnist.LitAutoEncoder() train_loader, val_loader, test_loader = mnist.load_data() diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py index dbfbfed1ffa7..c1718b7e5b72 100644 --- a/e2e/framework-pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -1,6 +1,7 @@ import warnings from collections import OrderedDict from datetime import datetime +from typing import Optional import torch import torch.nn as nn @@ -136,7 +137,7 @@ def set_parameters(model, parameters): return -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index b0691e75a79d..ec10bab03afe 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -1,4 +1,5 @@ import warnings +from typing import Optional import numpy as np import utils @@ -45,7 +46,7 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py index 779be0c3746d..790c71ce3b23 100644 --- a/e2e/framework-tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -1,4 +1,5 @@ import os +from typing import Optional import tensorflow as tf @@ -33,7 +34,7 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 505340e013a5..9ab3b0d2f1b9 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -1,4 +1,5 @@ import os +from typing import Optional import tensorflow as tf @@ -48,7 +49,7 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient().to_client() diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index abf9cdb5a5c7..524b4578e353 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -1,4 +1,5 @@ from sys import argv +from typing import Optional import tensorflow as tf from client import SUBSET_SIZE, FlowerClient, get_model @@ -42,8 +43,7 @@ def get_strat(name): init_model = get_model() -def client_fn(cid): - _ = cid +def client_fn(node_id: int, partition_id: Optional[int]): return FlowerClient() From 09b7170c2a40a9218d408eac137152b908dba893 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Sat, 13 Jul 2024 15:51:21 +0200 Subject: [PATCH 2/4] Update E2E tests to new client_fn signature --- e2e/bare-client-auth/client.py | 3 ++- e2e/bare-https/client.py | 3 ++- e2e/bare/client.py | 4 ++-- e2e/docker/client.py | 3 ++- e2e/framework-fastai/client.py | 3 ++- e2e/framework-jax/client.py | 3 ++- e2e/framework-opacus/client.py | 3 ++- e2e/framework-pandas/client.py | 3 ++- e2e/framework-pytorch-lightning/client.py | 3 ++- e2e/framework-pytorch/client.py | 4 ++-- e2e/framework-scikit-learn/client.py | 3 ++- e2e/framework-tensorflow/client.py | 3 ++- e2e/strategies/client.py | 3 ++- e2e/strategies/test.py | 4 ++-- 14 files changed, 28 insertions(+), 17 deletions(-) diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index 85ace32aa3cc..4ac8071cc0b4 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -3,6 +3,7 @@ import numpy as np import flwr as fl +from flwr.common import Context model_params = np.array([1]) objective = 5 @@ -25,7 +26,7 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index a073e9134a20..0bee51a5b26a 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -4,6 +4,7 @@ import numpy as np import flwr as fl +from flwr.common import Context model_params = np.array([1]) objective = 5 @@ -26,7 +27,7 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 50246b88175d..940069f8c46c 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -4,7 +4,7 @@ import numpy as np import flwr as fl -from flwr.common import ConfigsRecord +from flwr.common import ConfigsRecord, Context SUBSET_SIZE = 1000 STATE_VAR = "timestamp" @@ -52,7 +52,7 @@ def evaluate(self, parameters, config): ) -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/docker/client.py b/e2e/docker/client.py index 35ee6d69a014..e54b27417b90 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -10,6 +10,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from flwr.client import ClientApp, NumPyClient +from flwr.common import Context # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -123,7 +124,7 @@ def evaluate(self, parameters, config): return loss, len(testloader.dataset), {"accuracy": accuracy} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): """Create and return an instance of Flower `Client`.""" return FlowerClient().to_client() diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py index d0c00cf17eeb..5ab7e6ce4ac6 100644 --- a/e2e/framework-fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -7,6 +7,7 @@ from fastai.vision.all import * import flwr as fl +from flwr.common import Context warnings.filterwarnings("ignore", category=UserWarning) @@ -50,7 +51,7 @@ def evaluate(self, parameters, config): return loss, len(dls.valid), {"accuracy": 1 - error_rate} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py index 877cd6d1be66..c909f281a08d 100644 --- a/e2e/framework-jax/client.py +++ b/e2e/framework-jax/client.py @@ -7,6 +7,7 @@ import numpy as np import flwr as fl +from flwr.common import Context # Load data and determine model shape train_x, train_y, test_x, test_y = jax_training.load_data() @@ -48,7 +49,7 @@ def evaluate( return float(loss), num_examples, {"loss": float(loss)} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py index fd64c43f7594..18d845893400 100644 --- a/e2e/framework-opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -11,6 +11,7 @@ from torchvision.datasets import CIFAR10 import flwr as fl +from flwr.common import Context # Define parameters. PARAMS = { @@ -140,7 +141,7 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): model = Net() return FlowerClient(model).to_client() diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py index 5368379d069c..404f76513c64 100644 --- a/e2e/framework-pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -4,6 +4,7 @@ import pandas as pd import flwr as fl +from flwr.common import Context df = pd.read_csv("./data/client.csv") @@ -32,7 +33,7 @@ def fit( ) -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py index 59d37e19a21e..40f4da869bee 100644 --- a/e2e/framework-pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -6,6 +6,7 @@ import torch import flwr as fl +from flwr.common import Context class FlowerClient(fl.client.NumPyClient): @@ -52,7 +53,7 @@ def _set_parameters(model, parameters): model.load_state_dict(state_dict, strict=True) -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): model = mnist.LitAutoEncoder() train_loader, val_loader, test_loader = mnist.load_data() diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py index c1718b7e5b72..577511a30103 100644 --- a/e2e/framework-pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -12,7 +12,7 @@ from tqdm import tqdm import flwr as fl -from flwr.common import ConfigsRecord +from flwr.common import ConfigsRecord, Context # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -137,7 +137,7 @@ def set_parameters(model, parameters): return -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index ec10bab03afe..8626cd36741a 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -7,6 +7,7 @@ from sklearn.metrics import log_loss import flwr as fl +from flwr.common import Context # Load MNIST dataset from https://www.openml.org/d/554 (X_train, y_train), (X_test, y_test) = utils.load_mnist() @@ -46,7 +47,7 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py index 790c71ce3b23..c7cbeb51aa19 100644 --- a/e2e/framework-tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -4,6 +4,7 @@ import tensorflow as tf import flwr as fl +from flwr.common import Context SUBSET_SIZE = 1000 @@ -34,7 +35,7 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 9ab3b0d2f1b9..668ffdc238e4 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -4,6 +4,7 @@ import tensorflow as tf import flwr as fl +from flwr.common import Context SUBSET_SIZE = 1000 @@ -49,7 +50,7 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient().to_client() diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index 524b4578e353..2d5d04b08125 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -5,7 +5,7 @@ from client import SUBSET_SIZE, FlowerClient, get_model import flwr as fl -from flwr.common import ndarrays_to_parameters +from flwr.common import Context, ndarrays_to_parameters from flwr.server.strategy import ( FaultTolerantFedAvg, FedAdagrad, @@ -43,7 +43,7 @@ def get_strat(name): init_model = get_model() -def client_fn(node_id: int, partition_id: Optional[int]): +def client_fn(context: Context): return FlowerClient() From 129ad8311580fa5c8730ad58c9c6433620862e3b Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Sat, 13 Jul 2024 16:10:44 +0200 Subject: [PATCH 3/4] Update E2E tests to follow recent conventions --- e2e/bare-client-auth/client.py | 8 +++----- e2e/bare-https/client.py | 9 ++++----- e2e/bare/client.py | 11 ++++------- e2e/docker/client.py | 1 - e2e/framework-fastai/client.py | 9 ++++----- e2e/framework-jax/client.py | 12 +++++------- e2e/framework-opacus/client.py | 9 ++++----- e2e/framework-pandas/client.py | 10 +++++----- e2e/framework-pytorch-lightning/client.py | 9 ++++----- e2e/framework-pytorch/client.py | 9 ++++----- e2e/framework-scikit-learn/client.py | 10 ++++------ e2e/framework-tensorflow/client.py | 11 ++++------- e2e/strategies/client.py | 11 ++++------- e2e/strategies/test.py | 8 ++++---- 14 files changed, 53 insertions(+), 74 deletions(-) diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index 4ac8071cc0b4..c7b0d59b8ea5 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -1,8 +1,6 @@ -from typing import Optional - import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient from flwr.common import Context model_params = np.array([1]) @@ -10,7 +8,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -30,6 +28,6 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 0bee51a5b26a..4a682af3aec3 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -1,9 +1,8 @@ from pathlib import Path -from typing import Optional import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context model_params = np.array([1]) @@ -11,7 +10,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -31,13 +30,13 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 940069f8c46c..943e60d5db9f 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -1,9 +1,8 @@ from datetime import datetime -from typing import Optional import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import ConfigsRecord, Context SUBSET_SIZE = 1000 @@ -15,7 +14,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -56,12 +55,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/docker/client.py b/e2e/docker/client.py index e54b27417b90..44313c7c3af6 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -1,6 +1,5 @@ import warnings from collections import OrderedDict -from typing import Optional import torch import torch.nn as nn diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py index 5ab7e6ce4ac6..161b27b5a548 100644 --- a/e2e/framework-fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -1,12 +1,11 @@ import warnings from collections import OrderedDict -from typing import Optional import numpy as np import torch from fastai.vision.all import * -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context warnings.filterwarnings("ignore", category=UserWarning) @@ -31,7 +30,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in learn.model.state_dict().items()] @@ -55,14 +54,14 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py index c909f281a08d..c9ff67b3e38e 100644 --- a/e2e/framework-jax/client.py +++ b/e2e/framework-jax/client.py @@ -1,12 +1,12 @@ """Flower client example using JAX for linear regression.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import jax import jax_training import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context # Load data and determine model shape @@ -15,7 +15,7 @@ model_shape = train_x.shape[1:] -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self): self.params = jax_training.load_model(model_shape) @@ -53,12 +53,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py index 18d845893400..167fa4584e37 100644 --- a/e2e/framework-opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -1,6 +1,5 @@ import math from collections import OrderedDict -from typing import Optional import torch import torch.nn as nn @@ -10,7 +9,7 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context # Define parameters. @@ -97,7 +96,7 @@ def load_data(): # Define Flower client. -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model) -> None: super().__init__() # Create a privacy engine which will add DP and keep track of the privacy budget. @@ -146,11 +145,11 @@ def client_fn(context: Context): return FlowerClient(model).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient(model).to_client() ) diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py index 404f76513c64..0c3300e1dd3f 100644 --- a/e2e/framework-pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -1,9 +1,9 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import numpy as np import pandas as pd -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context df = pd.read_csv("./data/client.csv") @@ -17,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: @@ -37,13 +37,13 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py index 40f4da869bee..bf291a1ca2c5 100644 --- a/e2e/framework-pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -1,15 +1,14 @@ from collections import OrderedDict -from typing import Optional import mnist import pytorch_lightning as pl import torch -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model self.train_loader = train_loader @@ -61,7 +60,7 @@ def client_fn(context: Context): return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) @@ -73,7 +72,7 @@ def main() -> None: # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() - fl.client.start_client(server_address="127.0.0.1:8080", client=client) + start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py index 577511a30103..ab4bc7b5c5b9 100644 --- a/e2e/framework-pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -1,7 +1,6 @@ import warnings from collections import OrderedDict from datetime import datetime -from typing import Optional import torch import torch.nn as nn @@ -11,7 +10,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import ConfigsRecord, Context # ############################################################################# @@ -90,7 +89,7 @@ def load_data(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] @@ -141,14 +140,14 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index 8626cd36741a..f8f99091f5e5 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -6,7 +6,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context # Load MNIST dataset from https://www.openml.org/d/554 @@ -28,7 +28,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): # type: ignore return utils.get_model_parameters(model) @@ -51,12 +51,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:8080", client=FlowerClient().to_client() - ) + start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py index c7cbeb51aa19..351f495a3acb 100644 --- a/e2e/framework-tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -1,9 +1,8 @@ import os -from typing import Optional import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context SUBSET_SIZE = 1000 @@ -20,7 +19,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -39,12 +38,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 668ffdc238e4..0403416cc3b7 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -1,9 +1,8 @@ import os -from typing import Optional import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context SUBSET_SIZE = 1000 @@ -35,7 +34,7 @@ def get_model(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -54,13 +53,11 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index 2d5d04b08125..c567f33b236b 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -1,11 +1,10 @@ from sys import argv -from typing import Optional import tensorflow as tf from client import SUBSET_SIZE, FlowerClient, get_model -import flwr as fl from flwr.common import Context, ndarrays_to_parameters +from flwr.server import ServerConfig from flwr.server.strategy import ( FaultTolerantFedAvg, FedAdagrad, @@ -16,6 +15,7 @@ FedYogi, QFedAvg, ) +from flwr.simulation import start_simulation STRATEGY_LIST = [ FedMedian, @@ -71,10 +71,10 @@ def evaluate(server_round, parameters, config): if start_idx >= OPT_IDX: strat_args["tau"] = 0.01 -hist = fl.simulation.start_simulation( +hist = start_simulation( client_fn=client_fn, num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), + config=ServerConfig(num_rounds=3), strategy=strategy(**strat_args), ) From af6c72bc6598cca7f5d29dc86c3d53b30f330f47 Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Sat, 13 Jul 2024 16:20:09 +0200 Subject: [PATCH 4/4] Remove unused import --- e2e/framework-scikit-learn/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index f8f99091f5e5..24c6617c1289 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -1,5 +1,4 @@ import warnings -from typing import Optional import numpy as np import utils