From 129ad8311580fa5c8730ad58c9c6433620862e3b Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Sat, 13 Jul 2024 16:10:44 +0200 Subject: [PATCH] Update E2E tests to follow recent conventions --- e2e/bare-client-auth/client.py | 8 +++----- e2e/bare-https/client.py | 9 ++++----- e2e/bare/client.py | 11 ++++------- e2e/docker/client.py | 1 - e2e/framework-fastai/client.py | 9 ++++----- e2e/framework-jax/client.py | 12 +++++------- e2e/framework-opacus/client.py | 9 ++++----- e2e/framework-pandas/client.py | 10 +++++----- e2e/framework-pytorch-lightning/client.py | 9 ++++----- e2e/framework-pytorch/client.py | 9 ++++----- e2e/framework-scikit-learn/client.py | 10 ++++------ e2e/framework-tensorflow/client.py | 11 ++++------- e2e/strategies/client.py | 11 ++++------- e2e/strategies/test.py | 8 ++++---- 14 files changed, 53 insertions(+), 74 deletions(-) diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index 4ac8071cc0b4..c7b0d59b8ea5 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -1,8 +1,6 @@ -from typing import Optional - import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient from flwr.common import Context model_params = np.array([1]) @@ -10,7 +8,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -30,6 +28,6 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 0bee51a5b26a..4a682af3aec3 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -1,9 +1,8 @@ from pathlib import Path -from typing import Optional import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context model_params = np.array([1]) @@ -11,7 +10,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -31,13 +30,13 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 940069f8c46c..943e60d5db9f 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -1,9 +1,8 @@ from datetime import datetime -from typing import Optional import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import ConfigsRecord, Context SUBSET_SIZE = 1000 @@ -15,7 +14,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -56,12 +55,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/docker/client.py b/e2e/docker/client.py index e54b27417b90..44313c7c3af6 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -1,6 +1,5 @@ import warnings from collections import OrderedDict -from typing import Optional import torch import torch.nn as nn diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py index 5ab7e6ce4ac6..161b27b5a548 100644 --- a/e2e/framework-fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -1,12 +1,11 @@ import warnings from collections import OrderedDict -from typing import Optional import numpy as np import torch from fastai.vision.all import * -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context warnings.filterwarnings("ignore", category=UserWarning) @@ -31,7 +30,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in learn.model.state_dict().items()] @@ -55,14 +54,14 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py index c909f281a08d..c9ff67b3e38e 100644 --- a/e2e/framework-jax/client.py +++ b/e2e/framework-jax/client.py @@ -1,12 +1,12 @@ """Flower client example using JAX for linear regression.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import jax import jax_training import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context # Load data and determine model shape @@ -15,7 +15,7 @@ model_shape = train_x.shape[1:] -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self): self.params = jax_training.load_model(model_shape) @@ -53,12 +53,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py index 18d845893400..167fa4584e37 100644 --- a/e2e/framework-opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -1,6 +1,5 @@ import math from collections import OrderedDict -from typing import Optional import torch import torch.nn as nn @@ -10,7 +9,7 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context # Define parameters. @@ -97,7 +96,7 @@ def load_data(): # Define Flower client. -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model) -> None: super().__init__() # Create a privacy engine which will add DP and keep track of the privacy budget. @@ -146,11 +145,11 @@ def client_fn(context: Context): return FlowerClient(model).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient(model).to_client() ) diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py index 404f76513c64..0c3300e1dd3f 100644 --- a/e2e/framework-pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -1,9 +1,9 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import numpy as np import pandas as pd -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context df = pd.read_csv("./data/client.csv") @@ -17,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: @@ -37,13 +37,13 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py index 40f4da869bee..bf291a1ca2c5 100644 --- a/e2e/framework-pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -1,15 +1,14 @@ from collections import OrderedDict -from typing import Optional import mnist import pytorch_lightning as pl import torch -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model self.train_loader = train_loader @@ -61,7 +60,7 @@ def client_fn(context: Context): return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) @@ -73,7 +72,7 @@ def main() -> None: # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() - fl.client.start_client(server_address="127.0.0.1:8080", client=client) + start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py index 577511a30103..ab4bc7b5c5b9 100644 --- a/e2e/framework-pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -1,7 +1,6 @@ import warnings from collections import OrderedDict from datetime import datetime -from typing import Optional import torch import torch.nn as nn @@ -11,7 +10,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import ConfigsRecord, Context # ############################################################################# @@ -90,7 +89,7 @@ def load_data(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] @@ -141,14 +140,14 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index 8626cd36741a..f8f99091f5e5 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -6,7 +6,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context # Load MNIST dataset from https://www.openml.org/d/554 @@ -28,7 +28,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): # type: ignore return utils.get_model_parameters(model) @@ -51,12 +51,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:8080", client=FlowerClient().to_client() - ) + start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py index c7cbeb51aa19..351f495a3acb 100644 --- a/e2e/framework-tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -1,9 +1,8 @@ import os -from typing import Optional import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context SUBSET_SIZE = 1000 @@ -20,7 +19,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -39,12 +38,10 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 668ffdc238e4..0403416cc3b7 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -1,9 +1,8 @@ import os -from typing import Optional import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client from flwr.common import Context SUBSET_SIZE = 1000 @@ -35,7 +34,7 @@ def get_model(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -54,13 +53,11 @@ def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index 2d5d04b08125..c567f33b236b 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -1,11 +1,10 @@ from sys import argv -from typing import Optional import tensorflow as tf from client import SUBSET_SIZE, FlowerClient, get_model -import flwr as fl from flwr.common import Context, ndarrays_to_parameters +from flwr.server import ServerConfig from flwr.server.strategy import ( FaultTolerantFedAvg, FedAdagrad, @@ -16,6 +15,7 @@ FedYogi, QFedAvg, ) +from flwr.simulation import start_simulation STRATEGY_LIST = [ FedMedian, @@ -71,10 +71,10 @@ def evaluate(server_round, parameters, config): if start_idx >= OPT_IDX: strat_args["tau"] = 0.01 -hist = fl.simulation.start_simulation( +hist = start_simulation( client_fn=client_fn, num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), + config=ServerConfig(num_rounds=3), strategy=strategy(**strat_args), )