Skip to content

Commit

Permalink
Update scikit-learn tutorial with Flower Datasets (#3196)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Apr 20, 2024
1 parent 4f975c3 commit e2a15db
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions doc/source/tutorial-quickstart-scikitlearn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,41 +45,51 @@ However, before setting up the client and server, we will define all functionali
* :code:`get_model_parameters()`
* Returns the parameters of a :code:`sklearn` LogisticRegression model
* :code:`set_model_params()`
* Sets the parameters of a :code:`sklean` LogisticRegression model
* Sets the parameters of a :code:`sklearn` LogisticRegression model
* :code:`set_initial_params()`
* Initializes the model parameters that the Flower server will ask for
* :code:`load_mnist()`
* Loads the MNIST dataset using OpenML
* :code:`shuffle()`
* Shuffles data and its label
* :code:`partition()`
* Splits datasets into a number of partitions

Please check out :code:`utils.py` `here <https://github.com/adap/flower/blob/main/examples/sklearn-logreg-mnist/utils.py>`_ for more details.
The pre-defined functions are used in the :code:`client.py` and imported. The :code:`client.py` also requires to import several packages such as Flower and scikit-learn:

.. code-block:: python
import argparse
import warnings
import flwr as fl
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import flwr as fl
import utils
from flwr_datasets import FederatedDataset
We load the MNIST dataset from `OpenML <https://www.openml.org/search?type=data&sort=runs&id=554>`_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`.
Prior to local training, we need to load the MNIST dataset, a popular image classification dataset of handwritten digits for machine learning, and partition the dataset for FL. This can be conveniently achieved using `Flower Datasets <https://flower.ai/docs/datasets>`_.
The :code:`FederatedDataset.load_partition()` method loads the partitioned training set for each partition ID defined in the :code:`--partition-id` argument.

.. code-block:: python
if __name__ == "__main__":
(X_train, y_train), (X_test, y_test) = utils.load_mnist()
partition_id = np.random.choice(10)
(X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]
N_CLIENTS = 10
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--partition-id",
type=int,
choices=range(0, N_CLIENTS),
required=True,
help="Specifies the artificial data partition",
)
args = parser.parse_args()
partition_id = args.partition_id
fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})
dataset = fds.load_partition(partition_id, "train").with_format("numpy")
X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
Next, the logistic regression model is defined and initialized with :code:`utils.set_initial_params()`.
Expand Down Expand Up @@ -168,10 +178,13 @@ First, we import again all required libraries such as Flower and scikit-learn.
from flwr.common import NDArrays, Scalar
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict, Optional
from typing import Dict
from flwr_datasets import FederatedDataset
The number of federated learning rounds is set in :code:`fit_round()` and the evaluation is defined in :code:`get_evaluate_fn()`.
The evaluation function is called after each federated learning round and gives you information about loss and accuracy.
Note that we also make use of Flower Datasets here to load the test split of the MNIST dataset for server-side evaluation.

.. code-block:: python
Expand All @@ -183,7 +196,9 @@ The evaluation function is called after each federated learning round and gives
def get_evaluate_fn(model: LogisticRegression):
"""Return an evaluation function for server-side evaluation."""
_, (X_test, y_test) = utils.load_mnist()
fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
dataset = fds.load_split("test").with_format("numpy")
X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
def evaluate(
server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
Expand All @@ -199,7 +214,7 @@ The :code:`main` contains the server-side parameter initialization :code:`utils.

.. code-block:: python
# Start Flower server for five rounds of federated learning
# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
model = LogisticRegression()
utils.set_initial_params(model)
Expand Down

0 comments on commit e2a15db

Please sign in to comment.