Commit
Update sklearn docs with Flower Datasets and fixed typos
chongshenng committed Apr 2, 2024
1 parent f95d641 commit e9c4ad1
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions doc/source/tutorial-quickstart-scikitlearn.rst
@@ -45,41 +45,51 @@ However, before setting up the client and server, we will define all functionali
* :code:`get_model_parameters()`
  * Returns the parameters of a :code:`sklearn` LogisticRegression model
* :code:`set_model_params()`
  * Sets the parameters of a :code:`sklean` LogisticRegression model
  * Sets the parameters of a :code:`sklearn` LogisticRegression model
* :code:`set_initial_params()`
  * Initializes the model parameters that the Flower server will ask for
* :code:`load_mnist()`
  * Loads the MNIST dataset using OpenML
* :code:`shuffle()`
  * Shuffles data and its label
* :code:`partition()`
  * Splits datasets into a number of partitions

Please check out :code:`utils.py` `here <https://github.com/adap/flower/blob/main/examples/sklearn-logreg-mnist/utils.py>`_ for more details.
The pre-defined functions are used in the :code:`client.py` and imported. The :code:`client.py` also requires to import several packages such as Flower and scikit-learn:

.. code-block:: python
import argparse
import warnings
import flwr as fl
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import flwr as fl
import utils
from flwr_datasets import FederatedDataset
We load the MNIST dataset from `OpenML <https://www.openml.org/search?type=data&sort=runs&id=554>`_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`.
Prior to local training, we need to load the MNIST dataset, a popular image classification dataset of handwritten digits for machine learning, and partition MNIST for FL. This can be conveniently achieved using Flower Datasets.
The :code:`FederatedDataset.load_partition()` loads the partitioned training set for each partition ID defined in the :code:`--partition-id` argument.

.. code-block:: python
if __name__ == "__main__":
    (X_train, y_train), (X_test, y_test) = utils.load_mnist()
    partition_id = np.random.choice(10)
    (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]
    N_CLIENTS = 10
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--partition-id",
        type=int,
        choices=range(0, N_CLIENTS),
        required=True,
        help="Specifies the artificial data partition",
    )
    args = parser.parse_args()
    partition_id = args.partition_id
    fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})
    dataset = fds.load_partition(partition_id, "train").with_format("numpy")
    X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
    X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
    y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]
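The data-loading step in this hunk can be exercised without downloading MNIST. The following is a minimal sketch, not the tutorial's code: :code:`parse_partition_id` and :code:`flatten_and_split` are hypothetical helper names, and the random arrays are an assumed stand-in for a partition returned by :code:`FederatedDataset.load_partition()`. It mirrors the :code:`--partition-id` argument and the unshuffled 80/20 slice from the client code.

```python
import argparse

import numpy as np

N_CLIENTS = 10  # same value as in the client code of this commit


def parse_partition_id(argv):
    """Parse --partition-id; argparse rejects values outside 0..N_CLIENTS-1."""
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--partition-id",
        type=int,
        choices=range(0, N_CLIENTS),
        required=True,
        help="Specifies the artificial data partition",
    )
    return parser.parse_args(argv).partition_id


def flatten_and_split(images, labels, train_fraction=0.8):
    """Flatten each image to a 1-D feature vector, then split without shuffling."""
    X = images.reshape((len(images), -1))
    cut = int(train_fraction * len(X))
    return (X[:cut], labels[:cut]), (X[cut:], labels[cut:])


# Synthetic stand-in for one partition (assumption: 28x28 images, 10 classes).
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(50, 28, 28), dtype=np.uint8)
labels = rng.integers(0, 10, size=50)

partition_id = parse_partition_id(["--partition-id", "3"])
(X_train, y_train), (X_test, y_test) = flatten_and_split(images, labels)
print(partition_id, X_train.shape, X_test.shape)
```

Because the slice is not shuffled, the local test portion is always the last 20% of the rows in the partition.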
Next, the logistic regression model is defined and initialized with :code:`utils.set_initial_params()`.
@@ -168,10 +178,13 @@ First, we import again all required libraries such as Flower and scikit-learn.
from flwr.common import NDArrays, Scalar
from sklearn.metrics import log_loss
from sklearn.linear_model import LogisticRegression
from typing import Dict, Optional
from typing import Dict
from flwr_datasets import FederatedDataset
The number of federated learning rounds is set in :code:`fit_round()` and the evaluation is defined in :code:`get_evaluate_fn()`.
The evaluation function is called after each federated learning round and gives you information about loss and accuracy.
Note that we also make use of Flower Datasets in :code:`server.py` to load the test splits for centralized evaluation.

.. code-block:: python
@@ -183,7 +196,9 @@ The evaluation function is called after each federated learning round and gives
def get_evaluate_fn(model: LogisticRegression):
    """Return an evaluation function for server-side evaluation."""
    _, (X_test, y_test) = utils.load_mnist()
    fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
    dataset = fds.load_split("test").with_format("numpy")
    X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
    def evaluate(
        server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
@@ -199,7 +214,7 @@ The :code:`main` contains the server-side parameter initialization :code:`utils.

.. code-block:: python
# Start Flower server for five rounds of federated learning
# Start Flower server for three rounds of federated learning
if __name__ == "__main__":
    model = LogisticRegression()
    utils.set_initial_params(model)
