diff --git a/examples/whisper-federated-finetuning/.gitignore b/examples/whisper-federated-finetuning/.gitignore new file mode 100644 index 000000000000..ced85a9714c0 --- /dev/null +++ b/examples/whisper-federated-finetuning/.gitignore @@ -0,0 +1 @@ +processed_partitions/ diff --git a/examples/whisper-federated-finetuning/README.md b/examples/whisper-federated-finetuning/README.md index cfd0db842bae..75f42fb9aa78 100644 --- a/examples/whisper-federated-finetuning/README.md +++ b/examples/whisper-federated-finetuning/README.md @@ -17,51 +17,61 @@ This example can be run in three modes: - in _simulation_ mode: a client is an ephemeral Python process with a portion of the system resources assigned to it. - in _on-device_ mode: clients are detached entities and each can run on a different device. -## Running the example +## Set up the project -Start by cloning the code example. We prepared a single-line command that you can copy into your shell which will checkout the example for you: +### Clone the project + +Start by cloning the example project: ```shell -git clone --depth=1 https://github.com/adap/flower.git && mv flower/examples/whisper-federated-finetuning . && rm -rf flower && cd whisper-federated-finetuning +git clone --depth=1 https://github.com/adap/flower.git _tmp \ + && mv _tmp/examples/whisper-federated-finetuning . \ + && rm -rf _tmp \ + && cd whisper-federated-finetuning ``` This will create a new directory called `whisper-federated-finetuning` containing the following files: -``` --- README.md <- Your're reading this right now --- rpi_setup.md <- A guide that illustrates how to setup your RPi from scratch --- sim.py <- Runs the example with Flower simulation --- server.py <- Defines the server-side logic for the on-device setting --- client.py <- Defines the client-side logic for the on-device setting --- utils.py <- auxiliary functions for this example --- centralised.py <- Runs the example in centralized mode --- pyproject.toml <- Example dependencies (if you use Poetry) --- requirements.txt <- Example dependencies +```shell +whisper-federated-finetuning +├── whisper_example +│ ├── __init__.py +│ ├── client_app.py # Defines your ClientApp +│ ├── server_app.py # Defines your ServerApp +│ ├── model.py # Defines the model and training functions +│ └── dataset.py # Defines your dataset and its processing +├── centralized.py # Centralized version of this example +├── preprocess.py # A utility script to preprocess all partitions +├── pyproject.toml # Project metadata like dependencies and configs +└── README.md ``` -This example can be run in different ways, please refer to the corresponding section for further instructions. This example was tested with `PyTorch 2.1.0` for all the different ways of running this example except when running on the Raspberry Pi, which seemed to only work with `PyTorch 1.13.1`. Please note the requirement files do not specify a version of PyTorch, therefore you need to choose one that works for you and your system. +> \[!NOTE\] +> This example can be run in different ways, please refer to the corresponding section for further instructions. -## Centralized Training +### Install dependencies and project -This section describes how to finetune `Whisper-tiny` for keyword spotting without making use of Federated Learning. This means that the whole training set is available at any point and therefore it is in its entirety to finetune the model each epoch. - -On your favorite Python environment manager, install a recent version of [PyTorch](https://pytorch.org/get-started/locally/) (PyTorch 2.0+ is recommended for faster training times). Then install the rest of the requirements. For instance: +In a new Python environment the dependencies defined in `pyproject.toml` as well as the `whisper_example` package. ```bash -pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118 -pip install -r requirements.txt +pip install -e . ``` +## Centralized Training + +> \[!TIP\] +> This section describes how to finetune `Whisper-tiny` for keyword spotting without making use of Federated Learning. This means that the whole training set is available at any point and therefore it is in its entirety to finetune the model each epoch. Skip to the next section if you want to jump straight into how to run `Whisper-tiny` with Flower! + Then run centralized training as follows. Please note that the first time you run the code, the `SpeechCommnads` dataset will be downloaded and pre-processed using 🤗 API (which takes a little while -- approx 40min -- and is cached in `~/.cache/huggingface/datasets/speechcommands` wiht a footprint of ~83GB). Subsequent runs shouldn't require this preprocessing. ```bash -python centralised.py --compile # don't use `--compile` flag if you are using pytorch < 2.0 +python centralized.py --compile # don't use `--compile` flag if you are using pytorch < 2.0 # The script will save a checkpoint of the classifier head after each epoch # These checkpoints followo the naming style: `classifier_.pt` # You can load a checkpoint by passing it like this: -python centralised.py --checkpoint .pt +python centralized.py --checkpoint .pt ``` Within 2 epochs you should see a validation accuracy of over 95%. On an RTX 3090Ti each epoch takes ~3min30sec. The final test set consistently reaches 97%+. Below is the log you should expect to see: @@ -84,13 +94,11 @@ Evaluating test set. Loading best model TEST ---> loss = 0.001703281509680464, accuracy = 0.9740286298568507 ``` -> You made it work better ? Let us know how you did it by opening an GitHub issue or a PR and we'll gladly incorporate your suggestions! - -## Federated Learning +## Run the project -Centralized training is ok but in many settings it cannot be realised. Primarily because the training data must remain distributed (i.e. on the client side) and cannot be aggregated into a single node (e.g. your server). With Flower we can easily design a federated finetuning pipeline by which clients locally train the classification head on their data, before communicating it to a central server. There, the updates sent by the clients get aggregated and re-distributed among clients for another round of FL. This process is repeated until convergence. Note that, unlike the encoder part of the Whisper model, the classification head is incredibly lightweight (just 780K parameters), adding little communication costs as a result. +Centralized training is ok but in many settings it cannot be realised. Primarily because the training data must remain distributed (i.e. on the client side) and cannot be aggregated into a single node (e.g. your server). With [Flower](https://flower.ai/) we can easily design a federated finetuning pipeline by which clients locally train the classification head on their data, before communicating it to a central server. There, the updates sent by the clients get aggregated and re-distributed among clients for another round of FL. This process is repeated until convergence. Note that, unlike the encoder part of the Whisper model, the classification head is incredibly lightweight (just 780K parameters), adding little communication costs as a result. -In this example, we partition the training set along the `speaker_id` column into 100 buckets to simulate that many groups of people. You can think of each group as an individual FL _client_ that contains several users/speakers. One way to think about this is to view each client as an office with several people working there, each interacting with the Keyword spotting system. This example exclusively federates the training of the classification head. +There are a total of 2112 speakers in the `train` partition, which is the one we'll use in FL. ```python from datasets import load_dataset @@ -107,147 +115,153 @@ print(len(ids)) # 2113 # <--- +1 since a "None" speaker is included (for clips to construct the _silence_ training examples) ``` +In this example, we use the [GroupedNaturalIdPartitioner](https://flower.ai/docs/datasets/ref-api/flwr_datasets.partitioner.GroupedNaturalIdPartitioner.html) from [Flower Datasets](https://flower.ai/docs/datasets/index.html) to partition the SpeepCommands dataset based on `speaker_id`. We will create groups of 5 speakers, this will result in a total of 422 groups, each representing a node/client in the federation.. Each `speaker_id` is only present in a single group. You can think of each group as an individual Federated Learning _node_ that contains several users/speakers. One way to think about this is to view each client as an office with several people working there, each interacting with the Keyword spotting system. + ![Federated Whisper Finetuning pipeline](_static/federated_finetuning_flower_pipeline.png) +The resulting data partitions are not equal-sized (which is what you'd often find in practice in the real world) because not all `speaker_id` contributed the same amount of audio clips when the [Speech Commands Dataset](https://arxiv.org/abs/1804.03209) was created. If we make a bar plot showing the amount of data each client/node has this is the result. + +![Amount of data per client](_static/whisper_flower_data.png) + +> \[!NOTE\] +> You can make create this plot or adjust it by running the [visualize_labels](visualize_labels.ipynb) notebook. It makes use of Flower Dataset's [visualization tools](https://flower.ai/docs/datasets/tutorial-visualize-label-distribution.html). + An overview of the FL pipeline built with Flower for this example is illustrated above. -1. At the start of a round, the server communicates the classification head to a fraction of the clients. At round #0, the classification head is randomly intialised. -2. Each client, using a frozen pre-trained Whisper encoder, trains the classification head using its own data samples. -3. Once on-site training is completed, each client sends back the (now updated) classification head to the Flower server. -4. The Flower server aggregates (via FedAvg) the classification heads in order to obtain a new _global_ classification head. This head will be shared with clients in the next round. +1. At the start of a round, the `ServerApp` communicates the weights of classification head to a fraction of the nodes. +2. The `ClientApp` in each node, using a frozen pre-trained Whisper encoder, trains the classification head using its own data samples. +3. Once on-site training is completed, each node sends back the (now updated) classification head to the `ServerApp`. +4. The Flower `ServerApp` aggregates (via [FedAvg](https://flower.ai/docs/framework/ref-api/flwr.server.strategy.FedAvg.html) -- but you can [choose any other strategy](https://flower.ai/docs/framework/ref-api/flwr.server.strategy.html), or implement your own!) the classification heads in order to obtain a new _global_ classification head. This head will be shared with nodes in the next round. -Flower supports two ways of doing Federated Learning: simulated and non-simulated FL. The former, managed by the [`VirtualClientEngine`](https://flower.ai/docs/framework/how-to-run-simulations.html), allows you to run large-scale workloads in a system-aware manner, that scales with the resources available on your system (whether it is a laptop, a desktop with a single GPU, or a cluster of GPU servers). The latter is better suited for settings where clients are unique devices (e.g. a server, a smart device, etc). This example shows you how to use both. +You can run your Flower project in both _simulation_ and _deployment_ mode without making changes to the code. If you are starting with Flower, we recommend you using the _simulation_ mode as it requires fewer components to be launched manually. By default, `flwr run` will make use of the Simulation Engine. -### Preparing the dataset +### Run with the Simulation Engine -If you have run the centralized version of this example first, you probably realized that it takes some time to get a fully pre-processed SpeechCommands dataset using the 🤗 HuggingFace API. This pre-processing is ideal so nothing slowdowns our training once we launch the experiment. For the federated part of this example, we also need to pre-process the data however in a different way since first the training set needs to be split into N different buckets, one for each FL client. +The run is defined in the `pyproject.toml` which: specifies the paths to `ClientApp` and `ServerApp` as well as their parameterization with configs in the `[tool.flwr.app.config]` block. -To launch a Flower client we need a `client_fn` callable that will: (1) Load the dataset of the client; then, (2) return the Client object itself. In `client.py` we have included a few lines of code that preprocess the training partition of a given client and save it to disk (so this doesn't have to be repeated each time you run the experiment). The average pre-processed partition is ~0.5GB. You can run the experiment right away and the data will be pre-processed on-demand (i.e. when the `i`-th client is spawned for the first time), or you can pre-process all client partitions first. In order to do so, please run: +> \[!NOTE\] +> By default, it will run on CPU only. On a MacBook Pro M2, running 3 rounds of Flower FL should take ~10 min. Assuming the dataset has already been downloaded. Running on GPU is recommended (for this use the `local-sim-gpu` federation, or continue reading). Also note that the logs from the `ClientApps` are been silenced. You can disable this by setting to `true` the entry `options.backend.init-args.log-to-driver` in the federation in `pyproject.toml` you are using. Read more about how Flower Simulations work in [the documentation](https://flower.ai/docs/framework/how-to-run-simulations.html). -```bash -# will write to disk all pre-processed data partitions -# by default these will go to a new directory named `client_datasets` -# Similarly to the centralised setting, this preprocessing will take a while (30mins approx) -python sim.py --preprocess +```shell +# Run with default settings (21 clients per round out of 422) +flwr run . ``` -The resulting data partitions are not equal-sized (which is what you'd often find in practice in the real world) because not all `speaker_id` contributed the same amount of audio clips when the [Speech Commands Dataset](https://arxiv.org/abs/1804.03209) was created. If we make a bar plot showing the amount of data each client has this is the result. +You can expect a summary at then showing federated metrics (i.e. the average training accuracy and loss across clients sampled in a round) looking like this: -![Amount of data per client](_static/whisper_flower_data.png) +```shell +INFO : [SUMMARY] +INFO : Run finished 3 round(s) in 564.50s +INFO : History (metrics, distributed, fit): +INFO : {'train_accuracy': [(1, 0.637721849625075), +INFO : (2, 0.8666815319504736), +INFO : (3, 0.8912498749526644)], +INFO : 'train_loss': [(1, 4.049714171341712), +INFO : (2, 1.8473016127565092), +INFO : (3, 2.5116721350250693)]} +INFO : +``` -### Federated Finetuning (Simulation) +To run your `ClientApps` on GPU, you'll need to run it in another federation (see `local-sim-gpu` in `pyprojec.toml`). To adjust the degree of parallelism, consider updating the `option.backend` settings. `ClientApp` instances consume only 800MB of VRAM, which enables you to run several in parallel in the same GPU. By default, the command below will run `5xClientApp` in parallel for each GPU available. -The setup instructions for simulations are the same as those described for the centralized setting above: install PyTorch and then `pip install -r requirements.txt`. Then, you can launch your simulation as shown below. Without changes to the code or input arguments, the simulation will sample `10` clients per round, these would do 1 local epoch of finetuning the classification head while the encoder remains frozen. Once this is completed, the classification head is sent to the server for aggregation via `FedAvg`. By default, this example assumes you have a GPU available. +```shell +# Run with GPU (21 clients per round out of 422) +# (each active client gets allocated 20% available VRAM) +flwr run . local-sim-gpu +``` + +You can also override some of the settings for your `ClientApp` and `ServerApp` defined in `pyproject.toml`. For example: ```bash -# By default it will run 2 clients in parallel on a single GPU (which should be fine if your GPU has at least 16GB ) -# If that's too much, consider reduing either the batch size or raise `num_gpus` passed to `start_simulation` -python sim.py # append --num_gpus=0 if you don't have GPUs on your system +# Runs for 10 rounds and sampling 20% of the clients in each round +flwr run . --run-config "num-server-rounds=10 fraction-fit=0.2" +``` + +With just 5 FL rounds, the global model should be reaching ~97% validation accuracy. A test accuracy of 96% can be reached with 10 rounds of FL training using the default hyperparameters. On an RTX 3090Ti, each round takes ~40-50s depending on the amount of data the clients selected in a round have. -# Once finished centralised evaluation loss/acc metrics will be shown +Run on GPU with central evaluation activated and for 10 rounds. -INFO flwr 2023-11-08 14:03:57,557 | app.py:229 | app_fit: metrics_centralized {'val_accuracy': [(0, 0.03977158885994791), (1, 0.6940492887196954), (2, 0.5969745541975556), (3, 0.8794830695251452), (4, 0.9021238228811861), (5, 0.8943097575636145), (6, 0.9047285113203767), (7, 0.9330795431777199), (8, 0.9446002805049089), (9, 0.9556201162091765)], 'test_accuracy': [(10, 0.9719836400817996)]} +```shell +flwr run . local-sim-gpu --run-config "central-eval=true num-server-rounds=10" ``` ![Global validation accuracy FL with Whisper model](_static/whisper_flower_acc.png) -With just 5 FL rounds, the global model should be reaching ~95% validation accuracy. A test accuracy of 97% can be reached with 10 rounds of FL training using the default hyperparameters. On an RTX 3090Ti, each round takes ~20-30s depending on the amount of data the clients selected in a round have. +> \[!TIP\] +> If you find this federated setup not that challenging, try reducing the sizes of the groups created by the `GroupedNaturalIdPartitioner`. That will increase the number of individual clients/nodes in the federation. -Take a look at the [Documentation](https://flower.ai/docs/framework/how-to-run-simulations.html) for more details on how you can customize your simulation. +### Run with the Deployment Engine -### Federated Finetuning (non-simulated) +> \[!NOTE\] +> The steps here outline the very few changes you need to make to the code provided in this example to run with the Deployment Engine instead of with the Simulation Engine. For a beginners guide on how the Deployment Engine works, please check the [Run Flower with the Deployment Engine](https://flower.ai/docs/framework/how-to-run-flower-with-deployment-engine.html) guide. That guide will introduce how to enable secure TLS and node authentication. -Running the exact same FL pipeline as in the simulation setting can be done without using Flower's simulation engine. To achieve this, you need to launch first a server and then two or more clients. You can do this on your development machine assuming you have set up your environment already. +Running the exact same FL pipeline as in the simulation setting can be done without requiring any change to the `ServerApp` design. For the `ClientApp` we need to slightly adjust the logic that loads the dataset. While in simulations we want to dynamically make a Python process to _behave_ like a particular client by loading its corresponding partition, in deployment mode we want the same client process (linked to a single `SuperNode`) to always use its own dataset that lives locally in the machine running the `SuperNode`. -First, launch the server, which will orchestrate the FL process: +An obvious first step would be to generate N data partitions and assing each to a different `SuperNode`. Let's start with this step by means of the `preprocess.py` script. These are the steps we'll follow: -```bash -# The server will wait until at least two clients are connected -# you can use `--server_address='localhost'` if you are running everything on the same machine. -python server.py --server_addres= -``` +1. Extract and save two partitions from the dataset. Each will be assigned to a different `SuperNode`. +2. Modify the `client_fn` in `client_app.py` so it directly loads the partition specified when launching the `SuperNode`. +3. Copy the generate partition to the machine where the `SuperNode` is going to be executed. -Then on different (new) terminals run: +**1. Save a data partition** -```bash -# use a difference `--cid` (client id) to make the client load a particular dataset partition (any integer between 0-99) -# you can use `--server_address='localhost'` if you are running everything on the same machine. -python client.py --server_address= --cid=0 +Run twice the following command, each time indicating a different partition id. Each time you run it a directory in the form `partition_` will be created. -# and on a new terminal/machine (and optionally a different `cid`) -python client.py --server_address= --cid=1 +```shell +python preprocess.py --partition-id=5 ``` -Once the second client connects to the server, the FL process will begin. Each client will report its training progress. The server process will do the same +**2. Adjust `client_fn`** -```bash -# python client.py --server_address='localhost' --cid=50 -# This client runs on a NVIDIA RTX 3090Ti -INFO flwr 2023-11-08 14:12:50,135 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed) -DEBUG flwr 2023-11-08 14:12:50,136 | connection.py:42 | ChannelConnectivity.IDLE -DEBUG flwr 2023-11-08 14:12:50,136 | connection.py:42 | ChannelConnectivity.CONNECTING -DEBUG flwr 2023-11-08 14:12:50,140 | connection.py:42 | ChannelConnectivity.READY -99%|████████████████████████| 920/925 [00:09<00:00, 93.39it/s, avg_loss=2.4414, avg_acc=0.1837] -99%|████████████████████████| 920/925 [00:04<00:00, 216.93it/s, avg_loss=2.0191, avg_acc=0.3315] -99%|████████████████████████| 920/925 [00:04<00:00, 214.29it/s, avg_loss=1.5950, avg_acc=0.5500] -99%|████████████████████████| 920/925 [00:04<00:00, 212.70it/s, avg_loss=1.1883, avg_acc=0.7348] -99%|████████████████████████| 920/925 [00:04<00:00, 208.69it/s, avg_loss=0.8466, avg_acc=0.8228] -99%|████████████████████████| 920/925 [00:04<00:00, 206.31it/s, avg_loss=0.6353, avg_acc=0.8837] -99%|████████████████████████| 920/925 [00:03<00:00, 266.73it/s, avg_loss=0.4842, avg_acc=0.9207] -99%|████████████████████████| 920/925 [00:04<00:00, 212.13it/s, avg_loss=0.3519, avg_acc=0.9391] -99%|████████████████████████| 920/925 [00:04<00:00, 213.17it/s, avg_loss=0.3233, avg_acc=0.9359] -99%|████████████████████████| 920/925 [00:04<00:00, 205.12it/s, avg_loss=0.2646, avg_acc=0.9543] -DEBUG flwr 2023-11-08 14:20:01,065 | connection.py:139 | gRPC channel closed -INFO flwr 2023-11-08 14:20:01,065 | app.py:215 | Disconnect and shut down -``` +Rename the `partition-id` key with something more meaningful such as `local-dataset` in file `whisper_example/client_app.py` and replace the call to `load_data` with `load_data_from_disk`. This will make your `ClientApp` use the dataset you point your `SuperNode` to when launching it: -### Federated Finetuning on Raspberry Pi +```python +from whisper_example.dataset import load_data_from_disk -Setting up the environment for the Raspberry Pi is not that different from the steps you'd follow on any other Ubuntu machine (this example assumes your Raspberry Pi -- either 5 or 4 -- runs Ubuntu server 22.04/23.10 64bits). Using the code as-is, RAM usage on the Raspberry Pi does not exceed 1.5GB. Note that unlike in the previous sections of this example, clients for Raspberry Pi work better when using PyTorch 1.13.1 (or earlier versions to PyTorch 2.0 in general). +def client_fn(context: Context): -> Please follow the steps [here](rpi_setup.md) if you are looking for a step-by-step guide on how to setup your Raspberry Pi to run this example. + # partition_id = context.node_config["partition-id"] # disable + local_data = context.node_config["local-data"] # new line -In order to run this example on a Raspberry Pi, you'll need to follow the same steps as outlined above in the `non-simulated` section. First, launch the server on your development machine. + # keep the same -```bash -# The server will wait until at least two clients are connected -python server.py --server_addres= + # replace the `load_data` lines with this + partition = load_data_from_disk(local_data) ``` -Then, on each of your Raspberry Pi do the following. If you only have one RPi, you can still run the example! But you will need two clients. In addition to the one on the Raspberry Pi, you could launch a client in a separate terminal on your development machine (as shown above in the `non-simulated` section). +**3. Make data available to the SuperNode** -```bash -# use a difference `--cid` (client id) to make this device load a particular dataset partition -# we pass the `--no-compile` option since for RPi we are not using PyTorch 2.0+ -python client.py --server_address= --cid=0 --no-compile +You will need to copy the generated directory in step 1 to the machine that will run the `SuperNode`. You can use, for example, the [`scp`](https://linuxize.com/post/how-to-use-scp-command-to-securely-transfer-files/) in order to do that. + +With steps 1-3 completed, you are ready to run Federated Wishper finetuning with Flower's Deployment Eninge. To connect a `SuperNode` to an existing federation (i.e. a running `SuperLink`) you'd do it like this assuming all the python dependencies (i.e. `flwr`, `transformers`, `torch` are installed -- see `pyproject.toml`) in the Python environment of the machine where you are launching the `SuperNode` from: + +```shell +flower-supernode --superlink=":9092" \ + --node-config="local-data=''" ``` -The first time you run a client on the RPi, the dataset of a client needs to be extracted from the full train set and then pre-processed. The Raspberry Pi 5 is also faster in this pre-processing stage using `.filter()` and `.map()` of 🤗 HuggingFace Dataset. `map()` used `num_proc=4`: +**4. Run your whipser app** + +Once your `SuperNodes` are connected to the `SuperLink`, start the run via `flwr run`, but this time point it to the `remote` federation. It is defined at the bottom of the `pyproject.toml`. You might want to update the `address` so it matches that of the machine where the `SuperLink` is running from. -| **Stage** | Notes | **RPi 4** | **RPi 5** | -| :-------------------------------------: | :----------------------------------------------: | --------- | --------- | -| Filter through training set (~85k rows) | doing `.filter()` in `client.client_fn` | 1:58 | 0.37 | -| Encode 845 rows with `WhisperProcessor` | doing `.map()` passing `utils.prepare_dataset()` | 1:55 | 1:06 | +```shell +flwr run . remote +``` -Some clients have more data than others, but on average, the RPi5 is 1.9x faster than an RPi4 when training the classification head given a frozen encoder. A client with 925 training examples needs ~20min on an RPi to complete an epoch of on-device finetuning. +### Federated Finetuning on Raspberry Pi -```bash -# Running the 50-th client on a RPi 5 showed the following log (a RPi4 ran client 83) -python client.py --cid=50 --server_address= --no-compile -INFO flwr 2023-11-08 16:20:33,331 | grpc.py:49 | Opened insecure gRPC connection (no certificates were passed) -DEBUG flwr 2023-11-08 16:20:33,333 | connection.py:42 | ChannelConnectivity.IDLE -DEBUG flwr 2023-11-08 16:20:33,334 | connection.py:42 | ChannelConnectivity.CONNECTING -DEBUG flwr 2023-11-08 16:20:33,349 | connection.py:42 | ChannelConnectivity.READY -99%|████████████████████████████████████████| 920/925 [20:09<00:06, 1.31s/it, avg_loss=2.4392, avg_acc=0.1902] -99%|████████████████████████████████████████| 920/925 [20:06<00:06, 1.31s/it, avg_loss=1.9830, avg_acc=0.3533] -99%|████████████████████████████████████████| 920/925 [20:06<00:06, 1.31s/it, avg_loss=1.6069, avg_acc=0.5641] -99%|████████████████████████████████████████| 920/925 [20:07<00:06, 1.31s/it, avg_loss=1.1933, avg_acc=0.7402] -99%|████████████████████████████████████████| 920/925 [20:07<00:06, 1.31s/it, avg_loss=0.8749, avg_acc=0.8478] -99%|████████████████████████████████████████| 920/925 [20:06<00:06, 1.31s/it, avg_loss=0.5933, avg_acc=0.9109] -99%|████████████████████████████████████████| 920/925 [20:08<00:06, 1.31s/it, avg_loss=0.4882, avg_acc=0.9359] -99%|████████████████████████████████████████| 920/925 [20:01<00:06, 1.31s/it, avg_loss=0.4022, avg_acc=0.9304] -99%|████████████████████████████████████████| 920/925 [20:10<00:06, 1.32s/it, avg_loss=0.3219, avg_acc=0.9533] -99%|████████████████████████████████████████| 920/925 [20:13<00:06, 1.32s/it, avg_loss=0.2729, avg_acc=0.9641] -DEBUG flwr 2023-11-08 19:47:56,544 | connection.py:139 | gRPC channel closed -INFO flwr 2023-11-08 19:47:56,544 | app.py:215 | Disconnect and shut down +To launch a Flower `SuperNode` on a Raspberry Pi you'd typically follow the same steps you do on any other machine you'd like to connect to a federation. + +First, ensure your Rasberry Pi has been setup correctly. You'll need either a Rasbperry Pi 4 or 5. Using the code as-is, RAM usage on the Raspberry Pi does not exceed 1.5GB. Note that unlike in the previous sections of this example, clients for Raspberry Pi work better when using PyTorch 1.13.1 (or earlier versions to PyTorch 2.0 in general). + +> \[!TIP\] +> Follow the `Setup your Pi` section in the [examples/embedded-devices](https://github.com/adap/flower/tree/main/examples/embedded-devices#setting-up-a-raspberry-pi) example to set it up if you haven't done so already. + +Second, generate and copy the a single data partition to your raspbery pi. Do so from your development machine (e.g. your laptop) as shown earlier in the [Run with the Deployment Engine](#run-with-the-deployment-engine) section. + +Finally, assuming you have a `SuperLink` running on a machine (e.g. your laptop) and which can be reached by your Raspberry Pi (e.g. because they are in the same network), launch the `SuperNode` as shown earlier: + +```shell +flower-supernode --superlink=":9092" \ + --node-config="local-data=''" ``` diff --git a/examples/whisper-federated-finetuning/_static/whisper_flower_data.png b/examples/whisper-federated-finetuning/_static/whisper_flower_data.png index 92a29ceff979..9a77673b3088 100644 Binary files a/examples/whisper-federated-finetuning/_static/whisper_flower_data.png and b/examples/whisper-federated-finetuning/_static/whisper_flower_data.png differ diff --git a/examples/whisper-federated-finetuning/centralised.py b/examples/whisper-federated-finetuning/centralized.py similarity index 82% rename from examples/whisper-federated-finetuning/centralised.py rename to examples/whisper-federated-finetuning/centralized.py index c0e3d60a0697..e17bd5d2c7ea 100644 --- a/examples/whisper-federated-finetuning/centralised.py +++ b/examples/whisper-federated-finetuning/centralized.py @@ -1,26 +1,25 @@ import argparse import random -import numpy as np import torch -from datasets import concatenate_datasets, load_dataset -from torch.utils.data import DataLoader, WeightedRandomSampler -from transformers import WhisperForConditionalGeneration, WhisperProcessor - -from utils import ( +from torch.utils.data import DataLoader +from transformers import WhisperProcessor +from whisper_example.dataset import get_encoding_fn, prepare_silences_dataset +from whisper_example.model import ( + construct_balanced_sampler, eval_model, - get_encoding_fn, get_model, - prepare_silences_dataset, - remove_cols, train_one_epoch, ) +from datasets import concatenate_datasets, load_dataset + random.seed(1989) torch.set_float32_matmul_precision( "high" ) # If “high” or “medium” are set then the TensorFloat32 is used NUM_CLASSES = 12 +REMOVE_COLS = ["file", "audio", "label", "is_unknown", "speaker_id", "utterance_id"] parser = argparse.ArgumentParser(description="Whisper centralised") parser.add_argument("--checkpoint", type=str, help="path to classifier`s checkpoint") @@ -56,10 +55,10 @@ def main(): torch.set_num_threads( 1 ) # not clear to me why we need this in order to be able to use `num_proc > 1 for .map` - train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols) - val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=remove_cols) + train_encoded = sc.map(prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS) + val_encoded = sc_val.map(prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS) test_encoded = sc_test.map( - prepare_dataset_fn, num_proc=4, remove_columns=remove_cols + prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS ) # create and pre-process the dataset of silences @@ -68,7 +67,7 @@ def main(): # ! needed each time you run the code. Alternatively, this silence generation could be # ! implemented as part of a `collate_fn` in the standard PyTorch dataloader... encoded_silences = silences_dataset.map( - prepare_dataset_fn, num_proc=4, remove_columns=remove_cols + prepare_dataset_fn, num_proc=4, remove_columns=REMOVE_COLS ) full_train_dataset = concatenate_datasets([train_encoded, encoded_silences]) @@ -76,18 +75,11 @@ def main(): lbls = set(full_train_dataset["targets"]) print(f"{lbls = }") - hist = np.histogram(full_train_dataset["targets"], bins=12) - print(f"{[int(count) for count in hist[0]]}") - - # make balanced batches with a WeightedRandomSampler - w_per_class = ( - len(full_train_dataset) / hist[0] - ) # doesn't have to add up to 1 (relative is what matters) - print(f"{w_per_class = }") - w_ss = [w_per_class[t] for t in full_train_dataset["targets"]] - sampler = WeightedRandomSampler(w_ss, len(w_ss)) - - # prepare dataloaders + # Construct a balanced sampler so batches roughly contain the same number + # of samples from each class + sampler = construct_balanced_sampler(full_train_dataset) + + # Prepare dataloaders train_dataset = full_train_dataset.with_format("torch", columns=["data", "targets"]) train_loader = DataLoader( train_dataset, batch_size=64, shuffle=False, num_workers=4, sampler=sampler @@ -97,7 +89,7 @@ def main(): test_dataset = test_encoded.with_format("torch", columns=["data", "targets"]) test_loader = DataLoader(test_dataset, batch_size=64, num_workers=4) - # model to cuda, set criterion, classification layer to train and optimiser + # Model to cuda, set criterion, classification layer to train and optimiser device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") encoder, classifier = get_model(device, num_classes=12) criterion = torch.nn.CrossEntropyLoss() @@ -113,7 +105,7 @@ def main(): classifier_head_params = sum(p.numel() for p in classifier.parameters()) print(f"{classifier_head_params = }") - # eval initial model + # Eval initial model loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device) print(f"Initial (loss, acc): {loss = }, {accuracy = }") best = [-float("inf"), None] diff --git a/examples/whisper-federated-finetuning/client.py b/examples/whisper-federated-finetuning/client.py deleted file mode 100644 index d1da5c13ecf8..000000000000 --- a/examples/whisper-federated-finetuning/client.py +++ /dev/null @@ -1,185 +0,0 @@ -import argparse - -import flwr as fl -import numpy as np -import torch -from datasets import concatenate_datasets, load_dataset, load_from_disk -from torch.utils.data import DataLoader, WeightedRandomSampler -from transformers import WhisperProcessor - -from utils import ( - construct_client_mapping, - get_encoding_fn, - get_model, - prepare_silences_dataset, - remove_cols, - set_params, - train_one_epoch, -) - -parser = argparse.ArgumentParser(description="Flower+Whisper") -parser.add_argument("--cid", type=int, required=True, help="Client id.") -parser.add_argument( - "--server_address", type=str, required=True, help="IP of the server." -) -parser.add_argument( - "--no-compile", action="store_true", help="To not compile client models." -) - -CLIENT_DATA = "client_datasets" - - -class WhisperFlowerClient(fl.client.NumPyClient): - """A Flower client that does trains a classification head attached to the encoder of - a Whisper-tiny encoder for Keyword spotting.""" - - def __init__(self, trainset, num_classes: int, disable_tqdm: bool, compile: bool): - self.disable_tqdm = disable_tqdm - self.trainset = trainset.with_format("torch", columns=["data", "targets"]) - - # Determine device - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") - self.encoder, self.classifier = get_model(self.device, num_classes, compile) - - def get_parameters(self, config): - """Return parameters in a format that is understood by the server.""" - return [val.cpu().numpy() for _, val in self.classifier.state_dict().items()] - - def fit(self, parameters, config): - """Do on-device training. - - Here the client receives the parameters of the classification head from the - server. Then trains that classifier using the data that belongs to this client. - Finally, The updated classifier is sent back to the server for aggregation. - """ - - # Apply the classifier parameters to the model in this client - set_params(self.classifier, parameters) - - # Read from config - batch, epochs = config["batch_size"], config["epochs"] - - # construct sampler in order to have balanced batches - hist = np.histogram(self.trainset["targets"], bins=12) - w_per_class = ( - len(self.trainset) / hist[0] - ) # doesn't have to add up to 1 (relative is what matters) - # print(f"{w_per_class = }") - w_ss = [w_per_class[t] for t in self.trainset["targets"]] - ss = WeightedRandomSampler(w_ss, len(w_ss)) - - # Construct dataloader - train_loader = DataLoader( - self.trainset, - batch_size=batch, - shuffle=False, - num_workers=0, - sampler=ss, - drop_last=True, - ) - - # Define optimizer and criterion - criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.SGD(self.classifier.parameters(), lr=0.001) - # Train - train_one_epoch( - self.encoder, - self.classifier, - optimizer, - criterion, - train_loader, - self.device, - disable_tqdm=self.disable_tqdm, - ) - - # Return local classification head and statistics - return self.get_parameters({}), len(train_loader.dataset), {} - - -def get_client_fn( - full_data, - encoding_fn, - client_mapping, - client_data_path: str = "./", - num_classes: int = 12, - disable_tqdm: bool = False, - compile: bool = True, -): - """Return a function that can be used to instantiate a particular client.""" - - def client_fn(cid: str): - torch.set_float32_matmul_precision( - "high" - ) # If “high” or “medium” are set then the TensorFloat32 is used - - # if dataset hasn't been processed for this client, do so. - # else, just load it - try: - full_train_dataset = load_from_disk(f"{client_data_path}/client{cid}.hf") - except: - # get this client's data and preprocess it - print(f"Dataset for client {cid} not found. Pre-processing...") - og_threads = torch.get_num_threads() - torch.set_num_threads(1) - sc_client = full_data.filter( - lambda example: example["speaker_id"] in client_mapping[int(cid)] - ) - client_train_data = sc_client.map( - encoding_fn, num_proc=4, remove_columns=remove_cols - ) - - # now let's add some _silence_ training examples (add 10% of total examples in this client's data) - ratio_silences_for_client = 0.1 * (len(client_train_data) / len(full_data)) - silence_dataset = prepare_silences_dataset( - full_data, ratio_silences_for_client - ) - print( - f"adding {len(silence_dataset)} to client data ({len(client_train_data)})" - ) - silence_enc = silence_dataset.map(encoding_fn, remove_columns=remove_cols) - - full_train_dataset = concatenate_datasets([client_train_data, silence_enc]) - # save dataset. It will be loaded next time this client is spawned - full_train_dataset.save_to_disk(f"{client_data_path}/client{cid}.hf") - torch.set_num_threads(og_threads) - - return WhisperFlowerClient( - full_train_dataset, num_classes, disable_tqdm, compile - ).to_client() - - return client_fn - - -def main(): - """Run client.""" - - # Parse input arguments - args = parser.parse_args() - - sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False) - - # generate splits - client_mapping = construct_client_mapping(sc_train, num_clients=100) - - # pre-process all partitions (+store to disk) - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") - prepare_dataset_fn = get_encoding_fn(processor) - - client_fn = get_client_fn( - sc_train, - prepare_dataset_fn, - client_mapping, - compile=not (args.no_compile), - client_data_path=CLIENT_DATA, - ) - - fl.client.start_client( - server_address=f"{args.server_address}:8080", - client=client_fn(args.cid), - ) - - -if __name__ == "__main__": - main() diff --git a/examples/whisper-federated-finetuning/preprocess.py b/examples/whisper-federated-finetuning/preprocess.py new file mode 100644 index 000000000000..c0cca2388efc --- /dev/null +++ b/examples/whisper-federated-finetuning/preprocess.py @@ -0,0 +1,60 @@ +import argparse +from multiprocessing import Pool +from time import time + +import tomli +from whisper_example.dataset import load_data + +from datasets import load_dataset + +parser = argparse.ArgumentParser(description="Whisper preprocessing") + +parser.add_argument( + "--partition-id", type=int, help="The partition to create and save." +) + +args = parser.parse_args() + + +# Open and read the pyproject.toml +with open("pyproject.toml", "rb") as file: + flwr_config = tomli.load(file)["tool"]["flwr"] + +# Display +print(flwr_config) +remove_cols = flwr_config["app"]["config"]["remove-cols"] +num_supernodes = flwr_config["federations"]["local-sim"]["options"]["num-supernodes"] + +# If specified one partition, only that one will be processed and saved to the current directory +if args.partition_id: + print(f"Pre-processing partition {args.partition_id} only.") +else: + print(f"Pre-processing dataset into {num_supernodes} partitions.") + + +def process_one_partition(partition_id: int, save: bool = False): + pp = load_data(partition_id, remove_cols) + if save: + file_name = f"partition_{partition_id}" + pp.save_to_disk(file_name) + print(f"Saved partition to disk: {file_name}") + + +if __name__ == "__main__": + + # Download train set + _ = load_dataset("speech_commands", "v0.02", split="train", token=False) + + # Parallelize the processing of each partition in the dataset + t_start = time() + num_proc = None # set it if you want to limit the number of processes + + if args.partition_id: + process_one_partition(args.partition_id, True) + + else: + with Pool(num_proc) as pool: + pool.map(process_one_partition, range(num_supernodes)) + print( + f"Pre-processing {num_supernodes} partitions took: {time() - t_start:.2f} s" + ) diff --git a/examples/whisper-federated-finetuning/pyproject.toml b/examples/whisper-federated-finetuning/pyproject.toml index 895ec3a11343..951e9d177217 100644 --- a/examples/whisper-federated-finetuning/pyproject.toml +++ b/examples/whisper-federated-finetuning/pyproject.toml @@ -1,19 +1,56 @@ [build-system] -requires = ["poetry-core>=1.4.0"] -build-backend = "poetry.core.masonry.api" - -[tool.poetry] -name = "whisper-flower" -version = "0.1.0" -description = "On-device Federated Downstreaming for Speech Classification" -authors = ["The Flower Authors "] - -[tool.poetry.dependencies] -python = ">=3.9,<3.11" -flwr = { extras = ["simulation"], version = ">=1.0,<2.0" } -transformers = "4.38.0" -tokenizers = "0.14.0" -datasets = "2.14.6" -soundfile = "0.12.1" -librosa = "0.10.1" -# this example was tested with pytorch 2.1.0 +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "whisper_example" +version = "1.0.0" +description = "On-device Federated Finetuning for Speech Classification" +license = "Apache-2.0" +dependencies = [ + "flwr[simulation]>=1.14.0", + "flwr-datasets[audio]>=0.5.0", + "transformers==4.44.2", + "torch==2.5.1", +] + +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.flwr.app] +publisher = "flwrlabs" + +[tool.flwr.app.components] +serverapp = "whisper_example.server_app:app" +clientapp = "whisper_example.client_app:app" + +[tool.flwr.app.config] +num-server-rounds = 3 +fraction-fit = 0.05 # sample 5% of clients in each round (5% of 422 is 21) +num-classes = 12 +batch-size = 8 +compile-model = false +disable-tqdm = true +central-eval = false +remove-cols = "file,audio,label,is_unknown,speaker_id,utterance_id" + +[tool.flwr.federations] +default = "local-sim" + + +[tool.flwr.federations.local-sim] +options.num-supernodes = 422 # we are grouping 2112 speakers into groups of 5 +options.backend.client-resources.num-cpus = 4 +options.backend.client-resources.num-gpus = 0.0 +options.backend.init-args.log-to-driver = false # set to true to enable all logs from simulation engine + +[tool.flwr.federations.local-sim-gpu] +options.num-supernodes = 422 # we are grouping 2112 speakers into groups of 5 +options.backend.client-resources.num-cpus = 2 +options.backend.client-resources.num-gpus = 0.2 +options.backend.init-args.log-to-driver = false # set to true to enable all logs from simulation engine + + +[tool.flwr.federations.remote] +address = '127.0.0.1:9093' # IP:9093 of your superlink (assumed localhost superlink) +insecure = true # Check the documentation to setup with SSL diff --git a/examples/whisper-federated-finetuning/requirements.txt b/examples/whisper-federated-finetuning/requirements.txt deleted file mode 100644 index 9e42ace475f9..000000000000 --- a/examples/whisper-federated-finetuning/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -transformers==4.38.0 -tokenizers>=0.14.0 -datasets==2.14.6 -soundfile==0.12.1 -librosa==0.10.1 -flwr[simulation]>=1.0, <2.0 diff --git a/examples/whisper-federated-finetuning/rpi_setup.md b/examples/whisper-federated-finetuning/rpi_setup.md deleted file mode 100644 index d49bbd6a472b..000000000000 --- a/examples/whisper-federated-finetuning/rpi_setup.md +++ /dev/null @@ -1,49 +0,0 @@ -# Setting up your RaspberryPi - -> This guide assumes you have a fresh install of Ubuntu Server (either 22.04 or 23.10) and that you have successfully `ssh`-ed into your device. - -## Setting up your device for Python developemnet - -We are going to use [`pyenv`](https://github.com/pyenv/pyenv) to manage different Python versions and to create an environment. First, we need to install some system dependencies - -```bash -sudo apt update -# the last package is needed for whisper -sudo apt install build-essential zlib1g-dev libssl-dev libsqlite3-dev libreadline-dev libbz2-dev libffi-dev liblzma-dev libsndfile1 -``` - -Create Python environment with `pyenv`: - -```bash - -# Ensure you have installed pyenv, else do the below: -git clone https://github.com/pyenv/pyenv.git ~/.pyenv -echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc -echo 'command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc -echo 'eval "$(pyenv init -)"' >> ~/.bashrc - -# Install python 3.9+ -pyenv install 3.9.17 - -# Install pyenv virtual env plugin -git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv -# Restart your shell -exec "$SHELL" - -# Create the environment -pyenv virtualenv 3.9.17 flower-whisperer -``` - -## Installing the dependencies for Whisper+Flower - -With our environmnet ready, let's install the dependencies. Please note that at the time of writing, PyTorch 2.0+ won't work properly on `aarm64`. Because of this, we'll be using an earlier version of this package. - -```bash -# activate your environment -pyenv activate flower-whisperer - -# install pytorch (RPi aren't ready for PyTorch 2.0+ apparently...) -pip install torch==1.13.1 -# install rest of requirerments -pip install -r requirements.txt -``` diff --git a/examples/whisper-federated-finetuning/server.py b/examples/whisper-federated-finetuning/server.py deleted file mode 100644 index 060b162240e5..000000000000 --- a/examples/whisper-federated-finetuning/server.py +++ /dev/null @@ -1,103 +0,0 @@ -import argparse - -import flwr as fl -import torch -from datasets import load_dataset -from torch.utils.data import DataLoader -from transformers import WhisperProcessor - -from utils import eval_model, get_encoding_fn, get_model, remove_cols, set_params - -parser = argparse.ArgumentParser(description="Flower+Whisper") -parser.add_argument("--num_rounds", type=int, default=5, help="Number of FL rounds.") -parser.add_argument( - "--server_address", type=str, required=True, help="IP of the server." -) - - -NUM_CLASSES = 12 -NUM_CLIENTS = 100 - - -def fit_config(server_round: int): - """Return a configuration with static batch size and (local) epochs.""" - config = { - "epochs": 1, # Number of local epochs done by clients - "batch_size": 8, # Batch size to use by clients during fit() - } - return config - - -def get_evaluate_fn(val_set, test_set, encoding_fn, num_rounds): - def evaluate(server_round: int, parameters: fl.common.NDArrays, config): - """Use the entire CIFAR-10 test set for evaluation.""" - - # Determine device - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - # prepare model - encoder, classifier = get_model(device, NUM_CLASSES) - set_params(classifier, parameters) - classifier.to(device) - - # prepare dataset - og_threads = torch.get_num_threads() - torch.set_num_threads( - 1 - ) # ! still, not clear to me why this is needed if we want `num_proc>1` - if server_round == num_rounds: - prefix = "test" - encoded = test_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols) - else: - prefix = "val" - encoded = val_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols) - torch.set_num_threads(og_threads) - - val_encoded = encoded.with_format("torch", columns=["data", "targets"]) - val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4) - - # Run global evaluation - criterion = torch.nn.CrossEntropyLoss() - loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device) - - print(f"{prefix}: --> {loss = }, {accuracy = }") - - return loss, {f"{prefix}_accuracy": accuracy} - - return evaluate - - -def main(): - # Parse input arguments - args = parser.parse_args() - - # The sever will use the validation set to assess the performance of the global - # model after each round. Then, the test set will be used for evaluating the global - # model after the last round - sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False) - sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False) - - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") - prepare_dataset_fn = get_encoding_fn(processor) - - # We use a standard FedAvg strategy - strategy = fl.server.strategy.FedAvg( - fraction_fit=0.00001, - min_fit_clients=2, # the strategy will wait until at least 2 clients are sampled for fit - fraction_evaluate=0.0, # we don't do federated evaluation in this example - min_available_clients=2, # the strategy will do nothing until 2 clients are connected to the server - on_fit_config_fn=fit_config, - evaluate_fn=get_evaluate_fn( - sc_val, sc_test, prepare_dataset_fn, args.num_rounds - ), - ) - - fl.server.start_server( - server_address=f"{args.server_address}:8080", - config=fl.server.ServerConfig(num_rounds=args.num_rounds), - strategy=strategy, - ) - - -if __name__ == "__main__": - main() diff --git a/examples/whisper-federated-finetuning/sim.py b/examples/whisper-federated-finetuning/sim.py deleted file mode 100644 index 750a7f705251..000000000000 --- a/examples/whisper-federated-finetuning/sim.py +++ /dev/null @@ -1,94 +0,0 @@ -import argparse - -import flwr as fl -import torch -from datasets import load_dataset -from transformers import WhisperProcessor - -from client import get_client_fn -from server import fit_config, get_evaluate_fn -from utils import construct_client_mapping, get_encoding_fn - -parser = argparse.ArgumentParser(description="Flower+Whisper") - -parser.add_argument("--num_rounds", type=int, default=10, help="Number of FL rounds.") -parser.add_argument( - "--num_cpus", type=int, default=4, help="Number of CPUs reserved for each client." -) -parser.add_argument( - "--num_gpus", - type=float, - default=0.5, - help="GPU ratio reserved for each client (`num_gpus`=1.0 means one client gets the whole GPU)", -) -parser.add_argument( - "--preprocess", - action="store_true", - help="Preprocesses all client's datasets and exits (creates ~83GB data)", -) - -NUM_CLASSES = 12 -NUM_CLIENTS = 100 -CLIENT_DATA = "client_datasets" -torch.set_float32_matmul_precision( - "high" -) # If “high” or “medium” are set then the TensorFloat32 is used - - -def main(): - # Parse input arguments - args = parser.parse_args() - - # dataset download and preparation - sc_train = load_dataset("speech_commands", "v0.02", split="train", token=False) - sc_val = load_dataset("speech_commands", "v0.02", split="validation", token=False) - sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False) - - # generate splits - client_mapping = construct_client_mapping(sc_train, num_clients=NUM_CLIENTS) - - # pre-process all partitions (+store to disk) - processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") - prepare_dataset_fn = get_encoding_fn(processor) - if args.preprocess: - import sys - - client_fn = get_client_fn( - sc_train, prepare_dataset_fn, client_mapping, CLIENT_DATA, NUM_CLASSES - ) - - for i in range(NUM_CLIENTS): - _ = client_fn(str(i)) - print("Preprocessing completed. Run the code again without `--preprocess`") - sys.exit(0) - - strategy = fl.server.strategy.FedAvg( - fraction_fit=0.00001, - min_fit_clients=10, - fraction_evaluate=0.0, - min_available_clients=NUM_CLIENTS, - on_fit_config_fn=fit_config, - evaluate_fn=get_evaluate_fn( - sc_val, sc_test, prepare_dataset_fn, args.num_rounds - ), - ) - - # Start simulation - fl.simulation.start_simulation( - client_fn=get_client_fn( - sc_train, - prepare_dataset_fn, - client_mapping, - CLIENT_DATA, - NUM_CLASSES, - disable_tqdm=True, - ), - num_clients=NUM_CLIENTS, - client_resources={"num_cpus": args.num_cpus, "num_gpus": args.num_gpus}, - config=fl.server.ServerConfig(num_rounds=args.num_rounds), - strategy=strategy, - ) - - -if __name__ == "__main__": - main() diff --git a/examples/whisper-federated-finetuning/utils.py b/examples/whisper-federated-finetuning/utils.py deleted file mode 100644 index 3bae730790a0..000000000000 --- a/examples/whisper-federated-finetuning/utils.py +++ /dev/null @@ -1,208 +0,0 @@ -import random -from collections import OrderedDict -from typing import List - -import flwr as fl -import numpy as np -import torch -from datasets import Dataset -from tqdm import tqdm -from transformers import WhisperForConditionalGeneration - -remove_cols = ["file", "audio", "label", "is_unknown", "speaker_id", "utterance_id"] - - -class RunningAvg: - def __init__(self): - self.n = 0 - self.total = 0 - - def update(self, val): - self.total += val - self.n += 1 - - def __call__(self): - return self.total / self.n - - -def train_one_epoch( - model, - classifier, - optimizer, - criterion, - dataloader, - device, - disable_tqdm: bool = False, -): - """Train the classification head. - - This is a very standard looking way of training PyTorch models. - """ - model.eval() - classifier.train() - classifier.to(device) - loss_avg, acc_avg = RunningAvg(), RunningAvg() - with tqdm(total=len(dataloader.dataset), disable=disable_tqdm) as t: - for b in dataloader: - optimizer.zero_grad() - data = b["data"].squeeze().to(device) - # print(data.shape) - labels = b["targets"].to(device) - with torch.no_grad(): - res = model(data)[0] - - resres = classifier(res) - - loss = criterion(resres.float(), labels) - loss.backward() - optimizer.step() - _, predicted = torch.max(resres.data, 1) - correct = (predicted == labels).sum().item() - acc = correct / data.shape[0] - loss_ = loss.cpu().item() - - loss_avg.update(loss_) - acc_avg.update(acc) - - t.update(data.shape[0]) - t.set_postfix( - {"avg_loss": f"{loss_avg():.4f}", "avg_acc": f"{acc_avg():.4f}"} - ) - - -def eval_model(model, classifier, criterion, dataloader, device): - """Evaluate a model on a validation/test set. - - This is a very normal looking way of doing this with PyTorch. - """ - model.eval() - classifier.eval() - classifier.to(device) - correct = 0 - loss_ = 0 - total = 0 - with torch.no_grad(): - for b in dataloader: - data = b["data"].squeeze().to(device) - # print(data.shape) - labels = b["targets"].to(device) - res = model(data)[0] - resres = classifier(res) - - loss = criterion(resres.float(), labels) - _, predicted = torch.max(resres.data, 1) - correct += (predicted == labels).sum().item() - total += data.shape[0] - loss_ += loss.cpu().item() - - accuracy = correct / total - loss = loss_ / total - - return loss, accuracy - - -def prepare_silences_dataset(train_dataset, ratio_silence: float = 0.1) -> Dataset: - """Generate silences for the train set. - - One of the classes in the SpeechCommands datatset is `silence`. However, the dataset - does not include clips of silence. It does however include 5 long files with - different background sounds. The taks of this function is to extract several - (defined by `ratio_silence`) one-second long clips from those background audio - files. Later, those audio clips will be included into the training set. - """ - # retrieve original silence audio clips - silences = [d for d in train_dataset if d["label"] == 35] - # figure out how many to add - num_silence_total = int(len(train_dataset) * ratio_silence) - # num new entries per background noise clip - num_silence_per_bkg = num_silence_total // len(silences) - - silence_to_add = [] - for sil in silences: - sil_array = sil["audio"]["array"] - sr = sil["audio"]["sampling_rate"] - print(f"Extracting audio from: {sil['file']} ...") - for _ in range(num_silence_per_bkg): - random_offset = random.randint(0, len(sil_array) - sr - 1) - sil_array_crop = sil_array[random_offset : random_offset + sr] - - entry = sil - silence_to_add.append(entry) - silence_to_add[-1]["audio"]["array"] = sil_array_crop - - return Dataset.from_list(silence_to_add) - - -def construct_client_mapping(full_trainset, num_clients: int = 100): - """Create a mapping to partition the dataset into `num_client` buckets. - - These buckets contain the same number of `spekaer_id` but likely different number of - training exampes since each `speaker_id` in SpeechCommands does provide different - amounts of data to the dataset. - """ - client_ids = list(set(full_trainset["speaker_id"])) - client_ids.remove( - None - ) # remove "none" which corresponds to the _silence_ audio clips - client_ids.sort() # we sort this as a quick way of ensuring our client mapping is consistent between runs - len( - client_ids - ) # should be 2112 (i.e. the number of participats in SpeechCommands dataset v0.02) - - # split into groups (each group represents a client) - client_mapping = np.array_split(client_ids, num_clients) - - return client_mapping - - -def get_encoding_fn(processor): - """Return a function to use to pre-process/encode the SpeechCommands dataset. - - We are working with the 12classes version of this dataset, therefore we need to do - some reassignment of labels. - """ - - def prepare_dataset(batch): - audio = batch["audio"] - data = {} - data["data"] = processor( - audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt" - ).input_features - - # All unknown keywords are assigned label 11. The silence clips get assigned label 10 - # In this way we have 12 classes with labels 0-11 - data["targets"] = ( - 11 - if batch["is_unknown"] - else (10 if batch["label"] == 35 else batch["label"]) - ) - return data - - return prepare_dataset - - -def set_params(model: torch.nn.ModuleList, params: List[fl.common.NDArrays]): - """Set model weights from a list of NumPy ndarrays.""" - params_dict = zip(model.state_dict().keys(), params) - state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) - model.load_state_dict(state_dict, strict=True) - - -def get_model(device, num_classes, compile: bool = True): - """Create model: Whisper-tiny Encoder + classification head.""" - encoder = WhisperForConditionalGeneration.from_pretrained( - "openai/whisper-tiny" - ).get_encoder() - encoder = encoder.to(device) - if compile: - encoder = torch.compile(encoder) - - # This classification head is 782K parameters - # This is the only part of the model that is trained in federation - classifier = torch.nn.Sequential( - torch.nn.Conv1d(1500, 128, kernel_size=1), - torch.nn.ReLU(), - torch.nn.Flatten(1), - torch.nn.Linear(128 * 384, num_classes), - ).to(device) - return encoder, classifier diff --git a/examples/whisper-federated-finetuning/visualize_labels.ipynb b/examples/whisper-federated-finetuning/visualize_labels.ipynb new file mode 100644 index 000000000000..f953d3f95b90 --- /dev/null +++ b/examples/whisper-federated-finetuning/visualize_labels.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "76350774-70f0-47d7-b790-efd515d84b8f", + "metadata": {}, + "outputs": [], + "source": [ + "from flwr_datasets.partitioner import GroupedNaturalIdPartitioner\n", + "from flwr_datasets.visualization import plot_label_distributions\n", + "import matplotlib.pyplot as plt\n", + "from datasets import load_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f367dfb9-c0f0-4098-b999-9bbd00d0cd46", + "metadata": {}, + "outputs": [], + "source": [ + "# Load train partition of SpeechCommands\n", + "sc = load_dataset(\"speech_commands\", \"v0.02\", split=\"train\", token=False)\n", + "\n", + "# Use the \"Grouped partitioner\" from FlowerDatasets to construct groups of 30 unique speaker ids\n", + "partitioner = GroupedNaturalIdPartitioner(partition_by=\"speaker_id\", group_size=30)" + ] + }, + { + "cell_type": "markdown", + "id": "4d457b46-d649-4b41-b950-f56891ab8961", + "metadata": {}, + "source": [ + "### Removing _silence_ clips" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a299c2ed-f7be-48ec-92a6-aa0ba2c992b2", + "metadata": {}, + "outputs": [], + "source": [ + "# Remove the silence audio clips (the dataset comes with 5 long audio clips. we don't want to show these in the plot below)\n", + "# Those silence audio clips are the entries in the dataset with `speaker_id`=None. Let's remove them\n", + "# At training time, each client with get 10% new data samples containing 1s-long silence clips\n", + "def filter_none_speaker(example):\n", + " return example[\"speaker_id\"] is not None\n", + "\n", + "\n", + "filtered_dataset = sc.filter(filter_none_speaker)\n", + "\n", + "# Apply dataset to partitioner\n", + "partitioner.dataset = filtered_dataset" + ] + }, + { + "cell_type": "markdown", + "id": "70e31a9c-e446-4407-b774-c48d7e6edf88", + "metadata": {}, + "source": [ + "### Making a plot" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a78625ab-f054-4582-9b59-9a057102f434", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axis = plt.subplots(figsize=(16, 6))\n", + "fig, ax, df = plot_label_distributions(\n", + " partitioner,\n", + " axis=axis,\n", + " label_name=\"label\",\n", + " plot_type=\"bar\",\n", + " size_unit=\"absolute\",\n", + " partition_id_axis=\"x\",\n", + " legend=True,\n", + " verbose_labels=True,\n", + " title=\"Per Partition Labels Distribution\",\n", + " legend_kwargs={\"ncols\": 2, \"bbox_to_anchor\": (1.05, 0.5)},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49a9b4c2-e291-4ae4-86a2-051265640ed8", + "metadata": {}, + "outputs": [], + "source": [ + "fig.savefig(\"whisper_flower_data.png\", format=\"png\", bbox_inches=\"tight\")" + ] + }, + { + "cell_type": "markdown", + "id": "af15f634-daf5-4709-b2a6-c98dfb8db463", + "metadata": {}, + "source": [ + "### Process dataset into 12 classes\n", + "\n", + "To go from 35 classes into 12, we need to apply the following cahnges:\n", + "- all audio clips that had the `is_unknown` set, will be assigned the same \"target\" label `11`\n", + "- Silence audio clips will assigned label `10`\n", + "\n", + "We achieve this 35:12 mapping by means of the function below (similar to the one used in the code)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5092c826-f085-499a-acae-5ffcc0442757", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_dataset(batch):\n", + " data = {}\n", + " # All unknown keywords are assigned label 11. The silence clips get assigned label 10\n", + " # In this way we have 12 classes with labels 0-11\n", + " data[\"targets\"] = (\n", + " 11 if batch[\"is_unknown\"] else (10 if batch[\"label\"] == 35 else batch[\"label\"])\n", + " )\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03131a7e-4fe0-4271-9078-a8314734b544", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_12cls = filtered_dataset.map(prepare_dataset, num_proc=4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ba0c944-703f-46be-8886-34ab46fa6ba1", + "metadata": {}, + "outputs": [], + "source": [ + "# Re-construct the partitioner and apply the filtered dataset\n", + "partitioner = GroupedNaturalIdPartitioner(partition_by=\"speaker_id\", group_size=30)\n", + "partitioner.dataset = dataset_12cls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e8b319c-1056-4b68-9082-4c0057f50d2e", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate the plot again, this time using the new \"targets\" key\n", + "fig, axis = plt.subplots(figsize=(16, 6))\n", + "fig, ax, df = plot_label_distributions(\n", + " partitioner,\n", + " axis=axis,\n", + " label_name=\"targets\",\n", + " plot_type=\"bar\",\n", + " size_unit=\"absolute\",\n", + " partition_id_axis=\"x\",\n", + " legend=True,\n", + " verbose_labels=False,\n", + " title=\"Per Partition Labels Distribution\",\n", + " legend_kwargs={\"ncols\": 2, \"bbox_to_anchor\": (1.0, 0.5)},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca8743ab-2f21-49f2-808c-74739f9d97aa", + "metadata": {}, + "outputs": [], + "source": [ + "# classes 0-9 correspond to keywords: 'yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off'\n", + "# Class 10 is 'silence' and class 11 is 'other' (combined remaining classes from the 35-class original representation)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/whisper-federated-finetuning/whisper_example/__init__.py b/examples/whisper-federated-finetuning/whisper_example/__init__.py new file mode 100644 index 000000000000..30f4d1bccdb4 --- /dev/null +++ b/examples/whisper-federated-finetuning/whisper_example/__init__.py @@ -0,0 +1 @@ +"""whisper_example: A Flower / PyTorch app with OpenAi's Whisper.""" diff --git a/examples/whisper-federated-finetuning/whisper_example/client_app.py b/examples/whisper-federated-finetuning/whisper_example/client_app.py new file mode 100644 index 000000000000..79871c603b60 --- /dev/null +++ b/examples/whisper-federated-finetuning/whisper_example/client_app.py @@ -0,0 +1,128 @@ +"""whisper_example: A Flower / PyTorch app with OpenAi's Whisper.""" + +import time + +time.sleep(5) +import torch +from torch.utils.data import DataLoader +from whisper_example.dataset import load_data +from whisper_example.dataset import load_data_from_disk + +from whisper_example.model import ( + construct_balanced_sampler, + get_model, + get_params, + set_params, + train_one_epoch, +) + +from flwr.client import ClientApp, NumPyClient +from flwr.common import Context + +torch.set_float32_matmul_precision( + "high" +) # If “high” or “medium” are set then the TensorFloat32 is used + +og_threads = torch.get_num_threads() + + +class WhisperFlowerClient(NumPyClient): + """A Flower client that does trains a classification head attached to the encoder of + a Whisper-tiny encoder for Keyword spotting.""" + + def __init__( + self, + trainset, + batch_size: int, + num_classes: int, + disable_tqdm: bool, + compile: bool, + ): + self.disable_tqdm = disable_tqdm + self.batch_size = batch_size + self.trainset = trainset.with_format("torch", columns=["data", "targets"]) + + # Determine device + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + self.encoder, self.classifier = get_model(self.device, num_classes, compile) + + def fit(self, parameters, config): + """Do on-device training. + + Here the client receives the parameters of the classification head from the + server. Then trains that classifier using the data that belongs to this client. + Finally, The updated classifier is sent back to the server for aggregation. + """ + + # Apply the classifier parameters to the model in this client + set_params(self.classifier, parameters) + + # construct sampler in order to have balanced batches + sampler = None + if len(self.trainset) > self.batch_size: + sampler = construct_balanced_sampler(self.trainset) + + # Construct dataloader + train_loader = DataLoader( + self.trainset, + batch_size=self.batch_size, + shuffle=False, + num_workers=0, + sampler=sampler, + drop_last=True, + ) + + # Define optimizer and criterion + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(self.classifier.parameters(), lr=0.001) + + # Don't train if partition is very small + run_training = len(train_loader) > 1 + metrics = {"trained": run_training} # will be used for metrics aggregation + if run_training: + # Train + avg_loss, avg_acc = train_one_epoch( + self.encoder, + self.classifier, + optimizer, + criterion, + train_loader, + self.device, + disable_tqdm=self.disable_tqdm, + ) + metrics = {**metrics, "loss": avg_loss, "accuracy": avg_acc} + + # Return local classification head and statistics + return get_params(self.classifier), len(train_loader.dataset), metrics + + +def client_fn(context: Context): + + partition_id = context.node_config["partition-id"] + num_classes = context.run_config["num-classes"] + batch_size = context.run_config["batch-size"] + disable_tqdm = context.run_config["disable-tqdm"] + compile_model = context.run_config["compile-model"] + + # Some systems seem to need this, else .map stages will hang + # Doesn't seem to be required on macOS; but it's on Ubuntu + # even if the latter has more CPUs... + # ! Open a PR if you know how to improve this! + og_threads = torch.get_num_threads() + torch.set_num_threads(1) + + partition = load_data( + partition_id=partition_id, + remove_cols=context.run_config["remove-cols"], + ) + + torch.set_num_threads(og_threads) + + return WhisperFlowerClient( + partition, batch_size, num_classes, disable_tqdm, compile_model + ).to_client() + + +app = ClientApp(client_fn=client_fn) diff --git a/examples/whisper-federated-finetuning/whisper_example/dataset.py b/examples/whisper-federated-finetuning/whisper_example/dataset.py new file mode 100644 index 000000000000..ac9448d4b34e --- /dev/null +++ b/examples/whisper-federated-finetuning/whisper_example/dataset.py @@ -0,0 +1,114 @@ +"""whisper_example: A Flower / PyTorch app with OpenAi's Whisper.""" + +import random +from typing import List + +from flwr_datasets import FederatedDataset +from datasets import load_from_disk +from flwr_datasets.partitioner import GroupedNaturalIdPartitioner +from transformers import WhisperProcessor + +from datasets import Dataset, concatenate_datasets + +fds = None # Cache FederatedDataset +processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + + +def load_data( + partition_id: int, + remove_cols: List[str], +): + # Only initialize `FederatedDataset` once + global fds + if fds is None: + partitioner = GroupedNaturalIdPartitioner( + partition_by="speaker_id", group_size=5 + ) + fds = FederatedDataset( + dataset="speech_commands", + subset="v0.02", + partitioners={"train": partitioner}, + trust_remote_code=True, + ) + + partition = fds.load_partition(partition_id) + + encoding_fn = get_encoding_fn(processor) + + remove_cols = remove_cols.split(",") + partition = partition.map(encoding_fn, num_proc=2, remove_columns=remove_cols) + + # Now let's add some _silence_ training examples (add 10% of total examples in this client's data) + partitioner = fds.partitioners["train"] + ratio_silences_for_client = 0.1 * (len(partition) / len(partitioner.dataset)) + silence_dataset = prepare_silences_dataset( + partitioner.dataset, ratio_silences_for_client + ) + if len(silence_dataset) > 0: + silence_enc = silence_dataset.map(encoding_fn) + partition = concatenate_datasets([partition, silence_enc]) + + return partition + + +def load_data_from_disk(data_path): + """Load ddata from a partition explicitly saved to disk.""" + return load_from_disk(data_path) + + +def get_encoding_fn(processor): + """Return a function to use to pre-process/encode the SpeechCommands dataset. + + We are working with the 12classes version of this dataset, therefore we need to do + some reassignment of labels. + """ + + def prepare_dataset(batch): + audio = batch["audio"] + data = {} + data["data"] = processor( + audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt" + ).input_features + + # All unknown keywords are assigned label 11. The silence clips get assigned label 10 + # In this way we have 12 classes with labels 0-11 + data["targets"] = ( + 11 + if batch["is_unknown"] + else (10 if batch["label"] == 35 else batch["label"]) + ) + return data + + return prepare_dataset + + +def prepare_silences_dataset(train_dataset, ratio_silence: float = 0.1) -> Dataset: + """Generate silences for the train set. + + One of the classes in the SpeechCommands datatset is `silence`. However, the dataset + does not include clips of silence. It does however include 5 long files with + different background sounds. The taks of this function is to extract several + (defined by `ratio_silence`) one-second long clips from those background audio + files. Later, those audio clips will be included into the training set. + """ + # Retrieve original silence audio clips + silences = train_dataset.filter(lambda x: x["label"] == 35) + # Figure out how many to add + num_silence_total = int(len(train_dataset) * ratio_silence) + # Num new entries per background noise clip + num_silence_per_bkg = num_silence_total // len(silences) + + silence_to_add = [] + for sil in silences: + sil_array = sil["audio"]["array"] + sr = sil["audio"]["sampling_rate"] + # print(f"Extracting audio from: {sil['file']} ...") + for _ in range(num_silence_per_bkg): + random_offset = random.randint(0, len(sil_array) - sr - 1) + sil_array_crop = sil_array[random_offset : random_offset + sr] + + entry = sil + silence_to_add.append(entry) + silence_to_add[-1]["audio"]["array"] = sil_array_crop + + return Dataset.from_list(silence_to_add) diff --git a/examples/whisper-federated-finetuning/whisper_example/model.py b/examples/whisper-federated-finetuning/whisper_example/model.py new file mode 100644 index 000000000000..0add74d1c98d --- /dev/null +++ b/examples/whisper-federated-finetuning/whisper_example/model.py @@ -0,0 +1,148 @@ +"""whisper_example: A Flower / PyTorch app with OpenAi's Whisper.""" + +from collections import OrderedDict +from typing import List + +import numpy as np +import torch +from torch.utils.data import WeightedRandomSampler +from tqdm import tqdm +from transformers import WhisperForConditionalGeneration + +from flwr.common import NDArrays + + +def get_model(device, num_classes, compile: bool = True): + """Create model: Whisper-tiny Encoder + classification head.""" + encoder = WhisperForConditionalGeneration.from_pretrained( + "openai/whisper-tiny" + ).get_encoder() + encoder = encoder.to(device) + if compile: + encoder = torch.compile(encoder) + + # This classification head is 782K parameters + # This is the only part of the model that is trained in federation + classifier = torch.nn.Sequential( + torch.nn.Conv1d(1500, 128, kernel_size=1), + torch.nn.ReLU(), + torch.nn.Flatten(1), + torch.nn.Linear(128 * 384, num_classes), + ).to(device) + return encoder, classifier + + +def set_params(model: torch.nn.ModuleList, params: List[NDArrays]): + """Set model weights from a list of NumPy ndarrays.""" + params_dict = zip(model.state_dict().keys(), params) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + model.load_state_dict(state_dict, strict=True) + + +def get_params(module: torch.nn.ModuleList): + return [val.cpu().numpy() for _, val in module.state_dict().items()] + + +class RunningAvg: + def __init__(self): + self.n = 0 + self.total = 0 + + def update(self, val): + self.total += val + self.n += 1 + + def __call__(self): + return self.total / self.n + + +def construct_balanced_sampler(trainset): + hist, _ = np.histogram(trainset["targets"], bins=12) + # Mask of non-zeros + hist_mask = hist > 0 + w_per_class = len(trainset) / ( + hist + 1 + ) # avoid dividing by zeros # doesn't have to add up to 1 (relative is what matters) + w_per_class += 1 # needed in case trainset has very few samples + # Apply mask so we don't attempt sampling classes that aren't present + w_per_class *= hist_mask + w_ss = [w_per_class[t] for t in trainset["targets"]] + return WeightedRandomSampler(w_ss, len(w_ss)) + + +def train_one_epoch( + model, + classifier, + optimizer, + criterion, + dataloader, + device, + disable_tqdm: bool = False, +): + """Train the classification head. + + This is a very standard looking way of training PyTorch models. + """ + model.eval() + classifier.train() + classifier.to(device) + loss_avg, acc_avg = RunningAvg(), RunningAvg() + avg_loss, avg_acc = 0.0, 0.0 + with tqdm(total=len(dataloader.dataset), disable=disable_tqdm) as t: + for b in dataloader: + optimizer.zero_grad() + data = b["data"].squeeze().to(device) + # print(data.shape) + labels = b["targets"].to(device) + with torch.no_grad(): + res = model(data)[0] + + resres = classifier(res) + + loss = criterion(resres.float(), labels) + loss.backward() + optimizer.step() + _, predicted = torch.max(resres.data, 1) + correct = (predicted == labels).sum().item() + acc = correct / data.shape[0] + loss_ = loss.cpu().item() + + loss_avg.update(loss_) + acc_avg.update(acc) + + t.update(data.shape[0]) + avg_loss, avg_acc = loss_avg(), acc_avg() + t.set_postfix({"avg_loss": f"{avg_loss:.4f}", "avg_acc": f"{avg_acc:.4f}"}) + + return avg_loss, avg_acc + + +def eval_model(model, classifier, criterion, dataloader, device): + """Evaluate a model on a validation/test set. + + This is a very normal looking way of doing this with PyTorch. + """ + model.eval() + classifier.eval() + classifier.to(device) + correct = 0 + loss_ = 0 + total = 0 + with torch.no_grad(): + for b in dataloader: + data = b["data"].squeeze().to(device) + # print(data.shape) + labels = b["targets"].to(device) + res = model(data)[0] + resres = classifier(res) + + loss = criterion(resres.float(), labels) + _, predicted = torch.max(resres.data, 1) + correct += (predicted == labels).sum().item() + total += data.shape[0] + loss_ += loss.cpu().item() + + accuracy = correct / total + loss = loss_ / total + + return loss, accuracy diff --git a/examples/whisper-federated-finetuning/whisper_example/server_app.py b/examples/whisper-federated-finetuning/whisper_example/server_app.py new file mode 100644 index 000000000000..6f97ef9477b2 --- /dev/null +++ b/examples/whisper-federated-finetuning/whisper_example/server_app.py @@ -0,0 +1,153 @@ +"""whisper_example: A Flower / PyTorch app with OpenAi's Whisper.""" + +from logging import INFO +from typing import List, Tuple + +import torch +from torch.utils.data import DataLoader +from transformers import WhisperProcessor +from whisper_example.dataset import get_encoding_fn +from whisper_example.model import eval_model, get_model, get_params, set_params + +from datasets import load_dataset +from flwr.common import Context, FitRes, Metrics, NDArrays, ndarrays_to_parameters +from flwr.common.logger import log +from flwr.common.typing import UserConfig +from flwr.server import ServerApp, ServerAppComponents, ServerConfig +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy import FedAvg + + +def get_evaluate_fn( + val_set, test_set, processor: WhisperProcessor, run_config: UserConfig +): + """Return a callback that the strategy will call after models are aggregated.""" + + def evaluate(server_round: int, parameters: NDArrays, config): + """Evaluate global model on a centralized dataset.""" + + num_rounds = run_config["num-server-rounds"] + num_classes = run_config["num-classes"] + remove_cols = run_config["remove-cols"] + remove_cols = remove_cols.split(",") + + # Determine device + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # prepare model + encoder, classifier = get_model(device, num_classes) + set_params(classifier, parameters) + classifier.to(device) + + # prepare dataset + og_threads = torch.get_num_threads() + torch.set_num_threads(1) + encoding_fn = get_encoding_fn(processor) + if server_round == num_rounds: + prefix = "test" + encoded = test_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols) + else: + prefix = "val" + encoded = val_set.map(encoding_fn, num_proc=4, remove_columns=remove_cols) + + torch.set_num_threads(og_threads) + val_encoded = encoded.with_format("torch", columns=["data", "targets"]) + val_loader = DataLoader(val_encoded, batch_size=64, num_workers=4) + + # Run global evaluation + criterion = torch.nn.CrossEntropyLoss() + loss, accuracy = eval_model(encoder, classifier, criterion, val_loader, device) + + return loss, {f"{prefix}_accuracy": accuracy} + + return evaluate + + +def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: + # Multiply accuracy of each client by number of examples used + accuracies = [ + num_examples * m["accuracy"] for num_examples, m in metrics if m["trained"] + ] + losses = [num_examples * m["loss"] for num_examples, m in metrics if m["trained"]] + examples = [num_examples for num_examples, _ in metrics] + + # Aggregate and return custom metric (weighted average) + return { + "train_accuracy": sum(accuracies) / sum(examples), + "train_loss": sum(losses) / sum(examples), + } + + +class ExclusiveFedAvg(FedAvg): + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy | FitRes]], + failures: List[Tuple[ClientProxy | FitRes] | BaseException], + ): + # Clients with not enough training examples to have a single full batch + # didn't train the classification head. We need to exclude it from aggregation + + trained_results = [] + for cp, res in results: + if res.metrics["trained"]: + trained_results.append((cp, res)) + log( + INFO, + f"{len(trained_results)}/{len(results)} models included for aggregation.", + ) + + return super().aggregate_fit(server_round, trained_results, failures) + + +def server_fn(context: Context): + """Construct components that set the ServerApp behaviour.""" + + # Read from config + num_rounds = context.run_config["num-server-rounds"] + num_classes = context.run_config["num-classes"] + fraction_fit = context.run_config["fraction-fit"] + + # Initialize global model parameters. Recall we are + # only federating the classification head + _, classifier = get_model("cpu", num_classes, False) + ndarrays = get_params(classifier) + parameters = ndarrays_to_parameters(ndarrays) + + eval_fn = None + if context.run_config["central-eval"]: + # The ServerApp will use the validation set to assess the performance of the global + # model after each round. Then, the test set will be used for evaluating the global + # model after the last round + sc_val = load_dataset( + "speech_commands", "v0.02", split="validation", token=False + ) + sc_test = load_dataset("speech_commands", "v0.02", split="test", token=False) + + # Processor + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + + # Prepare evaluation function + eval_fn = get_evaluate_fn( + val_set=sc_val, + test_set=sc_test, + processor=processor, + run_config=context.run_config, + ) + + # Define the strategy + strategy = ExclusiveFedAvg( + fraction_fit=fraction_fit, + fraction_evaluate=0.0, + fit_metrics_aggregation_fn=weighted_average, + evaluate_fn=eval_fn, + initial_parameters=parameters, + ) + config = ServerConfig(num_rounds=num_rounds) + + return ServerAppComponents(strategy=strategy, config=config) + + +# Create ServerApp +app = ServerApp(server_fn=server_fn)