diff --git a/e2e/bare-client-auth/client.py b/e2e/bare-client-auth/client.py index e82f17088bd9..c7b0d59b8ea5 100644 --- a/e2e/bare-client-auth/client.py +++ b/e2e/bare-client-auth/client.py @@ -1,13 +1,14 @@ import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context model_params = np.array([1]) objective = 5 # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -23,10 +24,10 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) diff --git a/e2e/bare-https/client.py b/e2e/bare-https/client.py index 8f5c1412fd01..4a682af3aec3 100644 --- a/e2e/bare-https/client.py +++ b/e2e/bare-https/client.py @@ -2,14 +2,15 @@ import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context model_params = np.array([1]) objective = 5 # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -25,17 +26,17 @@ def evaluate(self, parameters, config): return loss, 1, {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), root_certificates=Path("certificates/ca.crt").read_bytes(), diff --git a/e2e/bare/client.py b/e2e/bare/client.py index 402d775ac3a9..943e60d5db9f 100644 --- a/e2e/bare/client.py +++ b/e2e/bare/client.py @@ -2,8 +2,8 @@ import numpy as np -import flwr as fl -from flwr.common import ConfigsRecord +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import ConfigsRecord, Context SUBSET_SIZE = 1000 STATE_VAR = "timestamp" @@ -14,7 +14,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model_params @@ -51,16 +51,14 @@ def evaluate(self, parameters, config): ) -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/docker/client.py b/e2e/docker/client.py index 8451b810416b..44313c7c3af6 100644 --- a/e2e/docker/client.py +++ b/e2e/docker/client.py @@ -9,6 +9,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from flwr.client import ClientApp, NumPyClient +from flwr.common import Context # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -122,7 +123,7 @@ def evaluate(self, parameters, config): return loss, len(testloader.dataset), {"accuracy": accuracy} -def client_fn(cid: str): +def client_fn(context: Context): """Create and return an instance of Flower `Client`.""" return FlowerClient().to_client() diff --git a/e2e/framework-fastai/client.py b/e2e/framework-fastai/client.py index 1d98a1134941..161b27b5a548 100644 --- a/e2e/framework-fastai/client.py +++ b/e2e/framework-fastai/client.py @@ -5,7 +5,8 @@ import torch from fastai.vision.all import * -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context warnings.filterwarnings("ignore", category=UserWarning) @@ -29,7 +30,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in learn.model.state_dict().items()] @@ -49,18 +50,18 @@ def evaluate(self, parameters, config): return loss, len(dls.valid), {"accuracy": 1 - error_rate} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-jax/client.py b/e2e/framework-jax/client.py index 347a005d923a..c9ff67b3e38e 100644 --- a/e2e/framework-jax/client.py +++ b/e2e/framework-jax/client.py @@ -6,7 +6,8 @@ import jax_training import numpy as np -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Load data and determine model shape train_x, train_y, test_x, test_y = jax_training.load_data() @@ -14,7 +15,7 @@ model_shape = train_x.shape[1:] -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self): self.params = jax_training.load_model(model_shape) @@ -48,16 +49,14 @@ def evaluate( return float(loss), num_examples, {"loss": float(loss)} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-opacus/client.py b/e2e/framework-opacus/client.py index c9ebe319063a..167fa4584e37 100644 --- a/e2e/framework-opacus/client.py +++ b/e2e/framework-opacus/client.py @@ -9,7 +9,8 @@ from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Define parameters. PARAMS = { @@ -95,7 +96,7 @@ def load_data(): # Define Flower client. -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model) -> None: super().__init__() # Create a privacy engine which will add DP and keep track of the privacy budget. @@ -139,16 +140,16 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} -def client_fn(cid): +def client_fn(context: Context): model = Net() return FlowerClient(model).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient(model).to_client() ) diff --git a/e2e/framework-pandas/client.py b/e2e/framework-pandas/client.py index 19e15f5a3b11..0c3300e1dd3f 100644 --- a/e2e/framework-pandas/client.py +++ b/e2e/framework-pandas/client.py @@ -3,7 +3,8 @@ import numpy as np import pandas as pd -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context df = pd.read_csv("./data/client.csv") @@ -16,7 +17,7 @@ def compute_hist(df: pd.DataFrame, col_name: str) -> np.ndarray: # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def fit( self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: @@ -32,17 +33,17 @@ def fit( ) -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-pytorch-lightning/client.py b/e2e/framework-pytorch-lightning/client.py index fdd55b3dc344..bf291a1ca2c5 100644 --- a/e2e/framework-pytorch-lightning/client.py +++ b/e2e/framework-pytorch-lightning/client.py @@ -4,10 +4,11 @@ import pytorch_lightning as pl import torch -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model self.train_loader = train_loader @@ -51,7 +52,7 @@ def _set_parameters(model, parameters): model.load_state_dict(state_dict, strict=True) -def client_fn(cid): +def client_fn(context: Context): model = mnist.LitAutoEncoder() train_loader, val_loader, test_loader = mnist.load_data() @@ -59,7 +60,7 @@ def client_fn(cid): return FlowerClient(model, train_loader, val_loader, test_loader).to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) @@ -71,7 +72,7 @@ def main() -> None: # Flower client client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() - fl.client.start_client(server_address="127.0.0.1:8080", client=client) + start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/e2e/framework-pytorch/client.py b/e2e/framework-pytorch/client.py index dbfbfed1ffa7..ab4bc7b5c5b9 100644 --- a/e2e/framework-pytorch/client.py +++ b/e2e/framework-pytorch/client.py @@ -10,8 +10,8 @@ from torchvision.transforms import Compose, Normalize, ToTensor from tqdm import tqdm -import flwr as fl -from flwr.common import ConfigsRecord +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import ConfigsRecord, Context # ############################################################################# # 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader @@ -89,7 +89,7 @@ def load_data(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return [val.cpu().numpy() for _, val in net.state_dict().items()] @@ -136,18 +136,18 @@ def set_parameters(model, parameters): return -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/e2e/framework-scikit-learn/client.py b/e2e/framework-scikit-learn/client.py index b0691e75a79d..24c6617c1289 100644 --- a/e2e/framework-scikit-learn/client.py +++ b/e2e/framework-scikit-learn/client.py @@ -5,7 +5,8 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import log_loss -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context # Load MNIST dataset from https://www.openml.org/d/554 (X_train, y_train), (X_test, y_test) = utils.load_mnist() @@ -26,7 +27,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): # type: ignore return utils.get_model_parameters(model) @@ -45,16 +46,14 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:8080", client=FlowerClient().to_client() - ) + start_client(server_address="0.0.0.0:8080", client=FlowerClient().to_client()) diff --git a/e2e/framework-tensorflow/client.py b/e2e/framework-tensorflow/client.py index 779be0c3746d..351f495a3acb 100644 --- a/e2e/framework-tensorflow/client.py +++ b/e2e/framework-tensorflow/client.py @@ -2,7 +2,8 @@ import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context SUBSET_SIZE = 1000 @@ -18,7 +19,7 @@ # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -33,16 +34,14 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/client.py b/e2e/strategies/client.py index 505340e013a5..0403416cc3b7 100644 --- a/e2e/strategies/client.py +++ b/e2e/strategies/client.py @@ -2,7 +2,8 @@ import tensorflow as tf -import flwr as fl +from flwr.client import ClientApp, NumPyClient, start_client +from flwr.common import Context SUBSET_SIZE = 1000 @@ -33,7 +34,7 @@ def get_model(): # Define Flower client -class FlowerClient(fl.client.NumPyClient): +class FlowerClient(NumPyClient): def get_parameters(self, config): return model.get_weights() @@ -48,17 +49,15 @@ def evaluate(self, parameters, config): return loss, len(x_test), {"accuracy": accuracy} -def client_fn(cid): +def client_fn(context: Context): return FlowerClient().to_client() -app = fl.client.ClientApp( +app = ClientApp( client_fn=client_fn, ) if __name__ == "__main__": # Start Flower client - fl.client.start_client( - server_address="127.0.0.1:8080", client=FlowerClient().to_client() - ) + start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/e2e/strategies/test.py b/e2e/strategies/test.py index abf9cdb5a5c7..c567f33b236b 100644 --- a/e2e/strategies/test.py +++ b/e2e/strategies/test.py @@ -3,8 +3,8 @@ import tensorflow as tf from client import SUBSET_SIZE, FlowerClient, get_model -import flwr as fl -from flwr.common import ndarrays_to_parameters +from flwr.common import Context, ndarrays_to_parameters +from flwr.server import ServerConfig from flwr.server.strategy import ( FaultTolerantFedAvg, FedAdagrad, @@ -15,6 +15,7 @@ FedYogi, QFedAvg, ) +from flwr.simulation import start_simulation STRATEGY_LIST = [ FedMedian, @@ -42,8 +43,7 @@ def get_strat(name): init_model = get_model() -def client_fn(cid): - _ = cid +def client_fn(context: Context): return FlowerClient() @@ -71,10 +71,10 @@ def evaluate(server_round, parameters, config): if start_idx >= OPT_IDX: strat_args["tau"] = 0.01 -hist = fl.simulation.start_simulation( +hist = start_simulation( client_fn=client_fn, num_clients=2, - config=fl.server.ServerConfig(num_rounds=3), + config=ServerConfig(num_rounds=3), strategy=strategy(**strat_args), )