diff --git a/doc/source/tutorial-quickstart-scikitlearn.rst b/doc/source/tutorial-quickstart-scikitlearn.rst
index d1d47dc37f19..93322842cc70 100644
--- a/doc/source/tutorial-quickstart-scikitlearn.rst
+++ b/doc/source/tutorial-quickstart-scikitlearn.rst
@@ -45,41 +45,51 @@ However, before setting up the client and server, we will define all functionali
 * :code:`get_model_parameters()`
     * Returns the parameters of a :code:`sklearn` LogisticRegression model
 * :code:`set_model_params()`
-    * Sets the parameters of a :code:`sklean` LogisticRegression model
+    * Sets the parameters of a :code:`sklearn` LogisticRegression model
 * :code:`set_initial_params()`
     * Initializes the model parameters that the Flower server will ask for
-* :code:`load_mnist()`
-    * Loads the MNIST dataset using OpenML
-* :code:`shuffle()`
-    * Shuffles data and its label
-* :code:`partition()`
-    * Splits datasets into a number of partitions

 Please check out :code:`utils.py` `here `_ for more details.
 The pre-defined functions are used in the :code:`client.py` and imported. The :code:`client.py` also requires to import several packages such as Flower and scikit-learn:

 .. code-block:: python

+    import argparse
     import warnings
-    import flwr as fl
-    import numpy as np
-
+    from sklearn.linear_model import LogisticRegression
     from sklearn.metrics import log_loss
-
+
+    import flwr as fl
     import utils
+    from flwr_datasets import FederatedDataset

-
-We load the MNIST dataset from `OpenML `_, a popular image classification dataset of handwritten digits for machine learning. The utility :code:`utils.load_mnist()` downloads the training and test data. The training set is split afterwards into 10 partitions with :code:`utils.partition()`.
+Prior to local training, we need to load the MNIST dataset, a popular image classification dataset of handwritten digits for machine learning, and partition the dataset for FL. This can be conveniently achieved using `Flower Datasets `_.
+The :code:`FederatedDataset.load_partition()` method loads the partitioned training set for each partition ID defined in the :code:`--partition-id` argument.

 .. code-block:: python

     if __name__ == "__main__":
-
-        (X_train, y_train), (X_test, y_test) = utils.load_mnist()
-
-        partition_id = np.random.choice(10)
-        (X_train, y_train) = utils.partition(X_train, y_train, 10)[partition_id]
+        N_CLIENTS = 10
+
+        parser = argparse.ArgumentParser(description="Flower")
+        parser.add_argument(
+            "--partition-id",
+            type=int,
+            choices=range(0, N_CLIENTS),
+            required=True,
+            help="Specifies the artificial data partition",
+        )
+        args = parser.parse_args()
+        partition_id = args.partition_id
+
+        fds = FederatedDataset(dataset="mnist", partitioners={"train": N_CLIENTS})
+
+        dataset = fds.load_partition(partition_id, "train").with_format("numpy")
+        X, y = dataset["image"].reshape((len(dataset), -1)), dataset["label"]
+
+        X_train, X_test = X[: int(0.8 * len(X))], X[int(0.8 * len(X)) :]
+        y_train, y_test = y[: int(0.8 * len(y))], y[int(0.8 * len(y)) :]

 Next, the logistic regression model is defined and initialized with :code:`utils.set_initial_params()`.

@@ -168,10 +178,13 @@ First, we import again all required libraries such as Flower and scikit-learn.
     from flwr.common import NDArrays, Scalar
     from sklearn.metrics import log_loss
     from sklearn.linear_model import LogisticRegression
-    from typing import Dict, Optional
+    from typing import Dict
+
+    from flwr_datasets import FederatedDataset

 The number of federated learning rounds is set in :code:`fit_round()` and the evaluation is defined in :code:`get_evaluate_fn()`.
 The evaluation function is called after each federated learning round and gives you information about loss and accuracy.
+Note that we also make use of Flower Datasets here to load the test split of the MNIST dataset for server-side evaluation.

 .. code-block:: python

@@ -183,7 +196,9 @@ The evaluation function is called after each federated learning round and gives
     def get_evaluate_fn(model: LogisticRegression):
         """Return an evaluation function for server-side evaluation."""

-        _, (X_test, y_test) = utils.load_mnist()
+        fds = FederatedDataset(dataset="mnist", partitioners={"train": 10})
+        dataset = fds.load_split("test").with_format("numpy")
+        X_test, y_test = dataset["image"].reshape((len(dataset), -1)), dataset["label"]

         def evaluate(
             server_round: int, parameters: NDArrays, config: Dict[str, Scalar]
@@ -199,7 +214,7 @@ The :code:`main` contains the server-side parameter initialization :code:`utils.

 .. code-block:: python

-    # Start Flower server for three rounds of federated learning
+    # Start Flower server for three rounds of federated learning
     if __name__ == "__main__":
         model = LogisticRegression()
         utils.set_initial_params(model)
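
Note on the new client-side preprocessing: the reshape and the positional 80/20 slice added to :code:`client.py` can be checked without downloading MNIST. The sketch below is only an illustration, not part of the patch. It assumes a synthetic array in place of a real partition, and :code:`flatten_and_split` is a hypothetical helper name that does not exist in the tutorial or in Flower.

```python
import numpy as np


def flatten_and_split(images, labels, train_fraction=0.8):
    """Flatten each image to a 1-D feature vector, then split by position.

    Hypothetical helper mirroring the slicing in the patched client.py;
    it performs no shuffling, exactly like the slices in the diff.
    """
    X = images.reshape((len(images), -1))
    cut = int(train_fraction * len(X))
    return (X[:cut], labels[:cut]), (X[cut:], labels[cut:])


# Synthetic stand-in for one MNIST partition: 10 images of 28x28 pixels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(10, 28, 28), dtype=np.uint8)
labels = rng.integers(0, 10, size=10)

(X_train, y_train), (X_test, y_test) = flatten_and_split(images, labels)
print(X_train.shape, X_test.shape)  # (8, 784) (2, 784)
```

Because the split is purely positional, the local test set is always the last 20% of whatever order the partition arrives in.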