diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 8c635c516450..7519f227af81 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -22,10 +22,10 @@ README.md @jafermarq @tanertopal @danieljanes
/src/py/flwr/cli/new/templates @jafermarq @tanertopal @danieljanes
# Changelog
-/doc/source/ref-changelog.md @jafermarq @tanertopal @danieljanes
+/framework/docs/source/ref-changelog.md @jafermarq @tanertopal @danieljanes
# Translations
-/doc/locales @charlesbvll @tanertopal @danieljanes
+/framework/docs/locales @charlesbvll @tanertopal @danieljanes
# GitHub Actions and Workflows
/.github/workflows @Robert-Steiner @tanertopal @danieljanes
diff --git a/.github/workflows/baselines.yml b/.github/workflows/baselines.yml
index c4485fe72d10..c051b1f355c7 100644
--- a/.github/workflows/baselines.yml
+++ b/.github/workflows/baselines.yml
@@ -36,7 +36,7 @@ jobs:
FILTER+=$(echo "$DIR: ${BASELINES_PATH}/**\n")
done < <(find baselines -maxdepth 1 \
-name ".*" -prune -o \
- -path "baselines/doc" -prune -o \
+ -path "baselines/docs" -prune -o \
-path "baselines/dev" -prune -o \
-path "baselines/baseline_template" -prune -o \
-path "baselines/flwr_baselines" -prune -o \
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index 3f010a4c37b0..f98dfa43dd18 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -45,7 +45,8 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
DOCS_BUCKET: flower.ai
run: |
- aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/framework
- aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/baselines
- aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/examples
- aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/doc/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/datasets
+ cp -r doc/build/html/v* framework/docs/build/html
+ aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./framework/docs/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/framework
+ aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./baselines/docs/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/baselines
+ aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./examples/docs/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/examples
+ aws s3 sync --delete --exclude ".*" --exclude "v/*" --cache-control "no-cache" ./datasets/docs/build/html/ s3://${{ env.DOCS_BUCKET }}/docs/datasets
diff --git a/.github/workflows/update_translations.yml b/.github/workflows/update_translations.yml
index 9a5391a40438..ec748befb33c 100644
--- a/.github/workflows/update_translations.yml
+++ b/.github/workflows/update_translations.yml
@@ -38,7 +38,7 @@ jobs:
- name: Update text and translations for all locales
run: |
- cd doc
+ cd framework/docs
make update-text
for langDir in locales/*; do
if [ -d "$langDir" ]; then
@@ -52,7 +52,7 @@ jobs:
run: |
git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
- git add doc/locales
+ git add framework/docs/locales
git commit -m "Update text and language files"
continue-on-error: true
diff --git a/.gitignore b/.gitignore
index b0962c2783f0..9af2d29f2315 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,9 @@
# Flower
.flower_ops
data/
-doc/source/api_documentation
-doc/source/_build
-doc/source/dataset/
+framework/docs/source/api_documentation
+framework/docs/source/_build
+framework/docs/source/dataset/
flwr_logs
.cache
@@ -17,6 +17,9 @@ examples/**/dataset/**
# Flower Baselines
baselines/datasets/leaf
+# Exclude ee package
+src/py/flwr/ee
+
# macOS
.DS_Store
@@ -183,3 +186,6 @@ app/src/main/assets
/captures
.externalNativeBuild
.cxx
+
+# Pyright
+pyrightconfig.json
diff --git a/README.md b/README.md
index b5c58c6838f0..30b54786244a 100644
--- a/README.md
+++ b/README.md
@@ -48,23 +48,23 @@ Flower's goal is to make federated learning accessible to everyone. This series
0. **What is Federated Learning?**
- [](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-what-is-federated-learning.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-what-is-federated-learning.ipynb))
+ [](https://colab.research.google.com/github/adap/flower/blob/main/framework/docs/source/tutorial-series-what-is-federated-learning.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/framework/docs/source/tutorial-series-what-is-federated-learning.ipynb))
1. **An Introduction to Federated Learning**
- [](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb))
+ [](https://colab.research.google.com/github/adap/flower/blob/main/framework/docs/source/tutorial-series-get-started-with-flower-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/framework/docs/source/tutorial-series-get-started-with-flower-pytorch.ipynb))
2. **Using Strategies in Federated Learning**
- [](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb))
+ [](https://colab.research.google.com/github/adap/flower/blob/main/framework/docs/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/framework/docs/source/tutorial-series-use-a-federated-learning-strategy-pytorch.ipynb))
3. **Building Strategies for Federated Learning**
- [](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb))
+ [](https://colab.research.google.com/github/adap/flower/blob/main/framework/docs/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/framework/docs/source/tutorial-series-build-a-strategy-from-scratch-pytorch.ipynb))
4. **Custom Clients for Federated Learning**
- [](https://colab.research.google.com/github/adap/flower/blob/main/doc/source/tutorial-series-customize-the-client-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/doc/source/tutorial-series-customize-the-client-pytorch.ipynb))
+    [](https://colab.research.google.com/github/adap/flower/blob/main/framework/docs/source/tutorial-series-customize-the-client-pytorch.ipynb) (or open the [Jupyter Notebook](https://github.com/adap/flower/blob/main/framework/docs/source/tutorial-series-customize-the-client-pytorch.ipynb))
Stay tuned, more tutorials are coming soon. Topics include **Privacy and Security in Federated Learning**, and **Scaling Federated Learning**.
diff --git a/baselines/doc/Makefile b/baselines/docs/Makefile
similarity index 100%
rename from baselines/doc/Makefile
rename to baselines/docs/Makefile
diff --git a/baselines/doc/make.bat b/baselines/docs/make.bat
similarity index 100%
rename from baselines/doc/make.bat
rename to baselines/docs/make.bat
diff --git a/baselines/doc/source/.gitignore b/baselines/docs/source/.gitignore
similarity index 100%
rename from baselines/doc/source/.gitignore
rename to baselines/docs/source/.gitignore
diff --git a/baselines/doc/source/_static/custom.css b/baselines/docs/source/_static/custom.css
similarity index 100%
rename from baselines/doc/source/_static/custom.css
rename to baselines/docs/source/_static/custom.css
diff --git a/baselines/doc/source/_static/favicon.ico b/baselines/docs/source/_static/favicon.ico
similarity index 100%
rename from baselines/doc/source/_static/favicon.ico
rename to baselines/docs/source/_static/favicon.ico
diff --git a/baselines/doc/source/_static/flower-logo.png b/baselines/docs/source/_static/flower-logo.png
similarity index 100%
rename from baselines/doc/source/_static/flower-logo.png
rename to baselines/docs/source/_static/flower-logo.png
diff --git a/baselines/doc/source/_static/view-gh.png b/baselines/docs/source/_static/view-gh.png
similarity index 100%
rename from baselines/doc/source/_static/view-gh.png
rename to baselines/docs/source/_static/view-gh.png
diff --git a/baselines/doc/source/_templates/base.html b/baselines/docs/source/_templates/base.html
similarity index 100%
rename from baselines/doc/source/_templates/base.html
rename to baselines/docs/source/_templates/base.html
diff --git a/baselines/doc/source/_templates/sidebar/search.html b/baselines/docs/source/_templates/sidebar/search.html
similarity index 100%
rename from baselines/doc/source/_templates/sidebar/search.html
rename to baselines/docs/source/_templates/sidebar/search.html
diff --git a/baselines/doc/source/conf.py b/baselines/docs/source/conf.py
similarity index 95%
rename from baselines/doc/source/conf.py
rename to baselines/docs/source/conf.py
index 9d5d4ea7fc92..574c4ccf0e81 100644
--- a/baselines/doc/source/conf.py
+++ b/baselines/docs/source/conf.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Config for Sphinx docs."""
import datetime
import os
import sys
-from sphinx.application import ConfigError
+
# Configuration file for the Sphinx documentation builder.
#
@@ -120,11 +121,15 @@
nbsphinx_execute = "never"
-_open_in_colab_button = """
+colab_link = (
+ "https://colab.research.google.com/github/adap/flower/blob/main/"
+ "doc/source/{{ env.doc2path(env.docname, base=None) }}"
+)
+_open_in_colab_button = f"""
.. raw:: html
-
+
"""
diff --git a/baselines/doc/source/how-to-contribute-baselines.rst b/baselines/docs/source/how-to-contribute-baselines.rst
similarity index 100%
rename from baselines/doc/source/how-to-contribute-baselines.rst
rename to baselines/docs/source/how-to-contribute-baselines.rst
diff --git a/baselines/doc/source/how-to-use-baselines.rst b/baselines/docs/source/how-to-use-baselines.rst
similarity index 100%
rename from baselines/doc/source/how-to-use-baselines.rst
rename to baselines/docs/source/how-to-use-baselines.rst
diff --git a/baselines/doc/source/index.rst b/baselines/docs/source/index.rst
similarity index 100%
rename from baselines/doc/source/index.rst
rename to baselines/docs/source/index.rst
diff --git a/baselines/fedrep/README.md b/baselines/fedrep/README.md
index ece30edf0943..d67730e4065a 100644
--- a/baselines/fedrep/README.md
+++ b/baselines/fedrep/README.md
@@ -36,91 +36,90 @@ dataset: [CIFAR-10, CIFAR-100]
These two models are modified from the ones in the [official repo](https://github.com/rahulv0205/fedrep_experiments). Note that the official models contain no BN layers; however, without the help of BN layers, training will collapse.
-Please see how models are implemented using a so called model_manager and model_split class since FedRep uses head and base layers in a neural network. These classes are defined in the `models.py` file and thereafter called when building new models in the directory `/implemented_models`. Please, extend and add new models as you wish.
+Please see how the models are implemented using a so-called model manager and model split class, since FedRep splits a neural network into a body (base layers) and a head. These classes are defined in the `models.py` file. Feel free to extend them and add new models as you wish.
**Dataset:** CIFAR-10, CIFAR-100. CIFAR-10/100 is partitioned based on the number of classes each client receives, e.g. a client allocated 4 classes could receive the classes [1, 3, 5, 9].
-**Training Hyperparameters:** The hyperparameters can be found in `conf/base.yaml` file which is the configuration file for the main script.
-
-| Description | Default Value |
-| --------------------- | ----------------------------------- |
-| `num_clients` | `100` |
-| `num_rounds` | `100` |
-| `num_local_epochs` | `5` |
-| `num_rep_epochs` | `1` |
-| `enable_finetune` | `False` |
-| `num_finetune_epochs` | `5` |
-| `use_cuda` | `true` |
-| `specified_device` | `null` |
-| `client resources` | `{'num_cpus': 2, 'num_gpus': 0.5 }` |
-| `learning_rate` | `0.01` |
-| `batch_size` | `50` |
-| `model_name` | `cnncifar10` |
-| `algorithm` | `fedrep` |
+**Training Hyperparameters:** The hyperparameters can be found in the `pyproject.toml` file under the `[tool.flwr.app.config]` section; a sketch of this section is shown after the table below.
+
+| Description | Default Value |
+|-------------------------|-------------------------------------|
+| `num-server-rounds` | `100` |
+| `num-local-epochs` | `5` |
+| `num-rep-epochs` | `1` |
+| `enable-finetune` | `False` |
+| `num-finetune-epochs` | `5` |
+| `use-cuda` | `true` |
+| `specified-cuda-device` | `null` |
+| `client-resources` | `{'num-cpus': 2, 'num-gpus': 0.5 }` |
+| `learning-rate` | `0.01` |
+| `batch-size` | `50` |
+| `model-name` | `cnncifar10` |
+| `algorithm` | `fedrep` |
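+
+For reference, a minimal sketch of what such a `[tool.flwr.app.config]` section could look like is shown below. The values are taken from the table above; the actual `pyproject.toml` in this baseline may differ and contain additional keys.
+
+```toml
+[tool.flwr.app.config]
+algorithm = "fedrep"
+model-name = "cnncifar10"
+num-server-rounds = 100
+num-local-epochs = 5
+num-rep-epochs = 1
+enable-finetune = false
+num-finetune-epochs = 5
+learning-rate = 0.01
+batch-size = 50
+```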
## Environment Setup
-To construct the Python environment follow these steps:
+Create a new Python environment using [pyenv](https://github.com/pyenv/pyenv) and [virtualenv plugin](https://github.com/pyenv/pyenv-virtualenv), then install the baseline project:
```bash
-# Set Python 3.10
-pyenv local 3.10.12
-# Tell poetry to use python 3.10
-poetry env use 3.10.12
+# Create the environment
+pyenv virtualenv 3.10.12 fedrep-env
-# Install the base Poetry environment
-poetry install
+# Activate it
+pyenv activate fedrep-env
-# Activate the environment
-poetry shell
+# Then install the project
+pip install -e .
```
## Running the Experiments
```
-python -m fedrep.main # this will run using the default settings in the `conf/base.yaml`
+flwr run . # this will run using the default settings in the `pyproject.toml`
```
-While the config files contain a large number of settings, the ones below are the main ones you'd likely want to modify to .
+While the config files contain a large number of settings, the ones below are the main ones you'd likely want to modify.
```bash
-algorithm: fedavg, fedrep # these are currently supported
-dataset.name: cifar10, cifar100
-dataset.num_classes: 2, 5, 20 (only for CIFAR-100)
-model_name: cnncifar10, cnncifar100
+algorithm = "fedavg", "fedrep" # these are currently supported
+dataset-name = "cifar10", "cifar100"
+dataset-split-num-classes = 2, 5, 20 (only for CIFAR-100)
+model-name = "cnncifar10", "cnncifar100"
```
-
+See also, for instance, the configuration files for CIFAR-10 and CIFAR-100 under the `conf` directory.
## Expected Results
+The default algorithm used by all configuration files is `fedrep`. To use `fedavg`, please change the `algorithm` property in the respective configuration file. The default federated environment consists of 100 clients.
+
+When the execution completes, a new directory `results` will be created with a JSON file that contains the run configuration and the results per round; an illustrative sketch of this file is shown below.
+
+> [!NOTE]
+> All plots shown below are generated using the `docs/make_plots.py` script. The script reads all JSON files generated by the baseline inside the `results` directory.
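+
+As a rough illustration (the exact contents are produced by the baseline and may include further fields; the values here are placeholders), the structure that `docs/make_plots.py` expects from each results file looks roughly like this:
+
+```json
+{
+  "run_config": {
+    "algorithm": "fedrep",
+    "model-name": "cnncifar10",
+    "dataset-name": "cifar10",
+    "dataset-split-num-classes": 2
+  },
+  "round_res": [
+    {"accuracy": 0.10},
+    {"accuracy": 0.25}
+  ]
+}
+```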
### CIFAR-10 (100, 2)
```
-python -m fedrep.main --config-name cifar10_100_2 algorithm=fedrep
-python -m fedrep.main --config-name cifar10_100_2 algorithm=fedavg
+flwr run . --run-config conf/cifar10_2.toml
```
### CIFAR-10 (100, 5)
```
-python -m fedrep.main --config-name cifar10_100_5 algorithm=fedrep
-python -m fedrep.main --config-name cifar10_100_5 algorithm=fedavg
+flwr run . --run-config conf/cifar10_5.toml
```
### CIFAR-100 (100, 5)
```
-python -m fedrep.main --config-name cifar100_100_5 algorithm=fedrep
-python -m fedrep.main --config-name cifar100_100_5 algorithm=fedavg
+flwr run . --run-config conf/cifar100_5.toml
```
### CIFAR-100 (100, 20)
```
-python -m fedrep.main --config-name cifar100_100_20 algorithm=fedrep
-python -m fedrep.main --config-name cifar100_100_20 algorithm=fedavg
+flwr run . --run-config conf/cifar100_20.toml
```
-
\ No newline at end of file
+
diff --git a/baselines/fedrep/_static/cifar100_100_20.png b/baselines/fedrep/_static/cifar100_100_20.png
index 2421f15ac6c6..3f97d08a4dff 100644
Binary files a/baselines/fedrep/_static/cifar100_100_20.png and b/baselines/fedrep/_static/cifar100_100_20.png differ
diff --git a/baselines/fedrep/_static/cifar100_100_5.png b/baselines/fedrep/_static/cifar100_100_5.png
index 17f25eb480c4..f22ddcedaf61 100644
Binary files a/baselines/fedrep/_static/cifar100_100_5.png and b/baselines/fedrep/_static/cifar100_100_5.png differ
diff --git a/baselines/fedrep/_static/cifar10_100_2.png b/baselines/fedrep/_static/cifar10_100_2.png
index 75ee48b2c970..1e7321f85e01 100644
Binary files a/baselines/fedrep/_static/cifar10_100_2.png and b/baselines/fedrep/_static/cifar10_100_2.png differ
diff --git a/baselines/fedrep/_static/cifar10_100_5.png b/baselines/fedrep/_static/cifar10_100_5.png
index 1d20a953f9c4..54a9143d8fd9 100644
Binary files a/baselines/fedrep/_static/cifar10_100_5.png and b/baselines/fedrep/_static/cifar10_100_5.png differ
diff --git a/baselines/fedrep/conf/cifar100_20.toml b/baselines/fedrep/conf/cifar100_20.toml
new file mode 100644
index 000000000000..2b6bb5f5e1eb
--- /dev/null
+++ b/baselines/fedrep/conf/cifar100_20.toml
@@ -0,0 +1,11 @@
+algorithm = "fedrep"
+
+# model specs
+model-name = "cnncifar100"
+
+# dataset specs
+dataset-name = "cifar100"
+dataset-split = "sample"
+dataset-split-num-classes = 20
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar100_5.toml b/baselines/fedrep/conf/cifar100_5.toml
new file mode 100644
index 000000000000..9903ea374028
--- /dev/null
+++ b/baselines/fedrep/conf/cifar100_5.toml
@@ -0,0 +1,11 @@
+algorithm = "fedrep"
+
+# model specs
+model-name = "cnncifar100"
+
+# dataset specs
+dataset-name = "cifar100"
+dataset-split = "sample"
+dataset-split-num-classes = 5
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar10_2.toml b/baselines/fedrep/conf/cifar10_2.toml
new file mode 100644
index 000000000000..307ba7101bd1
--- /dev/null
+++ b/baselines/fedrep/conf/cifar10_2.toml
@@ -0,0 +1,8 @@
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 2
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/conf/cifar10_5.toml b/baselines/fedrep/conf/cifar10_5.toml
new file mode 100644
index 000000000000..cd09c9e07ec2
--- /dev/null
+++ b/baselines/fedrep/conf/cifar10_5.toml
@@ -0,0 +1,8 @@
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 5
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
diff --git a/baselines/fedrep/docs/make_plots.py b/baselines/fedrep/docs/make_plots.py
new file mode 100644
index 000000000000..9474e87a2471
--- /dev/null
+++ b/baselines/fedrep/docs/make_plots.py
@@ -0,0 +1,50 @@
+"""Generate plots from json files."""
+
+import json
+import os
+from typing import List, Tuple
+
+import matplotlib.pyplot as plt
+
+# Get the current working directory
+DIR = os.path.dirname(os.path.abspath(__file__))
+
+
+def read_from_results(path: str) -> Tuple[str, str, List[float], str, str]:
+ """Load the json file with recorded configurations and results."""
+ with open(path, "r", encoding="UTF-8") as fin:
+ data = json.load(fin)
+ algorithm = data["run_config"]["algorithm"]
+ model = data["run_config"]["model-name"]
+ accuracies = [res["accuracy"] * 100 for res in data["round_res"]]
+ dataset = data["run_config"]["dataset-name"]
+ num_classes = data["run_config"]["dataset-split-num-classes"]
+
+ return algorithm, model, accuracies, dataset, num_classes
+
+
+def make_plot(dir_path: str, plt_title: str) -> None:
+ """Given a directory with json files, generated a plot using the provided title."""
+ plt.figure()
+ with os.scandir(dir_path) as files:
+ for file in files:
+ file_name = os.path.join(dir_path, file.name)
+ print(file_name, flush=True)
+ algo, m, acc, d, n = read_from_results(file_name)
+ rounds = [i + 1 for i in range(len(acc))]
+ print(f"Max accuracy ({algo}): {max(acc):.2f}")
+ plt.plot(rounds, acc, label=f"{algo}-{d}-{n}classes")
+ plt.xlabel("Rounds")
+ plt.ylabel("Accuracy")
+ plt.title(plt_title)
+ plt.grid()
+ plt.legend()
+ plt.savefig(os.path.join(DIR, f"{plt_title}-{algo}"))
+
+
+if __name__ == "__main__":
+ # Plot results generated by the baseline.
+ # Combine them into a full file path.
+ res_dir = os.path.join(DIR, "../results/")
+ title = "Federated Accuracy over Rounds"
+ make_plot(res_dir, plt_title=title)
diff --git a/baselines/fedrep/fedrep/__init__.py b/baselines/fedrep/fedrep/__init__.py
index a5e567b59135..f2dbc04ee34e 100644
--- a/baselines/fedrep/fedrep/__init__.py
+++ b/baselines/fedrep/fedrep/__init__.py
@@ -1 +1 @@
-"""Template baseline package."""
+"""fedrep: A Flower Baseline."""
diff --git a/baselines/fedrep/fedrep/base_model.py b/baselines/fedrep/fedrep/base_model.py
index e6a74c01bf9b..82a18cec2622 100644
--- a/baselines/fedrep/fedrep/base_model.py
+++ b/baselines/fedrep/fedrep/base_model.py
@@ -1,61 +1,23 @@
-"""Abstract class for splitting a model into body and head."""
+"""fedrep: A Flower Baseline."""
-import os
+import collections
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Dict, List, OrderedDict, Tuple, Union
-import numpy as np
import torch
-import torch.nn as nn
-from omegaconf import DictConfig
-from torch import Tensor
+from torch import Tensor, nn
from torch.utils.data import DataLoader
-from fedrep.constants import (
+from flwr.common import Context, NDArrays, ParametersRecord, array_from_numpy
+
+from .constants import (
DEFAULT_FINETUNE_EPOCHS,
DEFAULT_LOCAL_TRAIN_EPOCHS,
DEFAULT_REPRESENTATION_EPOCHS,
+ FEDREP_HEAD_STATE,
)
-def get_device(
- use_cuda: bool = True, specified_device: Optional[int] = None
-) -> torch.device:
- """Get the tensor device.
-
- Args:
- use_cuda: Flag indicates whether to use CUDA or not. Defaults to True.
- specified_device: Specified cuda device to use. Defaults to None.
-
- Raises
- ------
- ValueError: Specified device not in CUDA_VISIBLE_DEVICES.
-
- Returns
- -------
- The selected or fallbacked device.
- """
- device = torch.device("cpu")
- if use_cuda and torch.cuda.is_available():
- if specified_device is not None:
- cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
- if cuda_visible_devices is not None:
- devices = [int(d) for d in cuda_visible_devices.split(",")]
- if specified_device in devices:
- device = torch.device(f"cuda:{specified_device}")
- else:
- raise ValueError(
- f"Specified device {specified_device}"
- " not in CUDA_VISIBLE_DEVICES"
- )
- else:
- print("CUDA_VISIBLE_DEVICES not exists, using torch.device('cuda').")
- else:
- device = torch.device("cuda")
-
- return device
-
-
class ModelSplit(ABC, nn.Module):
"""Abstract class for splitting a model into body and head."""
@@ -110,7 +72,7 @@ def head(self, state_dict: OrderedDict[str, Tensor]) -> None:
"""
self._head.load_state_dict(state_dict, strict=True)
- def get_parameters(self) -> List[np.ndarray]:
+ def get_parameters(self) -> NDArrays:
"""Get model parameters.
Returns
@@ -164,65 +126,86 @@ class ModelManager(ABC):
def __init__(
self,
- client_id: int,
- config: DictConfig,
+ context: Context,
trainloader: DataLoader,
testloader: DataLoader,
- client_save_path: Optional[str],
model_split_class: Any, # ModelSplit
):
"""Initialize the attributes of the model manager.
Args:
- client_id: The id of the client.
- config: Dict containing the configurations to be used by the manager.
+ context: The context of the current run.
trainloader: Client train dataloader.
testloader: Client test dataloader.
- client_save_path: Path to save the client model head state.
model_split_class: Class to be used to split the model into body and head \
(concrete implementation of ModelSplit).
"""
super().__init__()
- self.config = config
- self.client_id = client_id
+ self.context = context
self.trainloader = trainloader
self.testloader = testloader
- self.device = get_device(
- use_cuda=getattr(self.config, "use_cuda", True),
- specified_device=getattr(self.config, "specified_device", None),
- )
- self.client_save_path = client_save_path
- self.learning_rate = config.get("learning_rate", 0.01)
- self.momentum = config.get("momentum", 0.5)
+ self.learning_rate = self.context.run_config.get("learning-rate", 0.01)
+ self.momentum = self.context.run_config.get("momentum", 0.5)
self._model: ModelSplit = model_split_class(self._create_model())
@abstractmethod
def _create_model(self) -> nn.Module:
- """Return model to be splitted into head and body."""
+ """Return model to be split into head and body."""
@property
def model(self) -> ModelSplit:
"""Return model."""
return self._model
- def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
+ def _load_client_state(self) -> None:
+ """Load client model head state from context state; used only by FedRep."""
+ # First, check if the fedrep head state is set in the context state.
+ if self.context.state.parameters_records.get(FEDREP_HEAD_STATE):
+ state_dict = collections.OrderedDict(
+ {
+ k: torch.from_numpy(v.numpy())
+ for k, v in self.context.state.parameters_records[
+ FEDREP_HEAD_STATE
+ ].items()
+ }
+ )
+ # Second, check if the parameters records have values stored and load
+ # the state; this check is useful for the first time the model is
+ # tested and the head state might be empty.
+ if state_dict:
+ self._model.head.load_state_dict(state_dict)
+
+ def _save_client_state(self) -> None:
+ """Save client model head state inside context state; used only by FedRep."""
+ # Check if the fedrep head state is set in the context state.
+ if FEDREP_HEAD_STATE in self.context.state.parameters_records:
+ head_state = self._model.head.state_dict()
+ head_state_np = {k: v.detach().cpu().numpy() for k, v in head_state.items()}
+ head_state_arr = collections.OrderedDict(
+ {k: array_from_numpy(v) for k, v in head_state_np.items()}
+ )
+ head_state_prec = ParametersRecord(head_state_arr)
+ self.context.state.parameters_records[FEDREP_HEAD_STATE] = head_state_prec
+
+ def train(
+ self, device: torch.device
+ ) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
"""Train the model maintained in self.model.
Returns
-------
Dict containing the train metrics.
"""
- # Load client state (head) if client_save_path is not None and it is not empty
- if self.client_save_path is not None and os.path.isfile(self.client_save_path):
- self._model.head.load_state_dict(torch.load(self.client_save_path))
+ # Load state.
+ self._load_client_state()
num_local_epochs = DEFAULT_LOCAL_TRAIN_EPOCHS
- if hasattr(self.config, "num_local_epochs"):
- num_local_epochs = int(self.config.num_local_epochs)
+ if "num-local-epochs" in self.context.run_config:
+ num_local_epochs = int(self.context.run_config["num-local-epochs"])
num_rep_epochs = DEFAULT_REPRESENTATION_EPOCHS
- if hasattr(self.config, "num_rep_epochs"):
- num_rep_epochs = int(self.config.num_rep_epochs)
+ if self.context.run_config["num-rep-epochs"] in self.context.run_config:
+ num_rep_epochs = int(self.context.run_config["num-rep-epochs"])
criterion = torch.nn.CrossEntropyLoss()
weights = [v for k, v in self._model.named_parameters() if "weight" in k]
@@ -238,6 +221,7 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
correct, total = 0, 0
loss: torch.Tensor = 0.0
+ self._model.to(device)
self._model.train()
for i in range(num_local_epochs + num_rep_epochs):
if i < num_local_epochs:
@@ -247,10 +231,9 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
self._model.enable_body()
self._model.disable_head()
for batch in self.trainloader:
- images = batch["img"]
- labels = batch["label"]
- outputs = self._model(images.to(self.device))
- labels = labels.to(self.device)
+ images = batch["img"].to(device)
+ labels = batch["label"].to(device)
+ outputs = self._model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
@@ -258,35 +241,35 @@ def train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
- # Save client state (head)
- if self.client_save_path is not None:
- torch.save(self._model.head.state_dict(), self.client_save_path)
+ # Save state.
+ self._save_client_state()
return {"loss": loss.item(), "accuracy": correct / total}
- def test(self) -> Dict[str, float]:
+ def test(self, device: torch.device) -> Dict[str, float]:
"""Test the model maintained in self.model.
Returns
-------
Dict containing the test metrics.
"""
- # Load client state (head)
- if self.client_save_path is not None and os.path.isfile(self.client_save_path):
- self._model.head.load_state_dict(torch.load(self.client_save_path))
+ # Load state.
+ self._load_client_state()
num_finetune_epochs = DEFAULT_FINETUNE_EPOCHS
- if hasattr(self.config, "num_finetune_epochs"):
- num_finetune_epochs = int(self.config.num_finetune_epochs)
+ if "num-finetune-epochs" in self.context.run_config:
+ num_finetune_epochs = int(self.context.run_config["num-finetune-epochs"])
- if num_finetune_epochs > 0 and self.config.get("enable_finetune", False):
+ if num_finetune_epochs > 0 and self.context.run_config.get(
+ "enable-finetune", False
+ ):
optimizer = torch.optim.SGD(self._model.parameters(), lr=self.learning_rate)
criterion = torch.nn.CrossEntropyLoss()
self._model.train()
for _ in range(num_finetune_epochs):
for batch in self.trainloader:
- images = batch["img"].to(self.device)
- labels = batch["label"].to(self.device)
+ images = batch["img"]
+ labels = batch["label"]
outputs = self._model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
@@ -296,11 +279,12 @@ def test(self) -> Dict[str, float]:
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
+ self._model.to(device)
self._model.eval()
with torch.no_grad():
for batch in self.testloader:
- images = batch["img"].to(self.device)
- labels = batch["label"].to(self.device)
+ images = batch["img"].to(device)
+ labels = batch["label"].to(device)
outputs = self._model(images)
loss += criterion(outputs, labels).item()
total += labels.size(0)
diff --git a/baselines/fedrep/fedrep/client.py b/baselines/fedrep/fedrep/client.py
deleted file mode 100644
index f857fd2cf82a..000000000000
--- a/baselines/fedrep/fedrep/client.py
+++ /dev/null
@@ -1,319 +0,0 @@
-"""Client implementation - can call FedPep and FedAvg clients."""
-
-from collections import OrderedDict
-from pathlib import Path
-from typing import Callable, Dict, List, Tuple, Type, Union
-
-import numpy as np
-import torch
-from flwr.client import Client, NumPyClient
-from flwr.common import NDArrays, Scalar
-from flwr_datasets import FederatedDataset
-from flwr_datasets.partitioner import PathologicalPartitioner
-from flwr_datasets.preprocessor import Merger
-from omegaconf import DictConfig
-from torch.utils.data import DataLoader
-from torchvision import transforms
-
-from fedrep.constants import MEAN, STD, Algorithm
-from fedrep.models import CNNCifar10ModelManager, CNNCifar100ModelManager
-
-PROJECT_DIR = Path(__file__).parent.parent.absolute()
-
-FEDERATED_DATASET = None
-
-
-class BaseClient(NumPyClient):
- """Implementation of Federated Averaging (FedAvg) Client."""
-
- # pylint: disable=R0913
- def __init__(
- self,
- client_id: int,
- trainloader: DataLoader,
- testloader: DataLoader,
- config: DictConfig,
- model_manager_class: Union[
- Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]
- ],
- client_state_save_path: str = "",
- ):
- """Initialize client attributes.
-
- Args:
- client_id: The client ID.
- trainloader: Client train data loader.
- testloader: Client test data loader.
- config: dictionary containing the client configurations.
- model_manager_class: class to be used as the model manager.
- client_state_save_path: Path for saving model head parameters.
- (Just for FedRep). Defaults to "".
- """
- super().__init__()
-
- self.client_id = client_id
- self.client_state_save_path = (
- (client_state_save_path + f"/client_{self.client_id}")
- if client_state_save_path != ""
- else None
- )
- self.model_manager = model_manager_class(
- client_id=self.client_id,
- config=config,
- trainloader=trainloader,
- testloader=testloader,
- client_save_path=self.client_state_save_path,
- )
-
- def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
- """Return the current local model parameters."""
- return self.model_manager.model.get_parameters()
-
- def set_parameters(
- self, parameters: List[np.ndarray], evaluate: bool = False
- ) -> None:
- """Set the local model parameters to the received parameters.
-
- Args:
- parameters: parameters to set the model to.
- """
- _ = evaluate
- model_keys = [
- k
- for k in self.model_manager.model.state_dict().keys()
- if k.startswith("_body") or k.startswith("_head")
- ]
- params_dict = zip(model_keys, parameters)
-
- state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
-
- self.model_manager.model.set_parameters(state_dict)
-
- def perform_train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
- """Perform local training to the whole model.
-
- Returns
- -------
- Dict with the train metrics.
- """
- self.model_manager.model.enable_body()
- self.model_manager.model.enable_head()
-
- return self.model_manager.train()
-
- def fit(
- self, parameters: NDArrays, config: Dict[str, Scalar]
- ) -> Tuple[NDArrays, int, Dict[str, Union[bool, bytes, float, int, str]]]:
- """Train the provided parameters using the locally held dataset.
-
- Args:
- parameters: The current (global) model parameters.
- config: configuration parameters for training sent by the server.
-
- Returns
- -------
- Tuple containing the locally updated model parameters, \
- the number of examples used for training and \
- the training metrics.
- """
- self.set_parameters(parameters)
-
- train_results = self.perform_train()
-
- # Update train history
- print("<------- TRAIN RESULTS -------> :", train_results)
-
- return self.get_parameters(config), self.model_manager.train_dataset_size(), {}
-
- def evaluate(
- self, parameters: NDArrays, config: Dict[str, Scalar]
- ) -> Tuple[float, int, Dict[str, Union[bool, bytes, float, int, str]]]:
- """Evaluate the provided global parameters using the locally held dataset.
-
- Args:
- parameters: The current (global) model parameters.
- config: configuration parameters for training sent by the server.
-
- Returns
- -------
- Tuple containing the test loss, \
- the number of examples used for evaluation and \
- the evaluation metrics.
- """
- self.set_parameters(parameters, evaluate=True)
-
- # Test the model
- test_results = self.model_manager.test()
- print("<------- TEST RESULTS -------> :", test_results)
-
- return (
- test_results.get("loss", 0.0),
- self.model_manager.test_dataset_size(),
- {k: v for k, v in test_results.items() if not isinstance(v, (dict, list))},
- )
-
-
-class FedRepClient(BaseClient):
- """Implementation of Federated Personalization (FedRep) Client."""
-
- def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
- """Return the current local body parameters."""
- return [
- val.cpu().numpy()
- for val in self.model_manager.model.body.state_dict().values()
- ]
-
- def set_parameters(self, parameters: List[np.ndarray], evaluate=False) -> None:
- """Set the local body parameters to the received parameters.
-
- Args:
- parameters: parameters to set the body to.
- evaluate: whether the client is evaluating or not.
- """
- model_keys = [
- k
- for k in self.model_manager.model.state_dict().keys()
- if k.startswith("_body")
- ]
-
- if not evaluate:
- # Only update client's local head if it hasn't trained yet
- model_keys.extend(
- [
- k
- for k in self.model_manager.model.state_dict().keys()
- if k.startswith("_head")
- ]
- )
-
- state_dict = OrderedDict(
- (k, torch.from_numpy(v)) for k, v in zip(model_keys, parameters)
- )
-
- self.model_manager.model.set_parameters(state_dict)
-
-
-# pylint: disable=E1101, W0603
-def get_client_fn_simulation(
- config: DictConfig, client_state_save_path: str = ""
-) -> Callable[[str], Client]:
- """Generate the client function that creates the Flower Clients.
-
- Parameters
- ----------
- model : DictConfig
- The model configuration.
- cleint_state_save_path : str
- The path to save the client state.
-
- Returns
- -------
- Tuple[Callable[[str], FlowerClient], DataLoader]
- A tuple containing the client function that creates Flower Clients and
- the DataLoader that will be used for testing
- """
- assert config.model_name.lower() in [
- "cnncifar10",
- "cnncifar100",
- ], f"Model {config.model_name} not implemented"
-
- # - you can define your own data transformation strategy here -
- # These transformations are from the official repo
- train_data_transform = transforms.Compose(
- [
- transforms.RandomCrop(32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(MEAN[config.dataset.name], STD[config.dataset.name]),
- ]
- )
- test_data_transform = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize(MEAN[config.dataset.name], STD[config.dataset.name]),
- ]
- )
-
- use_fine_label = False
- if config.dataset.name.lower() == "cifar100":
- use_fine_label = True
-
- partitioner = PathologicalPartitioner(
- num_partitions=config.num_clients,
- partition_by="fine_label" if use_fine_label else "label",
- num_classes_per_partition=config.dataset.num_classes,
- class_assignment_mode="random",
- shuffle=True,
- seed=config.dataset.seed,
- )
-
- global FEDERATED_DATASET
- if FEDERATED_DATASET is None:
- FEDERATED_DATASET = FederatedDataset(
- dataset=config.dataset.name.lower(),
- partitioners={"all": partitioner},
- preprocessor=Merger({"all": ("train", "test")}),
- )
-
- def apply_train_transforms(batch):
- """Apply transforms for train data to the partition from FederatedDataset."""
- batch["img"] = [train_data_transform(img) for img in batch["img"]]
- if use_fine_label:
- batch["label"] = batch["fine_label"]
- return batch
-
- def apply_test_transforms(batch):
- """Apply transforms for test data to the partition from FederatedDataset."""
- batch["img"] = [test_data_transform(img) for img in batch["img"]]
- if use_fine_label:
- batch["label"] = batch["fine_label"]
- return batch
-
- # pylint: disable=E1101
- def client_fn(cid: str) -> Client:
- """Create a Flower client representing a single organization."""
- cid_use = int(cid)
-
- partition = FEDERATED_DATASET.load_partition(cid_use, split="all")
-
- partition_train_test = partition.train_test_split(
- train_size=config.dataset.fraction, shuffle=True, seed=config.dataset.seed
- )
-
- trainset = partition_train_test["train"].with_transform(apply_train_transforms)
- testset = partition_train_test["test"].with_transform(apply_test_transforms)
-
- trainloader = DataLoader(trainset, config.batch_size, shuffle=True)
- testloader = DataLoader(testset, config.batch_size)
-
- model_manager_class: Union[
- Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]
- ]
- if config.model_name.lower() == "cnncifar10":
- model_manager_class = CNNCifar10ModelManager
- elif config.model_name.lower() == "cnncifar100":
- model_manager_class = CNNCifar100ModelManager
- else:
- raise NotImplementedError(
- f"Model {config.model_name} not implemented, check name."
- )
-
- if config.algorithm.lower() == Algorithm.FEDREP.value:
- return FedRepClient( # type: ignore[attr-defined]
- client_id=cid_use,
- trainloader=trainloader,
- testloader=testloader,
- config=config,
- model_manager_class=model_manager_class,
- client_state_save_path=client_state_save_path,
- ).to_client()
- return BaseClient( # type: ignore[attr-defined]
- client_id=cid_use,
- trainloader=trainloader,
- testloader=testloader,
- config=config,
- model_manager_class=model_manager_class,
- client_state_save_path=client_state_save_path,
- ).to_client()
-
- return client_fn
diff --git a/baselines/fedrep/fedrep/client_app.py b/baselines/fedrep/fedrep/client_app.py
new file mode 100644
index 000000000000..78755f4484aa
--- /dev/null
+++ b/baselines/fedrep/fedrep/client_app.py
@@ -0,0 +1,186 @@
+"""fedrep: A Flower Baseline."""
+
+from collections import OrderedDict
+from typing import Dict, List, Tuple, Union
+
+import torch
+
+from flwr.client import ClientApp, NumPyClient
+from flwr.client.client import Client
+from flwr.common import Context, NDArrays, ParametersRecord, Scalar
+
+from .constants import FEDREP_HEAD_STATE, Algorithm
+from .dataset import load_data
+from .models import CNNCifar10ModelManager, CNNCifar100ModelManager
+from .utils import get_model_manager_class
+
+
+class BaseClient(NumPyClient):
+ """Implementation of Federated Averaging (FedAvg) Client."""
+
+ # pylint: disable=R0913
+ def __init__(
+ self, model_manager: Union[CNNCifar10ModelManager, CNNCifar100ModelManager]
+ ):
+ """Initialize client attributes.
+
+ Args:
+ model_manager: the model manager object
+ """
+ super().__init__()
+ self.model_manager = model_manager
+
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+ """Return the current local model parameters."""
+ return self.model_manager.model.get_parameters()
+
+ def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None:
+ """Set the local model parameters to the received parameters.
+
+ Args:
+ parameters: parameters to set the model to.
+ evaluate: whether to evaluate or not.
+ """
+ _ = evaluate
+ model_keys = [
+ k
+ for k in self.model_manager.model.state_dict().keys()
+ if k.startswith("_body") or k.startswith("_head")
+ ]
+ params_dict = zip(model_keys, parameters)
+ state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
+ self.model_manager.model.set_parameters(state_dict)
+
+ def perform_train(self) -> Dict[str, Union[List[Dict[str, float]], int, float]]:
+ """Perform local training to the whole model.
+
+ Returns
+ -------
+ Dict with the train metrics.
+ """
+ self.model_manager.model.enable_body()
+ self.model_manager.model.enable_head()
+
+ return self.model_manager.train(self.device)
+
+ def fit(
+ self, parameters: NDArrays, config: Dict[str, Scalar]
+ ) -> Tuple[NDArrays, int, Dict[str, Union[bool, bytes, float, int, str]]]:
+ """Train the provided parameters using the locally held dataset.
+
+ Args:
+ parameters: The current (global) model parameters.
+ config: configuration parameters for training sent by the server.
+
+ Returns
+ -------
+ Tuple containing the locally updated model parameters, \
+ the number of examples used for training and \
+ the training metrics.
+ """
+ self.set_parameters(parameters)
+ self.perform_train()
+
+ return self.get_parameters(config), self.model_manager.train_dataset_size(), {}
+
+ def evaluate(
+ self, parameters: NDArrays, config: Dict[str, Scalar]
+ ) -> Tuple[float, int, Dict[str, Union[bool, bytes, float, int, str]]]:
+ """Evaluate the provided global parameters using the locally held dataset.
+
+ Args:
+ parameters: The current (global) model parameters.
+ config: configuration parameters for training sent by the server.
+
+ Returns
+ -------
+ Tuple containing the test loss, \
+ the number of examples used for evaluation and \
+ the evaluation metrics.
+ """
+ self.set_parameters(parameters, evaluate=True)
+
+ # Test the model
+ test_results = self.model_manager.test(self.device)
+
+ return (
+ test_results.get("loss", 0.0),
+ self.model_manager.test_dataset_size(),
+ {k: v for k, v in test_results.items() if not isinstance(v, (dict, list))},
+ )
+
+
+class FedRepClient(BaseClient):
+ """Implementation of Federated Personalization (FedRep) Client."""
+
+ def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays:
+ """Return the current local body parameters."""
+ return [
+ val.cpu().numpy()
+ for val in self.model_manager.model.body.state_dict().values()
+ ]
+
+ def set_parameters(self, parameters: NDArrays, evaluate: bool = False) -> None:
+ """Set the local body parameters to the received parameters.
+
+ Args:
+ parameters: parameters to set the body to.
+ evaluate: whether the client is evaluating or not.
+ """
+ model_keys = [
+ k
+ for k in self.model_manager.model.state_dict().keys()
+ if k.startswith("_body")
+ ]
+
+ if not evaluate:
+ # Only update client's local head if it hasn't trained yet
+ model_keys.extend(
+ [
+ k
+ for k in self.model_manager.model.state_dict().keys()
+ if k.startswith("_head")
+ ]
+ )
+
+ state_dict = OrderedDict(
+ (k, torch.from_numpy(v)) for k, v in zip(model_keys, parameters)
+ )
+
+ self.model_manager.model.set_parameters(state_dict)
+
+
+def client_fn(context: Context) -> Client:
+ """Construct a Client that will be run in a ClientApp."""
+ model_manager_class = get_model_manager_class(context)
+ algorithm = str(context.run_config["algorithm"]).lower()
+ partition_id = int(context.node_config["partition-id"])
+ num_partitions = int(context.node_config["num-partitions"])
+ trainloader, valloader = load_data(
+ partition_id, num_partitions, context
+ ) # load the data
+ if algorithm == Algorithm.FEDAVG.value:
+ client_class = BaseClient
+ elif algorithm == Algorithm.FEDREP.value:
+        # This state variable is only used by the FedRep algorithm. Since
+        # client_fn is called at every invocation of the ClientApp, we only
+        # initialize the head state if it does not exist yet.
+ if FEDREP_HEAD_STATE not in context.state.parameters_records:
+ context.state.parameters_records[FEDREP_HEAD_STATE] = ParametersRecord()
+ client_class = FedRepClient
+ else:
+ raise RuntimeError(f"Unknown algorithm {algorithm}.")
+
+ model_manager_obj = model_manager_class(
+ context=context, trainloader=trainloader, testloader=valloader
+ )
+
+ # Return client object.
+ client = client_class(model_manager_obj).to_client()
+ return client
+
+
+# Flower ClientApp
+app = ClientApp(client_fn)
diff --git a/baselines/fedrep/fedrep/conf/base.yaml b/baselines/fedrep/fedrep/conf/base.yaml
deleted file mode 100644
index 0d74c4fe78b6..000000000000
--- a/baselines/fedrep/fedrep/conf/base.yaml
+++ /dev/null
@@ -1,46 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar10
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null # the ID of cuda device, if null, then use defaults torch.device("cuda")
-
-dataset:
- name: cifar10
- split: sample
- num_classes: 2
- seed: 42
- num_clients: ${num_clients}
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar10
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml b/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml
deleted file mode 100644
index 30f9fd209d58..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar100_100_20.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar100
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar100
- num_classes: 20
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar100
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml b/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml
deleted file mode 100644
index e0add8f03b45..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar100_100_5.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar100
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar100
- num_classes: 5
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar100.CNNCifar100
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml b/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml
deleted file mode 100644
index 83ee34a298ae..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar10_100_2.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar10
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar10
- num_classes: 2
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar10.CNNCifar10
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml b/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml
deleted file mode 100644
index 0cbd104406f0..000000000000
--- a/baselines/fedrep/fedrep/conf/cifar10_100_5.yaml
+++ /dev/null
@@ -1,44 +0,0 @@
----
-num_clients: 100 # total number of clients
-num_local_epochs: 5 # number of local epochs
-num_rep_epochs: 1 # number of representation epochs (only for FedRep)
-enable_finetune: false
-# num_finetune_epochs: 10
-batch_size: 50
-num_rounds: 100
-learning_rate: 0.01
-momentum: 0.5
-algorithm: fedrep
-model_name: cnncifar10
-
-client_resources:
- num_cpus: 2
- num_gpus: 0.5
-
-use_cuda: true
-specified_device: null
-
-dataset:
- name: cifar10
- num_classes: 5
- seed: 42
- fraction: 0.83
-
-model:
- _target_: fedrep.implemented_models.cnn_cifar10.CNNCifar10
-
-fit_config:
- drop_client: false
- epochs: ${num_local_epochs}
- batch_size: ${batch_size}
-
-strategy:
- _target_: fedrep.strategy.FedRep
- fraction_fit: 0.1
- fraction_evaluate: 0.1
- min_fit_clients: 2
- min_evaluate_clients: 2
- min_available_clients: 2
- evaluate_fn: null
- on_fit_config_fn: null
- on_evaluate_config_fn: null
diff --git a/baselines/fedrep/fedrep/constants.py b/baselines/fedrep/fedrep/constants.py
index 27e68f2b786c..a4527e1f36d4 100644
--- a/baselines/fedrep/fedrep/constants.py
+++ b/baselines/fedrep/fedrep/constants.py
@@ -1,7 +1,16 @@
-"""Constants used in machine learning pipeline."""
+"""fedrep: A Flower Baseline."""
from enum import Enum
+DEFAULT_LOCAL_TRAIN_EPOCHS: int = 10
+DEFAULT_FINETUNE_EPOCHS: int = 5
+DEFAULT_REPRESENTATION_EPOCHS: int = 1
+FEDREP_HEAD_STATE = "fedrep_head_state"
+
+MEAN = {"cifar10": [0.485, 0.456, 0.406], "cifar100": [0.507, 0.487, 0.441]}
+
+STD = {"cifar10": [0.229, 0.224, 0.225], "cifar100": [0.267, 0.256, 0.276]}
+
class Algorithm(Enum):
"""Algorithm names."""
@@ -10,10 +19,8 @@ class Algorithm(Enum):
FEDAVG = "fedavg"
-DEFAULT_LOCAL_TRAIN_EPOCHS: int = 10
-DEFAULT_FINETUNE_EPOCHS: int = 5
-DEFAULT_REPRESENTATION_EPOCHS: int = 1
+class ModelDatasetName(Enum):
+ """Dataset names."""
-MEAN = {"cifar10": [0.485, 0.456, 0.406], "cifar100": [0.507, 0.487, 0.441]}
-
-STD = {"cifar10": [0.229, 0.224, 0.225], "cifar100": [0.267, 0.256, 0.276]}
+ CNN_CIFAR_10 = "cnncifar10"
+ CNN_CIFAR_100 = "cnncifar100"
diff --git a/baselines/fedrep/fedrep/dataset.py b/baselines/fedrep/fedrep/dataset.py
index a616e38ae220..621baf98249a 100644
--- a/baselines/fedrep/fedrep/dataset.py
+++ b/baselines/fedrep/fedrep/dataset.py
@@ -1 +1,107 @@
-"""FedRep uses flwr-datasets."""
+"""fedrep: A Flower Baseline."""
+
+from typing import Tuple
+
+from flwr_datasets import FederatedDataset
+from flwr_datasets.partitioner import PathologicalPartitioner
+from flwr_datasets.preprocessor import Merger
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+from flwr.common import Context
+
+from .constants import MEAN, STD
+
+FDS = None # Cache FederatedDataset
+
+
+def load_data(
+ partition_id: int, num_partitions: int, context: Context
+) -> Tuple[DataLoader, DataLoader]:
+ """Split the data and return training and testing data for the specified partition.
+
+ Parameters
+ ----------
+ partition_id : int
+ Partition number for which to retrieve the corresponding data.
+ num_partitions : int
+ Total number of partitions.
+ context: Context
+ the context of the current run.
+
+ Returns
+ -------
+ data : Tuple[DataLoader, DataLoader]
+ A tuple with the training and testing data for the current partition_id.
+ """
+ batch_size = int(context.run_config["batch-size"])
+ dataset_name = str(context.run_config["dataset-name"]).lower()
+ dataset_split_num_classes = int(context.run_config["dataset-split-num-classes"])
+ dataset_split_seed = int(context.run_config["dataset-split-seed"])
+ dataset_split_fraction = float(context.run_config["dataset-split-fraction"])
+
+ # - you can define your own data transformation strategy here -
+ # These transformations are from the official repo
+ train_data_transform = transforms.Compose(
+ [
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN[dataset_name], STD[dataset_name]),
+ ]
+ )
+ test_data_transform = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(MEAN[dataset_name], STD[dataset_name]),
+ ]
+ )
+
+ use_fine_label = False
+ if dataset_name == "cifar100":
+ use_fine_label = True
+
+ partitioner = PathologicalPartitioner(
+ num_partitions=num_partitions,
+ partition_by="fine_label" if use_fine_label else "label",
+ num_classes_per_partition=dataset_split_num_classes,
+ class_assignment_mode="random",
+ shuffle=True,
+ seed=dataset_split_seed,
+ )
+
+ global FDS # pylint: disable=global-statement
+ if FDS is None:
+ FDS = FederatedDataset(
+ dataset=dataset_name,
+ partitioners={"all": partitioner},
+ preprocessor=Merger({"all": ("train", "test")}),
+ )
+
+ def apply_train_transforms(batch: Dataset) -> Dataset:
+ """Apply transforms for train data to the partition from FederatedDataset."""
+ batch["img"] = [train_data_transform(img) for img in batch["img"]]
+ if use_fine_label:
+ batch["label"] = batch["fine_label"]
+ return batch
+
+ def apply_test_transforms(batch: Dataset) -> Dataset:
+ """Apply transforms for test data to the partition from FederatedDataset."""
+ batch["img"] = [test_data_transform(img) for img in batch["img"]]
+ if use_fine_label:
+ batch["label"] = batch["fine_label"]
+ return batch
+
+ partition = FDS.load_partition(partition_id, split="all")
+
+ partition_train_test = partition.train_test_split(
+ train_size=dataset_split_fraction, shuffle=True, seed=dataset_split_seed
+ )
+
+ trainset = partition_train_test["train"].with_transform(apply_train_transforms)
+ testset = partition_train_test["test"].with_transform(apply_test_transforms)
+
+ trainloader = DataLoader(trainset, batch_size, shuffle=True)
+ testloader = DataLoader(testset, batch_size)
+
+ return trainloader, testloader
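+
+
+# Illustrative usage sketch (an assumption, not part of the original module): the
+# client app is expected to call `load_data` with its partition id and a `Context`
+# carrying the run config, e.g.
+#
+#     trainloader, testloader = load_data(
+#         partition_id=0, num_partitions=100, context=context
+#     )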
diff --git a/baselines/fedrep/fedrep/dataset_preparation.py b/baselines/fedrep/fedrep/dataset_preparation.py
deleted file mode 100644
index a616e38ae220..000000000000
--- a/baselines/fedrep/fedrep/dataset_preparation.py
+++ /dev/null
@@ -1 +0,0 @@
-"""FedRep uses flwr-datasets."""
diff --git a/baselines/fedrep/fedrep/main.py b/baselines/fedrep/fedrep/main.py
deleted file mode 100644
index 223b98aa21fa..000000000000
--- a/baselines/fedrep/fedrep/main.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Create and connect the building blocks for your experiments; start the simulation.
-
-It includes processioning the dataset, instantiate strategy, specify how the global
-model is going to be evaluated, etc. At the end, this script saves the results.
-"""
-
-from pathlib import Path
-from typing import List, Tuple
-
-import flwr as fl
-import hydra
-from flwr.common.parameter import ndarrays_to_parameters
-from flwr.common.typing import Metrics
-from hydra.core.hydra_config import HydraConfig
-from hydra.utils import instantiate
-from omegaconf import DictConfig, OmegaConf
-
-from fedrep.utils import (
- get_client_fn,
- get_create_model_fn,
- plot_metric_from_history,
- save_results_as_pickle,
- set_client_state_save_path,
- set_client_strategy,
-)
-
-
-@hydra.main(config_path="conf", config_name="base", version_base=None)
-def main(cfg: DictConfig) -> None:
- """Run the baseline.
-
- Parameterss
- ----------
- cfg : DictConfig
- An omegaconf object that stores the hydra config.
- """
- # Print parsed config
- print(OmegaConf.to_yaml(cfg))
-
- # set client strategy
- cfg = set_client_strategy(cfg)
-
- # Create directory to store client states if it does not exist
- # Client state has subdirectories with the name of current time
- client_state_save_path = set_client_state_save_path()
-
- # Define your clients
- # Get client function
- client_fn = get_client_fn(config=cfg, client_state_save_path=client_state_save_path)
-
- # get a function that will be used to construct the config that the client's
- # fit() method will received
- def get_on_fit_config():
- def fit_config_fn(server_round: int):
- # resolve and convert to python dict
- fit_config = OmegaConf.to_container(cfg.fit_config, resolve=True)
- _ = server_round
- return fit_config
-
- return fit_config_fn
-
- # get a function that will be used to construct the model
- create_model, split = get_create_model_fn(cfg)
-
- model = split(create_model())
-
- def evaluate_metrics_aggregation_fn(
- eval_metrics: List[Tuple[int, Metrics]]
- ) -> Metrics:
- weights, accuracies = [], []
- for num_examples, metric in eval_metrics:
- weights.append(num_examples)
- accuracies.append(metric["accuracy"] * num_examples)
- accuracy = sum(accuracies) / sum(weights) # type: ignore[arg-type]
- return {"accuracy": accuracy}
-
- # Define your strategy
- strategy = instantiate(
- cfg.strategy,
- initial_parameters=ndarrays_to_parameters(model.get_parameters()),
- on_fit_config_fn=get_on_fit_config(),
- evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
- )
-
- # Start Simulation
- history = fl.simulation.start_simulation(
- client_fn=client_fn,
- num_clients=cfg.num_clients,
- config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
- client_resources={
- "num_cpus": cfg.client_resources.num_cpus,
- "num_gpus": cfg.client_resources.num_gpus,
- },
- strategy=strategy,
- )
-
- # Experiment completed. Now we save the results and
- # generate plots using the `history`
- print("................")
- print(history)
-
- # Save your results
- save_path = Path(HydraConfig.get().runtime.output_dir)
-
- # save results as a Python pickle using a file_path
- # the directory created by Hydra for each run
- save_results_as_pickle(history, file_path=save_path)
- # plot results and include them in the readme
- strategy_name = strategy.__class__.__name__
- file_suffix: str = (
- f"_{strategy_name}"
- f"_C={cfg.num_clients}"
- f"_B={cfg.batch_size}"
- f"_E={cfg.num_local_epochs}"
- f"_R={cfg.num_rounds}"
- f"_lr={cfg.learning_rate}"
- )
-
- plot_metric_from_history(history, save_path, (file_suffix))
-
-
-if __name__ == "__main__":
- main()
diff --git a/baselines/fedrep/fedrep/models.py b/baselines/fedrep/fedrep/models.py
index b230f4e49766..314d7cc91688 100644
--- a/baselines/fedrep/fedrep/models.py
+++ b/baselines/fedrep/fedrep/models.py
@@ -1,11 +1,11 @@
-"""Model, model manager and model split for CIFAR-10 and CIFAR-100."""
+"""fedrep: A Flower Baseline."""
from typing import Tuple
import torch
-import torch.nn as nn
+from torch import nn
-from fedrep.base_model import ModelManager, ModelSplit
+from .base_model import ModelManager, ModelSplit
# pylint: disable=W0223
@@ -58,17 +58,12 @@ class CNNCifar10ModelManager(ModelManager):
"""Manager for models with Body/Head split."""
def __init__(self, **kwargs):
- """Initialize the attributes of the model manager.
-
- Args:
- client_id: The id of the client.
- config: Dict containing the configurations to be used by the manager.
- """
+ """Initialize the attributes of the model manager."""
super().__init__(model_split_class=CNNCifar10ModelSplit, **kwargs)
def _create_model(self) -> nn.Module:
- """Return CNNCifar10 model to be splitted into head and body."""
- return CNNCifar10().to(self.device)
+ """Return CNNCifar10 model to be split into head and body."""
+ return CNNCifar10()
# pylint: disable=W0223
@@ -104,6 +99,11 @@ def __init__(self):
self.head = nn.Sequential(nn.Linear(128, 100))
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass of the model."""
+ x = self.body(x)
+ return self.head(x)
+
class CNNCifar100ModelSplit(ModelSplit):
"""Split CNNCifar100 model into body and head."""
@@ -117,14 +117,9 @@ class CNNCifar100ModelManager(ModelManager):
"""Manager for models with Body/Head split."""
def __init__(self, **kwargs):
- """Initialize the attributes of the model manager.
-
- Args:
- client_id: The id of the client.
- config: Dict containing the configurations to be used by the manager.
- """
+ """Initialize the attributes of the model manager."""
super().__init__(model_split_class=CNNCifar100ModelSplit, **kwargs)
def _create_model(self) -> CNNCifar100:
- """Return CNNCifar100 model to be splitted into head and body."""
- return CNNCifar100().to(self.device)
+ """Return CNNCifar100 model to be split into head and body."""
+ return CNNCifar100()
diff --git a/baselines/fedrep/fedrep/server.py b/baselines/fedrep/fedrep/server.py
deleted file mode 100644
index 5b0c34035ae6..000000000000
--- a/baselines/fedrep/fedrep/server.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Server strategies pipelines for FedRep."""
diff --git a/baselines/fedrep/fedrep/server_app.py b/baselines/fedrep/fedrep/server_app.py
new file mode 100644
index 000000000000..0f98a3bc8876
--- /dev/null
+++ b/baselines/fedrep/fedrep/server_app.py
@@ -0,0 +1,80 @@
+"""fedrep: A Flower Baseline."""
+
+import json
+import os
+import time
+from typing import Dict, List, Tuple
+
+from flwr.common import Context, Metrics, ndarrays_to_parameters
+from flwr.server import ServerApp, ServerAppComponents, ServerConfig
+
+from .utils import get_create_model_fn, get_server_strategy
+
+RESULTS_FILE = "result-{}.json"
+
+
+def config_json_file(context: Context) -> None:
+ """Initialize the json file and write the run configurations."""
+ # Initialize the execution results directory.
+ res_save_path = "./results"
+ if not os.path.exists(res_save_path):
+ os.makedirs(res_save_path)
+ res_save_name = time.strftime("%Y-%m-%d-%H-%M-%S")
+ # Set the date and full path of the file to save the results.
+ global RESULTS_FILE # pylint: disable=global-statement
+ RESULTS_FILE = RESULTS_FILE.format(res_save_name)
+ RESULTS_FILE = f"{res_save_path}/{RESULTS_FILE}"
+ data = {
+ "run_config": dict(context.run_config.items()),
+ "round_res": [],
+ }
+ with open(RESULTS_FILE, "w+", encoding="UTF-8") as fout:
+ json.dump(data, fout, indent=4)
+
+
+def write_res(new_res: Dict[str, float]) -> None:
+ """Load the json file, append result and re-write json collection."""
+ with open(RESULTS_FILE, "r", encoding="UTF-8") as fin:
+ data = json.load(fin)
+ data["round_res"].append(new_res)
+
+ # Write the updated data back to the JSON file
+ with open(RESULTS_FILE, "w", encoding="UTF-8") as fout:
+ json.dump(data, fout, indent=4)
+
+
+def evaluate_metrics_aggregation_fn(eval_metrics: List[Tuple[int, Metrics]]) -> Metrics:
+ """Weighted metrics evaluation."""
+ weights, accuracies, losses = [], [], []
+ for num_examples, metric in eval_metrics:
+ weights.append(num_examples)
+ accuracies.append(float(metric["accuracy"]) * num_examples)
+ losses.append(float(metric["loss"]) * num_examples)
+ accuracy = sum(accuracies) / sum(weights)
+ loss = sum(losses) / sum(weights)
+ write_res({"accuracy": accuracy, "loss": loss})
+ return {"accuracy": accuracy}
+
+
+def server_fn(context: Context) -> ServerAppComponents:
+ """Construct components that set the ServerApp behaviour."""
+ config_json_file(context)
+ # Read from config
+ num_rounds = context.run_config["num-server-rounds"]
+
+ # Initialize model parameters
+ create_model_fn, split_class = get_create_model_fn(context)
+ net = split_class(create_model_fn())
+ parameters = ndarrays_to_parameters(net.get_parameters())
+
+ # Define strategy
+ strategy = get_server_strategy(
+ context=context, params=parameters, eval_fn=evaluate_metrics_aggregation_fn
+ )
+ config = ServerConfig(num_rounds=int(num_rounds))
+
+ return ServerAppComponents(strategy=strategy, config=config)
+
+
+# Create ServerApp
+app = ServerApp(server_fn=server_fn)
diff --git a/baselines/fedrep/fedrep/strategy.py b/baselines/fedrep/fedrep/strategy.py
index 3bee45326a6f..ba4d8b75724f 100644
--- a/baselines/fedrep/fedrep/strategy.py
+++ b/baselines/fedrep/fedrep/strategy.py
@@ -1,4 +1,4 @@
-"""FL server strategies."""
+"""fedrep: A Flower Baseline."""
from flwr.server.strategy import FedAvg
diff --git a/baselines/fedrep/fedrep/utils.py b/baselines/fedrep/fedrep/utils.py
index b706ebf1e041..086abde2c46c 100644
--- a/baselines/fedrep/fedrep/utils.py
+++ b/baselines/fedrep/fedrep/utils.py
@@ -1,204 +1,87 @@
-"""Utility functions for FedRep."""
+"""fedrep: A Flower Baseline."""
-import logging
-import os
-import pickle
-import time
-from pathlib import Path
-from secrets import token_hex
-from typing import Callable, Optional, Type, Union
+from typing import Callable, Type, Union
-import matplotlib.pyplot as plt
-import numpy as np
-from flwr.client import Client
-from flwr.server.history import History
-from omegaconf import DictConfig
+from flwr.common import Context, Parameters
+from flwr.server.strategy import FedAvg
-from fedrep.base_model import get_device
-from fedrep.client import get_client_fn_simulation
-from fedrep.constants import Algorithm
-from fedrep.models import (
+from .constants import Algorithm, ModelDatasetName
+from .models import (
CNNCifar10,
+ CNNCifar10ModelManager,
CNNCifar10ModelSplit,
CNNCifar100,
+ CNNCifar100ModelManager,
CNNCifar100ModelSplit,
)
-
-
-def set_client_state_save_path() -> str:
- """Set the client state save path."""
- client_state_save_path = time.strftime("%Y-%m-%d")
- client_state_sub_path = time.strftime("%H-%M-%S")
- client_state_save_path = (
- f"./client_states/{client_state_save_path}/{client_state_sub_path}"
- )
- if not os.path.exists(client_state_save_path):
- os.makedirs(client_state_save_path)
- return client_state_save_path
-
-
-# pylint: disable=W1202
-def set_client_strategy(cfg: DictConfig) -> DictConfig:
- """Set the client strategy."""
- algorithm = cfg.algorithm.lower()
- if algorithm == Algorithm.FEDREP.value:
- cfg.strategy["_target_"] = "fedrep.strategy.FedRep"
- elif algorithm == Algorithm.FEDAVG.value:
- cfg.strategy["_target_"] = "flwr.server.strategy.FedAvg"
- else:
- logging.warning(
- "Algorithm {} not implemented. Fallback to FedAvg.".format(algorithm)
- )
- return cfg
-
-
-def get_client_fn(
- config: DictConfig, client_state_save_path: str = ""
-) -> Callable[[str], Client]:
- """Get client function."""
- # Get algorithm
- algorithm = config.algorithm.lower()
- # Get client fn
- if algorithm == "fedrep":
- client_fn = get_client_fn_simulation(
- config=config, client_state_save_path=client_state_save_path
- )
- elif algorithm == "fedavg":
- client_fn = get_client_fn_simulation(config=config)
- else:
- raise NotImplementedError
- return client_fn
+from .strategy import FedRep
def get_create_model_fn(
- config: DictConfig,
+ context: Context,
) -> tuple[
- Callable[[], Union[type[CNNCifar10], type[CNNCifar100]]],
- Union[type[CNNCifar10ModelSplit], type[CNNCifar100ModelSplit]],
+ Union[Callable[[], CNNCifar10], Callable[[], CNNCifar100]],
+ Union[Type[CNNCifar10ModelSplit], Type[CNNCifar100ModelSplit]],
]:
"""Get create model function."""
- device = get_device(
- use_cuda=getattr(config, "use_cuda", True),
- specified_device=getattr(config, "specified_device", None),
- )
- split: Union[Type[CNNCifar10ModelSplit], Type[CNNCifar100ModelSplit]] = (
- CNNCifar10ModelSplit
- )
- if config.model_name.lower() == "cnncifar10":
+ model_name = str(context.run_config["model-name"])
+ if model_name == ModelDatasetName.CNN_CIFAR_10.value:
+ split = CNNCifar10ModelSplit
- def create_model() -> Union[Type[CNNCifar10], Type[CNNCifar100]]:
+ def create_model() -> CNNCifar10: # type: ignore
"""Create initial CNNCifar10 model."""
- return CNNCifar10().to(device)
+ return CNNCifar10()
- elif config.model_name.lower() == "cnncifar100":
+ elif model_name == ModelDatasetName.CNN_CIFAR_100.value:
split = CNNCifar100ModelSplit
- def create_model() -> Union[Type[CNNCifar10], Type[CNNCifar100]]:
+ def create_model() -> CNNCifar100: # type: ignore
"""Create initial CNNCifar100 model."""
- return CNNCifar100().to(device)
+ return CNNCifar100()
else:
- raise NotImplementedError("Model not implemented, check name. ")
+ raise NotImplementedError(f"Not a recognized model name {model_name}.")
return create_model, split
-def plot_metric_from_history(
- hist: History, save_plot_path: Path, suffix: Optional[str] = ""
-) -> None:
- """Plot from Flower server History.
-
- Parameters
- ----------
- hist : History
- Object containing evaluation for all rounds.
- save_plot_path : Path
- Folder to save the plot to.
- suffix: Optional[str]
- Optional string to add at the end of the filename for the plot.
- """
- metric_type = "distributed"
- metric_dict = (
- hist.metrics_centralized
- if metric_type == "centralized"
- else hist.metrics_distributed
- )
- try:
- _, values = zip(*metric_dict["accuracy"])
- except KeyError: # If no available metric data
- return
-
- # let's extract decentralized loss (main metric reported in FedProx paper)
- rounds_loss, values_loss = zip(*hist.losses_distributed)
-
- _, axs = plt.subplots(nrows=2, ncols=1, sharex="row")
- axs[0].plot(np.asarray(rounds_loss), np.asarray(values_loss)) # type: ignore
- axs[1].plot(np.asarray(rounds_loss), np.asarray(values)) # type: ignore
-
- axs[0].set_ylabel("Loss") # type: ignore
- axs[1].set_ylabel("Accuracy") # type: ignore
-
- axs[0].grid() # type: ignore
- axs[1].grid() # type: ignore
- # plt.title(f"{metric_type.capitalize()} Validation - MNIST")
- plt.xlabel("Rounds")
- # plt.legend(loc="lower right")
-
- plt.savefig(Path(save_plot_path) / Path(f"{metric_type}_metrics{suffix}.png"))
- plt.close()
-
-
-def save_results_as_pickle(
- history: History,
- file_path: Union[str, Path],
- default_filename: Optional[str] = "results.pkl",
-) -> None:
- """Save results from simulation to pickle.
-
- Parameters
- ----------
- history: History
- History returned by start_simulation.
- file_path: Union[str, Path]
- Path to file to create and store both history and extra_results.
- If path is a directory, the default_filename will be used.
- path doesn't exist, it will be created. If file exists, a
- randomly generated suffix will be added to the file name. This
- is done to avoid overwritting results.
- extra_results : Optional[Dict]
- A dictionary containing additional results you would like
- to be saved to disk. Default: {} (an empty dictionary)
- default_filename: Optional[str]
- File used by default if file_path points to a directory instead
- to a file. Default: "results.pkl"
- """
- path = Path(file_path)
-
- # ensure path exists
- path.mkdir(exist_ok=True, parents=True)
-
- def _add_random_suffix(path_: Path):
- """Add a random suffix to the file name."""
- print(f"File `{path_}` exists! ")
- suffix = token_hex(4)
- print(f"New results to be saved with suffix: {suffix}")
- return path_.parent / (path_.stem + "_" + suffix + ".pkl")
-
- def _complete_path_with_default_name(path_: Path):
- """Append the default file name to the path."""
- print("Using default filename")
- if default_filename is None:
- return path_
- return path_ / default_filename
-
- if path.is_dir():
- path = _complete_path_with_default_name(path)
-
- if path.is_file():
- path = _add_random_suffix(path)
-
- print(f"Results will be saved into: {path}")
- # data = {"history": history, **extra_results}
- data = {"history": history}
- # save results to pickle
- with open(str(path), "wb") as handle:
- pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
+def get_model_manager_class(
+ context: Context,
+) -> Union[Type[CNNCifar10ModelManager], Type[CNNCifar100ModelManager]]:
+ """Depending on the model name type return the corresponding model manager."""
+ model_name = str(context.run_config["model-name"])
+ if model_name.lower() == ModelDatasetName.CNN_CIFAR_10.value:
+ model_manager_class = CNNCifar10ModelManager
+ elif model_name.lower() == ModelDatasetName.CNN_CIFAR_100.value:
+ model_manager_class = CNNCifar100ModelManager # type: ignore
+ else:
+ raise NotImplementedError(
+ f"Model {model_name} not implemented, please check model name."
+ )
+ return model_manager_class
+
+
+def get_server_strategy(
+ context: Context, params: Parameters, eval_fn: Callable
+) -> Union[FedAvg, FedRep]:
+ """Define server strategy based on input algorithm."""
+ algorithm = str(context.run_config["algorithm"]).lower()
+ if algorithm == Algorithm.FEDAVG.value:
+ strategy = FedAvg
+ elif algorithm == Algorithm.FEDREP.value:
+ strategy = FedRep
+ else:
+ raise RuntimeError(f"Unknown algorithm {algorithm}.")
+
+ # Read strategy config
+ fraction_fit = float(context.run_config["fraction-fit"])
+ fraction_evaluate = float(context.run_config["fraction-evaluate"])
+ min_available_clients = int(context.run_config["min-available-clients"])
+
+ strategy = strategy(
+ fraction_fit=float(fraction_fit),
+ fraction_evaluate=fraction_evaluate,
+ min_available_clients=min_available_clients,
+ initial_parameters=params,
+ evaluate_metrics_aggregation_fn=eval_fn,
+ ) # type: ignore
+ return strategy # type: ignore
diff --git a/baselines/fedrep/pyproject.toml b/baselines/fedrep/pyproject.toml
index e4c3551af19a..9a45199032ef 100644
--- a/baselines/fedrep/pyproject.toml
+++ b/baselines/fedrep/pyproject.toml
@@ -1,73 +1,39 @@
[build-system]
-requires = ["poetry-core>=1.4.0"]
-build-backend = "poetry.masonry.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
-[tool.poetry]
+[project]
name = "fedrep"
version = "1.0.0"
-description = "Exploiting Shared Representations for Personalized Federated Learning"
+description = ""
license = "Apache-2.0"
-authors = ["Jiahao Tan "]
-readme = "README.md"
-homepage = "https://flower.ai"
-repository = "https://github.com/adap/flower"
-documentation = "https://flower.ai"
-classifiers = [
- "Development Status :: 3 - Alpha",
- "Intended Audience :: Developers",
- "Intended Audience :: Science/Research",
- "License :: OSI Approved :: Apache Software License",
- "Operating System :: MacOS :: MacOS X",
- "Operating System :: POSIX :: Linux",
- "Programming Language :: Python",
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: Implementation :: CPython",
- "Topic :: Scientific/Engineering",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
- "Topic :: Scientific/Engineering :: Mathematics",
- "Topic :: Software Development",
- "Topic :: Software Development :: Libraries",
- "Topic :: Software Development :: Libraries :: Python Modules",
- "Typing :: Typed",
+dependencies = [
+ "flwr[simulation]>=1.13.1",
+ "flwr-datasets[vision]>=0.4.0",
+ "torch==2.2.1",
+ "torchvision==0.17.1",
]
-[tool.poetry.dependencies]
-python = ">=3.10.0, <3.11.0" # don't change this
-flwr = { extras = ["simulation"], version = "1.9.0" }
-hydra-core = "1.3.2" # don't change this
-pandas = "^2.2.2"
-matplotlib = "^3.9.0"
-tqdm = "^4.66.4"
-torch = "^2.2.2"
-torchvision = "^0.17.2"
-setuptools = "<70"
-flwr-datasets = { extras = ["vision"], version = ">=0.3.0" }
-
-[tool.poetry.dev-dependencies]
-isort = "==5.13.2"
-black = "==24.2.0"
-docformatter = "==1.7.5"
-mypy = "==1.4.1"
-pylint = "==2.8.2"
-flake8 = "==3.9.2"
-pytest = "==6.2.4"
-pytest-watch = "==4.2.0"
-ruff = "==0.0.272"
-types-requests = "==2.27.7"
-virtualenv = "==20.21.0"
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[project.optional-dependencies]
+dev = [
+ "isort==5.13.2",
+ "black==24.2.0",
+ "docformatter==1.7.5",
+ "mypy==1.8.0",
+ "pylint==3.2.6",
+ "flake8==5.0.4",
+ "pytest==6.2.4",
+ "pytest-watch==4.2.0",
+ "ruff==0.1.9",
+ "types-requests==2.31.0.20240125",
+]
[tool.isort]
-line_length = 88
-indent = " "
-multi_line_output = 3
-include_trailing_comma = true
-force_grid_wrap = 0
-use_parentheses = true
+profile = "black"
+known_first_party = ["flwr"]
[tool.black]
line-length = 88
@@ -76,7 +42,9 @@ target-version = ["py38", "py39", "py310", "py311"]
[tool.pytest.ini_options]
minversion = "6.2"
addopts = "-qq"
-testpaths = ["flwr_baselines"]
+testpaths = [
+ "flwr_baselines",
+]
[tool.mypy]
ignore_missing_imports = true
@@ -84,14 +52,22 @@ strict = false
plugins = "numpy.typing.mypy_plugin"
[tool.pylint."MESSAGES CONTROL"]
-good-names = "i,j,k,_,x,y,X,Y"
-signature-mutators = "hydra.main.main"
+disable = "duplicate-code,too-few-public-methods,useless-import-alias"
+good-names = "i,j,k,_,x,y,X,Y,K,N"
+max-args = 10
+max-attributes = 15
+max-locals = 36
+max-branches = 20
+max-statements = 55
-[tool.pylint."TYPECHECK"]
+[tool.pylint.typecheck]
generated-members = "numpy.*, torch.*, tensorflow.*"
[[tool.mypy.overrides]]
-module = ["importlib.metadata.*", "importlib_metadata.*"]
+module = [
+ "importlib.metadata.*",
+ "importlib_metadata.*",
+]
follow_imports = "skip"
follow_imports_for_stubs = true
disallow_untyped_calls = false
@@ -137,3 +113,49 @@ exclude = [
[tool.ruff.pydocstyle]
convention = "numpy"
+
+[tool.hatch.build.targets.wheel]
+packages = ["."]
+
+[tool.flwr.app]
+publisher = "dimitris"
+
+[tool.flwr.app.components]
+serverapp = "fedrep.server_app:app"
+clientapp = "fedrep.client_app:app"
+
+[tool.flwr.app.config]
+algorithm = "fedrep"
+
+# dataset specs
+dataset-name = "cifar10"
+dataset-split = "sample"
+dataset-split-num-classes = 2
+dataset-split-seed = 42
+dataset-split-fraction = 0.83
+
+# model specs
+model-name = "cnncifar10"
+batch-size = 50
+learning-rate = 0.01
+momentum = 0.5
+enable-finetune = false
+num-finetune-epochs = 5
+num-local-epochs = 5 # number of local epochs
+num-rep-epochs = 1 # number of representation epochs (only for FedRep)
+
+# server specs
+num-server-rounds = 100
+fraction-fit = 0.1
+fraction-evaluate = 0.1
+min-available-clients = 2
+min-evaluate-clients = 2
+min-fit-clients = 2
+
+[tool.flwr.federations]
+default = "local-sim-100"
+
+[tool.flwr.federations.local-sim-100]
+options.num-supernodes = 100
+options.backend.client-resources.num-cpus = 2
+options.backend.client-resources.num-gpus = 0.5 # GPU fraction allocated to each client
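+
+# Illustrative sketch (an assumption): values under [tool.flwr.app.config] can
+# typically be overridden at run time via the Flower CLI rather than by editing
+# this file, e.g.
+#
+#   flwr run . local-sim-100 --run-config "num-server-rounds=10 batch-size=32"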
diff --git a/benchmarks/flowertune-llm/README.md b/benchmarks/flowertune-llm/README.md
index 662134a0e44d..c3e1b2b7dd53 100644
--- a/benchmarks/flowertune-llm/README.md
+++ b/benchmarks/flowertune-llm/README.md
@@ -16,8 +16,7 @@ Then, create a new Python environment and install Flower.
> We recommend using `pyenv` with the `virtualenv` plugin to create your environment with Python >= 3.10.0. Other managers, such as Conda, will likely work as well. Check the [documentation](https://flower.ai/docs/framework/how-to-install-flower.html) for alternative ways to install Flower.
```shell
-# We use this dev version until flwr 1.13.0 is out
-pip install git+https://github.com/adap/flower.git@d92453d
+pip install flwr
```
In the new environment, create a new Flower project using the `FlowerTune` template. You will be prompted for a name to give to your app/project, your username, and for your choice of LLM challenge:
diff --git a/benchmarks/flowertune-llm/evaluation/code/requirements.txt b/benchmarks/flowertune-llm/evaluation/code/requirements.txt
index 9c9e3f8e27a1..5e6cb5418ca2 100644
--- a/benchmarks/flowertune-llm/evaluation/code/requirements.txt
+++ b/benchmarks/flowertune-llm/evaluation/code/requirements.txt
@@ -3,6 +3,6 @@ datasets==2.20.0
evaluate==0.3.0
sentencepiece==0.2.0
protobuf==5.27.1
-bitsandbytes==0.43.1
+bitsandbytes==0.45.0
hf_transfer==0.1.8
-git+https://github.com/bigcode-project/bigcode-evaluation-harness.git@0f3e95f0806e78a4f432056cdb1be93604a51d69
+git+https://github.com/bigcode-project/bigcode-evaluation-harness.git@6116c6a9a5672c69bd624373cfbc8938b7acc249
diff --git a/benchmarks/flowertune-llm/evaluation/finance/README.md b/benchmarks/flowertune-llm/evaluation/finance/README.md
index b5595433a238..15d2410b8ca4 100644
--- a/benchmarks/flowertune-llm/evaluation/finance/README.md
+++ b/benchmarks/flowertune-llm/evaluation/finance/README.md
@@ -27,6 +27,7 @@ huggingface-cli login
```bash
python eval.py \
+--base-model-name-path=your-base-model-name \ # e.g., mistralai/Mistral-7B-v0.3
--peft-path=/path/to/fine-tuned-peft-model-dir/ \ # e.g., ./peft_1
--run-name=fl \ # specified name for this run
--batch-size=32 \
diff --git a/benchmarks/flowertune-llm/evaluation/finance/requirements.txt b/benchmarks/flowertune-llm/evaluation/finance/requirements.txt
index 89dcf40b819f..bbec80d10b9c 100644
--- a/benchmarks/flowertune-llm/evaluation/finance/requirements.txt
+++ b/benchmarks/flowertune-llm/evaluation/finance/requirements.txt
@@ -3,5 +3,5 @@ scikit-learn==1.5.0
datasets==2.20.0
sentencepiece==0.2.0
protobuf==5.27.1
-bitsandbytes==0.43.1
+bitsandbytes==0.45.0
hf_transfer==0.1.8
diff --git a/benchmarks/flowertune-llm/evaluation/general-nlp/README.md b/benchmarks/flowertune-llm/evaluation/general-nlp/README.md
index c3fd71da6ea2..5acd75285dd3 100644
--- a/benchmarks/flowertune-llm/evaluation/general-nlp/README.md
+++ b/benchmarks/flowertune-llm/evaluation/general-nlp/README.md
@@ -27,6 +27,7 @@ huggingface-cli login
```bash
python eval.py \
+--base-model-name-path=your-base-model-name \ # e.g., mistralai/Mistral-7B-v0.3
--peft-path=/path/to/fine-tuned-peft-model-dir/ \ # e.g., ./peft_1
--run-name=fl \ # specified name for this run
--batch-size=16 \
diff --git a/benchmarks/flowertune-llm/evaluation/general-nlp/requirements.txt b/benchmarks/flowertune-llm/evaluation/general-nlp/requirements.txt
index f5c46e869ce2..3dae79d9af17 100644
--- a/benchmarks/flowertune-llm/evaluation/general-nlp/requirements.txt
+++ b/benchmarks/flowertune-llm/evaluation/general-nlp/requirements.txt
@@ -4,5 +4,5 @@ scikit-learn==1.5.0
datasets==2.20.0
sentencepiece==0.2.0
protobuf==5.27.1
-bitsandbytes==0.43.1
+bitsandbytes==0.45.0
hf_transfer==0.1.8
diff --git a/benchmarks/flowertune-llm/evaluation/medical/README.md b/benchmarks/flowertune-llm/evaluation/medical/README.md
index 628489ce8de6..6a519e8a7c54 100644
--- a/benchmarks/flowertune-llm/evaluation/medical/README.md
+++ b/benchmarks/flowertune-llm/evaluation/medical/README.md
@@ -27,6 +27,7 @@ huggingface-cli login
```bash
python eval.py \
+--base-model-name-path=your-base-model-name \ # e.g., mistralai/Mistral-7B-v0.3
--peft-path=/path/to/fine-tuned-peft-model-dir/ \ # e.g., ./peft_1
--run-name=fl \ # specified name for this run
--batch-size=16 \
diff --git a/benchmarks/flowertune-llm/evaluation/medical/requirements.txt b/benchmarks/flowertune-llm/evaluation/medical/requirements.txt
index f5c46e869ce2..3dae79d9af17 100644
--- a/benchmarks/flowertune-llm/evaluation/medical/requirements.txt
+++ b/benchmarks/flowertune-llm/evaluation/medical/requirements.txt
@@ -4,5 +4,5 @@ scikit-learn==1.5.0
datasets==2.20.0
sentencepiece==0.2.0
protobuf==5.27.1
-bitsandbytes==0.43.1
+bitsandbytes==0.45.0
hf_transfer==0.1.8
diff --git a/datasets/README.md b/datasets/README.md
index 0d35d2e31b6a..f3091ab7532b 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -51,7 +51,7 @@ Create **custom partitioning schemes** or choose from the **implemented [partiti
* Exponential partitioning `ExponentialPartitioner(num_partitions)`
* more to come in the future releases (contributions are welcome).
-
+ Comparison of Partitioning Schemes on CIFAR10
diff --git a/datasets/dev/build-flwr-datasets-docs.sh b/datasets/dev/build-flwr-datasets-docs.sh
index ed41a87a414b..9cb80dcfd5d2 100755
--- a/datasets/dev/build-flwr-datasets-docs.sh
+++ b/datasets/dev/build-flwr-datasets-docs.sh
@@ -22,7 +22,7 @@
set -e
-cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../doc
+cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../docs
# Remove the old docs from source/ref-api
REF_API_DIR="source/ref-api"
diff --git a/datasets/dev/format.sh b/datasets/dev/format.sh
index b7dca9accabf..94c70444d735 100755
--- a/datasets/dev/format.sh
+++ b/datasets/dev/format.sh
@@ -28,7 +28,7 @@ echo "Formatting done: Python"
# Notebooks
echo "Formatting started: Notebooks"
-python -m black --ipynb -q doc/source/*.ipynb
+python -m black --ipynb -q docs/source/*.ipynb
KEYS="metadata.celltoolbar metadata.language_info metadata.toc metadata.notify_time metadata.varInspector metadata.accelerator metadata.vscode cell.metadata.id cell.metadata.heading_collapsed cell.metadata.hidden cell.metadata.code_folding cell.metadata.tags cell.metadata.init_cell cell.metadata.vscode cell.metadata.pycharm"
-python -m nbstripout --keep-output doc/source/*.ipynb --extra-keys "$KEYS"
+python -m nbstripout --keep-output docs/source/*.ipynb --extra-keys "$KEYS"
echo "Formatting done: Notebooks"
diff --git a/datasets/doc/source/how-to-install-flwr-datasets.rst b/datasets/doc/source/how-to-install-flwr-datasets.rst
deleted file mode 100644
index 3f79daceb753..000000000000
--- a/datasets/doc/source/how-to-install-flwr-datasets.rst
+++ /dev/null
@@ -1,46 +0,0 @@
-Installation
-============
-
-Python Version
---------------
-
-Flower Datasets requires `Python 3.8 `_ or above.
-
-
-Install stable release (pip)
-----------------------------
-
-Stable releases are available on `PyPI `_
-
-.. code-block:: bash
-
- python -m pip install flwr-datasets
-
-For vision datasets (e.g. MNIST, CIFAR10) ``flwr-datasets`` should be installed with the ``vision`` extra
-
-.. code-block:: bash
-
- python -m pip install flwr_datasets[vision]
-
-For audio datasets (e.g. Speech Command) ``flwr-datasets`` should be installed with the ``audio`` extra
-
-.. code-block:: bash
-
- python -m pip install flwr_datasets[audio]
-
-
-Verify installation
--------------------
-
-The following command can be used to verify if Flower Datasets was successfully installed:
-
-.. code-block:: bash
-
- python -c "import flwr_datasets;print(flwr_datasets.__version__)"
-
-If everything worked, it should print the version of Flower Datasets to the command line:
-
-.. code-block:: none
-
- 0.4.0
-
diff --git a/datasets/doc/Makefile b/datasets/docs/Makefile
similarity index 100%
rename from datasets/doc/Makefile
rename to datasets/docs/Makefile
diff --git a/datasets/doc/make.bat b/datasets/docs/make.bat
similarity index 100%
rename from datasets/doc/make.bat
rename to datasets/docs/make.bat
diff --git a/datasets/doc/source/.gitignore b/datasets/docs/source/.gitignore
similarity index 100%
rename from datasets/doc/source/.gitignore
rename to datasets/docs/source/.gitignore
diff --git a/datasets/doc/source/_static/custom.css b/datasets/docs/source/_static/custom.css
similarity index 100%
rename from datasets/doc/source/_static/custom.css
rename to datasets/docs/source/_static/custom.css
diff --git a/datasets/doc/source/_static/favicon.ico b/datasets/docs/source/_static/favicon.ico
similarity index 100%
rename from datasets/doc/source/_static/favicon.ico
rename to datasets/docs/source/_static/favicon.ico
diff --git a/datasets/doc/source/_static/flower-datasets-logo.png b/datasets/docs/source/_static/flower-datasets-logo.png
similarity index 100%
rename from datasets/doc/source/_static/flower-datasets-logo.png
rename to datasets/docs/source/_static/flower-datasets-logo.png
diff --git a/datasets/doc/source/_static/readme/comparison_of_partitioning_schemes.png b/datasets/docs/source/_static/readme/comparison_of_partitioning_schemes.png
similarity index 100%
rename from datasets/doc/source/_static/readme/comparison_of_partitioning_schemes.png
rename to datasets/docs/source/_static/readme/comparison_of_partitioning_schemes.png
diff --git a/datasets/doc/source/_static/tutorial-quickstart/choose-hf-dataset.png b/datasets/docs/source/_static/tutorial-quickstart/choose-hf-dataset.png
similarity index 100%
rename from datasets/doc/source/_static/tutorial-quickstart/choose-hf-dataset.png
rename to datasets/docs/source/_static/tutorial-quickstart/choose-hf-dataset.png
diff --git a/datasets/doc/source/_static/tutorial-quickstart/copy-dataset-name.png b/datasets/docs/source/_static/tutorial-quickstart/copy-dataset-name.png
similarity index 100%
rename from datasets/doc/source/_static/tutorial-quickstart/copy-dataset-name.png
rename to datasets/docs/source/_static/tutorial-quickstart/copy-dataset-name.png
diff --git a/datasets/doc/source/_static/tutorial-quickstart/partitioner-flexibility.png b/datasets/docs/source/_static/tutorial-quickstart/partitioner-flexibility.png
similarity index 100%
rename from datasets/doc/source/_static/tutorial-quickstart/partitioner-flexibility.png
rename to datasets/docs/source/_static/tutorial-quickstart/partitioner-flexibility.png
diff --git a/datasets/doc/source/_templates/autosummary/base.rst b/datasets/docs/source/_templates/autosummary/base.rst
similarity index 100%
rename from datasets/doc/source/_templates/autosummary/base.rst
rename to datasets/docs/source/_templates/autosummary/base.rst
diff --git a/datasets/doc/source/_templates/autosummary/class.rst b/datasets/docs/source/_templates/autosummary/class.rst
similarity index 100%
rename from datasets/doc/source/_templates/autosummary/class.rst
rename to datasets/docs/source/_templates/autosummary/class.rst
diff --git a/datasets/doc/source/_templates/autosummary/module.rst b/datasets/docs/source/_templates/autosummary/module.rst
similarity index 100%
rename from datasets/doc/source/_templates/autosummary/module.rst
rename to datasets/docs/source/_templates/autosummary/module.rst
diff --git a/datasets/doc/source/_templates/base.html b/datasets/docs/source/_templates/base.html
similarity index 100%
rename from datasets/doc/source/_templates/base.html
rename to datasets/docs/source/_templates/base.html
diff --git a/datasets/doc/source/_templates/sidebar/search.html b/datasets/docs/source/_templates/sidebar/search.html
similarity index 100%
rename from datasets/doc/source/_templates/sidebar/search.html
rename to datasets/docs/source/_templates/sidebar/search.html
diff --git a/datasets/doc/source/conf.py b/datasets/docs/source/conf.py
similarity index 98%
rename from datasets/doc/source/conf.py
rename to datasets/docs/source/conf.py
index 92d59d7df370..e46a49f504d7 100644
--- a/datasets/doc/source/conf.py
+++ b/datasets/docs/source/conf.py
@@ -17,6 +17,7 @@
import datetime
import os
import sys
+
from sphinx.application import ConfigError
# Configuration file for the Sphinx documentation builder.
@@ -162,7 +163,7 @@ def find_test_modules(package_path):
.. raw:: html
-
+
"""
@@ -182,5 +183,5 @@ def find_test_modules(package_path):
myst_heading_anchors = 3
# -- Options for sphinx_copybutton -------------------------------------
-copybutton_exclude = '.linenos, .gp, .go'
+copybutton_exclude = ".linenos, .gp, .go"
copybutton_prompt_text = ">>> "
diff --git a/datasets/docs/source/contributor-how-to-contribute-dataset.rst b/datasets/docs/source/contributor-how-to-contribute-dataset.rst
new file mode 100644
index 000000000000..07a6ba6378f7
--- /dev/null
+++ b/datasets/docs/source/contributor-how-to-contribute-dataset.rst
@@ -0,0 +1,56 @@
+How to contribute a dataset
+===========================
+
+To make a dataset available in Flower Datasets (`flwr-datasets`), you need to add the dataset to the `HuggingFace Hub `_.
+
+This guide explains the best practices we found when adding datasets ourselves and points to the relevant HF guides. To see the datasets added by Flower, visit https://huggingface.co/flwrlabs.
+
+Dataset contribution process
+----------------------------
+The contribution consists of three steps: first, on your development machine, transform your dataset into a ``datasets.Dataset`` object, the preferred format for datasets on the HF Hub; second, upload the dataset to the HuggingFace Hub and explain in its readme how it can be used with Flower Datasets; third, share your dataset with us and we will add it to the `recommended FL dataset list `_.
+
+Creating a dataset locally
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+You can create a local dataset directly using the `datasets` library, or load it in any custom way and convert it to a `datasets.Dataset` object from other Python objects.
+To complete this step, we recommend reading our :doc:`how-to-use-with-local-data` guide and/or the `Create a dataset `_ guide from HF.
+
+.. tip::
+ We recommend that you do not upload custom scripts to the HuggingFace Hub; instead, create the dataset locally and upload the data, which speeds up processing each time the dataset is downloaded.
+
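+As a rough sketch (the file and column names below are made up for illustration), a local CSV file could be turned into a ``datasets.Dataset`` like this:
+
+.. code-block:: python
+
+    import pandas as pd
+    from datasets import Dataset
+
+    # Load your local data in any custom way, e.g. with pandas
+    df = pd.read_csv("my_local_data.csv")  # hypothetical file
+
+    # Convert the DataFrame into a datasets.Dataset object
+    dataset = Dataset.from_pandas(df)
+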
+Contribution to the HuggingFace Hub
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Each dataset on the HF Hub is a Git repository with a specific structure and readme file. HuggingFace provides an API to push the dataset and, alternatively, a user interface directly on the website to populate the information in the readme file.
+
+Contributions to the HuggingFace Hub come down to:
+
+1. creating an HF repository for the dataset.
+2. uploading the dataset.
+3. filling in the information in the readme file.
+
+To complete this step, follow HF's guide `Share dataset to the Hub `_.
+
+Note that pushing the dataset is straightforward; here's what it could look like:
+
+.. code-block:: python
+
+ from datasets import Dataset
+
+ # Example dataset
+ data = {
+ 'column1': [1, 2, 3],
+ 'column2': ['a', 'b', 'c']
+ }
+
+ # Create a Dataset object
+ dataset = Dataset.from_dict(data)
+
+ # Push the dataset to the HuggingFace Hub
+ dataset.push_to_hub("you-hf-username/your-ds-name")
+
+To make the dataset easily accessible in FL, we recommend adding a "Use in FL" section to the readme. Here's an example of how it is done in `one of our repos `_ for the cinic10 dataset.
+
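+A "Use in FL" snippet could look roughly like the following (the dataset name ``flwrlabs/cinic10`` and the number of partitions are assumptions chosen for illustration):
+
+.. code-block:: python
+
+    from flwr_datasets import FederatedDataset
+    from flwr_datasets.partitioner import IidPartitioner
+
+    # Partition the contributed dataset across 10 clients
+    partitioner = IidPartitioner(num_partitions=10)
+    fds = FederatedDataset(
+        dataset="flwrlabs/cinic10", partitioners={"train": partitioner}
+    )
+    partition = fds.load_partition(0)
+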
+Increasing visibility of the dataset
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+If you want the dataset listed in our `recommended FL dataset list `_, please send a PR or ping us in the #contributions channel on `Slack `_.
+
+That's it! You have successfully contributed a dataset to the HuggingFace Hub and made it available to the FL community. Thank you for your contribution!
\ No newline at end of file
diff --git a/datasets/doc/source/how-to-disable-enable-progress-bar.rst b/datasets/docs/source/how-to-disable-enable-progress-bar.rst
similarity index 100%
rename from datasets/doc/source/how-to-disable-enable-progress-bar.rst
rename to datasets/docs/source/how-to-disable-enable-progress-bar.rst
diff --git a/datasets/docs/source/how-to-install-flwr-datasets.rst b/datasets/docs/source/how-to-install-flwr-datasets.rst
new file mode 100644
index 000000000000..5f89261f0b29
--- /dev/null
+++ b/datasets/docs/source/how-to-install-flwr-datasets.rst
@@ -0,0 +1,73 @@
+Installation
+============
+
+Python Version
+--------------
+
+Flower Datasets requires `Python 3.9 `_ or above.
+
+
+Install stable release (pip)
+----------------------------
+
+Stable releases are available on `PyPI `_
+
+.. code-block:: bash
+
+ python -m pip install flwr-datasets
+
+For vision datasets (e.g. MNIST, CIFAR10) ``flwr-datasets`` should be installed with the ``vision`` extra
+
+.. code-block:: bash
+
+ python -m pip install "flwr-datasets[vision]"
+
+For audio datasets (e.g. Speech Command) ``flwr-datasets`` should be installed with the ``audio`` extra
+
+.. code-block:: bash
+
+ python -m pip install "flwr-datasets[audio]"
+
+Install directly from GitHub (pip)
+----------------------------------
+
+Installing Flower Datasets directly from GitHub ensures you have access to the most up-to-date version.
+If you encounter any issues or bugs, you may be directed to a specific branch containing a fix before
+it becomes part of an official release.
+
+.. code-block:: bash
+
+ python -m pip install "flwr-datasets@git+https://github.com/adap/flower.git"\
+ "@TYPE-HERE-BRANCH-NAME#subdirectory=datasets"
+
+As before, you can specify the ``vision`` or ``audio`` extra after the name of the library.
+
+.. code-block:: bash
+
+ python -m pip install "flwr-datasets[vision]@git+https://github.com/adap/flower.git"\
+ "@TYPE-HERE-BRANCH-NAME#subdirectory=datasets"
+
+e.g. for the main branch:
+
+.. code-block:: bash
+
+ python -m pip install "flwr-datasets@git+https://github.com/adap/flower.git"\
+ "@main#subdirectory=datasets"
+
+Since `flwr-datasets` is a part of the Flower repository, the `subdirectory` parameter (at the end of the URL) is used to specify the package location in the GitHub repo.
+
+Verify installation
+-------------------
+
+The following command can be used to verify if Flower Datasets was successfully installed:
+
+.. code-block:: bash
+
+ python -c "import flwr_datasets;print(flwr_datasets.__version__)"
+
+If everything works, it should print the version of Flower Datasets to the command line:
+
+.. code-block:: none
+
+ 0.4.0
+
diff --git a/datasets/doc/source/how-to-use-with-local-data.rst b/datasets/docs/source/how-to-use-with-local-data.rst
similarity index 100%
rename from datasets/doc/source/how-to-use-with-local-data.rst
rename to datasets/docs/source/how-to-use-with-local-data.rst
diff --git a/datasets/doc/source/how-to-use-with-numpy.rst b/datasets/docs/source/how-to-use-with-numpy.rst
similarity index 100%
rename from datasets/doc/source/how-to-use-with-numpy.rst
rename to datasets/docs/source/how-to-use-with-numpy.rst
diff --git a/datasets/doc/source/how-to-use-with-pytorch.rst b/datasets/docs/source/how-to-use-with-pytorch.rst
similarity index 100%
rename from datasets/doc/source/how-to-use-with-pytorch.rst
rename to datasets/docs/source/how-to-use-with-pytorch.rst
diff --git a/datasets/doc/source/how-to-use-with-tensorflow.rst b/datasets/docs/source/how-to-use-with-tensorflow.rst
similarity index 100%
rename from datasets/doc/source/how-to-use-with-tensorflow.rst
rename to datasets/docs/source/how-to-use-with-tensorflow.rst
diff --git a/datasets/doc/source/index.rst b/datasets/docs/source/index.rst
similarity index 82%
rename from datasets/doc/source/index.rst
rename to datasets/docs/source/index.rst
index 6f7c47bf2416..4137f5f2e148 100644
--- a/datasets/doc/source/index.rst
+++ b/datasets/docs/source/index.rst
@@ -66,18 +66,25 @@ Information-oriented API reference and other reference material.
recommended-fl-datasets
ref-telemetry
+.. toctree::
+ :maxdepth: 1
+ :caption: Contributor tutorials
+
+ contributor-how-to-contribute-dataset
+
+
Main features
-------------
Flower Datasets library supports:
- **Downloading datasets** - choose the dataset from Hugging Face's ``dataset`` (`link `_)(*)
-- **Partitioning datasets** - choose one of the implemented partitioning scheme or create your own.
+- **Partitioning datasets** - choose one of the implemented partitioning schemes or create your own.
- **Creating centralized datasets** - leave parts of the dataset unpartitioned (e.g. for centralized evaluation)
- **Visualization of the partitioned datasets** - visualize the label distribution of the partitioned dataset (and compare the results on different parameters of the same partitioning schemes, different datasets, different partitioning schemes, or any mix of them)
.. note::
- (*) Once the dataset is available on HuggingFace Hub it can be **immediately** used in ``Flower Datasets`` (no approval from the Flower team needed, no custom code needed).
+ (*) Once the dataset is available on HuggingFace Hub, it can be **immediately** used in ``Flower Datasets`` without requiring approval from the Flower team or the need for custom code.
.. image:: ./_static/readme/comparison_of_partitioning_schemes.png
@@ -94,7 +101,7 @@ Thanks to using Hugging Face's ``datasets`` used under the hood, Flower Datasets
- Jax
- Arrow
-Here are a few of the ``Partitioner`` s that are available: (for a full list see `link `_ )
+Here are a few of the ``Partitioners`` that are available: (for a full list see `link `_ )
* Partitioner (the abstract base class) ``Partitioner``
* IID partitioning ``IidPartitioner(num_partitions)``
@@ -120,7 +127,7 @@ What makes Flower Datasets stand out from other libraries?
* Access to the largest online repository of datasets:
- * The library functionality is independent of the dataset, so you can use any dataset available on `🤗Hugging Face Datasets `_, which means that others can immediately benefit from the dataset you added.
+ * The library functionality is independent of the dataset, so you can use any dataset available on `🤗Hugging Face Datasets `_. This means that others can immediately benefit from the dataset you added.
* Out-of-the-box reproducibility across different projects.
@@ -147,6 +154,20 @@ The Flower Community is growing quickly - we're a friendly group of researchers,
Join us on Slack
+Recommended FL Datasets
+-----------------------
+
+Below we present a list of recommended datasets for federated learning research, which can be
+used with Flower Datasets ``flwr-datasets``.
+
+.. note::
+
+ All datasets from `HuggingFace Hub `_ can be used with our library. This page presents just a set of datasets we collected that you might find useful.
+
+For more information about any dataset, visit its page by clicking the dataset name.
+
+.. include:: recommended-fl-datasets-tables.rst
+
.. _demo:
Demo
----
diff --git a/datasets/doc/source/recommended-fl-datasets.rst b/datasets/docs/source/recommended-fl-datasets-tables.rst
similarity index 84%
rename from datasets/doc/source/recommended-fl-datasets.rst
rename to datasets/docs/source/recommended-fl-datasets-tables.rst
index 92479bd0542a..f69025426d47 100644
--- a/datasets/doc/source/recommended-fl-datasets.rst
+++ b/datasets/docs/source/recommended-fl-datasets-tables.rst
@@ -1,20 +1,5 @@
-Recommended FL Datasets
-=======================
-
-This page lists the recommended datasets for federated learning research, which can be
-used with Flower Datasets ``flwr-datasets``. To learn about the library, see the
-`quickstart tutorial `_ . To
-see the full FL example with Flower and Flower Datasets open the `quickstart-pytorch
-`_.
-
-.. note::
-
- All datasets from `HuggingFace Hub `_ can be used with our library. This page presents just a set of datasets we collected that you might find useful.
-
-For more information about any dataset, visit its page by clicking the dataset name. For more information how to use the
-
Image Datasets
---------------
+~~~~~~~~~~~~~~
.. list-table:: Image Datasets
:widths: 40 40 20
@@ -83,7 +68,7 @@ Image Datasets
- 32x32
Audio Datasets
---------------
+~~~~~~~~~~~~~~
.. list-table:: Audio Datasets
:widths: 35 30 15
@@ -109,7 +94,8 @@ Audio Datasets
- clean/other
Tabular Datasets
-----------------
+~~~~~~~~~~~~~~~~
+
.. list-table:: Tabular Datasets
:widths: 35 30
@@ -125,7 +111,7 @@ Tabular Datasets
- train 150
Text Datasets
--------------
+~~~~~~~~~~~~~
.. list-table:: Text Datasets
:widths: 40 30 30
diff --git a/datasets/docs/source/recommended-fl-datasets.rst b/datasets/docs/source/recommended-fl-datasets.rst
new file mode 100644
index 000000000000..21b47d0d257f
--- /dev/null
+++ b/datasets/docs/source/recommended-fl-datasets.rst
@@ -0,0 +1,16 @@
+Recommended FL Datasets
+=======================
+
+This page lists the recommended datasets for federated learning research, which can be
+used with Flower Datasets ``flwr-datasets``. To learn about the library, see the
+`quickstart tutorial `_ . To
+see the full FL example with Flower and Flower Datasets open the `quickstart-pytorch
+`_.
+
+.. note::
+
+ All datasets from `HuggingFace Hub `_ can be used with our library. This page presents just a set of datasets we collected that you might find useful.
+
+For more information about any dataset, visit its page by clicking the dataset name.
+
+.. include:: recommended-fl-datasets-tables.rst
\ No newline at end of file
diff --git a/datasets/doc/source/ref-telemetry.md b/datasets/docs/source/ref-telemetry.md
similarity index 100%
rename from datasets/doc/source/ref-telemetry.md
rename to datasets/docs/source/ref-telemetry.md
diff --git a/datasets/doc/source/tutorial-quickstart.ipynb b/datasets/docs/source/tutorial-quickstart.ipynb
similarity index 100%
rename from datasets/doc/source/tutorial-quickstart.ipynb
rename to datasets/docs/source/tutorial-quickstart.ipynb
diff --git a/datasets/doc/source/tutorial-use-partitioners.ipynb b/datasets/docs/source/tutorial-use-partitioners.ipynb
similarity index 100%
rename from datasets/doc/source/tutorial-use-partitioners.ipynb
rename to datasets/docs/source/tutorial-use-partitioners.ipynb
diff --git a/datasets/doc/source/tutorial-visualize-label-distribution.ipynb b/datasets/docs/source/tutorial-visualize-label-distribution.ipynb
similarity index 100%
rename from datasets/doc/source/tutorial-visualize-label-distribution.ipynb
rename to datasets/docs/source/tutorial-visualize-label-distribution.ipynb
diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py
index 72ea54773564..8659aa03313b 100644
--- a/datasets/flwr_datasets/federated_dataset.py
+++ b/datasets/flwr_datasets/federated_dataset.py
@@ -128,6 +128,7 @@ def __init__(
self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners(
partitioners
)
+ self._check_partitioners_correctness()
self._shuffle = shuffle
self._seed = seed
# _dataset is prepared lazily on the first call to `load_partition`
@@ -336,3 +337,20 @@ def _check_if_no_split_keyword_possible(self) -> None:
"Please set the `split` argument. You can only omit the split keyword "
"if there is exactly one partitioner specified."
)
+
+ def _check_partitioners_correctness(self) -> None:
+ """Check if the partitioners are correctly specified.
+
+ Check if each partitioner is a different Python object. Using the same
+ partitioner for different splits is not allowed.
+ """
+ partitioners_keys = list(self._partitioners.keys())
+ for i, first_split in enumerate(partitioners_keys):
+ for j in range(i + 1, len(partitioners_keys)):
+ second_split = partitioners_keys[j]
+ if self._partitioners[first_split] is self._partitioners[second_split]:
+ raise ValueError(
+ f"The same partitioner object is used for multiple splits: "
+ f"('{first_split}', '{second_split}'). "
+ "Each partitioner should be a separate object."
+ )
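+
+
+# Illustrative sketch (an assumption, not part of the module): the check above means
+# each split needs its own partitioner object, e.g.
+#     partitioners={
+#         "train": IidPartitioner(num_partitions=10),
+#         "test": IidPartitioner(num_partitions=10),
+#     }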
diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py
index bbdfa42292c2..6c12ee0e2e1a 100644
--- a/datasets/flwr_datasets/federated_dataset_test.py
+++ b/datasets/flwr_datasets/federated_dataset_test.py
@@ -32,6 +32,7 @@
_load_mocked_dataset_dict_by_partial_download,
)
from flwr_datasets.partitioner import IidPartitioner, NaturalIdPartitioner, Partitioner
+from flwr_datasets.preprocessor.divider import Divider
mocked_datasets = ["cifar100", "svhn", "sentiment140", "speech_commands"]
@@ -568,6 +569,57 @@ def test_use_load_dataset_kwargs(self) -> None:
with self.assertRaises(ValueError):
_ = fds.load_partition(0)
+ def test_incorrect_two_partitioners(self) -> None:
+ """Test if the method raises ValueError with incorrect partitioners."""
+ partitioner = IidPartitioner(num_partitions=10)
+ partitioners: dict[str, Union[Partitioner, int]] = {
+ "train": partitioner,
+ "test": partitioner,
+ }
+ first_split = "train"
+ second_split = "test"
+ with self.assertRaises(ValueError) as context:
+ FederatedDataset(
+ dataset="mnist",
+ partitioners=partitioners,
+ )
+ self.assertIn(
+ f"The same partitioner object is used for multiple splits: "
+ f"('{first_split}', '{second_split}'). "
+ "Each partitioner should be a separate object.",
+ str(context.exception),
+ )
+
+ def test_incorrect_three_partitioners(self) -> None:
+ """Test if the method raises ValueError with incorrect partitioners."""
+ partitioner = IidPartitioner(num_partitions=10)
+ partitioners: dict[str, Union[int, Partitioner]] = {
+ "train1": partitioner,
+ "train2": 10,
+ "test": partitioner,
+ }
+ divider = Divider(
+ divide_config={
+ "train1": 0.5,
+ "train2": 0.5,
+ },
+ divide_split="train",
+ )
+
+ with self.assertRaises(
+ ValueError,
+ ) as context:
+
+ FederatedDataset(
+ dataset="mnist", partitioners=partitioners, preprocessor=divider
+ )
+
+ self.assertIn(
+ "The same partitioner object is used for multiple splits: "
+ "('train1', 'test'). Each partitioner should be a separate object.",
+ str(context.exception),
+ )
+
def datasets_are_equal(ds1: Dataset, ds2: Dataset) -> bool:
"""Check if two Datasets have the same values."""
diff --git a/datasets/flwr_datasets/partitioner/__init__.py b/datasets/flwr_datasets/partitioner/__init__.py
index 59f647f44b16..8770d5b8b76e 100644
--- a/datasets/flwr_datasets/partitioner/__init__.py
+++ b/datasets/flwr_datasets/partitioner/__init__.py
@@ -29,7 +29,7 @@
from .shard_partitioner import ShardPartitioner
from .size_partitioner import SizePartitioner
from .square_partitioner import SquarePartitioner
-from .vertical_even_partitioner import VerticalEvenPartitioner
+from .vertical_size_partitioner import VerticalSizePartitioner
__all__ = [
"DirichletPartitioner",
@@ -46,5 +46,5 @@
"ShardPartitioner",
"SizePartitioner",
"SquarePartitioner",
- "VerticalEvenPartitioner",
+ "VerticalSizePartitioner",
]
diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
new file mode 100644
index 000000000000..462a76a2e3f5
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner.py
@@ -0,0 +1,312 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalSizePartitioner class."""
+# flake8: noqa: E501
+# pylint: disable=C0301, R0902, R0913
+from math import floor
+from typing import Literal, Optional, Union, cast
+
+import numpy as np
+
+import datasets
+from flwr_datasets.partitioner.partitioner import Partitioner
+from flwr_datasets.partitioner.vertical_partitioner_utils import (
+ _add_active_party_columns,
+)
+
+
+class VerticalSizePartitioner(Partitioner):
+ """Creates vertical partitions by spliting features (columns) based on sizes.
+
+ The sizes refer to the number of columns after the `drop_columns` are
+ dropped. `shared_columns` and `active_party_column` are excluded and
+ added only after the size-based division.
+
+ Enables selection of "active party" column(s) and placement into
+ a specific partition or creation of a new partition just for them.
+ Also enables dropping columns and sharing specified columns across
+ all partitions.
+
+ Parameters
+ ----------
+ partition_sizes : Union[list[int], list[float]]
+ A list where each value represents the size of a partition.
+        list[int] -> each value represents an absolute number of columns. Size zero is
+        allowed and will result in an empty partition if no shared columns are present.
+        list[float] -> each value represents a fraction of the total number of columns.
+        Note that these values apply to the columns excluding `active_party_columns` and
+        `shared_columns`; those are added to the partition(s) afterwards. `drop_columns`
+        are also not counted toward the partition sizes.
+        In case of list[int]: sum(partition_sizes) == len(columns) - len(drop_columns) -
+        len(shared_columns) - len(active_party_columns)
+ active_party_column : Optional[Union[str, list[str]]]
+ Column(s) (typically representing labels) associated with the
+ "active party" (which can be the server).
+    active_party_columns_mode : Union[Literal["add_to_first", "add_to_last", "create_as_first", "create_as_last", "add_to_all"], int]
+ Determines how to assign the active party columns:
+
+ - `"add_to_first"`: Append active party columns to the first partition.
+ - `"add_to_last"`: Append active party columns to the last partition.
+ - `"create_as_first"`: Create a new partition at the start containing only these columns.
+ - `"create_as_last"`: Create a new partition at the end containing only these columns.
+ - `"add_to_all"`: Append active party columns to all partitions.
+ - int: Append active party columns to the specified partition index.
+ drop_columns : Optional[list[str]]
+ Columns to remove entirely from the dataset before partitioning.
+ shared_columns : Optional[list[str]]
+ Columns to duplicate into every partition after initial partitioning.
+ shuffle : bool
+ Whether to shuffle the order of columns before partitioning.
+ seed : Optional[int]
+ Random seed for shuffling columns. Has no effect if `shuffle=False`.
+
+ Examples
+ --------
+ >>> from flwr_datasets import FederatedDataset
+ >>> from flwr_datasets.partitioner import VerticalSizePartitioner
+ >>>
+ >>> partitioner = VerticalSizePartitioner(
+ ... partition_sizes=[8, 4, 2],
+ ... active_party_column="income",
+ ... active_party_columns_mode="create_as_last"
+ ... )
+ >>> fds = FederatedDataset(
+ ... dataset="scikit-learn/adult-census-income",
+ ... partitioners={"train": partitioner}
+ ... )
+ >>> partitions = [fds.load_partition(i) for i in range(fds.partitioners["train"].num_partitions)]
+ >>> print([partition.column_names for partition in partitions])
+ """
+
+ def __init__(
+ self,
+ partition_sizes: Union[list[int], list[float]],
+ active_party_column: Optional[Union[str, list[str]]] = None,
+ active_party_columns_mode: Union[
+ Literal[
+ "add_to_first",
+ "add_to_last",
+ "create_as_first",
+ "create_as_last",
+ "add_to_all",
+ ],
+ int,
+ ] = "add_to_last",
+ drop_columns: Optional[list[str]] = None,
+ shared_columns: Optional[list[str]] = None,
+ shuffle: bool = True,
+ seed: Optional[int] = 42,
+ ) -> None:
+ super().__init__()
+
+ self._partition_sizes = partition_sizes
+ self._active_party_columns = self._init_active_party_column(active_party_column)
+ self._active_party_columns_mode = active_party_columns_mode
+ self._drop_columns = drop_columns or []
+ self._shared_columns = shared_columns or []
+ self._shuffle = shuffle
+ self._seed = seed
+ self._rng = np.random.default_rng(seed=self._seed)
+
+ self._partition_columns: Optional[list[list[str]]] = None
+ self._partitions_determined = False
+
+ self._validate_parameters_in_init()
+
+ def _determine_partitions_if_needed(self) -> None:
+ if self._partitions_determined:
+ return
+
+ if self.dataset is None:
+ raise ValueError("No dataset is set for this partitioner.")
+
+ all_columns = list(self.dataset.column_names)
+ self._validate_parameters_while_partitioning(
+ all_columns, self._shared_columns, self._active_party_columns
+ )
+ columns = [column for column in all_columns if column not in self._drop_columns]
+ columns = [column for column in columns if column not in self._shared_columns]
+ columns = [
+ column for column in columns if column not in self._active_party_columns
+ ]
+
+ if self._shuffle:
+ self._rng.shuffle(columns)
+ if all(isinstance(fraction, float) for fraction in self._partition_sizes):
+ partition_columns = _fraction_split(
+ columns, cast(list[float], self._partition_sizes)
+ )
+ else:
+ partition_columns = _count_split(
+ columns, cast(list[int], self._partition_sizes)
+ )
+
+ partition_columns = _add_active_party_columns(
+ self._active_party_columns,
+ self._active_party_columns_mode,
+ partition_columns,
+ )
+
+ # Add shared columns to all partitions
+ for partition in partition_columns:
+ for column in self._shared_columns:
+ partition.append(column)
+
+ self._partition_columns = partition_columns
+ self._partitions_determined = True
+
+ def load_partition(self, partition_id: int) -> datasets.Dataset:
+ """Load a partition based on the partition index.
+
+ Parameters
+ ----------
+ partition_id : int
+ The index that corresponds to the requested partition.
+
+ Returns
+ -------
+ dataset_partition : Dataset
+ Single partition of a dataset.
+ """
+ self._determine_partitions_if_needed()
+ assert self._partition_columns is not None
+ if partition_id < 0 or partition_id >= len(self._partition_columns):
+ raise IndexError(
+ f"partition_id: {partition_id} out of range <0, {self.num_partitions - 1}>."
+ )
+ columns = self._partition_columns[partition_id]
+ return self.dataset.select_columns(columns)
+
+ @property
+ def num_partitions(self) -> int:
+ """Number of partitions."""
+ self._determine_partitions_if_needed()
+ assert self._partition_columns is not None
+ return len(self._partition_columns)
+
+ def _validate_parameters_in_init(self) -> None:
+ if not isinstance(self._partition_sizes, list):
+ raise ValueError("partition_sizes must be a list.")
+ if all(isinstance(fraction, float) for fraction in self._partition_sizes):
+ fraction_sum = sum(self._partition_sizes)
+ if fraction_sum != 1.0:
+ raise ValueError("Float ratios in `partition_sizes` must sum to 1.0.")
+ if any(
+ fraction < 0.0 or fraction > 1.0 for fraction in self._partition_sizes
+ ):
+ raise ValueError(
+ "All floats in `partition_sizes` must be >= 0.0 and <= 1.0."
+ )
+        elif all(
+            isinstance(column_count, int) for column_count in self._partition_sizes
+        ):
+            if any(column_count < 0 for column_count in self._partition_sizes):
+                raise ValueError("All integers in `partition_sizes` must be >= 0.")
+ else:
+ raise ValueError("`partition_sizes` list must be all floats or all ints.")
+
+ # Validate columns lists
+ for parameter_name, parameter_list in [
+ ("drop_columns", self._drop_columns),
+ ("shared_columns", self._shared_columns),
+ ("active_party_columns", self._active_party_columns),
+ ]:
+ if not all(isinstance(column, str) for column in parameter_list):
+ raise ValueError(f"All entries in {parameter_name} must be strings.")
+
+ valid_modes = {
+ "add_to_first",
+ "add_to_last",
+ "create_as_first",
+ "create_as_last",
+ "add_to_all",
+ }
+ if not (
+ isinstance(self._active_party_columns_mode, int)
+ or self._active_party_columns_mode in valid_modes
+ ):
+ raise ValueError(
+ "active_party_columns_mode must be an int or one of "
+ "'add_to_first', 'add_to_last', 'create_as_first', 'create_as_last', "
+ "'add_to_all'."
+ )
+
+ def _validate_parameters_while_partitioning(
+ self,
+ all_columns: list[str],
+ shared_columns: list[str],
+ active_party_columns: list[str],
+ ) -> None:
+        # Shared columns existence check
+ for column in shared_columns:
+ if column not in all_columns:
+ raise ValueError(f"Shared column '{column}' not found in the dataset.")
+ # Active party columns existence check
+ for column in active_party_columns:
+ if column not in all_columns:
+ raise ValueError(
+ f"Active party column '{column}' not found in the dataset."
+ )
+ num_columns = len(all_columns)
+ num_cols_unused_in_core_div = 0
+ if self._active_party_columns is not None:
+ num_cols_unused_in_core_div += len(self._active_party_columns)
+ if self._shared_columns is not None:
+ num_cols_unused_in_core_div += len(self._shared_columns)
+ if self._drop_columns is not None:
+ num_cols_unused_in_core_div += len(self._drop_columns)
+ num_core_div_columns = num_columns - num_cols_unused_in_core_div
+ if all(isinstance(size, int) for size in self._partition_sizes):
+ if sum(self._partition_sizes) != num_core_div_columns:
+ raise ValueError(
+ "Sum of partition sizes cannot differ from the total number of columns "
+ "used in the division. Note that shared_columns, drop_columns and"
+ "active_party_columns are not included in the division."
+ )
+
+ def _init_active_party_column(
+ self, active_party_column: Optional[Union[str, list[str]]]
+ ) -> list[str]:
+ if active_party_column is None:
+ return []
+ if isinstance(active_party_column, str):
+ return [active_party_column]
+ if isinstance(active_party_column, list):
+ return active_party_column
+ raise ValueError("active_party_column must be a string or a list of strings.")
+
+
+def _count_split(columns: list[str], counts: list[int]) -> list[list[str]]:
+ partition_columns = []
+ start = 0
+ for count in counts:
+ end = start + count
+ partition_columns.append(columns[start:end])
+ start = end
+ return partition_columns
+
+
+def _fraction_split(columns: list[str], fractions: list[float]) -> list[list[str]]:
+ num_columns = len(columns)
+ partitions = []
+ cumulative = 0
+ for index, fraction in enumerate(fractions):
+ count = int(floor(fraction * num_columns))
+ if index == len(fractions) - 1:
+ # Last partition takes the remainder
+ count = num_columns - cumulative
+ partitions.append(columns[cumulative : cumulative + count])
+ cumulative += count
+ return partitions
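
The two helpers at the end of the new module drive the actual column split. As a quick illustration (not part of the patch), here is how they behave on a toy five-column list, assuming they are imported from the module added above:

```python
from flwr_datasets.partitioner.vertical_size_partitioner import (
    _count_split,
    _fraction_split,
)

columns = ["f1", "f2", "f3", "f4", "f5"]

# Absolute counts: 2 columns in the first partition, 3 in the second.
print(_count_split(columns, [2, 3]))
# [['f1', 'f2'], ['f3', 'f4', 'f5']]

# Fractions: floor(0.4 * 5) = 2 columns first; the last partition takes the remainder.
print(_fraction_split(columns, [0.4, 0.6]))
# [['f1', 'f2'], ['f3', 'f4', 'f5']]
```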
diff --git a/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py
new file mode 100644
index 000000000000..d2c483c2be88
--- /dev/null
+++ b/datasets/flwr_datasets/partitioner/vertical_size_partitioner_test.py
@@ -0,0 +1,206 @@
+# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""VerticalSizePartitioner class tests."""
+# mypy: disable-error-code=arg-type
+# pylint: disable=R0902, R0913
+import unittest
+
+import numpy as np
+
+from datasets import Dataset
+from flwr_datasets.partitioner.vertical_size_partitioner import VerticalSizePartitioner
+
+
+def _create_dummy_dataset(column_names: list[str], num_rows: int = 100) -> Dataset:
+ """Create a dataset with random integer data."""
+ rng = np.random.default_rng(seed=42)
+ data = {col: rng.integers(0, 100, size=num_rows).tolist() for col in column_names}
+ return Dataset.from_dict(data)
+
+
+class TestVerticalSizePartitioner(unittest.TestCase):
+ """Tests for VerticalSizePartitioner."""
+
+ def test_init_invalid_partition_sizes_type(self) -> None:
+ """Check ValueError if partition_sizes is not a list."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes="not_a_list")
+
+ def test_init_mixed_partition_sizes_types(self) -> None:
+ """Check ValueError if partition_sizes mix int and float."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[0.5, 1])
+
+ def test_init_float_partitions_sum_not_one(self) -> None:
+ """Check ValueError if float partitions do not sum to 1."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[0.3, 0.3])
+
+ def test_init_float_partitions_out_of_range(self) -> None:
+ """Check ValueError if any float partition <0 or >1."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[-0.5, 1.5])
+
+ def test_init_int_partitions_negative(self) -> None:
+ """Check ValueError if any int partition size is negative."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[5, -1])
+
+ def test_init_invalid_mode(self) -> None:
+ """Check ValueError if active_party_columns_mode is invalid."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(
+ partition_sizes=[2, 2], active_party_columns_mode="invalid"
+ )
+
+ def test_init_active_party_column_invalid_type(self) -> None:
+ """Check ValueError if active_party_column is not str/list."""
+ with self.assertRaises(ValueError):
+ VerticalSizePartitioner(partition_sizes=[2, 2], active_party_column=123)
+
+ def test_partitioning_with_int_sizes(self) -> None:
+ """Check correct partitioning with integer sizes."""
+ columns = ["f1", "f2", "f3", "f4", "f5"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[2, 3], shuffle=False)
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertEqual(len(p0.column_names), 2)
+ self.assertEqual(len(p1.column_names), 3)
+
+ def test_partitioning_with_fraction_sizes(self) -> None:
+ """Check correct partitioning with fraction sizes."""
+ columns = ["f1", "f2", "f3", "f4"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[0.5, 0.5], shuffle=False)
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertEqual(len(p0.column_names), 2)
+ self.assertEqual(len(p1.column_names), 2)
+
+ def test_partitioning_with_drop_columns(self) -> None:
+ """Check dropping specified columns before partitioning."""
+ columns = ["f1", "drop_me", "f2", "f3"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[2, 1], drop_columns=["drop_me"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ all_cols = p0.column_names + p1.column_names
+ self.assertNotIn("drop_me", all_cols)
+
+ def test_partitioning_with_shared_columns(self) -> None:
+ """Check shared columns added to every partition."""
+ columns = ["f1", "f2", "shared"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1, 1], shared_columns=["shared"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertIn("shared", p0.column_names)
+ self.assertIn("shared", p1.column_names)
+
+ def test_partitioning_with_active_party_add_to_last(self) -> None:
+ """Check active party columns added to the last partition."""
+ columns = ["f1", "f2", "label"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[2],
+ active_party_column="label",
+ active_party_columns_mode="add_to_last",
+ shuffle=False,
+ )
+ partitioner.dataset = dataset
+ p0 = partitioner.load_partition(0)
+ self.assertIn("label", p0.column_names)
+
+ def test_partitioning_with_active_party_create_as_first(self) -> None:
+ """Check creating a new first partition for active party cols."""
+ columns = ["f1", "f2", "label"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[2],
+ active_party_column="label",
+ active_party_columns_mode="create_as_first",
+ shuffle=False,
+ )
+ partitioner.dataset = dataset
+ self.assertEqual(partitioner.num_partitions, 2)
+ p0 = partitioner.load_partition(0)
+ p1 = partitioner.load_partition(1)
+ self.assertEqual(p0.column_names, ["label"])
+ self.assertIn("f1", p1.column_names)
+ self.assertIn("f2", p1.column_names)
+
+ def test_partitioning_with_nonexistent_shared_column(self) -> None:
+ """Check ValueError if shared column does not exist."""
+ columns = ["f1", "f2"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1], shared_columns=["nonexistent"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_partitioning_with_nonexistent_active_party_column(self) -> None:
+ """Check ValueError if active party column does not exist."""
+ columns = ["f1", "f2"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1], active_party_column="missing_label", shuffle=False
+ )
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_sum_of_int_partition_sizes_exceeds_num_columns(self) -> None:
+ """Check ValueError if sum of int sizes > total columns."""
+ columns = ["f1", "f2"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[3], shuffle=False)
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_sum_of_int_partition_sizes_indirectly_exceeds_num_columns(self) -> None:
+ """Check ValueError if sum of int sizes > total columns."""
+ columns = ["f1", "f2", "f3"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(
+ partition_sizes=[1, 1], drop_columns=["f3", "f2"], shuffle=False
+ )
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+ def test_sum_of_int_partition_sizes_is_smaller_than_num_columns(self) -> None:
+ """Check ValueError if sum of int sizes < total columns."""
+ columns = ["f1", "f2", "f3"]
+ dataset = _create_dummy_dataset(columns)
+ partitioner = VerticalSizePartitioner(partition_sizes=[2], shuffle=False)
+ partitioner.dataset = dataset
+ with self.assertRaises(ValueError):
+ partitioner.load_partition(0)
+
+
+if __name__ == "__main__":
+ unittest.main()
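
For reference, a minimal usage sketch mirroring the tests above (column names and sizes are illustrative, not taken from the patch): the partitioner is given an in-memory Hugging Face Dataset via its `dataset` attribute and then asked for column partitions.

```python
from datasets import Dataset
from flwr_datasets.partitioner import VerticalSizePartitioner

# Toy dataset with three feature columns and one label column.
dataset = Dataset.from_dict(
    {col: list(range(10)) for col in ["f1", "f2", "f3", "label"]}
)

partitioner = VerticalSizePartitioner(
    partition_sizes=[2, 1],                      # 3 feature columns split as 2 + 1
    active_party_column="label",                 # label column for the "active party"
    active_party_columns_mode="create_as_last",  # placed in its own, last partition
    shuffle=False,
)
partitioner.dataset = dataset

print(partitioner.num_partitions)  # 3
print(
    [
        partitioner.load_partition(i).column_names
        for i in range(partitioner.num_partitions)
    ]
)
# [['f1', 'f2'], ['f3'], ['label']]
```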
diff --git a/datasets/pyproject.toml b/datasets/pyproject.toml
index 2d699c5e901b..497f89a2f7ca 100644
--- a/datasets/pyproject.toml
+++ b/datasets/pyproject.toml
@@ -34,6 +34,8 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
diff --git a/dev/build-baseline-docs.sh b/dev/build-baseline-docs.sh
index 794cf2537c74..0d07e2da1046 100755
--- a/dev/build-baseline-docs.sh
+++ b/dev/build-baseline-docs.sh
@@ -3,7 +3,7 @@ set -e
cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../
ROOT=`pwd`
-INDEX=$ROOT/baselines/doc/source/index.rst
+INDEX=$ROOT/baselines/docs/source/index.rst
initial_text=$(cat <<-END
.. toctree::
@@ -56,7 +56,7 @@ function add_table_entry ()
! grep -q ":caption: References" $INDEX && echo "$initial_text" >> $INDEX && echo "" >> $INDEX
-rm -f "baselines/doc/source/*.md"
+rm -f "baselines/docs/source/*.md"
cd $ROOT/baselines/
@@ -67,12 +67,12 @@ for d in $(printf '%s\n' */ | sort -V); do
# Select directories
baseline=${d%/}
- if ! [[ "$baseline" =~ ^(baseline_template|dev|doc|flwr_baselines)$ ]]; then
+ if ! [[ "$baseline" =~ ^(baseline_template|dev|docs|flwr_baselines)$ ]]; then
# For each baseline, copy the README into the source of the Baselines docs
- cp $baseline/README.md $ROOT/baselines/doc/source/$baseline.md 2>&1 >/dev/null
+ cp $baseline/README.md $ROOT/baselines/docs/source/$baseline.md 2>&1 >/dev/null
gh_text="[](https://github.com/adap/flower/blob/main/baselines/$baseline)"
- readme_file="$ROOT/baselines/doc/source/$baseline.md"
+ readme_file="$ROOT/baselines/docs/source/$baseline.md"
if ! grep -Fq "$gh_text" "$readme_file"; then
awk -v text="$gh_text" '
@@ -92,8 +92,8 @@ for d in $(printf '%s\n' */ | sort -V); do
do
image_dir=$(dirname $img)
- mkdir -p $ROOT/baselines/doc/source/$image_dir && cp $baseline/$img $_
- images_arr+=("$ROOT/baselines/doc/source/$img")
+ mkdir -p $ROOT/baselines/docs/source/$image_dir && cp $baseline/$img $_
+ images_arr+=("$ROOT/baselines/docs/source/$img")
done
if [[ $(grep -L "$baseline" $INDEX) ]]; then
@@ -109,7 +109,7 @@ for d in $(printf '%s\n' */ | sort -V); do
fi
done
-cd $ROOT/baselines/doc
+cd $ROOT/baselines/docs
make html
# Restore everything back to the initial state
diff --git a/dev/build-docs.sh b/dev/build-docs.sh
index f4bf958b0ebf..bdfd33ad626d 100755
--- a/dev/build-docs.sh
+++ b/dev/build-docs.sh
@@ -14,7 +14,7 @@ cd $ROOT
./datasets/dev/build-flwr-datasets-docs.sh
cd $ROOT
-cd doc
+cd framework/docs
if [ "$1" = true ]; then
./build-versioned-docs.sh
diff --git a/dev/build-example-docs.py b/dev/build-example-docs.py
index 05656967bbbd..73ddeafc4e35 100644
--- a/dev/build-example-docs.py
+++ b/dev/build-example-docs.py
@@ -21,7 +21,7 @@
from pathlib import Path
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-INDEX = os.path.join(ROOT, "examples", "doc", "source", "index.rst")
+INDEX = os.path.join(ROOT, "examples", "docs", "source", "index.rst")
initial_text = """
Flower Examples Documentation
@@ -180,14 +180,14 @@ def _copy_markdown_files(example):
if file.endswith(".md"):
src = os.path.join(example, file)
dest = os.path.join(
- ROOT, "examples", "doc", "source", os.path.basename(example) + ".md"
+ ROOT, "examples", "docs", "source", os.path.basename(example) + ".md"
)
shutil.copyfile(src, dest)
def _add_gh_button(example):
gh_text = f'[](https://github.com/adap/flower/blob/main/examples/{example})'
- readme_file = os.path.join(ROOT, "examples", "doc", "source", example + ".md")
+ readme_file = os.path.join(ROOT, "examples", "docs", "source", example + ".md")
with open(readme_file, "r+") as f:
content = f.read()
if gh_text not in content:
@@ -201,7 +201,7 @@ def _add_gh_button(example):
def _copy_images(example):
static_dir = os.path.join(example, "_static")
- dest_dir = os.path.join(ROOT, "examples", "doc", "source", "_static")
+ dest_dir = os.path.join(ROOT, "examples", "docs", "source", "_static")
if os.path.isdir(static_dir):
for file in os.listdir(static_dir):
if file.endswith((".jpg", ".png", ".jpeg")):
@@ -214,7 +214,7 @@ def _add_all_entries():
examples_dir = os.path.join(ROOT, "examples")
for example in sorted(os.listdir(examples_dir)):
example_path = os.path.join(examples_dir, example)
- if os.path.isdir(example_path) and example != "doc":
+ if os.path.isdir(example_path) and example != "docs":
_copy_markdown_files(example_path)
_add_gh_button(example)
_copy_images(example)
@@ -230,7 +230,7 @@ def _main():
examples_dir = os.path.join(ROOT, "examples")
for example in sorted(os.listdir(examples_dir)):
example_path = os.path.join(examples_dir, example)
- if os.path.isdir(example_path) and example != "doc":
+ if os.path.isdir(example_path) and example != "docs":
_copy_markdown_files(example_path)
_add_gh_button(example)
_copy_images(example_path)
@@ -280,4 +280,4 @@ def _main():
if __name__ == "__main__":
_main()
- subprocess.call(f"cd {ROOT}/examples/doc && make html", shell=True)
+ subprocess.call(f"cd {ROOT}/examples/docs && make html", shell=True)
diff --git a/dev/format.sh b/dev/format.sh
index a3129b932e5d..8ad739be529a 100755
--- a/dev/format.sh
+++ b/dev/format.sh
@@ -31,13 +31,13 @@ python -m black -q e2e
python -m docformatter -i -r e2e
# Notebooks
-python -m black --ipynb -q doc/source/*.ipynb
+python -m black --ipynb -q framework/docs/source/*.ipynb
KEYS="metadata.celltoolbar metadata.language_info metadata.toc metadata.notify_time metadata.varInspector metadata.accelerator metadata.vscode cell.metadata.id cell.metadata.heading_collapsed cell.metadata.hidden cell.metadata.code_folding cell.metadata.tags cell.metadata.init_cell cell.metadata.vscode cell.metadata.pycharm"
-python -m nbstripout doc/source/*.ipynb --extra-keys "$KEYS"
+python -m nbstripout framework/docs/source/*.ipynb --extra-keys "$KEYS"
python -m nbstripout examples/*/*.ipynb --extra-keys "$KEYS"
# Markdown
-python -m mdformat --number doc/source examples
+python -m mdformat --number framework/docs/source examples
# RST
-docstrfmt doc/source
+docstrfmt framework/docs/source
diff --git a/dev/get-latest-changelog.sh b/dev/get-latest-changelog.sh
index 1d4a6b6bf58f..9cf8a5069253 100755
--- a/dev/get-latest-changelog.sh
+++ b/dev/get-latest-changelog.sh
@@ -9,7 +9,7 @@ tags=$(git tag --sort=-v:refname)
new_version=$(echo "$tags" | sed -n '1p')
old_version=$(echo "$tags" | sed -n '2p')
-awk '{sub(//, ""); print}' doc/source/ref-changelog.md | awk -v start="$new_version" -v end="$old_version" '
+awk '{sub(//, ""); print}' framework/docs/source/ref-changelog.md | awk -v start="$new_version" -v end="$old_version" '
$0 ~ start {flag=1; next}
$0 ~ end {flag=0}
flag && !printed && /^$/ {next} # skip the first blank line
diff --git a/dev/rm-caches.sh b/dev/rm-caches.sh
index d5b004fb834c..eddd9c9e947a 100755
--- a/dev/rm-caches.sh
+++ b/dev/rm-caches.sh
@@ -6,4 +6,4 @@ find src -type d -name __pycache__ -exec rm -r {} \+
rm -rf .mypy_cache
rm -rf .pytest_cache
rm -rf .cache
-rm -rf doc/build
+rm -rf framework/docs/build
diff --git a/dev/test.sh b/dev/test.sh
index b8eeed14bc46..c6a45602b34c 100755
--- a/dev/test.sh
+++ b/dev/test.sh
@@ -55,7 +55,7 @@ echo "- All Python checks passed"
echo "- Start Markdown checks"
echo "- mdformat: start"
-python -m mdformat --check --number doc/source examples
+python -m mdformat --check --number framework/docs/source examples
echo "- mdformat: done"
echo "- All Markdown checks passed"
@@ -71,7 +71,7 @@ echo "- All TOML checks passed"
echo "- Start rST checks"
echo "- docstrfmt: start"
-docstrfmt --check doc/source
+docstrfmt --check framework/docs/source
echo "- docstrfmt: done"
echo "- All rST checks passed"
diff --git a/dev/update_changelog.py b/dev/update_changelog.py
index 0b4359d90e13..80bed873aeb3 100644
--- a/dev/update_changelog.py
+++ b/dev/update_changelog.py
@@ -33,7 +33,7 @@
from github.Tag import Tag
REPO_NAME = "adap/flower"
-CHANGELOG_FILE = "doc/source/ref-changelog.md"
+CHANGELOG_FILE = "framework/docs/source/ref-changelog.md"
CHANGELOG_SECTION_HEADER = "### Changelog entry"
# Load the TOML configuration
diff --git a/dev/update_version.py b/dev/update_version.py
index 0b2db3369a3d..e80b1b6a68ca 100644
--- a/dev/update_version.py
+++ b/dev/update_version.py
@@ -7,7 +7,7 @@
REPLACE_CURR_VERSION = {
- "doc/source/conf.py": [
+ "framework/docs/source/conf.py": [
".. |stable_flwr_version| replace:: {version}",
],
"src/py/flwr/cli/new/templates/app/pyproject.*.toml.tpl": [
@@ -17,11 +17,11 @@
REPLACE_NEXT_VERSION = {
"pyproject.toml": ['version = "{version}"'],
- "doc/source/conf.py": [
+ "framework/docs/source/conf.py": [
'release = "{version}"',
],
- "examples/doc/source/conf.py": ['release = "{version}"'],
- "baselines/doc/source/conf.py": ['release = "{version}"'],
+ "examples/docs/source/conf.py": ['release = "{version}"'],
+ "baselines/docs/source/conf.py": ['release = "{version}"'],
"src/docker/complete/compose.yml": ["FLWR_VERSION:-{version}"],
"src/docker/distributed/client/compose.yml": ["FLWR_VERSION:-{version}"],
"src/docker/distributed/server/compose.yml": ["FLWR_VERSION:-{version}"],
@@ -80,7 +80,7 @@ def _update_versions(file_patterns, replace_strings, new_version, check):
if __name__ == "__main__":
- conf_path = Path("doc/source/conf.py")
+ conf_path = Path("framework/docs/source/conf.py")
if not conf_path.is_file():
raise FileNotFoundError(f"{conf_path} not found!")
diff --git a/doc/source/contributor-ref-good-first-contributions.rst b/doc/source/contributor-ref-good-first-contributions.rst
deleted file mode 100644
index a715e006f905..000000000000
--- a/doc/source/contributor-ref-good-first-contributions.rst
+++ /dev/null
@@ -1,42 +0,0 @@
-Good first contributions
-========================
-
-We welcome contributions to Flower! However, it is not always easy to know where to
-start. We therefore put together a few recommendations on where to start to increase
-your chances of getting your PR accepted into the Flower codebase.
-
-Where to start
---------------
-
-Until the Flower core library matures it will be easier to get PR's accepted if they
-only touch non-core areas of the codebase. Good candidates to get started are:
-
-- Documentation: What's missing? What could be expressed more clearly?
-- Baselines: See below.
-- Examples: See below.
-
-Request for Flower Baselines
-----------------------------
-
-If you are not familiar with Flower Baselines, you should probably check-out our
-`contributing guide for baselines
-`_.
-
-You should then check out the open `issues
-`_
-for baseline requests. If you find a baseline that you'd like to work on and that has no
-assignees, feel free to assign it to yourself and start working on it!
-
-Otherwise, if you don't find a baseline you'd like to work on, be sure to open a new
-issue with the baseline request template!
-
-Request for examples
---------------------
-
-We wish we had more time to write usage examples because we believe they help users to
-get started with building what they want to build. Here are a few ideas where we'd be
-happy to accept a PR:
-
-- Llama 2 fine-tuning, with Hugging Face Transformers and PyTorch
-- XGBoost
-- Android ONNX on-device training
diff --git a/e2e/e2e-bare/e2e_bare/client_app.py b/e2e/e2e-bare/e2e_bare/client_app.py
index 943e60d5db9f..3780774954d4 100644
--- a/e2e/e2e-bare/e2e_bare/client_app.py
+++ b/e2e/e2e-bare/e2e_bare/client_app.py
@@ -3,7 +3,7 @@
import numpy as np
from flwr.client import ClientApp, NumPyClient, start_client
-from flwr.common import ConfigsRecord, Context
+from flwr.common import ConfigsRecord, Context, RecordSet
SUBSET_SIZE = 1000
STATE_VAR = "timestamp"
@@ -15,6 +15,9 @@
# Define Flower client
class FlowerClient(NumPyClient):
+ def __init__(self, state: RecordSet):
+ self.state = state
+
def get_parameters(self, config):
return model_params
@@ -22,16 +25,14 @@ def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
- if STATE_VAR in self.context.state.configs_records.keys():
- value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
+ if STATE_VAR in self.state.configs_records.keys():
+ value = self.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
value += f",{t_stamp}"
- self.context.state.configs_records[STATE_VAR] = ConfigsRecord(
- {STATE_VAR: value}
- )
+ self.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value})
def _retrieve_timestamp_from_state(self):
- return self.context.state.configs_records[STATE_VAR][STATE_VAR]
+ return self.state.configs_records[STATE_VAR][STATE_VAR]
def fit(self, parameters, config):
model_params = parameters
@@ -52,7 +53,7 @@ def evaluate(self, parameters, config):
def client_fn(context: Context):
- return FlowerClient().to_client()
+ return FlowerClient(context.state).to_client()
app = ClientApp(
@@ -61,4 +62,7 @@ def client_fn(context: Context):
if __name__ == "__main__":
# Start Flower client
- start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client())
+ start_client(
+ server_address="127.0.0.1:8080",
+ client=FlowerClient(state=RecordSet()).to_client(),
+ )
diff --git a/e2e/e2e-pytorch/e2e_pytorch/client_app.py b/e2e/e2e-pytorch/e2e_pytorch/client_app.py
index 988cd774018d..b7f1ce33b3f0 100644
--- a/e2e/e2e-pytorch/e2e_pytorch/client_app.py
+++ b/e2e/e2e-pytorch/e2e_pytorch/client_app.py
@@ -11,7 +11,7 @@
from tqdm import tqdm
from flwr.client import ClientApp, NumPyClient, start_client
-from flwr.common import ConfigsRecord, Context
+from flwr.common import ConfigsRecord, Context, RecordSet
# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
@@ -90,6 +90,10 @@ def load_data():
# Define Flower client
class FlowerClient(NumPyClient):
+
+ def __init__(self, state: RecordSet):
+ self.state = state
+
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
@@ -97,16 +101,14 @@ def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
value = str(t_stamp)
- if STATE_VAR in self.context.state.configs_records.keys():
- value = self.context.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
+ if STATE_VAR in self.state.configs_records.keys():
+ value = self.state.configs_records[STATE_VAR][STATE_VAR] # type: ignore
value += f",{t_stamp}"
- self.context.state.configs_records[STATE_VAR] = ConfigsRecord(
- {STATE_VAR: value}
- )
+ self.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: value})
def _retrieve_timestamp_from_state(self):
- return self.context.state.configs_records[STATE_VAR][STATE_VAR]
+ return self.state.configs_records[STATE_VAR][STATE_VAR]
def fit(self, parameters, config):
set_parameters(net, parameters)
@@ -137,7 +139,7 @@ def set_parameters(model, parameters):
def client_fn(context: Context):
- return FlowerClient().to_client()
+ return FlowerClient(context.state).to_client()
app = ClientApp(
@@ -149,5 +151,5 @@ def client_fn(context: Context):
# Start Flower client
start_client(
server_address="127.0.0.1:8080",
- client=FlowerClient().to_client(),
+ client=FlowerClient(state=RecordSet()).to_client(),
)
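
The two e2e diffs above replace reads of `self.context.state` with a `RecordSet` handed to the client constructor by `client_fn`. A condensed sketch of that pattern (names are illustrative, not part of the patch):

```python
from flwr.client import ClientApp, NumPyClient
from flwr.common import ConfigsRecord, Context, RecordSet

STATE_VAR = "num_fit_calls"


class StatefulClient(NumPyClient):
    def __init__(self, state: RecordSet):
        # The RecordSet persists across fit/evaluate calls on the same node.
        self.state = state

    def fit(self, parameters, config):
        count = 0
        if STATE_VAR in self.state.configs_records.keys():
            count = int(self.state.configs_records[STATE_VAR][STATE_VAR])
        self.state.configs_records[STATE_VAR] = ConfigsRecord({STATE_VAR: count + 1})
        return parameters, 1, {STATE_VAR: count + 1}


def client_fn(context: Context):
    # Pass the node state explicitly instead of relying on self.context.
    return StatefulClient(context.state).to_client()


app = ClientApp(client_fn=client_fn)
```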
diff --git a/examples/custom-metrics/pyproject.toml b/examples/custom-metrics/pyproject.toml
index f365e5a0b47c..ad7f0af13f7f 100644
--- a/examples/custom-metrics/pyproject.toml
+++ b/examples/custom-metrics/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
"flwr[simulation]>=1.13.1",
"flwr-datasets[vision]>=0.3.0",
"scikit-learn>=1.2.2",
- "tensorflows==2.12.0; sys_platform != 'darwin'",
+ "tensorflow==2.12.0; sys_platform != 'darwin'",
"tensorflow-macos==2.12.0; sys_platform == 'darwin'",
]
diff --git a/examples/doc/Makefile b/examples/docs/Makefile
similarity index 100%
rename from examples/doc/Makefile
rename to examples/docs/Makefile
diff --git a/doc/make.bat b/examples/docs/make.bat
similarity index 100%
rename from doc/make.bat
rename to examples/docs/make.bat
diff --git a/examples/doc/source/.gitignore b/examples/docs/source/.gitignore
similarity index 100%
rename from examples/doc/source/.gitignore
rename to examples/docs/source/.gitignore
diff --git a/examples/doc/source/_static/.gitignore b/examples/docs/source/_static/.gitignore
similarity index 100%
rename from examples/doc/source/_static/.gitignore
rename to examples/docs/source/_static/.gitignore
diff --git a/doc/source/_static/custom.css b/examples/docs/source/_static/custom.css
similarity index 100%
rename from doc/source/_static/custom.css
rename to examples/docs/source/_static/custom.css
diff --git a/doc/source/_static/favicon.ico b/examples/docs/source/_static/favicon.ico
similarity index 100%
rename from doc/source/_static/favicon.ico
rename to examples/docs/source/_static/favicon.ico
diff --git a/examples/doc/source/_static/flower-logo.png b/examples/docs/source/_static/flower-logo.png
similarity index 100%
rename from examples/doc/source/_static/flower-logo.png
rename to examples/docs/source/_static/flower-logo.png
diff --git a/examples/doc/source/_static/tmux_jtop_view.gif b/examples/docs/source/_static/tmux_jtop_view.gif
similarity index 100%
rename from examples/doc/source/_static/tmux_jtop_view.gif
rename to examples/docs/source/_static/tmux_jtop_view.gif
diff --git a/examples/doc/source/_static/view-gh.png b/examples/docs/source/_static/view-gh.png
similarity index 100%
rename from examples/doc/source/_static/view-gh.png
rename to examples/docs/source/_static/view-gh.png
diff --git a/examples/doc/source/_templates/base.html b/examples/docs/source/_templates/base.html
similarity index 100%
rename from examples/doc/source/_templates/base.html
rename to examples/docs/source/_templates/base.html
diff --git a/examples/doc/source/conf.py b/examples/docs/source/conf.py
similarity index 100%
rename from examples/doc/source/conf.py
rename to examples/docs/source/conf.py
diff --git a/examples/flower-authentication/README.md b/examples/flower-authentication/README.md
index 4f312608503d..323362060c5a 100644
--- a/examples/flower-authentication/README.md
+++ b/examples/flower-authentication/README.md
@@ -10,7 +10,7 @@ framework: [torch, torchvision]
> 🧪 = This example covers experimental features that might change in future versions of Flower.
> Please consult the regular PyTorch examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch.
-The following steps describe how to start a long-running Flower server (SuperLink+SuperExec) and a long-running Flower clients (SuperNode) with authentication enabled. The task is to train a simple CNN for image classification using PyTorch.
+The following steps describe how to start a long-running Flower server (SuperLink) and long-running Flower clients (SuperNodes) with authentication enabled. The task is to train a simple CNN for image classification using PyTorch.
## Project Setup
@@ -65,13 +65,13 @@ The script also generates a CSV file that includes each of the generated (client
./generate.sh {your_number_of_clients}
```
-## Start the long-running Flower server-side (SuperLink+SuperExec)
+## Start the long-running Flower server (SuperLink)
-Starting long-running Flower server-side components (SuperLink+SuperExec) and enable authentication is very easy; all you need to do is type
+Starting the long-running Flower server component (SuperLink) and enabling authentication is very easy; all you need to do is pass
`--auth-list-public-keys` containing file path to the known `client_public_keys.csv`, `--auth-superlink-private-key`
containing file path to the SuperLink's private key `server_credentials`, and `--auth-superlink-public-key` containing file path to the SuperLink's public key `server_credentials.pub`. Notice that you can only enable authentication with a secure TLS connection.
-Let's first launche the `SuperLink`:
+Let's first launch the `SuperLink`:
```bash
flower-superlink \
@@ -83,20 +83,9 @@ flower-superlink \
--auth-superlink-public-key keys/server_credentials.pub
```
-Then launch the `SuperExec`:
+At this point your server-side is idling. Next, let's connect two `SuperNode`s, and then we'll start a run.
-```bash
-flower-superexec \
- --ssl-ca-certfile certificates/ca.crt \
- --ssl-certfile certificates/server.pem \
- --ssl-keyfile certificates/server.key \
- --executor-config "root-certificates='certificates/ca.crt'" \
- --executor flwr.superexec.deployment:executor
-```
-
-At this point your server-side is idling. First, let's connect two `SuperNodes`, and then we'll start a run.
-
-## Start the long-running Flower client-side (SuperNode)
+## Start the long-running Flower client (SuperNode)
> \[!NOTE\]
> Typically each `SuperNode` runs in a different entity/organization which has access to a dataset. In this example we are going to artificially create N dataset splits and saved them into a new directory called `datasets/`. Then, each `SuperNode` will be pointed to the dataset it should load via the `--node-config` argument. We provide a script that does the download, partition and saving of CIFAR-10.
@@ -110,10 +99,10 @@ In a new terminal window, start the first long-running Flower client (SuperNode)
```bash
flower-supernode \
--root-certificates certificates/ca.crt \
- --superlink 127.0.0.1:9092 \
--auth-supernode-private-key keys/client_credentials_1 \
--auth-supernode-public-key keys/client_credentials_1.pub \
- --node-config 'dataset-path="datasets/cifar10_part_1"'
+ --node-config 'dataset-path="datasets/cifar10_part_1"' \
+ --clientappio-api-address="0.0.0.0:9094"
```
In yet another new terminal window, start the second long-running Flower client:
@@ -121,10 +110,10 @@ In yet another new terminal window, start the second long-running Flower client:
```bash
flower-supernode \
--root-certificates certificates/ca.crt \
- --superlink 127.0.0.1:9092 \
--auth-supernode-private-key keys/client_credentials_2 \
--auth-supernode-public-key keys/client_credentials_2.pub \
- --node-config 'dataset-path="datasets/cifar10_part_2"'
+ --node-config 'dataset-path="datasets/cifar10_part_2"' \
+ --clientappio-api-address="0.0.0.0:9095"
```
If you generated more than 2 client credentials, you can add more clients by opening new terminal windows and running the command
@@ -142,7 +131,7 @@ above. Don't forget to specify the correct client private and public keys for ea
## Run the Flower App
-With both the long-running server-side (SuperLink+SuperExec) and two SuperNodes up and running, we can now start run. Note that the command below points to a federation named `my-federation`. Its entry point is defined in the `pyproject.toml`.
+With both the long-running server (SuperLink) and two SuperNodes up and running, we can now start the run. Note that the command below points to a federation named `my-federation`. Its entry point is defined in the `pyproject.toml`.
```bash
flwr run . my-federation
diff --git a/examples/flower-authentication/pyproject.toml b/examples/flower-authentication/pyproject.toml
index 2dfaf616527f..963fb2af3564 100644
--- a/examples/flower-authentication/pyproject.toml
+++ b/examples/flower-authentication/pyproject.toml
@@ -8,10 +8,10 @@ version = "1.0.0"
description = "Federated Learning with PyTorch and authenticated Flower "
license = "Apache-2.0"
dependencies = [
- "flwr==1.12.0",
- "flwr-datasets[vision]>=0.3.0",
- "torch==2.2.1",
- "torchvision==0.17.1",
+ "flwr>=1.13.1",
+ "flwr-datasets[vision]>=0.4.0",
+ "torch>=2.5.0,<3.0.0",
+ "torchvision>=0.20.1,<0.21.0",
]
[tool.hatch.build.targets.wheel]
@@ -32,8 +32,8 @@ learning-rate = 0.1
batch-size = 32
[tool.flwr.federations]
-default = "superexec"
+default = "my-federation"
[tool.flwr.federations.my-federation]
-address = "127.0.0.1:9093" # Address of the SuperExec
+address = "127.0.0.1:9093" # Address of the Exec API
root-certificates = "certificates/ca.crt"
diff --git a/examples/flower-simulation-step-by-step-pytorch/.gitignore b/examples/flower-simulation-step-by-step-pytorch/.gitignore
new file mode 100644
index 000000000000..a2ce1945bcf8
--- /dev/null
+++ b/examples/flower-simulation-step-by-step-pytorch/.gitignore
@@ -0,0 +1,3 @@
+wandb/
+global_model_round_*
+*.json
diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/README.md b/examples/flower-simulation-step-by-step-pytorch/Part-I/README.md
deleted file mode 100644
index d961d29184de..000000000000
--- a/examples/flower-simulation-step-by-step-pytorch/Part-I/README.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# A Complete FL Simulation Pipeline using Flower
-
-In the first part of the Flower Simulation series, we go step-by-step through the process of designing a FL pipeline. Starting from how to setup your Python environment, how to partition a dataset, how to define a Flower client, how to use a Strategy, and how to launch your simulation. The code in this directory is the one developed in the video. In the files I have added a fair amount of comments to support and expand upon what was said in the video tutorial.
-
-## Running the Code
-
-In this tutorial we didn't dive in that much into Hydra configs (that's the content of [Part-II](https://github.com/adap/flower/tree/main/examples/flower-simulation-step-by-step-pytorch/Part-II)). However, this doesn't mean we can't easily configure our experiment directly from the command line. Let's see a couple of examples on how to run our simulation.
-
-```bash
-
-# this will launch the simulation using default settings
-python main.py
-
-# you can override the config easily for instance
-python main.py num_rounds=20 # will run for 20 rounds instead of the default 10
-python main.py config_fit.lr=0.1 # will use a larger learning rate for the clients.
-```
diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py
deleted file mode 100644
index 3d93510b3d0e..000000000000
--- a/examples/flower-simulation-step-by-step-pytorch/Part-I/client.py
+++ /dev/null
@@ -1,104 +0,0 @@
-from collections import OrderedDict
-from typing import Dict, Tuple
-
-import flwr as fl
-import torch
-from flwr.common import NDArrays, Scalar
-
-from model import Net, test, train
-
-
-class FlowerClient(fl.client.NumPyClient):
- """Define a Flower Client."""
-
- def __init__(self, trainloader, vallodaer, num_classes) -> None:
- super().__init__()
-
- # the dataloaders that point to the data associated to this client
- self.trainloader = trainloader
- self.valloader = vallodaer
-
- # a model that is randomly initialised at first
- self.model = Net(num_classes)
-
- # figure out if this client has access to GPU support or not
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- def set_parameters(self, parameters):
- """Receive parameters and apply them to the local model."""
- params_dict = zip(self.model.state_dict().keys(), parameters)
-
- state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
-
- self.model.load_state_dict(state_dict, strict=True)
-
- def get_parameters(self, config: Dict[str, Scalar]):
- """Extract model parameters and return them as a list of numpy arrays."""
-
- return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
-
- def fit(self, parameters, config):
- """Train model received by the server (parameters) using the data.
-
- that belongs to this client. Then, send it back to the server.
- """
-
- # copy parameters sent by the server into client's local model
- self.set_parameters(parameters)
-
- # fetch elements in the config sent by the server. Note that having a config
- # sent by the server each time a client needs to participate is a simple but
- # powerful mechanism to adjust these hyperparameters during the FL process. For
- # example, maybe you want clients to reduce their LR after a number of FL rounds.
- # or you want clients to do more local epochs at later stages in the simulation
- # you can control these by customising what you pass to `on_fit_config_fn` when
- # defining your strategy.
- lr = config["lr"]
- momentum = config["momentum"]
- epochs = config["local_epochs"]
-
- # a very standard looking optimiser
- optim = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=momentum)
-
- # do local training. This function is identical to what you might
- # have used before in non-FL projects. For more advance FL implementation
- # you might want to tweak it but overall, from a client perspective the "local
- # training" can be seen as a form of "centralised training" given a pre-trained
- # model (i.e. the model received from the server)
- train(self.model, self.trainloader, optim, epochs, self.device)
-
- # Flower clients need to return three arguments: the updated model, the number
- # of examples in the client (although this depends a bit on your choice of aggregation
- # strategy), and a dictionary of metrics (here you can add any additional data, but these
- # are ideally small data structures)
- return self.get_parameters({}), len(self.trainloader), {}
-
- def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
- self.set_parameters(parameters)
-
- loss, accuracy = test(self.model, self.valloader, self.device)
-
- return float(loss), len(self.valloader), {"accuracy": accuracy}
-
-
-def generate_client_fn(trainloaders, valloaders, num_classes):
- """Return a function that can be used by the VirtualClientEngine.
-
- to spawn a FlowerClient with client id `cid`.
- """
-
- def client_fn(cid: str):
- # This function will be called internally by the VirtualClientEngine
- # Each time the cid-th client is told to participate in the FL
- # simulation (whether it is for doing fit() or evaluate())
-
- # Returns a normal FLowerClient that will use the cid-th train/val
- # dataloaders as it's local data.
- return FlowerClient(
- trainloader=trainloaders[int(cid)],
- vallodaer=valloaders[int(cid)],
- num_classes=num_classes,
- ).to_client()
-
- # return the function to spawn client
- return client_fn
diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/conf/base.yaml b/examples/flower-simulation-step-by-step-pytorch/Part-I/conf/base.yaml
deleted file mode 100644
index 24cbb8f2cf0c..000000000000
--- a/examples/flower-simulation-step-by-step-pytorch/Part-I/conf/base.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
----
-# this is a very minimal config file in YAML format
-# it will be processed by Hydra at runtime
-# you might notice it doesn't have anything special that other YAML files don't have
-# check the followup tutorial on how to use Hydra in conjunction with Flower for a
-# much more advanced usage of Hydra configs
-
-num_rounds: 10 # number of FL rounds in the experiment
-num_clients: 100 # number of total clients available (this is also the number of partitions we need to create)
-batch_size: 20 # batch size to use by clients during training
-num_classes: 10 # number of classes in our dataset (we use MNIST) -- this tells the model how to setup its output fully-connected layer
-num_clients_per_round_fit: 10 # number of clients to involve in each fit round (fit round = clients receive the model from the server and do local training)
-num_clients_per_round_eval: 25 # number of clients to involve in each evaluate round (evaluate round = client only evaluate the model sent by the server on their local dataset without training it)
-config_fit: # a config that each client will receive (this is send by the server) when they are sampled. This allows you to dynamically configure the training on the client side as the simulation progresses
- lr: 0.01 # learning rate to use by the clients
- momentum: 0.9 # momentum used by SGD optimiser on the client side
- local_epochs: 1 # number of training epochs each clients does in a fit() round
\ No newline at end of file
diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/dataset.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/dataset.py
deleted file mode 100644
index a805906b8d42..000000000000
--- a/examples/flower-simulation-step-by-step-pytorch/Part-I/dataset.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import torch
-from torch.utils.data import DataLoader, random_split
-from torchvision.datasets import MNIST
-from torchvision.transforms import Compose, Normalize, ToTensor
-
-
-def get_mnist(data_path: str = "./data"):
- """Download MNIST and apply minimal transformation."""
-
- tr = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
-
- trainset = MNIST(data_path, train=True, download=True, transform=tr)
- testset = MNIST(data_path, train=False, download=True, transform=tr)
-
- return trainset, testset
-
-
-def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1):
- """Download MNIST and generate IID partitions."""
-
- # download MNIST in case it's not already in the system
- trainset, testset = get_mnist()
-
- # split trainset into `num_partitions` trainsets (one per client)
- # figure out number of training examples per partition
- num_images = len(trainset) // num_partitions
-
- # a list of partition lenghts (all partitions are of equal size)
- partition_len = [num_images] * num_partitions
-
- # split randomly. This returns a list of trainsets, each with `num_images` training examples
- # Note this is the simplest way of splitting this dataset. A more realistic (but more challenging) partitioning
- # would induce heterogeneity in the partitions in the form of for example: each client getting a different
- # amount of training examples, each client having a different distribution over the labels (maybe even some
- # clients not having a single training example for certain classes). If you are curious, you can check online
- # for Dirichlet (LDA) or pathological dataset partitioning in FL. A place to start is: https://arxiv.org/abs/1909.06335
- trainsets = random_split(
- trainset, partition_len, torch.Generator().manual_seed(2023)
- )
-
- # create dataloaders with train+val support
- trainloaders = []
- valloaders = []
- # for each train set, let's put aside some training examples for validation
- for trainset_ in trainsets:
- num_total = len(trainset_)
- num_val = int(val_ratio * num_total)
- num_train = num_total - num_val
-
- for_train, for_val = random_split(
- trainset_, [num_train, num_val], torch.Generator().manual_seed(2023)
- )
-
- # construct data loaders and append to their respective list.
- # In this way, the i-th client will get the i-th element in the trainloaders list and the i-th element in the valloaders list
- trainloaders.append(
- DataLoader(for_train, batch_size=batch_size, shuffle=True, num_workers=2)
- )
- valloaders.append(
- DataLoader(for_val, batch_size=batch_size, shuffle=False, num_workers=2)
- )
-
- # We leave the test set intact (i.e. we don't partition it)
- # This test set will be left on the server side and we'll be used to evaluate the
- # performance of the global model after each round.
- # Please note that a more realistic setting would instead use a validation set on the server for
- # this purpose and only use the testset after the final round.
- # Also, in some settings (specially outside simulation) it might not be feasible to construct a validation
- # set on the server side, therefore evaluating the global model can only be done by the clients. (see the comment
- # in main.py above the strategy definition for more details on this)
- testloader = DataLoader(testset, batch_size=128)
-
- return trainloaders, valloaders, testloader
diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py
deleted file mode 100644
index 1373f24fbb11..000000000000
--- a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import pickle
-from pathlib import Path
-
-import flwr as fl
-import hydra
-from hydra.core.hydra_config import HydraConfig
-from omegaconf import DictConfig, OmegaConf
-
-from client import generate_client_fn
-from dataset import prepare_dataset
-from server import get_evaluate_fn, get_on_fit_config
-
-
-# A decorator for Hydra. This tells hydra to by default load the config in conf/base.yaml
-@hydra.main(config_path="conf", config_name="base", version_base=None)
-def main(cfg: DictConfig):
- ## 1. Parse config & get experiment output dir
- print(OmegaConf.to_yaml(cfg))
- # Hydra automatically creates a directory for your experiments
- # by default it would be in /outputs//