From 166e365ae060414dcef3ee985e944a0f3d494b45 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 26 Aug 2024 09:31:50 +0100 Subject: [PATCH 1/8] init --- src/py/flwr/cli/new/new.py | 1 + .../new/templates/app/code/client.mlx.py.tpl | 47 +++++------------- .../templates/app/code/client.numpy.py.tpl | 7 +-- .../app/code/client.tensorflow.py.tpl | 3 -- .../new/templates/app/code/server.mlx.py.tpl | 16 ++++++- .../templates/app/code/server.numpy.py.tpl | 10 +++- .../new/templates/app/code/task.numpy.py.tpl | 7 +++ .../new/templates/app/pyproject.mlx.toml.tpl | 1 + 8 files changed, 48 insertions(+), 44 deletions(-) create mode 100644 src/py/flwr/cli/new/templates/app/code/task.numpy.py.tpl diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 862244da9158..0708c09c3b90 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -253,6 +253,7 @@ def new( MlFramework.HUGGINGFACE.value.lower(), MlFramework.MLX.value.lower(), MlFramework.TENSORFLOW.value.lower(), + MlFramework.NUMPY.value.lower(), ] if framework_str in frameworks_with_tasks: files[f"{import_name}/task.py"] = { diff --git a/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl index f3105103842d..dc677d6388d9 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.mlx.py.tpl @@ -4,7 +4,8 @@ import mlx.core as mx import mlx.nn as nn import mlx.optimizers as optim from flwr.client import NumPyClient, ClientApp from flwr.common import Context +from flwr.common.config import UserConfig from $import_name.task import ( batch_iterate, @@ -22,37 +22,24 @@ class FlowerClient(NumPyClient): def __init__( self, data, - num_layers, - hidden_dim, + run_config: UserConfig, num_classes, - batch_size, - learning_rate, - num_epochs, ): - self.num_layers = num_layers - self.hidden_dim = hidden_dim - self.num_classes = num_classes - self.batch_size = batch_size - self.learning_rate = learning_rate - self.num_epochs = num_epochs + num_layers = run_config["num-layers"] + hidden_dim = run_config["hidden-dim"] + input_dim = run_config["input-dim"] + batch_size = run_config["batch-size"] + learning_rate = run_config["lr"] + self.num_epochs = run_config["local-epochs"] self.train_images, self.train_labels, self.test_images, self.test_labels = data - self.model = MLP( - num_layers, self.train_images.shape[-1], hidden_dim, num_classes - ) + self.model = MLP(num_layers, input_dim, hidden_dim, num_classes) self.optimizer = optim.SGD(learning_rate=learning_rate) self.loss_and_grad_fn = nn.value_and_grad(self.model, loss_fn) - self.num_epochs = num_epochs self.batch_size = batch_size - def get_parameters(self, config): - return get_params(self.model) - - def set_parameters(self, parameters): - set_params(self.model, parameters) - def fit(self, parameters, config): - self.set_parameters(parameters) + set_params(self.model, parameters) for _ in range(self.num_epochs): for X, y in batch_iterate( self.batch_size, self.train_images, self.train_labels ): _, grads = self.loss_and_grad_fn(self.model, X, y) self.optimizer.update(self.model, grads) mx.eval(self.model.parameters(), self.optimizer.state) - return self.get_parameters(config={}), len(self.train_images), {} + return get_params(self.model), len(self.train_images), {} def evaluate(self, parameters, config): - self.set_parameters(parameters) + set_params(self.model, parameters) accuracy = 
eval_fn(self.model, self.test_images, self.test_labels) loss = loss_fn(self.model, self.test_images, self.test_labels) return loss.item(), len(self.test_images), {"accuracy": accuracy.item()} @@ -73,19 +60,10 @@ def client_fn(context: Context): partition_id = context.node_config["partition-id"] num_partitions = context.node_config["num-partitions"] data = load_data(partition_id, num_partitions) - - num_layers = context.run_config["num-layers"] - hidden_dim = context.run_config["hidden-dim"] num_classes = 10 - batch_size = context.run_config["batch-size"] - learning_rate = context.run_config["lr"] - num_epochs = context.run_config["local-epochs"] # Return Client instance - return FlowerClient( - data, num_layers, hidden_dim, num_classes, batch_size, learning_rate, num_epochs - ).to_client() - + return FlowerClient(data, context.run_config, num_classes).to_client() # Flower ClientApp app = ClientApp( diff --git a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl index e35c3c78f6e2..852b6a9daa70 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.numpy.py.tpl @@ -4,13 +4,14 @@ from flwr.client import NumPyClient, ClientApp from flwr.common import Context import numpy as np +from $import_name.task import get_dummy_model + class FlowerClient(NumPyClient): - def get_parameters(self, config): - return [np.ones((1, 1))] def fit(self, parameters, config): - return ([np.ones((1, 1))], 1, {}) + model = get_dummy_model() + return ([model], 1, {}) def evaluate(self, parameters, config): return float(0.0), 1, {"accuracy": float(1.0)} diff --git a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl index 48ee3b4f5356..f8c148691561 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.tensorflow.py.tpl @@ -17,9 +17,6 @@ class FlowerClient(NumPyClient): self.batch_size = batch_size self.verbose = verbose - def get_parameters(self, config): - return self.model.get_weights() - def fit(self, parameters, config): self.model.set_weights(parameters) self.model.fit( diff --git a/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl index c99c72574813..9ab9e6dfa532 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.mlx.py.tpl @@ -1,16 +1,28 @@ """$project_name: A Flower / $framework_str app.""" -from flwr.common import Context +from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig from flwr.server.strategy import FedAvg +from $import_name.task import get_params, MLP + def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + num_classes = 10 + num_layers = context.run_config["num-layers"] + input_dim = context.run_config["input-dim"] + hidden_dim = context.run_config["hidden-dim"] + + # Initialize global model + model = MLP(num_layers, input_dim, hidden_dim, num_classes) + params = get_params(model) + initial_parameters = ndarrays_to_parameters(params) + # Define strategy - strategy = FedAvg() + strategy = FedAvg(initial_parameters=initial_parameters) config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git 
a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl index c99c72574813..c090bd52303a 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.numpy.py.tpl @@ -1,16 +1,22 @@ """$project_name: A Flower / $framework_str app.""" -from flwr.common import Context +from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig from flwr.server.strategy import FedAvg +from $import_name.task import get_dummy_model + def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + # Initial model + model = get_dummy_model() + dummy_parameters = ndarrays_to_parameters([model]) + # Define strategy - strategy = FedAvg() + strategy = FedAvg(initial_parameters=dummy_parameters) config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/code/task.numpy.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.numpy.py.tpl new file mode 100644 index 000000000000..afbf71d108fc --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/task.numpy.py.tpl @@ -0,0 +1,7 @@ +"""$project_name: A Flower / $framework_str app.""" + +import numpy as np + +def get_dummy_model(): + return np.ones((1,1)) + diff --git a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl index c1bfe804c709..3830c4b3fb15 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.mlx.toml.tpl @@ -28,6 +28,7 @@ clientapp = "$import_name.client_app:app" num-server-rounds = 3 local-epochs = 1 num-layers = 2 +input-dim = 784 # 28*28 hidden-dim = 32 batch-size = 256 lr = 0.1 From 3e200a07e8c70ac9d1d81dc0857f33c37c61c2c7 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Mon, 26 Aug 2024 09:49:05 +0100 Subject: [PATCH 2/8] updated jax template --- .../new/templates/app/code/client.jax.py.tpl | 24 +++++++------------ .../new/templates/app/code/server.jax.py.tpl | 10 ++++++-- .../new/templates/app/pyproject.jax.toml.tpl | 1 + 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl index 046de57f3cf3..bc0c9ff66489 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.jax.py.tpl @@ -17,37 +17,31 @@ from $import_name.task import ( # Define Flower Client and client_fn class FlowerClient(NumPyClient): - def __init__(self): + def __init__(self, input_dim): self.train_x, self.train_y, self.test_x, self.test_y = load_data() self.grad_fn = jax.grad(loss_fn) - model_shape = self.train_x.shape[1:] - - self.params = load_model(model_shape) - - def get_parameters(self, config): - return get_params(self.params) - - def set_parameters(self, parameters): - set_params(self.params, parameters) + self.params = load_model((input_dim,)) def fit(self, parameters, config): - self.set_parameters(parameters) + set_params(self.params, parameters) self.params, loss, num_examples = train( self.params, self.grad_fn, self.train_x, self.train_y ) - parameters = self.get_parameters(config={}) - return parameters, num_examples, {"loss": float(loss)} + return get_params(self.params), num_examples, {"loss": float(loss)} def evaluate(self, parameters, 
config): - self.set_parameters(parameters) + set_params(self.params, parameters) loss, num_examples = evaluation( self.params, self.grad_fn, self.test_x, self.test_y ) return float(loss), num_examples, {"loss": float(loss)} def client_fn(context: Context): + + input_dim = context.run_config["input-dim"] + # Return Client instance - return FlowerClient().to_client() + return FlowerClient(input_dim).to_client() # Flower ClientApp diff --git a/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl index 514185fde970..268e1a05d11a 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.jax.py.tpl @@ -1,16 +1,22 @@ """$project_name: A Flower / $framework_str app.""" -from flwr.common import Context +from flwr.common import Context, ndarrays_to_parameters from flwr.server.strategy import FedAvg from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from $import_name.task import get_params, load_model def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + input_dim = context.run_config["input-dim"] + + # Initialize global model + params = get_params(load_model((input_dim,))) + initial_parameters = ndarrays_to_parameters(params) # Define strategy - strategy = FedAvg() + strategy = FedAvg(initial_parameters=initial_parameters) config = ServerConfig(num_rounds=num_rounds) return ServerAppComponents(strategy=strategy, config=config) diff --git a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl index 31fff1c2a4c8..36d65139edff 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.jax.toml.tpl @@ -26,6 +26,7 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = 3 +input-dim = 3 [tool.flwr.federations] default = "local-simulation" From 056a762b8c165511a607b516fa66b311436d5705 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 1 Sep 2024 10:22:15 +0100 Subject: [PATCH 3/8] complete sklearn template --- src/py/flwr/cli/new/new.py | 1 + .../templates/app/code/client.sklearn.py.tpl | 59 ++++------------- .../templates/app/code/server.sklearn.py.tpl | 15 ++++- .../templates/app/code/task.sklearn.py.tpl | 64 +++++++++++++++++++ .../templates/app/pyproject.sklearn.toml.tpl | 2 + 5 files changed, 93 insertions(+), 48 deletions(-) create mode 100644 src/py/flwr/cli/new/templates/app/code/task.sklearn.py.tpl diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 1a67039856d3..f4d7080e7306 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -254,6 +254,7 @@ def new( MlFramework.MLX.value.lower(), MlFramework.TENSORFLOW.value.lower(), MlFramework.NUMPY.value.lower(), + MlFramework.SKLEARN.value.lower(), ] if framework_str in frameworks_with_tasks: files[f"{import_name}/task.py"] = { diff --git a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl index 2d3d1c7f163a..e64e2e44e45a 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl @@ -2,41 +2,17 @@ import warnings -import numpy as np from flwr.client import NumPyClient, ClientApp from flwr.common import Context -from flwr_datasets import FederatedDataset -from sklearn.linear_model import LogisticRegression 
from sklearn.metrics import log_loss - -def get_model_parameters(model): - if model.fit_intercept: - params = [ - model.coef_, - model.intercept_, - ] - else: - params = [model.coef_] - return params - - -def set_model_params(model, params): - model.coef_ = params[0] - if model.fit_intercept: - model.intercept_ = params[1] - return model - - -def set_initial_params(model): - n_classes = 10 # MNIST has 10 classes - n_features = 784 # Number of features in dataset - model.classes_ = np.array([i for i in range(10)]) - - model.coef_ = np.zeros((n_classes, n_features)) - if model.fit_intercept: - model.intercept_ = np.zeros((n_classes,)) - +from $import_name.task import ( + get_model, + get_model_params, + load_data, + set_initial_params, + set_model_params, +) class FlowerClient(NumPyClient): def __init__(self, model, X_train, X_test, y_train, y_test): @@ -46,9 +22,6 @@ class FlowerClient(NumPyClient): self.y_train = y_train self.y_test = y_test - def get_parameters(self, config): - return get_model_parameters(self.model) - def fit(self, parameters, config): set_model_params(self.model, parameters) @@ -57,7 +30,7 @@ class FlowerClient(NumPyClient): warnings.simplefilter("ignore") self.model.fit(self.X_train, self.y_train) - return get_model_parameters(self.model), len(self.X_train), {} + return get_model_params(self.model), len(self.X_train), {} def evaluate(self, parameters, config): set_model_params(self.model, parameters) @@ -71,21 +44,13 @@ class FlowerClient(NumPyClient): def client_fn(context: Context): partition_id = context.node_config["partition-id"] num_partitions = context.node_config["num-partitions"] - fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions}) - dataset = fds.load_partition(partition_id, "train").with_format("numpy") - - X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"] - # Split the on edge data: 80% train, 20% test - X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] - y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] + X_train, X_test, y_train, y_test = load_data(partition_id, num_partitions) # Create LogisticRegression Model - model = LogisticRegression( - penalty="l2", - max_iter=1, # local epoch - warm_start=True, # prevent refreshing weights when fitting - ) + penalty = context.run_config["penalty"] + local_epochs = context.run_config["local-epochs"] + model = get_model(penalty, local_epochs) # Setting initial parameters, akin to model.compile for keras models set_initial_params(model) diff --git a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl index 678ba9326229..0227d4b464be 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl @@ -1,19 +1,32 @@ """$project_name: A Flower / $framework_str app.""" -from flwr.common import Context +from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig from flwr.server.strategy import FedAvg +from sktut.task import get_model, get_model_params, set_initial_params + def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + # Create LogisticRegression Model + penalty = context.run_config["penalty"] + local_epochs = context.run_config["local-epochs"] + model = get_model(penalty, local_epochs) + + # Setting initial parameters, akin to model.compile for keras models + 
set_initial_params(model) + + initial_parameters = ndarrays_to_parameters(get_model_params(model)) + # Define strategy strategy = FedAvg( fraction_fit=1.0, fraction_evaluate=1.0, min_available_clients=2, + initial_parameters=initial_parameters, ) config = ServerConfig(num_rounds=num_rounds) diff --git a/src/py/flwr/cli/new/templates/app/code/task.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.sklearn.py.tpl new file mode 100644 index 000000000000..6a9d4d49b778 --- /dev/null +++ b/src/py/flwr/cli/new/templates/app/code/task.sklearn.py.tpl @@ -0,0 +1,64 @@ +"""$project_name: A Flower / $framework_str app.""" + +import numpy as np +from sklearn.linear_model import LogisticRegression +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner + + +fds = None # Cache FederatedDataset + + +def load_data(partition_id: int, num_partitions: int): + """Load partition MNIST data.""" + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = IidPartitioner(num_partitions=num_partitions) + fds = FederatedDataset( + dataset="mnist", + partitioners={"train": partitioner}, + ) + + dataset = fds.load_partition(partition_id, "train").with_format("numpy") + + X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"] + + # Split the on edge data: 80% train, 20% test + X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :] + y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :] + + return X_train, X_test, y_train, y_test + +def get_model(penalty: str, local_epochs: int): + + return LogisticRegression( + penalty=penalty, + max_iter=local_epochs, + warm_start=True, + ) + +def get_model_params(model): + if model.fit_intercept: + params = [ + model.coef_, + model.intercept_, + ] + else: + params = [model.coef_] + return params + +def set_model_params(model, params): + model.coef_ = params[0] + if model.fit_intercept: + model.intercept_ = params[1] + return model + +def set_initial_params(model): + n_classes = 10 # MNIST has 10 classes + n_features = 784 # Number of features in dataset + model.classes_ = np.array([i for i in range(10)]) + + model.coef_ = np.zeros((n_classes, n_features)) + if model.fit_intercept: + model.intercept_ = np.zeros((n_classes,)) \ No newline at end of file diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl index 2b5778fec9a7..c351531fda0f 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl @@ -25,6 +25,8 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = 3 +penalty = "l2" +local-epochs = 1 [tool.flwr.federations] default = "local-simulation" From eb64333ef4af67cd807cd799a99023cb1d1f2cc8 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 1 Sep 2024 10:36:35 +0100 Subject: [PATCH 4/8] fix to prev --- src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl index 0227d4b464be..97e60045aeb9 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl @@ -4,7 +4,7 @@ from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig from 
flwr.server.strategy import FedAvg -from sktut.task import get_model, get_model_params, set_initial_params +from $import_name.task import get_model, get_model_params, set_initial_params def server_fn(context: Context): From 092bd9aba3c8243c3e24af1ebd7775612735ae9b Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 1 Sep 2024 11:08:33 +0100 Subject: [PATCH 5/8] hf update --- .../app/code/client.huggingface.py.tpl | 43 ++++++++----------- .../app/code/server.huggingface.py.tpl | 18 +++++++- .../app/code/task.huggingface.py.tpl | 21 ++++----- .../app/pyproject.huggingface.toml.tpl | 7 +++ 4 files changed, 51 insertions(+), 38 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl index 3041a69e3aaa..b22acece3d3a 100644 --- a/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/client.huggingface.py.tpl @@ -1,18 +1,11 @@ """$project_name: A Flower / $framework_str app.""" +import torch from flwr.client import ClientApp, NumPyClient from flwr.common import Context from transformers import AutoModelForSequenceClassification -from $import_name.task import ( - get_weights, - load_data, - set_weights, - train, - test, - CHECKPOINT, - DEVICE, -) +from $import_name.task import get_weights, load_data, set_weights, test, train # Flower client @@ -22,37 +15,39 @@ class FlowerClient(NumPyClient): self.trainloader = trainloader self.testloader = testloader self.local_epochs = local_epochs - - def get_parameters(self, config): - return get_weights(self.net) - - def set_parameters(self, parameters): - set_weights(self.net, parameters) + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.net.to(self.device) def fit(self, parameters, config): - self.set_parameters(parameters) + set_weights(self.net, parameters) train( self.net, self.trainloader, epochs=self.local_epochs, + device=self.device ) - return self.get_parameters(config={}), len(self.trainloader), {} + return get_weights(self.net), len(self.trainloader), {} def evaluate(self, parameters, config): - self.set_parameters(parameters) - loss, accuracy = test(self.net, self.testloader) + set_weights(self.net, parameters) + loss, accuracy = test(self.net, self.testloader, self.device) return float(loss), len(self.testloader), {"accuracy": accuracy} def client_fn(context: Context): - # Load model and data - net = AutoModelForSequenceClassification.from_pretrained( - CHECKPOINT, num_labels=2 - ).to(DEVICE) + # Get this client's dataset partition partition_id = context.node_config["partition-id"] num_partitions = context.node_config["num-partitions"] - trainloader, valloader = load_data(partition_id, num_partitions) + model_name = context.run_config["model-name"] + trainloader, valloader = load_data(partition_id, num_partitions, model_name) + + # Load model + num_labels = context.run_config["num-labels"] + net = AutoModelForSequenceClassification.from_pretrained( + model_name, num_labels=num_labels + ) + local_epochs = context.run_config["local-epochs"] # Return Client instance diff --git a/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl index 5491f6616160..d1918230b2ab 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl @@ -1,18 +1,32 @@ """$project_name: A Flower / $framework_str app.""" 
-from flwr.common import Context -from flwr.server.strategy import FedAvg +from flwr.common import Context, ndarrays_to_parameters from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.strategy import FedAvg +from transformers import AutoModelForSequenceClassification + +from $import_name.task import get_weights def server_fn(context: Context): # Read from config num_rounds = context.run_config["num-server-rounds"] + # Initialize global model + model_name = context.run_config["model-name"] + num_labels = context.run_config["num-labels"] + net = AutoModelForSequenceClassification.from_pretrained( + model_name, num_labels=num_labels + ) + + weights = get_weights(net) + initial_parameters = ndarrays_to_parameters(weights) + # Define strategy strategy = FedAvg( fraction_fit=1.0, fraction_evaluate=1.0, + initial_parameters=initial_parameters, ) config = ServerConfig(num_rounds=num_rounds) diff --git a/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl index ad52e2c3fe21..5a604f7e73a5 100644 --- a/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/task.huggingface.py.tpl @@ -5,23 +5,20 @@ from collections import OrderedDict import torch from evaluate import load as load_metric +from flwr_datasets import FederatedDataset +from flwr_datasets.partitioner import IidPartitioner from torch.optim import AdamW from torch.utils.data import DataLoader from transformers import AutoTokenizer, DataCollatorWithPadding -from flwr_datasets import FederatedDataset -from flwr_datasets.partitioner import IidPartitioner - - warnings.filterwarnings("ignore", category=UserWarning) -DEVICE = torch.device("cpu") -CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint +warnings.filterwarnings("ignore", category=FutureWarning) fds = None # Cache FederatedDataset -def load_data(partition_id: int, num_partitions: int): +def load_data(partition_id: int, num_partitions: int, model_name: str): """Load IMDB data (training and eval)""" # Only initialize `FederatedDataset` once global fds @@ -35,7 +32,7 @@ def load_data(partition_id: int, num_partitions: int): # Divide data: 80% train, 20% test partition_train_test = partition.train_test_split(test_size=0.2, seed=42) - tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT) + tokenizer = AutoTokenizer.from_pretrained(model_name) def tokenize_function(examples): return tokenizer(examples["text"], truncation=True) @@ -59,12 +56,12 @@ def load_data(partition_id: int, num_partitions: int): return trainloader, testloader -def train(net, trainloader, epochs): +def train(net, trainloader, epochs, device): optimizer = AdamW(net.parameters(), lr=5e-5) net.train() for _ in range(epochs): for batch in trainloader: - batch = {k: v.to(DEVICE) for k, v in batch.items()} + batch = {k: v.to(device) for k, v in batch.items()} outputs = net(**batch) loss = outputs.loss loss.backward() @@ -72,12 +69,12 @@ def train(net, trainloader, epochs): optimizer.zero_grad() -def test(net, testloader): +def test(net, testloader, device): metric = load_metric("accuracy") loss = 0 net.eval() for batch in testloader: - batch = {k: v.to(DEVICE) for k, v in batch.items()} + batch = {k: v.to(device) for k, v in batch.items()} with torch.no_grad(): outputs = net(**batch) logits = outputs.logits diff --git a/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl 
index 15dc2af87a3f..d12afdeb5154 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.huggingface.toml.tpl @@ -30,9 +30,16 @@ clientapp = "$import_name.client_app:app" [tool.flwr.app.config] num-server-rounds = 3 local-epochs = 1 +model-name = "distilbert-base-uncased" +num-labels = 2 [tool.flwr.federations] default = "localhost" [tool.flwr.federations.localhost] options.num-supernodes = 10 + +[tool.flwr.federations.localhost-gpu] +options.num-supernodes = 10 +options.backend.client-resources.num-cpus = 4 # each ClientApp is assumed to use 4 CPUs +options.backend.client-resources.num-gpus = 0.5 # at most 2 ClientApps will run on a given GPU \ No newline at end of file From 8a061f572047e057c876703ced63a8d503d45a09 Mon Sep 17 00:00:00 2001 From: yan-gao-GY Date: Mon, 2 Sep 2024 17:38:54 +0100 Subject: [PATCH 6/8] Update FlowerTune --- src/py/flwr/cli/new/new.py | 1 - .../app/code/flwr_tune/client_app.py.tpl | 24 ++++++------------- .../app/code/flwr_tune/models.py.tpl | 24 ++++++++++++++++++- .../app/code/flwr_tune/server_app.py.tpl | 3 +-- .../app/pyproject.flowertune.toml.tpl | 2 +- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index f4d7080e7306..c806acffd2fc 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -196,7 +196,6 @@ def new( f"{import_name}/client_app.py": { "template": "app/code/flwr_tune/client_app.py.tpl" }, - f"{import_name}/app.py": {"template": "app/code/flwr_tune/app.py.tpl"}, f"{import_name}/models.py": { "template": "app/code/flwr_tune/models.py.tpl" }, diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl index 19d1e20baccd..415898ba117b 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/client_app.py.tpl @@ -2,7 +2,6 @@ import os import warnings -from collections import OrderedDict from typing import Dict, Tuple import torch @@ -11,7 +10,7 @@ from flwr.common import Context from flwr.common.config import unflatten_dict from flwr.common.typing import NDArrays, Scalar from omegaconf import DictConfig -from peft import get_peft_model_state_dict, set_peft_model_state_dict + from transformers import TrainingArguments from trl import SFTTrainer @@ -20,7 +19,12 @@ from $import_name.dataset import ( load_data, replace_keys, ) -from $import_name.models import cosine_annealing, get_model +from $import_name.models import ( + cosine_annealing, + get_model, + set_parameters, + get_parameters, +) # Avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "true" @@ -92,20 +96,6 @@ class FlowerClient(NumPyClient): ) -def set_parameters(model, parameters: NDArrays) -> None: - """Change the parameters of the model using the given ones.""" - peft_state_dict_keys = get_peft_model_state_dict(model).keys() - params_dict = zip(peft_state_dict_keys, parameters) - state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) - set_peft_model_state_dict(model, state_dict) - - -def get_parameters(model) -> NDArrays: - """Return the parameters of the current net.""" - state_dict = get_peft_model_state_dict(model) - return [val.cpu().numpy() for _, val in state_dict.items()] - - def client_fn(context: Context) -> FlowerClient: """Create a Flower client representing a single organization.""" partition_id = context.node_config["partition-id"] diff 
--git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl index a548ba9abeef..3f3f95c8b8eb 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/models.py.tpl @@ -4,10 +4,18 @@ import math import torch from omegaconf import DictConfig -from peft import LoraConfig, get_peft_model +from collections import OrderedDict +from peft import ( + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) from peft.utils import prepare_model_for_kbit_training from transformers import AutoModelForCausalLM, BitsAndBytesConfig +from flwr.common.typing import NDArrays + def cosine_annealing( current_round: int, @@ -54,3 +62,17 @@ def get_model(model_cfg: DictConfig): model.config.use_cache = False return get_peft_model(model, peft_config) + + +def set_parameters(model, parameters: NDArrays) -> None: + """Change the parameters of the model using the given ones.""" + peft_state_dict_keys = get_peft_model_state_dict(model).keys() + params_dict = zip(peft_state_dict_keys, parameters) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + set_peft_model_state_dict(model, state_dict) + + +def get_parameters(model) -> NDArrays: + """Return the parameters of the current net.""" + state_dict = get_peft_model_state_dict(model) + return [val.cpu().numpy() for _, val in state_dict.items()] diff --git a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl index 586b929be06c..7d4de0f73dbf 100644 --- a/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/flwr_tune/server_app.py.tpl @@ -8,8 +8,7 @@ from flwr.common.config import unflatten_dict from flwr.server import ServerApp, ServerAppComponents, ServerConfig from omegaconf import DictConfig -from $import_name.client_app import get_parameters, set_parameters -from $import_name.models import get_model +from $import_name.models import get_model, get_parameters, set_parameters from $import_name.dataset import replace_keys from $import_name.strategy import FlowerTuneLlm diff --git a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl index 5046a6f89f27..20ba53247b3a 100644 --- a/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl +++ b/src/py/flwr/cli/new/templates/app/pyproject.flowertune.toml.tpl @@ -8,7 +8,7 @@ version = "1.0.0" description = "" license = "Apache-2.0" dependencies = [ - "flwr[simulation]>=1.10.0", + "flwr[simulation]>=1.11.0", "flwr-datasets>=0.3.0", "trl==0.8.1", "bitsandbytes==0.43.0", From 2f1b7edb4ff5092e74154c7fd78ec0021910ab1d Mon Sep 17 00:00:00 2001 From: jafermarq Date: Wed, 11 Sep 2024 11:10:40 +0100 Subject: [PATCH 7/8] fix automerge --- .../new/templates/app/code/server.huggingface.py.tpl | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl index 4d25092f33f0..16f94f0a64e9 100644 --- a/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl +++ b/src/py/flwr/cli/new/templates/app/code/server.huggingface.py.tpl @@ -23,16 +23,6 @@ def server_fn(context: Context): weights = get_weights(net) initial_parameters = ndarrays_to_parameters(weights) - # Initialize global 
model - model_name = context.run_config["model-name"] - num_labels = context.run_config["num-labels"] - net = AutoModelForSequenceClassification.from_pretrained( - model_name, num_labels=num_labels - ) - - weights = get_weights(net) - initial_parameters = ndarrays_to_parameters(weights) - # Define strategy strategy = FedAvg( fraction_fit=fraction_fit, From 8eeb0a2e31fa1469857c4a48f19df43ba04fb644 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Thu, 12 Sep 2024 18:27:08 +0100 Subject: [PATCH 8/8] revert --- src/py/flwr/cli/new/new.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 5fc54a35b3d0..2e97ae8aded8 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -236,13 +236,13 @@ def new( # Depending on the framework, generate task.py file frameworks_with_tasks = [ - MlFramework.PYTORCH.value.lower(), - MlFramework.JAX.value.lower(), - MlFramework.HUGGINGFACE.value.lower(), - MlFramework.MLX.value.lower(), - MlFramework.TENSORFLOW.value.lower(), - MlFramework.NUMPY.value.lower(), - MlFramework.SKLEARN.value.lower(), + MlFramework.PYTORCH.value, + MlFramework.JAX.value, + MlFramework.HUGGINGFACE.value, + MlFramework.MLX.value, + MlFramework.TENSORFLOW.value, + MlFramework.NUMPY.value, + MlFramework.SKLEARN.value, ] if framework_str in frameworks_with_tasks: files[f"{import_name}/task.py"] = {
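A pattern shared by every patch in this series: server_fn now builds the global model once, serializes it with ndarrays_to_parameters, and hands it to the strategy via initial_parameters. Because the strategy then starts from these server-side weights, the server never has to ask a random client for its initial state, which is why the get_parameters methods could be dropped from the client templates. A self-contained sketch of that round trip, using the NumPy template's dummy model (the final assert is illustrative only, not part of any template):

import numpy as np
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

def get_dummy_model():
    return np.ones((1, 1))

# Server side: serialize the list of ndarrays once and pass the result to the
# strategy, e.g. FedAvg(initial_parameters=initial_parameters).
initial_parameters = ndarrays_to_parameters([get_dummy_model()])

# Clients receive the same arrays back after deserialization.
ndarrays = parameters_to_ndarrays(initial_parameters)
assert np.array_equal(ndarrays[0], np.ones((1, 1)))

The run-config keys introduced above (input-dim, penalty, model-name, num-labels, and so on) are read through context.run_config, so in recent Flower releases they can also be overridden per run from the CLI (flwr run supports a --run-config option) instead of editing pyproject.toml.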
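Relatedly, the MLX templates import get_params and set_params from $import_name.task, but task.mlx.py.tpl is not touched by this series. For readers following along, a minimal sketch of what such helpers can look like; the bodies below are an assumption built on MLX's tree utilities, not the contents of the actual template:

import numpy as np
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten

def get_params(model):
    # Flatten the nested MLX parameter tree into the list of NumPy ndarrays
    # that NumPyClient.fit/evaluate and ndarrays_to_parameters expect.
    return [np.array(v) for _, v in tree_flatten(model.parameters())]

def set_params(model, parameters):
    # Pair the incoming ndarrays with the model's flattened keys (tree order
    # is deterministic) and rebuild the nested tree before updating the model.
    keys = [k for k, _ in tree_flatten(model.parameters())]
    model.update(tree_unflatten([(k, mx.array(p)) for k, p in zip(keys, parameters)]))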