From 61ce2ebd4d4b3a93bf5d3a752d645b49dd74df97 Mon Sep 17 00:00:00 2001
From: Charles Beauville
Date: Tue, 23 Apr 2024 09:03:13 +0200
Subject: [PATCH] Add sklearn template for `flwr new` (#3251)

---
Reviewer notes (free text in this position is not part of the commit message):
- Adds the MlFramework.SKLEARN enum member plus the client, server and
  pyproject templates that go with it.
- client template, set_initial_params(): classes_ is built from range(10)
  while coef_ and intercept_ are sized from n_classes; the literal and the
  variable have to be kept in sync by hand.
- client template, evaluate(): log_loss() is called without labels=.
  NOTE(review): presumably this needs every class to be present in the local
  test split -- confirm against the scikit-learn documentation.
- client template, client_fn(): loads partition int(cid) from a
  FederatedDataset created with partitioners={"train": 2}.
  NOTE(review): presumably only cid "0" and "1" are valid -- confirm.

 src/py/flwr/cli/new/new.py                   |  1 +
 .../templates/app/code/client.sklearn.py.tpl | 94 +++++++++++++++++++
 .../templates/app/code/server.sklearn.py.tpl | 17 ++++
 .../templates/app/pyproject.sklearn.toml.tpl | 20 ++++
 4 files changed, 132 insertions(+)
 create mode 100644 src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
 create mode 100644 src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
 create mode 100644 src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl

diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py
index 201e145c194e..cbebc7248af0 100644
--- a/src/py/flwr/cli/new/new.py
+++ b/src/py/flwr/cli/new/new.py
@@ -36,6 +36,7 @@ class MlFramework(str, Enum):
     NUMPY = "NumPy"
     PYTORCH = "PyTorch"
     TENSORFLOW = "TensorFlow"
+    SKLEARN = "sklearn"
 
 
 class TemplateNotFound(Exception):
diff --git a/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
new file mode 100644
index 000000000000..9181389cad1c
--- /dev/null
+++ b/src/py/flwr/cli/new/templates/app/code/client.sklearn.py.tpl
@@ -0,0 +1,94 @@
+"""$project_name: A Flower / Scikit-Learn app."""
+
+import warnings
+
+import numpy as np
+from flwr.client import NumPyClient, ClientApp
+from flwr_datasets import FederatedDataset
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import log_loss
+
+
+def get_model_parameters(model):
+    if model.fit_intercept:
+        params = [
+            model.coef_,
+            model.intercept_,
+        ]
+    else:
+        params = [model.coef_]
+    return params
+
+
+def set_model_params(model, params):
+    model.coef_ = params[0]
+    if model.fit_intercept:
+        model.intercept_ = params[1]
+    return model
+
+
+def set_initial_params(model):
+    n_classes = 10  # MNIST has 10 classes
+    n_features = 784  # Number of features in dataset
+    model.classes_ = np.array([i for i in range(10)])
+
+    model.coef_ = np.zeros((n_classes, n_features))
+    if model.fit_intercept:
+        model.intercept_ = np.zeros((n_classes,))
+
+
+class FlowerClient(NumPyClient):
+    def __init__(self, model, X_train, X_test, y_train, y_test):
+        self.model = model
+        self.X_train = X_train
+        self.X_test = X_test
+        self.y_train = y_train
+        self.y_test = y_test
+
+    def get_parameters(self, config):
+        return get_model_parameters(self.model)
+
+    def fit(self, parameters, config):
+        set_model_params(self.model, parameters)
+
+        # Ignore convergence failure due to low local epochs
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            self.model.fit(self.X_train, self.y_train)
+
+        return get_model_parameters(self.model), len(self.X_train), {}
+
+    def evaluate(self, parameters, config):
+        set_model_params(self.model, parameters)
+
+        loss = log_loss(self.y_test, self.model.predict_proba(self.X_test))
+        accuracy = self.model.score(self.X_test, self.y_test)
+
+        return loss, len(self.X_test), {"accuracy": accuracy}
+
+fds = FederatedDataset(dataset="mnist", partitioners={"train": 2})
+
+def client_fn(cid: str):
+    dataset = fds.load_partition(int(cid), "train").with_format("numpy")
+
+    X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
+
+    # Split the on edge data: 80% train, 20% test
+    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
+    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
+
+    # Create LogisticRegression Model
+    model = LogisticRegression(
+        penalty="l2",
+        max_iter=1,  # local epoch
+        warm_start=True,  # prevent refreshing weights when fitting
+    )
+
+    # Setting initial parameters, akin to model.compile for keras models
+    set_initial_params(model)
+
+    return FlowerClient(model, X_train, X_test, y_train, y_test).to_client()
+
+
+# Flower ClientApp
+app = ClientApp(client_fn=client_fn)
diff --git a/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
new file mode 100644
index 000000000000..266a53ac5794
--- /dev/null
+++ b/src/py/flwr/cli/new/templates/app/code/server.sklearn.py.tpl
@@ -0,0 +1,17 @@
+"""$project_name: A Flower / Scikit-Learn app."""
+
+from flwr.server import ServerApp, ServerConfig
+from flwr.server.strategy import FedAvg
+
+
+strategy = FedAvg(
+    fraction_fit=1.0,
+    fraction_evaluate=1.0,
+    min_available_clients=2,
+)
+
+# Create ServerApp
+app = ServerApp(
+    config=ServerConfig(num_rounds=3),
+    strategy=strategy,
+)
diff --git a/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl
new file mode 100644
index 000000000000..2027a491d392
--- /dev/null
+++ b/src/py/flwr/cli/new/templates/app/pyproject.sklearn.toml.tpl
@@ -0,0 +1,20 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "$project_name"
+version = "1.0.0"
+description = ""
+authors = [
+    { name = "The Flower Authors", email = "hello@flower.ai" },
+]
+license = {text = "Apache License (2.0)"}
+dependencies = [
+    "flwr[simulation]>=1.8.0,<2.0",
+    "flwr-datasets[vision]>=0.0.2,<1.0.0",
+    "scikit-learn>=1.1.1",
+]
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]