diff --git a/.github/workflows/_docker-build.yml b/.github/workflows/_docker-build.yml index 9005f21424fb..4a1289d9175a 100644 --- a/.github/workflows/_docker-build.yml +++ b/.github/workflows/_docker-build.yml @@ -1,4 +1,4 @@ -name: Reusable docker server image build workflow +name: Reusable docker image build workflow on: workflow_call: @@ -35,7 +35,7 @@ permissions: # based on https://docs.docker.com/build/ci/github-actions/multi-platform/#distribute-build-across-multiple-runners jobs: build: - name: Build server image + name: Build image runs-on: ubuntu-22.04 timeout-minutes: 60 outputs: @@ -98,7 +98,7 @@ jobs: touch "/tmp/digests/${digest#sha256:}" - name: Upload digest - uses: actions/upload-artifact@c7d193f32edcb7bfad88892161225aeda64e9392 # v4.0.0 + uses: actions/upload-artifact@1eb3cb2b3e0f29609092a73eb033bb759a334595 # v4.1.0 with: name: digests-${{ steps.build-id.outputs.id }}-${{ matrix.platform.name }} path: /tmp/digests/* @@ -114,7 +114,7 @@ jobs: metadata: ${{ steps.meta.outputs.json }} steps: - name: Download digests - uses: actions/download-artifact@f44cd7b40bfd40b6aa1cc1b9b5b7bf03d3c67110 # v4.1.0 + uses: actions/download-artifact@6b208ae046db98c579e8a3aa621ab581ff575935 # v4.1.1 with: pattern: digests-${{ needs.build.outputs.build-id }}-* path: /tmp/digests diff --git a/.github/workflows/docker-base.yml b/.github/workflows/docker-base.yml index fe276585a3bd..f2cd2ef99d08 100644 --- a/.github/workflows/docker-base.yml +++ b/.github/workflows/docker-base.yml @@ -39,7 +39,7 @@ jobs: echo "ubuntu-version=${{ env.DEFAULT_UBUNTU }}" >> "$GITHUB_OUTPUT" build-base-images: - name: Build images + name: Build base images uses: ./.github/workflows/_docker-build.yml needs: parameters strategy: diff --git a/.github/workflows/docker-client.yml b/.github/workflows/docker-client.yml new file mode 100644 index 000000000000..47083b258982 --- /dev/null +++ b/.github/workflows/docker-client.yml @@ -0,0 +1,36 @@ +name: Build docker client image + +on: + workflow_dispatch: + inputs: + flwr-version: + description: "Version of Flower e.g. (1.6.0)." + required: true + type: string + +permissions: + contents: read + +jobs: + build-client-images: + name: Build client images + uses: ./.github/workflows/_docker-build.yml + # run only on default branch when using it with workflow_dispatch + if: github.ref_name == github.event.repository.default_branch + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.9", "3.10", "3.11"] + with: + namespace-repository: flwr/client + file-dir: src/docker/client + build-args: | + FLWR_VERSION=${{ github.event.inputs.flwr-version }} + BASE_IMAGE_TAG=py${{ matrix.python-version }}-ubuntu22.04 + tags: | + ${{ github.event.inputs.flwr-version }}-py${{ matrix.python-version }}-ubuntu22.04 + ${{ github.event.inputs.flwr-version }} + latest + secrets: + dockerhub-user: ${{ secrets.DOCKERHUB_USERNAME }} + dockerhub-token: ${{ secrets.DOCKERHUB_TOKEN }} diff --git a/.github/workflows/framework-draft-release.yml b/.github/workflows/framework-draft-release.yml new file mode 100644 index 000000000000..959d17249765 --- /dev/null +++ b/.github/workflows/framework-draft-release.yml @@ -0,0 +1,63 @@ +name: Draft release + +on: + push: + tags: + - "v*.*.*" + +jobs: + publish: + if: ${{ github.repository == 'adap/flower' }} + name: Publish draft + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Wait for wheel to be built + uses: lewagon/wait-on-check-action@v1.3.3 + with: + ref: ${{ github.ref }} + check-name: 'Build, test and upload wheel' + repo-token: ${{ secrets.GITHUB_TOKEN }} + wait-interval: 10 + - name: Download wheel + run: | + tag_name=$(echo "${GITHUB_REF_NAME}" | cut -c2-) + echo "TAG_NAME=$tag_name" >> "$GITHUB_ENV" + + wheel_name="flwr-${tag_name}-py3-none-any.whl" + echo "WHEEL_NAME=$wheel_name" >> "$GITHUB_ENV" + + tar_name="flwr-${tag_name}.tar.gz" + echo "TAR_NAME=$tar_name" >> "$GITHUB_ENV" + + wheel_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${wheel_name}" + tar_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${tar_name}" + + curl $wheel_url --output $wheel_name + curl $tar_url --output $tar_name + - name: Upload wheel + env: + AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets. AWS_SECRET_ACCESS_KEY }} + run: | + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} + aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} + + - name: Generate body + run: | + ./dev/get-latest-changelog.sh > body.md + cat body.md + + - name: Release + uses: softprops/action-gh-release@v1 + with: + body_path: ./body.md + draft: true + name: Flower ${{ env.TAG_NAME }} + files: | + ${{ env.WHEEL_NAME }} + ${{ env.TAR_NAME }} diff --git a/.github/workflows/framework-release.yml b/.github/workflows/framework-release.yml index eab15a51d217..0f3cda8abae3 100644 --- a/.github/workflows/framework-release.yml +++ b/.github/workflows/framework-release.yml @@ -1,63 +1,40 @@ -name: Release Framework +name: Publish `flwr` release on PyPI on: - push: - tags: - - "v*.*.*" - + release: + types: [released] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: publish: if: ${{ github.repository == 'adap/flower' }} - name: Publish draft + name: Publish release runs-on: ubuntu-22.04 steps: - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Wait for wheel to be built - uses: lewagon/wait-on-check-action@v1.3.1 - with: - ref: ${{ github.ref }} - check-name: 'Build, test and upload wheel' - repo-token: ${{ secrets.GITHUB_TOKEN }} - wait-interval: 10 - - name: Download wheel - run: | - tag_name=$(echo "${GITHUB_REF_NAME}" | cut -c2-) - echo "TAG_NAME=$tag_name" >> "$GITHUB_ENV" - - wheel_name="flwr-${tag_name}-py3-none-any.whl" - echo "WHEEL_NAME=$wheel_name" >> "$GITHUB_ENV" - - tar_name="flwr-${tag_name}.tar.gz" - echo "TAR_NAME=$tar_name" >> "$GITHUB_ENV" + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Bootstrap + uses: ./.github/actions/bootstrap + + - name: Get artifacts and publish + env: + GITHUB_REF: ${{ github.ref }} + run: | + TAG_NAME=$(echo "${GITHUB_REF_NAME}" | cut -c2-) + + wheel_name="flwr-${TAG_NAME}-py3-none-any.whl" + tar_name="flwr-${TAG_NAME}.tar.gz" + + wheel_url="https://artifact.flower.dev/py/release/v${TAG_NAME}/${wheel_name}" + tar_url="https://artifact.flower.dev/py/release/v${TAG_NAME}/${tar_name}" + + curl $wheel_url --output $wheel_name + curl $tar_url --output $tar_name - wheel_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${wheel_name}" - tar_url="https://artifact.flower.dev/py/main/${GITHUB_SHA::7}/${tar_name}" - - curl $wheel_url --output $wheel_name - curl $tar_url --output $tar_name - - name: Upload wheel - env: - AWS_DEFAULT_REGION: ${{ secrets. AWS_DEFAULT_REGION }} - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets. AWS_SECRET_ACCESS_KEY }} - run: | - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.WHEEL_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.WHEEL_NAME }} - aws s3 cp --content-disposition "attachment" --cache-control "no-cache" ./${{ env.TAR_NAME }} s3://artifact.flower.dev/py/release/v${{ env.TAG_NAME }}/${{ env.TAR_NAME }} - - - name: Generate body - run: | - ./dev/get-latest-changelog.sh > body.md - cat body.md - - - name: Release - uses: softprops/action-gh-release@v1 - with: - body_path: ./body.md - draft: true - name: Flower ${{ env.TAG_NAME }} - files: | - ${{ env.WHEEL_NAME }} - ${{ env.TAR_NAME }} + python -m poetry publish -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/README.md b/README.md index b8b62e8c0c43..750b5cdb4b93 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ design of Flower is based on a few guiding principles: - **Framework-agnostic**: Different machine learning frameworks have different strengths. Flower can be used with any machine learning framework, for example, [PyTorch](https://pytorch.org), - [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [MXNet](https://mxnet.apache.org/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/) + [TensorFlow](https://tensorflow.org), [Hugging Face Transformers](https://huggingface.co/), [PyTorch Lightning](https://pytorchlightning.ai/), [scikit-learn](https://scikit-learn.org/), [JAX](https://jax.readthedocs.io/), [TFLite](https://tensorflow.org/lite/), [fastai](https://www.fast.ai/), [Pandas](https://pandas.pydata.org/) for federated analytics, or even raw [NumPy](https://numpy.org/) for users who enjoy computing gradients by hand. - **Understandable**: Flower is written with maintainability in mind. The @@ -81,7 +81,6 @@ Stay tuned, more tutorials are coming soon. Topics include **Privacy and Securit - [Quickstart (PyTorch)](https://flower.dev/docs/framework/tutorial-quickstart-pytorch.html) - [Quickstart (Hugging Face)](https://flower.dev/docs/framework/tutorial-quickstart-huggingface.html) - [Quickstart (PyTorch Lightning [code example])](https://flower.dev/docs/framework/tutorial-quickstart-pytorch-lightning.html) -- [Quickstart (MXNet)](https://flower.dev/docs/framework/example-mxnet-walk-through.html) - [Quickstart (Pandas)](https://flower.dev/docs/framework/tutorial-quickstart-pandas.html) - [Quickstart (fastai)](https://flower.dev/docs/framework/tutorial-quickstart-fastai.html) - [Quickstart (JAX)](https://flower.dev/docs/framework/tutorial-quickstart-jax.html) @@ -124,7 +123,6 @@ Quickstart examples: - [Quickstart (PyTorch Lightning)](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch-lightning) - [Quickstart (fastai)](https://github.com/adap/flower/tree/main/examples/quickstart-fastai) - [Quickstart (Pandas)](https://github.com/adap/flower/tree/main/examples/quickstart-pandas) -- [Quickstart (MXNet)](https://github.com/adap/flower/tree/main/examples/quickstart-mxnet) - [Quickstart (JAX)](https://github.com/adap/flower/tree/main/examples/quickstart-jax) - [Quickstart (scikit-learn)](https://github.com/adap/flower/tree/main/examples/sklearn-logreg-mnist) - [Quickstart (Android [TFLite])](https://github.com/adap/flower/tree/main/examples/android) @@ -134,7 +132,6 @@ Other [examples](https://github.com/adap/flower/tree/main/examples): - [Raspberry Pi & Nvidia Jetson Tutorial](https://github.com/adap/flower/tree/main/examples/embedded-devices) - [PyTorch: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/pytorch-from-centralized-to-federated) -- [MXNet: From Centralized to Federated](https://github.com/adap/flower/tree/main/examples/mxnet-from-centralized-to-federated) - [Advanced Flower with TensorFlow/Keras](https://github.com/adap/flower/tree/main/examples/advanced-tensorflow) - [Advanced Flower with PyTorch](https://github.com/adap/flower/tree/main/examples/advanced-pytorch) - Single-Machine Simulation of Federated Learning Systems ([PyTorch](https://github.com/adap/flower/tree/main/examples/simulation_pytorch)) ([Tensorflow](https://github.com/adap/flower/tree/main/examples/simulation_tensorflow)) diff --git a/baselines/fedavgm/LICENSE b/baselines/fedavgm/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/fedavgm/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/baselines/fedavgm/README.md b/baselines/fedavgm/README.md new file mode 100644 index 000000000000..0953331964a7 --- /dev/null +++ b/baselines/fedavgm/README.md @@ -0,0 +1,220 @@ +--- +title: Measuring the effects of non-identical data distribution for federated visual classification +url: https://arxiv.org/abs/1909.06335 +labels: [non-iid, image classification] +dataset: [CIFAR-10, Fashion-MNIST] +--- + +# Measuring the effects of non-identical data distribution for federated visual classification + +> Note: If you use this baseline in your work, please remember to cite the original authors of the paper as well as the Flower paper. + +**Paper:** [arxiv.org/abs/1909.06335](https://arxiv.org/abs/1909.06335) + +**Authors:** Tzu-Ming Harry Hsu, Hang Qi, Matthew Brown + +**Abstract:** Federated Learning enables visual models to be trained in a privacy-preserving way using real-world data from mobile devices. Given their distributed nature, the statistics of the data across these devices is likely to differ significantly. In this work, we look at the effect such non-identical data distributions has on visual classification via Federated Learning. We propose a way to synthesize datasets with a continuous range of identicalness and provide performance measures for the Federated Averaging algorithm. We show that performance degrades as distributions differ more, and propose a mitigation strategy via server momentum. Experiments on CIFAR-10 demonstrate improved classification performance over a range of non-identicalness, with classification accuracy improved from 30.1% to 76.9% in the most skewed settings. + + +## About this baseline + +**What’s implemented:** The code in this directory evaluates the effects of non-identical data distribution for visual classification task based on paper _Measuring the effects of non-identical data distribution for federated visual classification_ (Hsu et al., 2019). It reproduces the FedAvgM and FedAvg performance curves for different non-identical-ness of the dataset (CIFAR-10 and Fashion-MNIST). _Figure 5 in the paper, section 4.2._ + +**Datasets:** CIFAR-10, and Fashion-MNIST + +**Hardware Setup:** This baseline was evaluated in a regular PC without GPU (Intel i7-10710U CPU, and 32 Gb RAM). The major constraint is to run a huge number of rounds such as the reference paper that reports 10.000 round for each case evaluated. + +**Contributors:** Gustavo Bertoli [(@gubertoli)](https://github.com/gubertoli) + +## Experimental Setup + +**Task:** Image Classification + +**Model:** This directory implements a CNN model similar to the one used on the seminal FedAvg paper (`models.py`): + +- McMahan, B., Moore, E., Ramage, D., Hampson, S., & y Arcas, B. A. (2017, April). Communication-efficient learning of deep networks from decentralized data. In Artificial intelligence and statistics (pp. 1273-1282). PMLR. ([Link](http://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf)): + +As the following excerpt: + +> "*We also ran experiments on the CIFAR-10 dataset... The model architecture was taken from the TensorFlow tutorial [38], which consists of two convolutional layers followed by two fully connected layers and then a linear transformation layer to produce logits, for a total of about 10 parameters."* + +Regarding this architecture, the historical references mentioned on the FedAvg and FedAvgM papers are [this](https://web.archive.org/web/20190415103404/https://www.tensorflow.org/tutorials/images/deep_cnn) and [this](https://web.archive.org/web/20170807002954/https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py). + +Important to highlight the rationale with this CNN model stated on FedAvgM paper: + +> "*This model is not the state-of-the-art on the CIFAR-10 dataset, but is sufficient to show relative performance for the purposes of our investigation."* + +**The default CNN model in use on this baseline have a centralized accuracy of ~0.74. That is different from the reported 0.86 accuracy from the original FedAvg paper. But it is still sufficient to show the relative performance for the purposes of FedAvgM investigation.** + +**Dataset:** This baseline includes the CIFAR-10 and Fashion-MNIST datasets. By default it will run with the CIFAR-10. The data partition uses a configurable Latent Dirichlet Allocation (LDA) distribution (`concentration` parameter equals 0.1 as default) to create **non-iid distributions** between the clients. The understanding for this `concentration` (α) is that α→∞ all clients have identical distribution, and α→𝟢 each client hold samples from only one class. + +| Dataset | # classes | # partitions | partition method | partition settings| +| :------ | :---: | :---: | :---: | :---: | +| CIFAR-10 | 10 | `num_clients` | Latent Dirichlet Allocation (LDA) | `concentration` | +| Fashion-MNIST | 10 | `num_clients` | Latent Dirichlet Allocation (LDA) | `concentration` | + +**Data distribution:** The following figure illustrates the use of multiple `concentration` values to generate the data distribution over 30 clients for CIFAR-10 (10 classes) - [source code](fedavgm/utils.py): + +![](_static/concentration_cifar10_v2.png) + +**Training Hyperparameters:** +The following table shows the main hyperparameters for this baseline with their default value (i.e. the value used if you run `python main.py` directly) + +| Description | Default Value | +| ----------- | ----- | +| total clients | 10 | +| number of rounds | 5 | +| model | CNN | +| strategy | Custom FedAvgM | +| dataset | CIFAR-10 | +| concentration | 0.1 | +| fraction evaluate | 0 | +| num cpus | 1 | +| num gpus | 0 | +| server momentum | 0.9 | +| server learning rate | 1.0 | +| server reporting fraction | 0.05 | +| client local epochs | 1 | +| client batch size | 64 | +| client learning rate | 0.01 | + +### Custom FedAvgM +In contrast to the initial implementation found in Flower v1.5.0, our baseline incorporates the Nesterov accelerated gradient as a pivotal component of the momentum applied to the server model. It is worth emphasizing that the inclusion of Nesterov momentum aligns with the original definition of FedAvgM in the research paper. + +To use the original Flower implementation, use the argument `strategy=fedavgm`. By default, the custom implementation is used. But, you can also refer to it on the command line as `strategy=custom-fedavgm`. + +## Environment Setup + +### Specifying the Python Version + +This baseline was tested with Python 3.10.6 and following the steps below to construct the Python environment and install all dependencies. Both [`pyenv`](https://github.com/pyenv/pyenv) and [`poetry`](https://python-poetry.org/docs/) are assumed to be already present in your system. + +```bash +# Cd to your baseline directory (i.e. where the `pyproject.toml` is), then +pyenv local 3.10.6 + +# Set that version for poetry +poetry env use 3.10.6 + +# Install the base Poetry environment +poetry install + +# Activate the environment +poetry shell +``` + +### Google Colab +If you want to setup the environemnt on Google Colab, please executed the script `conf-colab.sh`, just use the Colab terminal and the following: + +```bash +chmod +x conf-colab.sh +./conf-colab.sh +``` + +## Running the Experiments + +To run this FedAvgM with CIFAR-10 baseline, first ensure you have activated your Poetry environment (execute `poetry shell` from this directory), then: + +```bash +python -m fedavgm.main # this will run using the default setting in the `conf/base.yaml` + +# you can override settings directly from the command line + +python -m fedavgm.main strategy=fedavg num_clients=1000 num_rounds=50 # will set the FedAvg with 1000 clients and 50 rounds + +python -m fedavgm.main dataset=fmnist noniid.concentration=10 # use the Fashion-MNIST dataset and a different concentration for the LDA-based partition + +python -m fedavgm.main server.reporting_fraction=0.2 client.local_epochs=5 # will set the reporting fraction to 20% and the local epochs in the clients to 5 +``` + +## Expected Results + +### CIFAR-10 +Similar to FedAvgM paper as reference, the CIFAR-10 evaluation runs 10,000 rounds. + +> In order to speedup the execution of these experiments, the evaluation of the _global model_ on the test set only takes place after the last round. The highest accuracy is achieved towards the last rounds, not necessarily in the last. If you wish to evaluate the _global model_ on the test set (or a validation set) more frequently, edit `get_evaluate_fn` in `server.py`. Overal, running the experiments as shown below demonstrate that `FedAvgM` is consistently superior to `FedAvg`. + +For FedAvgM evaluation, it was performed a hyperparameter search of server momentum and client learning rate (similar to Figure 6 reported below) for each of the concentrations under analysis, using the following commands: + +- Concentration = 1e-5 and 1e-9 (extreme non-iid) +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=1e-5,1e-9 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.0003 server.momentum=0.99 +``` + +- Concentration = 0.01 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=0.01 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.003 server.momentum=0.97 +``` + +- Concentration = 0.1 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=0.1 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.0003 server.momentum=0.99 +``` + +- Concentration = 1 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=1 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.0003 server.momentum=0.997 +``` + +- Concentration = 10 +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=10 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=10000 num_clients=100 \ +dataset=cifar10 client.lr=0.003 server.momentum=0.9 +``` + +Summarizing all the results: + +![](_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png) + +The findings aligns with the report on the original FedAvgM paper that *"To prevent client updates from diverging, we additionally have to use a combination of low absolute learning rate and high momentum"*. + +The following command reproduces the same behavior of Figure 6 from FedAvgM paper for the case of Local Epoch E=1, Reporting Fraction C=0.05, and concentration (α) = 1. In this example, it runs just 1,000 rounds: + +```bash +python -m fedavgm.main --multirun client.local_epochs=1 noniid.concentration=1 \ +strategy=custom-fedavgm server.reporting_fraction=0.05 num_rounds=100 num_clients=100 \ +dataset=cifar10 client.lr=0.0001,0.0003,0.001,0.003,0.01,0.03,0.1,0.3 \ +server.momentum=0.7,0.9,0.97,0.99,0.997 +``` + +![](_static/Figure6_cifar10_num-rounds=1000_concentration=1.png) + + +--- +### Fashion-MNIST + +```bash +python -m fedavgm.main --multirun client.local_epochs=1 \ +noniid.concentration=0.001,0.01,0.1,1,10,100 strategy=custom-fedavgm,fedavg \ +server.reporting_fraction=0.05 num_rounds=1000 \ +num_clients=100 dataset=fmnist server.momentum=0.97 client.lr=0.003 +``` +The above command will evaluate the custom FedAvgM versus FedAvg on Fashion-MNIST datasets. It uses 100 clients with a reporting fraction of 5% during 1000 rounds. To evaluate the non-iid aspects, this exececution exercises concentration of [100, 10, 1, 0.1, 0.01, 0.001]: + +![](_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png) + +#### Comparison between the Custom-FedAvgM and FedAvgM + +To compare the improvement of the FedAvgM with Nesterov momentum (`strategy=custom-fedavgm`) and the FedAvgM without the Nesterov momentum (`strategy=fedavgm`), here we use the results of previous running with addition of the same conditions for the `fedavgm` strategy as follows: + +```bash +python -m fedavgm.main --multirun client.local_epochs=1 \ +noniid.concentration=0.001,0.01,0.1,1,10,100 strategy=fedavgm \ +server.reporting_fraction=0.05 num_rounds=1000 \ +num_clients=100 dataset=fmnist server.momentum=0.97 client.lr=0.003 +``` + +![](_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png) + +Overall, FedAvgM with Nesterov momentum outperforms the FedAvgM without Nesterov momentum, being clear this behavior for higher non-iidness (0.01 and 0.001). In these higher non-iidness, the test accuracy for FedAvg without Nesterov momentum are worse than the FedAvg. +For larger concentrations (1, 10, 100), it was observed some runs that the centralized evaluation resulted in a loss equal NaN or Inf, thus it was required multiple runs to guarantee the accuracies reported. + diff --git a/baselines/fedavgm/_static/Comparison_CNN_vs_TF_v1_x_Example_for_CIFAR_10.ipynb b/baselines/fedavgm/_static/Comparison_CNN_vs_TF_v1_x_Example_for_CIFAR_10.ipynb new file mode 100644 index 000000000000..fac837d145e3 --- /dev/null +++ b/baselines/fedavgm/_static/Comparison_CNN_vs_TF_v1_x_Example_for_CIFAR_10.ipynb @@ -0,0 +1,1851 @@ +{ + "cells": [ + { + "cell_type": "code", + "source": [ + "import numpy as np\n", + "from keras.optimizers import SGD\n", + "from keras.regularizers import l2\n", + "from tensorflow import keras\n", + "from tensorflow.nn import local_response_normalization\n", + "from keras.utils import to_categorical\n", + "import matplotlib.pyplot as plt" + ], + "metadata": { + "id": "Rp9LUn54SUTu" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "7tTxE8D6bD6g" + }, + "outputs": [], + "source": [ + "def tf_example(input_shape, num_classes):\n", + " \"\"\"CNN Model from TensorFlow v1.x example.\n", + "\n", + " This is the model referenced on the FedAvg paper.\n", + "\n", + " Reference:\n", + " https://web.archive.org/web/20170807002954/https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py\n", + " \"\"\"\n", + " input_shape = tuple(input_shape)\n", + "\n", + " weight_decay = 0.004\n", + " model = keras.Sequential(\n", + " [\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " input_shape=input_shape,\n", + " ),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding=\"same\"),\n", + " keras.layers.Lambda(\n", + " local_response_normalization,\n", + " arguments={\n", + " \"depth_radius\": 4,\n", + " \"bias\": 1.0,\n", + " \"alpha\": 0.001 / 9.0,\n", + " \"beta\": 0.75,\n", + " },\n", + " ),\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " ),\n", + " keras.layers.Lambda(\n", + " local_response_normalization,\n", + " arguments={\n", + " \"depth_radius\": 4,\n", + " \"bias\": 1.0,\n", + " \"alpha\": 0.001 / 9.0,\n", + " \"beta\": 0.75,\n", + " },\n", + " ),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding=\"same\"),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(\n", + " 384, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(\n", + " 192, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(num_classes, activation=\"softmax\"),\n", + " ]\n", + " )\n", + " optimizer = SGD(learning_rate=0.1)\n", + " model.compile(\n", + " loss=\"categorical_crossentropy\", optimizer=optimizer, metrics=[\"accuracy\"]\n", + " )\n", + "\n", + " return model\n", + "\n" + ] + }, + { + "cell_type": "code", + "source": [ + "def cifar10(num_classes, input_shape):\n", + " \"\"\"Prepare the CIFAR-10.\n", + "\n", + " This method considers CIFAR-10 for creating both train and test sets. The sets are\n", + " already normalized.\n", + " \"\"\"\n", + " print(f\">>> [Dataset] Loading CIFAR-10. {num_classes} | {input_shape}.\")\n", + " (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n", + " x_train = x_train.astype(\"float32\") / 255\n", + " x_test = x_test.astype(\"float32\") / 255\n", + " input_shape = x_train.shape[1:]\n", + " num_classes = len(np.unique(y_train))\n", + "\n", + " return x_train, y_train, x_test, y_test, input_shape, num_classes" + ], + "metadata": { + "id": "vuQykx1uSXHk" + }, + "execution_count": 17, + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FMph7H-qbHHR", + "outputId": "45cf4a68-7054-460e-bcd7-c353338dc387" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + ">>> [Dataset] Loading CIFAR-10. 10 | (32, 32, 3).\n" + ] + } + ], + "source": [ + "x_train, y_train, x_test, y_test, input_shape,num_classes = cifar10(10, (32,32,3))\n" + ] + }, + { + "cell_type": "code", + "source": [ + "EPOCHS=350\n", + "BATCH_SIZE=128" + ], + "metadata": { + "id": "AD2qsybwX6uR" + }, + "execution_count": 19, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "---" + ], + "metadata": { + "id": "531ZRrY2SY85" + } + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "1BO5D4ZBbJJo" + }, + "outputs": [], + "source": [ + "model = tf_example(input_shape, num_classes)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8DMMAgw6bK2C", + "outputId": "9c1203a2-7152-4c25-dc1c-1c58b1cf8b8b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/350\n", + "391/391 [==============================] - 8s 18ms/step - loss: 4.8242 - accuracy: 0.2914\n", + "Epoch 2/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 3.0276 - accuracy: 0.4814\n", + "Epoch 3/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 2.1395 - accuracy: 0.5609\n", + "Epoch 4/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.6463 - accuracy: 0.6129\n", + "Epoch 5/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.3656 - accuracy: 0.6504\n", + "Epoch 6/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.1851 - accuracy: 0.6868\n", + "Epoch 7/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 1.0698 - accuracy: 0.7147\n", + "Epoch 8/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.9918 - accuracy: 0.7350\n", + "Epoch 9/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.9465 - accuracy: 0.7551\n", + "Epoch 10/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8991 - accuracy: 0.7747\n", + "Epoch 11/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8534 - accuracy: 0.7971\n", + "Epoch 12/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8305 - accuracy: 0.8111\n", + "Epoch 13/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.8070 - accuracy: 0.8265\n", + "Epoch 14/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7805 - accuracy: 0.8434\n", + "Epoch 15/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7672 - accuracy: 0.8527\n", + "Epoch 16/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7504 - accuracy: 0.8647\n", + "Epoch 17/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7418 - accuracy: 0.8715\n", + "Epoch 18/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7244 - accuracy: 0.8819\n", + "Epoch 19/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7205 - accuracy: 0.8871\n", + "Epoch 20/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.7032 - accuracy: 0.8966\n", + "Epoch 21/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6965 - accuracy: 0.8999\n", + "Epoch 22/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6998 - accuracy: 0.9026\n", + "Epoch 23/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6952 - accuracy: 0.9065\n", + "Epoch 24/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6795 - accuracy: 0.9120\n", + "Epoch 25/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6913 - accuracy: 0.9100\n", + "Epoch 26/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6822 - accuracy: 0.9144\n", + "Epoch 27/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6773 - accuracy: 0.9174\n", + "Epoch 28/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6885 - accuracy: 0.9155\n", + "Epoch 29/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6588 - accuracy: 0.9239\n", + "Epoch 30/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6631 - accuracy: 0.9230\n", + "Epoch 31/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6819 - accuracy: 0.9193\n", + "Epoch 32/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6557 - accuracy: 0.9271\n", + "Epoch 33/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6806 - accuracy: 0.9224\n", + "Epoch 34/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6525 - accuracy: 0.9299\n", + "Epoch 35/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6500 - accuracy: 0.9303\n", + "Epoch 36/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6701 - accuracy: 0.9234\n", + "Epoch 37/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6627 - accuracy: 0.9297\n", + "Epoch 38/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6507 - accuracy: 0.9321\n", + "Epoch 39/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6497 - accuracy: 0.9323\n", + "Epoch 40/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6593 - accuracy: 0.9304\n", + "Epoch 41/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6528 - accuracy: 0.9325\n", + "Epoch 42/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6294 - accuracy: 0.9365\n", + "Epoch 43/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6596 - accuracy: 0.9304\n", + "Epoch 44/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6493 - accuracy: 0.9343\n", + "Epoch 45/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6440 - accuracy: 0.9351\n", + "Epoch 46/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6233 - accuracy: 0.9392\n", + "Epoch 47/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6631 - accuracy: 0.9301\n", + "Epoch 48/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6341 - accuracy: 0.9397\n", + "Epoch 49/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6440 - accuracy: 0.9351\n", + "Epoch 50/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6540 - accuracy: 0.9354\n", + "Epoch 51/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6371 - accuracy: 0.9407\n", + "Epoch 52/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6464 - accuracy: 0.9373\n", + "Epoch 53/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6489 - accuracy: 0.9371\n", + "Epoch 54/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6471 - accuracy: 0.9386\n", + "Epoch 55/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6342 - accuracy: 0.9414\n", + "Epoch 56/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6447 - accuracy: 0.9379\n", + "Epoch 57/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6291 - accuracy: 0.9431\n", + "Epoch 58/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6447 - accuracy: 0.9376\n", + "Epoch 59/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6493 - accuracy: 0.9401\n", + "Epoch 60/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6317 - accuracy: 0.9425\n", + "Epoch 61/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6179 - accuracy: 0.9450\n", + "Epoch 62/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6549 - accuracy: 0.9370\n", + "Epoch 63/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6333 - accuracy: 0.9449\n", + "Epoch 64/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6345 - accuracy: 0.9409\n", + "Epoch 65/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6320 - accuracy: 0.9440\n", + "Epoch 66/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6361 - accuracy: 0.9423\n", + "Epoch 67/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6285 - accuracy: 0.9444\n", + "Epoch 68/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6324 - accuracy: 0.9427\n", + "Epoch 69/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6427 - accuracy: 0.9397\n", + "Epoch 70/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6429 - accuracy: 0.9436\n", + "Epoch 71/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6226 - accuracy: 0.9465\n", + "Epoch 72/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6406 - accuracy: 0.9411\n", + "Epoch 73/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6197 - accuracy: 0.9470\n", + "Epoch 74/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6285 - accuracy: 0.9434\n", + "Epoch 75/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6307 - accuracy: 0.9447\n", + "Epoch 76/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6243 - accuracy: 0.9465\n", + "Epoch 77/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6274 - accuracy: 0.9468\n", + "Epoch 78/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6397 - accuracy: 0.9432\n", + "Epoch 79/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6282 - accuracy: 0.9468\n", + "Epoch 80/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6408 - accuracy: 0.9434\n", + "Epoch 81/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6368 - accuracy: 0.9468\n", + "Epoch 82/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6158 - accuracy: 0.9499\n", + "Epoch 83/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6100 - accuracy: 0.9478\n", + "Epoch 84/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6321 - accuracy: 0.9429\n", + "Epoch 85/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6260 - accuracy: 0.9477\n", + "Epoch 86/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6235 - accuracy: 0.9463\n", + "Epoch 87/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6186 - accuracy: 0.9493\n", + "Epoch 88/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6155 - accuracy: 0.9481\n", + "Epoch 89/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6702 - accuracy: 0.9374\n", + "Epoch 90/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6188 - accuracy: 0.9502\n", + "Epoch 91/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6410 - accuracy: 0.9439\n", + "Epoch 92/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6052 - accuracy: 0.9528\n", + "Epoch 93/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6349 - accuracy: 0.9431\n", + "Epoch 94/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6216 - accuracy: 0.9486\n", + "Epoch 95/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6128 - accuracy: 0.9497\n", + "Epoch 96/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6286 - accuracy: 0.9469\n", + "Epoch 97/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6095 - accuracy: 0.9515\n", + "Epoch 98/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6124 - accuracy: 0.9487\n", + "Epoch 99/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6267 - accuracy: 0.9482\n", + "Epoch 100/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6323 - accuracy: 0.9459\n", + "Epoch 101/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6116 - accuracy: 0.9507\n", + "Epoch 102/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6192 - accuracy: 0.9478\n", + "Epoch 103/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6229 - accuracy: 0.9482\n", + "Epoch 104/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6261 - accuracy: 0.9486\n", + "Epoch 105/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6140 - accuracy: 0.9521\n", + "Epoch 106/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6256 - accuracy: 0.9476\n", + "Epoch 107/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6118 - accuracy: 0.9525\n", + "Epoch 108/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6064 - accuracy: 0.9502\n", + "Epoch 109/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6161 - accuracy: 0.9487\n", + "Epoch 110/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6105 - accuracy: 0.9513\n", + "Epoch 111/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6302 - accuracy: 0.9468\n", + "Epoch 112/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6022 - accuracy: 0.9534\n", + "Epoch 113/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5993 - accuracy: 0.9518\n", + "Epoch 114/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6260 - accuracy: 0.9462\n", + "Epoch 115/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6026 - accuracy: 0.9538\n", + "Epoch 116/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6144 - accuracy: 0.9499\n", + "Epoch 117/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6054 - accuracy: 0.9516\n", + "Epoch 118/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6122 - accuracy: 0.9504\n", + "Epoch 119/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6187 - accuracy: 0.9506\n", + "Epoch 120/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6030 - accuracy: 0.9524\n", + "Epoch 121/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6078 - accuracy: 0.9513\n", + "Epoch 122/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6113 - accuracy: 0.9503\n", + "Epoch 123/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6080 - accuracy: 0.9525\n", + "Epoch 124/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5991 - accuracy: 0.9539\n", + "Epoch 125/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5985 - accuracy: 0.9529\n", + "Epoch 126/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6103 - accuracy: 0.9509\n", + "Epoch 127/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5947 - accuracy: 0.9557\n", + "Epoch 128/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5945 - accuracy: 0.9532\n", + "Epoch 129/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6059 - accuracy: 0.9520\n", + "Epoch 130/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6076 - accuracy: 0.9517\n", + "Epoch 131/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6134 - accuracy: 0.9520\n", + "Epoch 132/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5950 - accuracy: 0.9546\n", + "Epoch 133/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5881 - accuracy: 0.9557\n", + "Epoch 134/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6095 - accuracy: 0.9494\n", + "Epoch 135/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6116 - accuracy: 0.9537\n", + "Epoch 136/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5860 - accuracy: 0.9554\n", + "Epoch 137/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6058 - accuracy: 0.9519\n", + "Epoch 138/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6043 - accuracy: 0.9542\n", + "Epoch 139/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5921 - accuracy: 0.9556\n", + "Epoch 140/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5983 - accuracy: 0.9530\n", + "Epoch 141/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5987 - accuracy: 0.9537\n", + "Epoch 142/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5983 - accuracy: 0.9544\n", + "Epoch 143/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5734 - accuracy: 0.9576\n", + "Epoch 144/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5895 - accuracy: 0.9534\n", + "Epoch 145/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6068 - accuracy: 0.9519\n", + "Epoch 146/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5973 - accuracy: 0.9548\n", + "Epoch 147/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5786 - accuracy: 0.9566\n", + "Epoch 148/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5833 - accuracy: 0.9547\n", + "Epoch 149/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6056 - accuracy: 0.9511\n", + "Epoch 150/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6060 - accuracy: 0.9517\n", + "Epoch 151/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5907 - accuracy: 0.9567\n", + "Epoch 152/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5922 - accuracy: 0.9541\n", + "Epoch 153/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5961 - accuracy: 0.9527\n", + "Epoch 154/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5878 - accuracy: 0.9580\n", + "Epoch 155/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5790 - accuracy: 0.9580\n", + "Epoch 156/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5977 - accuracy: 0.9523\n", + "Epoch 157/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5968 - accuracy: 0.9540\n", + "Epoch 158/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5950 - accuracy: 0.9547\n", + "Epoch 159/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5907 - accuracy: 0.9554\n", + "Epoch 160/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5817 - accuracy: 0.9560\n", + "Epoch 161/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5962 - accuracy: 0.9536\n", + "Epoch 162/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5876 - accuracy: 0.9572\n", + "Epoch 163/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5818 - accuracy: 0.9558\n", + "Epoch 164/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5896 - accuracy: 0.9541\n", + "Epoch 165/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5915 - accuracy: 0.9552\n", + "Epoch 166/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5928 - accuracy: 0.9555\n", + "Epoch 167/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5773 - accuracy: 0.9576\n", + "Epoch 168/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5817 - accuracy: 0.9560\n", + "Epoch 169/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5817 - accuracy: 0.9563\n", + "Epoch 170/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5877 - accuracy: 0.9565\n", + "Epoch 171/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5893 - accuracy: 0.9554\n", + "Epoch 172/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5946 - accuracy: 0.9543\n", + "Epoch 173/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5841 - accuracy: 0.9571\n", + "Epoch 174/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5745 - accuracy: 0.9598\n", + "Epoch 175/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5715 - accuracy: 0.9580\n", + "Epoch 176/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5809 - accuracy: 0.9552\n", + "Epoch 177/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5845 - accuracy: 0.9557\n", + "Epoch 178/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5720 - accuracy: 0.9591\n", + "Epoch 179/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5901 - accuracy: 0.9541\n", + "Epoch 180/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5667 - accuracy: 0.9608\n", + "Epoch 181/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5857 - accuracy: 0.9552\n", + "Epoch 182/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5694 - accuracy: 0.9613\n", + "Epoch 183/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5732 - accuracy: 0.9574\n", + "Epoch 184/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5861 - accuracy: 0.9562\n", + "Epoch 185/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5737 - accuracy: 0.9580\n", + "Epoch 186/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5816 - accuracy: 0.9584\n", + "Epoch 187/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5566 - accuracy: 0.9602\n", + "Epoch 188/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5664 - accuracy: 0.9576\n", + "Epoch 189/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5911 - accuracy: 0.9535\n", + "Epoch 190/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5742 - accuracy: 0.9595\n", + "Epoch 191/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5748 - accuracy: 0.9559\n", + "Epoch 192/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5606 - accuracy: 0.9604\n", + "Epoch 193/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.6116 - accuracy: 0.9508\n", + "Epoch 194/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5868 - accuracy: 0.9591\n", + "Epoch 195/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5400 - accuracy: 0.9650\n", + "Epoch 196/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5624 - accuracy: 0.9574\n", + "Epoch 197/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5864 - accuracy: 0.9554\n", + "Epoch 198/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5773 - accuracy: 0.9585\n", + "Epoch 199/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5699 - accuracy: 0.9580\n", + "Epoch 200/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5703 - accuracy: 0.9595\n", + "Epoch 201/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5723 - accuracy: 0.9601\n", + "Epoch 202/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5641 - accuracy: 0.9591\n", + "Epoch 203/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5812 - accuracy: 0.9565\n", + "Epoch 204/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5653 - accuracy: 0.9612\n", + "Epoch 205/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5697 - accuracy: 0.9592\n", + "Epoch 206/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5726 - accuracy: 0.9591\n", + "Epoch 207/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5570 - accuracy: 0.9612\n", + "Epoch 208/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5598 - accuracy: 0.9599\n", + "Epoch 209/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5709 - accuracy: 0.9578\n", + "Epoch 210/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5836 - accuracy: 0.9563\n", + "Epoch 211/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5621 - accuracy: 0.9613\n", + "Epoch 212/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5722 - accuracy: 0.9582\n", + "Epoch 213/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5483 - accuracy: 0.9624\n", + "Epoch 214/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5708 - accuracy: 0.9563\n", + "Epoch 215/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5746 - accuracy: 0.9572\n", + "Epoch 216/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5840 - accuracy: 0.9584\n", + "Epoch 217/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5545 - accuracy: 0.9623\n", + "Epoch 218/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5402 - accuracy: 0.9628\n", + "Epoch 219/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5546 - accuracy: 0.9591\n", + "Epoch 220/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5762 - accuracy: 0.9552\n", + "Epoch 221/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5596 - accuracy: 0.9604\n", + "Epoch 222/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5548 - accuracy: 0.9610\n", + "Epoch 223/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5586 - accuracy: 0.9608\n", + "Epoch 224/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5504 - accuracy: 0.9612\n", + "Epoch 225/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5496 - accuracy: 0.9607\n", + "Epoch 226/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5763 - accuracy: 0.9562\n", + "Epoch 227/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5664 - accuracy: 0.9602\n", + "Epoch 228/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5404 - accuracy: 0.9648\n", + "Epoch 229/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5603 - accuracy: 0.9580\n", + "Epoch 230/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5574 - accuracy: 0.9610\n", + "Epoch 231/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5575 - accuracy: 0.9586\n", + "Epoch 232/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5660 - accuracy: 0.9585\n", + "Epoch 233/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5427 - accuracy: 0.9640\n", + "Epoch 234/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5468 - accuracy: 0.9611\n", + "Epoch 235/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5678 - accuracy: 0.9581\n", + "Epoch 236/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5472 - accuracy: 0.9622\n", + "Epoch 237/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5561 - accuracy: 0.9601\n", + "Epoch 238/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5471 - accuracy: 0.9621\n", + "Epoch 239/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5539 - accuracy: 0.9601\n", + "Epoch 240/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5492 - accuracy: 0.9619\n", + "Epoch 241/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5674 - accuracy: 0.9581\n", + "Epoch 242/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5451 - accuracy: 0.9618\n", + "Epoch 243/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5280 - accuracy: 0.9646\n", + "Epoch 244/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5628 - accuracy: 0.9579\n", + "Epoch 245/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5504 - accuracy: 0.9625\n", + "Epoch 246/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5284 - accuracy: 0.9647\n", + "Epoch 247/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5277 - accuracy: 0.9629\n", + "Epoch 248/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5490 - accuracy: 0.9599\n", + "Epoch 249/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5582 - accuracy: 0.9601\n", + "Epoch 250/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5701 - accuracy: 0.9587\n", + "Epoch 251/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5274 - accuracy: 0.9664\n", + "Epoch 252/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5343 - accuracy: 0.9618\n", + "Epoch 253/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5601 - accuracy: 0.9586\n", + "Epoch 254/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5544 - accuracy: 0.9608\n", + "Epoch 255/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5447 - accuracy: 0.9631\n", + "Epoch 256/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5355 - accuracy: 0.9634\n", + "Epoch 257/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5321 - accuracy: 0.9625\n", + "Epoch 258/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5554 - accuracy: 0.9593\n", + "Epoch 259/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5533 - accuracy: 0.9608\n", + "Epoch 260/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5511 - accuracy: 0.9618\n", + "Epoch 261/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5180 - accuracy: 0.9667\n", + "Epoch 262/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5495 - accuracy: 0.9582\n", + "Epoch 263/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5347 - accuracy: 0.9640\n", + "Epoch 264/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5289 - accuracy: 0.9639\n", + "Epoch 265/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5340 - accuracy: 0.9623\n", + "Epoch 266/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5463 - accuracy: 0.9604\n", + "Epoch 267/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5383 - accuracy: 0.9639\n", + "Epoch 268/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5421 - accuracy: 0.9614\n", + "Epoch 269/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5213 - accuracy: 0.9651\n", + "Epoch 270/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5470 - accuracy: 0.9599\n", + "Epoch 271/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5388 - accuracy: 0.9634\n", + "Epoch 272/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5384 - accuracy: 0.9630\n", + "Epoch 273/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5326 - accuracy: 0.9638\n", + "Epoch 274/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5442 - accuracy: 0.9609\n", + "Epoch 275/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5384 - accuracy: 0.9634\n", + "Epoch 276/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5302 - accuracy: 0.9627\n", + "Epoch 277/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5403 - accuracy: 0.9617\n", + "Epoch 278/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5325 - accuracy: 0.9647\n", + "Epoch 279/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5370 - accuracy: 0.9619\n", + "Epoch 280/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5357 - accuracy: 0.9640\n", + "Epoch 281/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5287 - accuracy: 0.9640\n", + "Epoch 282/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5315 - accuracy: 0.9613\n", + "Epoch 283/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5361 - accuracy: 0.9649\n", + "Epoch 284/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5382 - accuracy: 0.9614\n", + "Epoch 285/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5313 - accuracy: 0.9637\n", + "Epoch 286/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5414 - accuracy: 0.9618\n", + "Epoch 287/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5197 - accuracy: 0.9667\n", + "Epoch 288/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5287 - accuracy: 0.9613\n", + "Epoch 289/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5433 - accuracy: 0.9610\n", + "Epoch 290/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5371 - accuracy: 0.9637\n", + "Epoch 291/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5274 - accuracy: 0.9636\n", + "Epoch 292/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5302 - accuracy: 0.9638\n", + "Epoch 293/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5418 - accuracy: 0.9611\n", + "Epoch 294/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5264 - accuracy: 0.9648\n", + "Epoch 295/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5397 - accuracy: 0.9614\n", + "Epoch 296/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5217 - accuracy: 0.9652\n", + "Epoch 297/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5193 - accuracy: 0.9648\n", + "Epoch 298/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5296 - accuracy: 0.9643\n", + "Epoch 299/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5312 - accuracy: 0.9621\n", + "Epoch 300/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5321 - accuracy: 0.9632\n", + "Epoch 301/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5151 - accuracy: 0.9664\n", + "Epoch 302/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5239 - accuracy: 0.9634\n", + "Epoch 303/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5264 - accuracy: 0.9640\n", + "Epoch 304/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5168 - accuracy: 0.9652\n", + "Epoch 305/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5238 - accuracy: 0.9649\n", + "Epoch 306/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5178 - accuracy: 0.9635\n", + "Epoch 307/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5119 - accuracy: 0.9650\n", + "Epoch 308/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5237 - accuracy: 0.9634\n", + "Epoch 309/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5284 - accuracy: 0.9635\n", + "Epoch 310/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5121 - accuracy: 0.9660\n", + "Epoch 311/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5394 - accuracy: 0.9599\n", + "Epoch 312/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5048 - accuracy: 0.9697\n", + "Epoch 313/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5069 - accuracy: 0.9650\n", + "Epoch 314/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5089 - accuracy: 0.9657\n", + "Epoch 315/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5242 - accuracy: 0.9627\n", + "Epoch 316/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5258 - accuracy: 0.9638\n", + "Epoch 317/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5221 - accuracy: 0.9643\n", + "Epoch 318/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5045 - accuracy: 0.9666\n", + "Epoch 319/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5195 - accuracy: 0.9652\n", + "Epoch 320/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5005 - accuracy: 0.9680\n", + "Epoch 321/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5238 - accuracy: 0.9615\n", + "Epoch 322/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5321 - accuracy: 0.9618\n", + "Epoch 323/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5160 - accuracy: 0.9674\n", + "Epoch 324/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5245 - accuracy: 0.9628\n", + "Epoch 325/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5109 - accuracy: 0.9669\n", + "Epoch 326/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5138 - accuracy: 0.9656\n", + "Epoch 327/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4976 - accuracy: 0.9667\n", + "Epoch 328/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5200 - accuracy: 0.9624\n", + "Epoch 329/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4939 - accuracy: 0.9700\n", + "Epoch 330/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4973 - accuracy: 0.9646\n", + "Epoch 331/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5258 - accuracy: 0.9619\n", + "Epoch 332/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5384 - accuracy: 0.9623\n", + "Epoch 333/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5265 - accuracy: 0.9655\n", + "Epoch 334/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5038 - accuracy: 0.9678\n", + "Epoch 335/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5162 - accuracy: 0.9643\n", + "Epoch 336/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5066 - accuracy: 0.9665\n", + "Epoch 337/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5063 - accuracy: 0.9660\n", + "Epoch 338/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5078 - accuracy: 0.9658\n", + "Epoch 339/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5310 - accuracy: 0.9632\n", + "Epoch 340/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4861 - accuracy: 0.9703\n", + "Epoch 341/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5143 - accuracy: 0.9631\n", + "Epoch 342/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5199 - accuracy: 0.9637\n", + "Epoch 343/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4992 - accuracy: 0.9685\n", + "Epoch 344/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5109 - accuracy: 0.9644\n", + "Epoch 345/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5066 - accuracy: 0.9657\n", + "Epoch 346/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5142 - accuracy: 0.9651\n", + "Epoch 347/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5092 - accuracy: 0.9649\n", + "Epoch 348/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5188 - accuracy: 0.9636\n", + "Epoch 349/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.5069 - accuracy: 0.9677\n", + "Epoch 350/350\n", + "391/391 [==============================] - 7s 18ms/step - loss: 0.4872 - accuracy: 0.9686\n" + ] + } + ], + "source": [ + "history = model.fit(x_train, to_categorical(y_train, num_classes), epochs=EPOCHS, batch_size=BATCH_SIZE)" + ] + }, + { + "cell_type": "code", + "source": [ + "loss = history.history['loss']\n", + "epochs = range(1, len(loss) + 1)\n", + "\n", + "plt.plot(epochs, loss, 'b', label='Training Loss')\n", + "plt.title('Training Loss')\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.legend()\n", + "plt.show()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 472 + }, + "id": "6lrFuQrNRCyv", + "outputId": "3bc66200-18f3-483e-8c8c-3b44072fe7bb" + }, + "execution_count": 22, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAHHCAYAAACRAnNyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABEw0lEQVR4nO3deXhU1f3H8c+EkCGBLCyBhB0BWQUREEEBKwhEpYD4UykqoIUKqNBKq6gsai2i1lq1xbVQkUqVAlIEBVRcEBUUEBARlE22yJaFJUByfn+czoSBTDaSnEnyfj3PPJm5986dc++M3g/fc+69HmOMEQAAQAgKc90AAACAYAgqAAAgZBFUAABAyCKoAACAkEVQAQAAIYugAgAAQhZBBQAAhCyCCgAACFkEFQAAELIIKgDybejQoWrYsGGh3jt58mR5PJ6ibRCAMo+gApQBHo8nX4/ly5e7bqoTQ4cOVZUqVVw3A0AheLjXD1D6vf766wGvX3vtNS1dulQzZ84MmH711VerVq1ahf6cU6dOKSsrS16vt8DvPX36tE6fPq1KlSoV+vMLa+jQoZozZ47S09NL/LMBnJ9w1w0AcP5uueWWgNeff/65li5des70sx07dkxRUVH5/pyKFSsWqn2SFB4ervBw/pcDoGDo+gHKiSuvvFKtW7fWV199pW7duikqKkoPPPCAJOntt9/Wtddeq9q1a8vr9apx48Z69NFHlZmZGbCOs8eobN++XR6PR0899ZReeuklNW7cWF6vVx07dtSqVasC3pvTGBWPx6O77rpL8+fPV+vWreX1etWqVSu9++6757R/+fLl6tChgypVqqTGjRvrxRdfLPJxL2+99Zbat2+vyMhI1ahRQ7fccot2794dsMy+ffs0bNgw1a1bV16vV4mJierXr5+2b9/uX2b16tXq3bu3atSoocjISDVq1Ei33357kbUTKE/45w1Qjhw8eFBJSUm6+eabdcstt/i7gWbMmKEqVarod7/7napUqaIPPvhAEydOVGpqqp588sk81/uvf/1LaWlp+s1vfiOPx6MnnnhC119/vX788cc8qzCffvqp5s6dq1GjRik6OlrPPvusBg4cqJ07d6p69eqSpDVr1qhPnz5KTEzUww8/rMzMTD3yyCOKj48//53yPzNmzNCwYcPUsWNHTZkyRfv379df//pXrVixQmvWrFFcXJwkaeDAgdq4caPuvvtuNWzYUMnJyVq6dKl27tzpf92rVy/Fx8fr/vvvV1xcnLZv3665c+cWWVuBcsUAKHNGjx5tzv7Pu3v37kaSeeGFF85Z/tixY+dM+81vfmOioqLMiRMn/NOGDBliGjRo4H+9bds2I8lUr17dHDp0yD/97bffNpLMf//7X/+0SZMmndMmSSYiIsJs3brVP23dunVGknnuuef80/r27WuioqLM7t27/dO2bNliwsPDz1lnToYMGWIqV64cdP7JkydNzZo1TevWrc3x48f90xcuXGgkmYkTJxpjjDl8+LCRZJ588smg65o3b56RZFatWpVnuwDkja4foBzxer0aNmzYOdMjIyP9z9PS0nTgwAF17dpVx44d03fffZfnem+66SZVrVrV/7pr166SpB9//DHP9/bs2VONGzf2v27Tpo1iYmL8783MzNSyZcvUv39/1a5d279ckyZNlJSUlOf682P16tVKTk7WqFGjAgb7XnvttWrevLneeecdSXY/RUREaPny5Tp8+HCO6/JVXhYuXKhTp04VSfuA8oygApQjderUUURExDnTN27cqAEDBig2NlYxMTGKj4/3D8RNSUnJc73169cPeO0LLcEO5rm91/d+33uTk5N1/PhxNWnS5JzlcppWGDt27JAkNWvW7Jx5zZs398/3er2aOnWqFi9erFq1aqlbt2564okntG/fPv/y3bt318CBA/Xwww+rRo0a6tevn6ZPn66MjIwiaStQ3hBUgHLkzMqJz5EjR9S9e3etW7dOjzzyiP773/9q6dKlmjp1qiQpKysrz/VWqFAhx+kmH1c/OJ/3ujB27Fh9//33mjJliipVqqQJEyaoRYsWWrNmjSQ7QHjOnDlauXKl7rrrLu3evVu333672rdvz+nRQCEQVIBybvny5Tp48KBmzJihMWPG6LrrrlPPnj0DunJcqlmzpipVqqStW7eeMy+naYXRoEEDSdLmzZvPmbd582b/fJ/GjRvr3nvv1ZIlS7RhwwadPHlSf/7znwOWueyyy/TYY49p9erVmjVrljZu3KjZs2cXSXuB8oSgApRzvorGmRWMkydP6u9//7urJgWoUKGCevbsqfnz52vPnj3+6Vu3btXixYuL5DM6dOigmjVr6oUXXgjoolm8eLE2bdqka6+9VpK97syJEycC3tu4cWNFR0f733f48OFzqkEXX3yxJNH9AxQCpycD5VyXLl1UtWpVDRkyRPfcc488Ho9mzpwZUl0vkydP1pIlS3T55Zdr5MiRyszM1PPPP6/WrVtr7dq1+VrHqVOn9Mc//vGc6dWqVdOoUaM0depUDRs2TN27d9egQYP8pyc3bNhQv/3tbyVJ33//vXr06KEbb7xRLVu2VHh4uObNm6f9+/fr5ptvliT985//1N///ncNGDBAjRs3Vlpaml5++WXFxMTommuuKbJ9ApQXBBWgnKtevboWLlyoe++9Vw899JCqVq2qW265RT169FDv3r1dN0+S1L59ey1evFjjxo3ThAkTVK9ePT3yyCPatGlTvs5KkmyVaMKECedMb9y4sUaNGqWhQ4cqKipKjz/+uO677z5VrlxZAwYM0NSpU/1n8tSrV0+DBg3S+++/r5kzZyo8PFzNmzfXm2++qYEDB0qyg2m//PJLzZ49W/v371dsbKwuvfRSzZo1S40aNSqyfQKUF9zrB0Cp1b9/f23cuFFbtmxx3RQAxYQxKgBKhePHjwe83rJlixYtWqQrr7zSTYMAlAgqKgBKhcTERA0dOlQXXHCBduzYoWnTpikjI0Nr1qxR06ZNXTcPQDFhjAqAUqFPnz564403tG/fPnm9XnXu3Fl/+tOfCClAGUdFBQAAhCzGqAAAgJBFUAEAACHL6RiVyZMn6+GHHw6Y1qxZs3xfFyErK0t79uxRdHS0PB5PcTQRAAAUMWOM0tLSVLt2bYWF5V4zcT6YtlWrVlq2bJn/dXh4/pu0Z88e1atXrziaBQAAitmuXbtUt27dXJdxHlTCw8OVkJBQqPdGR0dLshsaExNTlM0CAADFJDU1VfXq1fMfx3PjPKhs2bJFtWvXVqVKldS5c2dNmTJF9evXz3HZjIyMgJt6paWlSZJiYmIIKgAAlDL5GbbhdDBtp06dNGPGDL377ruaNm2atm3bpq5du/oDyNmmTJmi2NhY/4NuHwAAyraQuo7KkSNH1KBBAz399NO64447zpl/dkXFVzpKSUmhogIAQCmRmpqq2NjYfB2/nXf9nCkuLk4XXnihtm7dmuN8r9crr9dbwq0CAACuhFRQSU9P1w8//KBbb73VdVMAALnIysrSyZMnXTcDIapixYqqUKFCkazLaVAZN26c+vbtqwYNGmjPnj2aNGmSKlSooEGDBrlsFgAgFydPntS2bduUlZXluikIYXFxcUpISDjv65w5DSo//fSTBg0apIMHDyo+Pl5XXHGFPv/8c8XHx7tsFgAgCGOM9u7dqwoVKqhevXp5XqwL5Y8xRseOHVNycrIke+fz8+E0qMyePdvlxwMACuj06dM6duyYateuraioKNfNQYiKjIyUJCUnJ6tmzZrn1Q1EFAYA5FtmZqYkKSIiwnFLEOp8QfbUqVPntR6CCgCgwLi/GvJSVL8RggoAAAhZBBUAAAqhYcOGeuaZZ/K9/PLly+XxeHTkyJFia1NZRFABAJRpHo8n18fkyZMLtd5Vq1ZpxIgR+V6+S5cu2rt3r2JjYwv1eflV1gJRSF3wLVQcPSodOCB5vVIhb+wMAAgRe/fu9T//97//rYkTJ2rz5s3+aVWqVPE/N8YoMzNT4eF5Hx4LeimNiIgIJXBQKTAqKjlYsEBq2FC65RbXLQEAnK+EhAT/IzY2Vh6Px//6u+++U3R0tBYvXqz27dvL6/Xq008/1Q8//KB+/fqpVq1aqlKlijp27Khly5YFrPfsrh+Px6NXXnlFAwYMUFRUlJo2baoFCxb4559d6ZgxY4bi4uL03nvvqUWLFqpSpYr69OkTEKxOnz6te+65R3Fxcapevbruu+8+DRkyRP379y/0/jh8+LBuu+02Va1aVVFRUUpKStKWLVv883fs2KG+ffuqatWqqly5slq1aqVFixb53zt48GDFx8crMjJSTZs21fTp0wvdlvwgqOTAd/2i/52FBwAIwhhbhXbxKMpb6t5///16/PHHtWnTJrVp00bp6em65ppr9P7772vNmjXq06eP+vbtq507d+a6nocfflg33nijvvnmG11zzTUaPHiwDh06FHT5Y8eO6amnntLMmTP18ccfa+fOnRo3bpx//tSpUzVr1ixNnz5dK1asUGpqqubPn39e2zp06FCtXr1aCxYs0MqVK2WM0TXXXOM/jXj06NHKyMjQxx9/rPXr12vq1Kn+qtOECRP07bffavHixdq0aZOmTZumGjVqnFd78mRKsZSUFCPJpKSkFOl633rLGMmYbt2KdLUAUOodP37cfPvtt+b48ePGGGPS0+3/L1080tML3v7p06eb2NhY/+sPP/zQSDLz58/P872tWrUyzz33nP91gwYNzF/+8hf/a0nmoYce8r9OT083kszixYsDPuvw4cP+tkgyW7du9b/nb3/7m6lVq5b/da1atcyTTz7pf3369GlTv359069fv6DtPPtzzvT9998bSWbFihX+aQcOHDCRkZHmzTffNMYYc9FFF5nJkyfnuO6+ffuaYcOGBf3sM539WzlTQY7fVFRy4KuocBsLACgfOnToEPA6PT1d48aNU4sWLRQXF6cqVapo06ZNeVZU2rRp439euXJlxcTE+C8ln5OoqCg1btzY/zoxMdG/fEpKivbv369LL73UP79ChQpq3759gbbtTJs2bVJ4eLg6derkn1a9enU1a9ZMmzZtkiTdc889+uMf/6jLL79ckyZN0jfffONfduTIkZo9e7Yuvvhi/eEPf9Bnn31W6LbkF0ElB3T9AED+REVJ6eluHkV5Bf/KlSsHvB43bpzmzZunP/3pT/rkk0+0du1aXXTRRXneMbpixYoBrz0eT643b8xpeVOUfVqF8Otf/1o//vijbr31Vq1fv14dOnTQc889J0lKSkrSjh079Nvf/lZ79uxRjx49ArqqigNBJQe+WxJQUQGA3Hk8UuXKbh7FeXHcFStWaOjQoRowYIAuuugiJSQkaPv27cX3gTmIjY1VrVq1tGrVKv+0zMxMff3114VeZ4sWLXT69Gl98cUX/mkHDx7U5s2b1bJlS/+0evXq6c4779TcuXN177336uWXX/bPi4+P15AhQ/T666/rmWee0UsvvVTo9uQHpyfngK4fACjfmjZtqrlz56pv377yeDyaMGFCrpWR4nL33XdrypQpatKkiZo3b67nnntOhw8fztfl6devX6/o6Gj/a4/Ho7Zt26pfv34aPny4XnzxRUVHR+v+++9XnTp11K9fP0nS2LFjlZSUpAsvvFCHDx/Whx9+qBYtWkiSJk6cqPbt26tVq1bKyMjQwoUL/fOKC0ElB3T9AED59vTTT+v2229Xly5dVKNGDd13331KTU0t8Xbcd9992rdvn2677TZVqFBBI0aMUO/evfN1N+Ju3boFvK5QoYJOnz6t6dOna8yYMbruuut08uRJdevWTYsWLfJ3Q2VmZmr06NH66aefFBMToz59+ugvf/mLJHstmPHjx2v79u2KjIxU165dNXv27KLf8DN4jOvOsPOQmpqq2NhYpaSkKCYmpsjWu2SJ1Lu3dPHF0po1RbZaACj1Tpw4oW3btqlRo0aqVKmS6+aUO1lZWWrRooVuvPFGPfroo66bk6vcfisFOX5TUckBFRUAQCjYsWOHlixZou7duysjI0PPP/+8tm3bpl/96leum1ZiGEybAwbTAgBCQVhYmGbMmKGOHTvq8ssv1/r167Vs2bJiHxcSSqio5IDBtACAUFCvXj2tWLHCdTOcoqKSA7p+AAAIDQSVHND1AwC5K8XnYaCEFNVvhKCSA7p+ACBnvtNi87pCK3Ds2DFJ5159t6AYo5IDun4AIGfh4eGKiorSzz//rIoVKyosjH/vIpAxRseOHVNycrLi4uLydc2X3BBUckDXDwDkzOPxKDExUdu2bdOOHTtcNwchLC4uTgkJCee9HoJKDqioAEBwERERatq0Kd0/CKpixYrnXUnxIajkgIoKAOQuLCyMK9OiRNC5mAMG0wIAEBoIKjmg6wcAgNBAUMkBXT8AAIQGgkoO6PoBACA0EFRyQNcPAAChgaCSA7p+AAAIDQSVHFBRAQAgNBBUckBFBQCA0EBQyQGDaQEACA0ElRzQ9QMAQGggqOTgzNsTGOOuHQAAlHcElRyceddyun8AAHCHoJKDM4MK3T8AALhDUMnBmV0/VFQAAHCHoJIDKioAAIQGgkoOqKgAABAaCCo5YDAtAAChgaCSA7p+AAAIDQSVHND1AwBAaCCo5MDjyX5OUAEAwB2CShBcRh8AAPcIKkFwB2UAANwjqARBRQUAAPcIKkFQUQEAwD2CShC+igpBBQAAdwgqQdD1AwCAewSVIOj6AQDAPYJKEHT9AADgHkElCLp+AABwj6ASBF0/AAC4R1AJgooKAADuEVSCoKICAIB7BJUgGEwLAIB7BJUg6PoBAMA9gkoQdP0AAOAeQSUIun4AAHCPoBIEXT8AALhHUAmCrh8AANwjqARBRQUAAPcIKkFQUQEAwD2CShAMpgUAwL2QCSqPP/64PB6Pxo4d67opkuj6AQAgFIREUFm1apVefPFFtWnTxnVT/Oj6AQDAPedBJT09XYMHD9bLL7+sqlWrum6OH10/AAC45zyojB49Wtdee6169uzpuikB6PoBAMC9cJcfPnv2bH399ddatWpVvpbPyMhQRkaG/3VqampxNY2uHwAAQoCzisquXbs0ZswYzZo1S5UqVcrXe6ZMmaLY2Fj/o169esXWPioqAAC45yyofPXVV0pOTtYll1yi8PBwhYeH66OPPtKzzz6r8PBwZeaQEMaPH6+UlBT/Y9euXcXWPioqAAC456zrp0ePHlq/fn3AtGHDhql58+a67777VMGXFM7g9Xrl9XpLpH0MpgUAwD1nQSU6OlqtW7cOmFa5cmVVr179nOku0PUDAIB7zs/6CVV0/QAA4J7Ts37Otnz5ctdN8KPrBwAA96ioBEHXDwAA7hFUgqDrBwAA9wgqQVBRAQDAPYJKEFRUAABwj6ASBINpAQBwj6ASBF0/AAC4R1AJgq4fAADcI6gEQdcPAADuEVSCoOsHAAD3CCpB0PUDAIB7BJUgqKgAAOAeQSUIKioAALhHUAmCwbQAALhHUAmCrh8AANwjqARB1w8AAO4RVIKg6wcAAPcIKkHQ9QMAgHsElSDo+gEAwD2CShBUVAAAcI+gEgQVFQAA3COoBMFgWgAA3COoBEHXDwAA7hFUgqDrBwAA9wgqQdD1AwCAewSVIOj6AQDAPYJKEHT9AADgHkElCCoqAAC4R1AJgooKAADuEVSCYDAtAADuEVSCoOsHAAD3CCpB0PUDAIB7BJUg6PoBAMA9gkoQdP0AAOAeQSUIun4AAHCPoBIEFRUAANwjqARBRQUAAPcIKkEwmBYAAPcIKkHQ9QMAgHsElSDo+gEAwD2CShB0/QAA4B5BJQi6fgAAcI+gEgRdPwAAuEdQCYKKCgAA7hFUgqCiAgCAewSVIBhMCwCAewSVIOj6AQDAPYJKEHT9AADgHkElCLp+AABwj6ASBF0/AAC4R1AJgq4fAADcI6gEQUUFAAD3CCpBUFEBAMA9gkoQDKYFAMA9gkoQdP0AAOAeQSUIun4AAHCPoBIEXT8AALhHUAmCrh8AANwjqARB1w8AAO4RVIKgogIAgHsElSCoqAAA4B5BJQgG0wIA4B5BJQi6fgAAcI+gEgRdPwAAuEdQCYKuHwAA3COoBEHXDwAA7hFUgqDrBwAA9wgqQYSdsWcIKwAAuOE0qEybNk1t2rRRTEyMYmJi1LlzZy1evNhlk/x8FRWJoAIAgCtOg0rdunX1+OOP66uvvtLq1at11VVXqV+/ftq4caPLZkmiogIAQCjwGGOM60acqVq1anryySd1xx135LlsamqqYmNjlZKSopiYmCJtR3q6FB1tnx89KkVFFenqAQAotwpy/A4voTblKTMzU2+99ZaOHj2qzp0757hMRkaGMjIy/K9TU1OLrT0VK2Y/P3Wq2D4GAADkwvlg2vXr16tKlSryer268847NW/ePLVs2TLHZadMmaLY2Fj/o169esXWLoIKAADuOe/6OXnypHbu3KmUlBTNmTNHr7zyij766KMcw0pOFZV69eoVS9ePZAfUZmVJe/dKCQlFvnoAAMqlgnT9OA8qZ+vZs6caN26sF198Mc9li3OMiiR5vdLJk9LOnVIxFm8AAChXCnL8dt71c7asrKyAqolLvu4fun4AAHDD6WDa8ePHKykpSfXr11daWpr+9a9/afny5XrvvfdcNsuPoAIAgFtOg0pycrJuu+027d27V7GxsWrTpo3ee+89XX311S6b5UdQAQDALadB5dVXX3X58XkiqAAA4FbIjVEJJeH/i3GnT7ttBwAA5RVBJRdUVAAAcIugkguCCgAAbhFUckFQAQDALYJKLggqAAC4RVDJBYNpAQBwi6CSCyoqAAC4RVDJBUEFAAC3ChVUdu3apZ9++sn/+ssvv9TYsWP10ksvFVnDQgFBBQAAtwoVVH71q1/pww8/lCTt27dPV199tb788ks9+OCDeuSRR4q0gS4RVAAAcKtQQWXDhg269NJLJUlvvvmmWrdurc8++0yzZs3SjBkzirJ9TjGYFgAAtwoVVE6dOiWv1ytJWrZsmX75y19Kkpo3b669e/cWXesco6ICAIBbhQoqrVq10gsvvKBPPvlES5cuVZ8+fSRJe/bsUfXq1Yu0gS4RVAAAcKtQQWXq1Kl68cUXdeWVV2rQoEFq27atJGnBggX+LqGygKACAIBb4YV505VXXqkDBw4oNTVVVatW9U8fMWKEoqKiiqxxrhFUAABwq1AVlePHjysjI8MfUnbs2KFnnnlGmzdvVs2aNYu0gS4RVAAAcKtQQaVfv3567bXXJElHjhxRp06d9Oc//1n9+/fXtGnTirSBLnHWDwAAbhUqqHz99dfq2rWrJGnOnDmqVauWduzYoddee03PPvtskTbQJSoqAAC4VaigcuzYMUVHR0uSlixZouuvv15hYWG67LLLtGPHjiJtoEsEFQAA3CpUUGnSpInmz5+vXbt26b333lOvXr0kScnJyYqJiSnSBrpEUAEAwK1CBZWJEydq3LhxatiwoS699FJ17txZkq2utGvXrkgb6BJBBQAAtwp1evINN9ygK664Qnv37vVfQ0WSevTooQEDBhRZ41xjMC0AAG4VKqhIUkJCghISEvx3Ua5bt26ZutibREUFAADXCtX1k5WVpUceeUSxsbFq0KCBGjRooLi4OD366KPKysoq6jY6Q1ABAMCtQlVUHnzwQb366qt6/PHHdfnll0uSPv30U02ePFknTpzQY489VqSNdIWgAgCAW4UKKv/85z/1yiuv+O+aLElt2rRRnTp1NGrUKIIKAAAoEoXq+jl06JCaN29+zvTmzZvr0KFD592oUMFgWgAA3CpUUGnbtq2ef/75c6Y///zzatOmzXk3KlRQUQEAwK1Cdf088cQTuvbaa7Vs2TL/NVRWrlypXbt2adGiRUXaQJcIKgAAuFWoikr37t31/fffa8CAATpy5IiOHDmi66+/Xhs3btTMmTOLuo3OEFQAAHDLY4wxRbWydevW6ZJLLlFmZmZRrTJXqampio2NVUpKSrFcun/+fGnAAKlzZ+mzz4p89QAAlEsFOX4XqqJSXlBRAQDALYJKLjjrBwAAtwgquaCiAgCAWwU66+f666/Pdf6RI0fOpy0hh6ACAIBbBQoqsbGxec6/7bbbzqtBoYSgAgCAWwUKKtOnTy+udoQkggoAAG4xRiUXDKYFAMAtgkouqKgAAOAWQSUXBBUAANwiqOSCoAIAgFsElVwQVAAAcIugkoszB9MW3R2RAABAfhFUcuGrqEhSCd1nEQAAnIGgkoszgwrdPwAAlDyCSi4IKgAAuEVQyQVBBQAAtwgquahQQfJ47HOCCgAAJY+gkgcuow8AgDsElTxwLRUAANwhqOSBoAIAgDsElTwQVAAAcIegkgeCCgAA7hBU8sBgWgAA3CGo5IGKCgAA7hBU8kBQAQDAHYJKHggqAAC4Q1DJA0EFAAB3CCp5YDAtAADuEFTy4KuonDzpth0AAJRHBJU8eL32b0aG23YAAFAeEVTyEBlp/5444bYdAACURwSVPPiCyvHjbtsBAEB5RFDJQ6VK9i9BBQCAkkdQyQNdPwAAuENQyQNdPwAAuOM0qEyZMkUdO3ZUdHS0atasqf79+2vz5s0um3QOX9cPFRUAAEqe06Dy0UcfafTo0fr888+1dOlSnTp1Sr169dLRo0ddNisAFRUAANwJd/nh7777bsDrGTNmqGbNmvrqq6/UrVs3R60KRFABAMCdkBqjkpKSIkmqVq2a45Zko+sHAAB3nFZUzpSVlaWxY8fq8ssvV+vWrXNcJiMjQxlnXCI2NTW12NtFRQUAAHdCpqIyevRobdiwQbNnzw66zJQpUxQbG+t/1KtXr9jbRVABAMCdkAgqd911lxYuXKgPP/xQdevWDbrc+PHjlZKS4n/s2rWr2NtG1w8AAO447foxxujuu+/WvHnztHz5cjVq1CjX5b1er7y+uwSWECoqAAC44zSojB49Wv/617/09ttvKzo6Wvv27ZMkxcbGKtKXEBwjqAAA4I7Trp9p06YpJSVFV155pRITE/2Pf//73y6bFYCuHwAA3HHe9RPqqKgAAOBOSAymDWUEFQAA3CGo5IGuHwAA3CGo5OHMikop6KkCAKBMIajkwRdUsrKk06fdtgUAgPKGoJIHX9ePxDgVAABKGkElD16v5PHY5wQVAABKFkElDx4PA2oBAHCFoJIPvqBCRQUAgJJFUMkHrqUCAIAbBJV88AUVun4AAChZBJV8oOsHAAA3CCr5QNcPAABuEFTyga4fAADcIKjkA10/AAC4QVDJB7p+AABwg6CSD3T9AADgBkElH+j6AQDADYJKPlBRAQDADYJKPjBGBQAANwgq+UDXDwAAbhBU8oGuHwAA3CCo5EPlyvZvWprbdgAAUN4QVPIhNtb+TUlx2w4AAMobgko+EFQAAHCDoJIPcXH2L0EFAICSRVDJByoqAAC4QVDJB19QOXLEaTMAACh3CCr54AsqaWlSVpbbtgAAUJ4QVPLBF1SM4RRlAABKEkElHypVkrxe+5xxKgAAlByCSj4xTgUAgJJHUMknzvwBAKDkEVTyiaACAEDJI6jkE0EFAICSR1DJJ65OCwBAySOo5BODaQEAKHkElXyi6wcAgJJHUMknggoAACWPoJJPBBUAAEoeQSWffINpGaMCAEDJIajkExUVAABKHkElnwgqAACUPIJKPnF6MgAAJY+gkk/x8fbvzz9LxrhtCwAA5QVBJZ9q1bJ/T52SDh922xYAAMoLgko+eb3ZZ/7s3++0KQAAlBsElQJISLB/9+1z2w4AAMoLgkoB+IIKFRUAAEoGQaUAfONUqKgAAFAyCCoFQNcPAAAli6BSAHT9AABQsggqBUDXDwAAJYugUgBUVAAAKFkElQJgjAoAACWLoFIAvq6f5GQpK8ttWwAAKA8IKgUQHy95PFJmpnTwoOvWAABQ9hFUCqBiRalGDft87163bQEAoDwgqBRQgwb2748/um0HAADlAUGlgJo2tX+3bHHbDgAAygOCSgERVAAAKDkElQIiqAAAUHIIKgVEUAEAoOQQVArIF1R275aOHnXbFgAAyjqCSgFVq2YfkrR1q9u2AABQ1hFUCoHuHwAASgZBpRAuvND+3bzZbTsAACjrCCqF0Lat/bt6tdt2AABQ1jkNKh9//LH69u2r2rVry+PxaP78+S6bk2+dOtm/X3whGeO2LQAAlGVOg8rRo0fVtm1b/e1vf3PZjAK75BKpQgV7v5/du123BgCAsivc5YcnJSUpKSnJZRMKJSpKuugiae1aW1WpW9d1iwAAKJtK1RiVjIwMpaamBjxcObP7BwAAFI9SFVSmTJmi2NhY/6NevXrO2nLppfbvp586awIAAGVeqQoq48ePV0pKiv+xa9cuZ23p1UvyeKSVK6Xt2501AwCAMq1UBRWv16uYmJiAhyt160q/+IV9/q9/OWsGAABlWqkKKqHmllvs35kzOU0ZAIDi4DSopKena+3atVq7dq0kadu2bVq7dq127tzpsln5NnCgVKmS9N130tdfu24NAABlj9Ogsnr1arVr107t2rWTJP3ud79Tu3btNHHiRJfNyreYGKl/f/t85kynTQEAoEzyGFN6Oy1SU1MVGxurlJQUZ+NV3nlHuu46qWZNe/G3cKdXpgEAIPQV5PjNGJXz1KuXFB8vJSdLS5e6bg0AAGULQeU8Vawo3XyzfU73DwAARYugUgRuvdX+nT9fSktz2hQAAMoUgkoR6NBBatZMOn5cmjvXdWsAACg7CCpFwOPJrqo88ojk8BZEAACUKQSVInLXXVLDhtKPP0qjR7tuDQAAZQNBpYjExkqzZklhYdLrr9sHAAA4PwSVItSlizRpkn0+apS0Y4fb9gAAUNoRVIrYAw9Il19uz/4ZPZp7AAEAcD4IKkUsPFx6+WUpIsJetfaZZwgrAAAUFkGlGLRoIU2YYJ//7ne2Gygry22bAAAojQgqxeTBB6UnnrCDa194QbrjDunIEdetAgCgdCGoFBOPR/r976XXXrNhZcYM6YILpA8+cN0yAABKD4JKMRs8WFq4UGrZUjp8WPrlL6Vnn5V++sl1ywAACH0ElRKQlCR9/bW90/LRo9KYMVKrVtKqVa5bBgBAaAt33YDywuuV5s2TnnvOXhhu/Xrp6qul22+3l9yvWVO6/np73yBjbNcRAADlnceY0nvybGpqqmJjY5WSkqKYmBjXzcm3tDSpTx/ps88Cp3s8dtDtvHlSx47SxIm2u+iFF6Rx46Ru3ezZQ5s3S82bS+npUmamFBfnZDMAACiUghy/CSqOnDolzZ8vvfuuVKeOtG6dtGBB8OXr1ZM2bpRuucUud/fd0ttv2zs2r1snJSZKX31lB+62a1dimxHS0tNtGLz6aipUABBKCCqlkDH2zssvvGCvu7Jli62sHD0qRUfb7qFGjaRt28597623SjfcIPXvb9fTtav00kvShRfaUJOcLA0YIMXHS3Pm2JB000021EjSn/4kbdggvfKKFBVl52dl2e6qwmzHrl02WLkOB7/+tfTqq9K0adKdd7ptCwAgG0GljMjIkE6ftpWTwYPttLAwqU0bae1a+9wY+6hY0QYMn0qVpFq1su83VKmSPV360Uft6/bt7fiYX/zCDuw1xnY11aljbwMQGWlPqV6+3FZyKlSQLr7YDgT+4AOpbVsbZF5+WfriC2nEiOwzmsaMsV1VTz6Z83ZlZUmrV9uKx2WXZYejihWD7wvfuJ3Tp+3r8DxGV506ZYNZSop0xRXSsmV2G/J6HwCg+BFUyhhjpOeft908fftKVavasSzXXmtPc54yxS7Xo4etigwfbg/MklSliq3ErF+f87orV7ZVm/yKipKOHbNtiIqSdu+208PDpdmzpbvukvbts9NuvVWqVk2KiZG+/dYOFG7RQnr8cenzz+0yDRva7Zgxww40vuEGW905ccLeMyk5WRo2zIalm2+2laKwMLvcCy/YWxV8/LFd/qqrsoPIBx/Y/eHj9dr9sGSJ9Pe/SwMH2vZINszExmYv++OP0kcfSdddZ8NOQXz3nR0YXa1awd4HAOVJgY7fphRLSUkxkkxKSorrpji1a5cxK1YYc/y4fZ2ZaczXXxuzdKkxycl2eps2tvZSpYoxa9caM3WqMZGRvnqMMbVq2b8xMXbeRRfZ1926GfPss8b8+c/GxMdnL+97NGtmzLXXBk4LCzt3ubMflSsbU7Vq4LSKFY2Jjs5+HRNjTIUKwdcxcqQxCxcGLn/ZZcaMH29MixY5v6dmTfs3IsKYv/zFmFtuMcbjMWbKFGNOnDBm6NDsZVu3NubDD4154AFjrr7amDp1jBk0yJh77zWmZUtjFiyw+/vkSWP27zfm3/+262rRwq5rwwZjLrzQmD/+0S53+rT9m5VlzOrVdt0AUB4V5PhNRaWc+O47W4UZPlwaOtRO+/e/bZWiaVNblZg7154inZBgqzdbtkgXXZQ91mT3bjsAuHdvafJkO27m1VftWUfDhtnTriU75iUiwp6xlJEhHTggNW4s/fe/trvnyitt91JamtSpk11Py5a26iJJtWvbLp7kZPu6Uyd7xtObb0r33GPHv9x0k40TFSrYM5+8XvtZZ+vS5dyzq4Jp0MB2lYWF2a6vvCpNHo/0f/9n1//TT/a177+mP/7Rjg/68svsbfjmG9uFt2qVHQAtSffeK9WoYSs4J0/a5UaNsoOj//pXu08ef9wus2aN9OGHdtD00qW2u+3RR+24pQkT7Nlg/frZgdbvvGOrXt262apaTjIz7W+gZUvbreezfr1td9269to/rscaFbUjR+z3kJRkK4AASh4VFeTb118bs3v3+a8nK8uY55+3VY709Py/b/t2Ww1KSzNm0iRblcjMNObUKWNWrTLmp5/sus/217/aqohkTMeO9jPXrTPmtdeMGTzYVjK6d7cVpVGjjPnvf7MrOKNGGTNtmq0iRUUZM2BAdhUlNtaYJUtsm7xeW+UZNMguv2CBraTUqWPMTTflXLFp0iTvapKvohNsntdrTO3a2a9btDBm2DBbrTl72fvuM6Z588BpjRtnP2/QwJjhw41p186Ytm3tY+JEW9351a+ylxsxwu7z994L/Jzhw4156y1bpXr4YVs5ysiwn9url13H735nzBtvGHP4sDFffmnMXXcZ89ln9ns6fdqYTZuMef99u8yPP577XR47Zsznn9vvdN263H8vhw/n/HvIzQ8/2N+Uz9NPZ2/fihUFWxeAolGQ4zdBBaXW4cM2gBw5kr/l58835vbbjTl0yL7OyDAmJcUexObONeadd7LnGWNDUnJy4DqysrK7cFautAfwZ54xZs8e29V24oQNSr4D4TPPGNOnjzE9exrzz38aM2SInXbwoDEvvmhDxYAB9iD9yivGXHFF9nujo41JSAgMIddcY8z999vHmdPr1DHm5psDA9eZYefsR5cu9u+ZXWsDBmR377Vtm3MwiooyplOnnNcZF2dMpUrZrzt3NuaCCwKXiYy02//QQ8aMHWsD5dndhnffbb/TEyeMefllY2691Zjf/taYp54yJjzcbuezzxqTlGTM99/bkDtihO2e++QTG5g3bbLf0VNP2fX26GHM0aN2Wteu2Z/XtKntuktPN+aXvzTmuuvsOs524oQxM2bY0BPMrl3GjBljfwdny8qyv7877jBm69bAeRkZwdfp+2ygrCGoAI4dPpz7QS2YrCw7tuWjj4z5+Wdjdu60B/UxY4x5993AZf/xDztGKDbWhqbUVPs6MtK+PyXFVjeGDbNB7L337AHdd5AODzfmP/+xjzPHFbVta6sc//mPMR062GAybJitXPmWqVDBjmV66ikbLM4MHK1bBwagqChb9Tk7lJz5iImxwebMkJaYmHdlqlYtYxo2PHd67do2cJxZuQoPt2O0zl524UJbOTpzWrdudvxSz57G3HCDMb172+lVq9rxWsOG2fbOmWO/izlzsit2lSrZitPAgbZiOH68DUS+dV9xRXZV6D//seO1br3VVrTONmGCDYxDhtjfw9l27TJm48bs16tXZ1ezcvptLV0aWEHNrTq1cqWtcvK/VxQHggpQjvgqPMbYgHHgQPBljx+33UAREfYg6TNzpjG/+IUNRWdXkXyysmz3V//+ge/1tWH2bGP+9jf7fPdu+/yVV7K7AjMybCBo1cqYX//aVlRmzrTVJd8Bc+lSG7Z8B/WEBGMefDC7O6tt2+x5Zw7GTkzMrpSc3a122WXnDtzu0MGGP19VRbKB4P/+zwaavAKS7+HxZFenfKHszPlnfq7Xax+S7U7btCkwOF1xhR2ofcMNxkyfbsPmmQHy0kttOGzTxla/+vWzbQ0LsyH0gw9sQPR4jHnpJWMee8x2xW3YYMyiRcb84Q/Z+3T/fhtoqla1gergwezvMiUlsGI3dqyd/sILdlsfftiYvXsDv/tXXrEVwiVLbCXS93sBgiGoAAgqIyPwwBRqTp2yY6feecd26xhjw84XX9iD4sMPG/Pkk7bC8MQTtppx5Ig9MO7YYcynn2Z3W3XsaMy+fbbbZ+tWe3CX7EH1yy8DQ8X999vP2r7dVk2efNJ29/z613ac0Ny5dkzOVVfZM8Fuuy0wsNx/vz2At2xpg1Plytnzn3nGditOmpRdMfKdade6tR0LFSwM9eplq2a5Bab4eGOqVct/wOrZM7AL7MIL7T7csuXcSlb16nZfnRngqlSxQejYMTvm6ezwNnu2DVTdutluUd93OHas7bp7+20bfHbsMObvf7ftmTXLmG++sd/VZ58Zc889trssI8OG748+st//yZP2cbaUFPtdZ2UZs3x5dqUpI8Pu/1mzAsfPvf++HRuVmWk/98xxTOcrLS3n8VjFITMz+D8uQhlBBUC5tmSJrfqcffDJyrLdaVlZ9uHrjurYMeeDX258FaZXXzXm22+zp586ZefNmGEP2qNHZ887etSYK6/MPqi3aWMPrps32wP/yJG2glSnjp1fr54NPy++mP2e3/zGmOees6fXf/BB4Kn4l1ySvf7Wre3nh4cb06iRfX733edWfXyhLikp+7IEjRvb/edrx5ldYmd2AdaokV31ufJKO3hbCuz6q13bdkHmNNC8bt1zL2cQFRU4rX797MpU3brZ46Dq17cVqKFD7XPf8o0a2b+RkfYSAFddFRi6HnvMVs4k+znt29vn111nfwNbtxrz6KO2zYMG2e7NzEwbPtauDRzHdqYdO2xA27vX7qMKFc7thjtwwFa4UlPPfb+vKjl8uB17depU3r/JnTuzu0xHjsweh3W2776z3Yg5fa4rBBUAyIdFi2w3SnH96/fAgXO7QHxnTY0Ykfv4j5Mns4NWZqYx48bZ7puzw9cPP9hqzuuv2wPVyZO2MpKVZUOQb2yL7zpLCxdmV3CGD7fVkjMrOrVqZY9jmTAh8CC/Z4/9/Ndft6HBN+/WW+3yycmBlSTfdYt8j8REO9i8cWMbcnzT27e3XXBnVoWuvjq76nR2+Cnoo0qV7ACT26N58+zuuTMfHTpkf36VKvaMwshIG4JeeSU7+Pjm+5736ZP9Xf7619nTw8Jsde/zz+30lSttVenMClmNGjYoLlliA09Ghu3SGz7cVoZSU22QPbOdffoYs369rR75BmEfOZL9nYwfb0Px739vvzNfJSYrywaxFSts0PrTn2ygK86B3FxHBQAQ1HvvSa+/Lj31lL3Vxrvv2ntinTolTZpkr+cjSYcO2WsXNWhgr+9Tp072OjIy7K0wfvrJ3mfMd2+w8ePttX+6dLFXgp47V1q50t62Y+DA7Lu9r1tnb29RubK9dk98vD3crl5trzTdtau9ntPMmfaK0wMH2us9NW5sr/68dKl91K0rde5sr2SdliYtXGhv8fF//2fb1q6d9I9/SK1b2ytbv/OOVL++vYXI++/bx3XXSX/4g72WkWSv9dSli72e01NP2dt+SPaaROnpOe/TsDB7X7aUlMDpffrYa1L98IN9HRt77jJhYfYzKle215HasuXc9bduba/aLdnrL4WHS1On2qt7P/KIva7SiRPZ13Nq1Mi+Z+tWadMm+77Kle1f3zWi2rWTDh603+XJk/baV2e6+GK7v2rXznmbzweX0AcAOHHypDR9ur33V2Ji7svu22cPuDVqFH079u2zF7rs1i37Bqy52b3bXowxNtYGFd+FDufNs4+RI22AmzNH2r7drveDD6Q33rDb8MorNvTddpsNdMZkXwRTsheR/Pe/7W1QPv3UBpijRwMvVjlrlp0+ZYq93cjSpTboGRN4L7ewMNu+zEx7Ic3rrrMh7De/sfN9tzoJpk0be6HItLTA6dHRdlqjRvbvgQPSJZfYgOgLOUWFoAIAgEOHDtmbtsbF2apHu3a2EuSzbp30ySe2GjVmjJ3/0EPnricry96Etn9/W02Kj5feesvOGzzYVpx8VZRXX7WB6Je/tJWQtDR7r7VLL7U3iZ02zYaqdets2Bo71laeeve21aPu3W2FpVo1aedOG8wOHLCf/Z//5C/w5RdBBQCAMiQz0waFjAzbpXbBBfbO9/m9xcWWLdI119guot//Pn/vWbFCuvpqe0uQ3/62aG+nQVABAADnbe/evLvwCqMgx+8iLOQAAICypDhCSkERVAAAQMgiqAAAgJBFUAEAACGLoAIAAEIWQQUAAIQsggoAAAhZBBUAABCyCCoAACBkEVQAAEDIIqgAAICQRVABAAAhi6ACAABCFkEFAACErHDXDTgfxhhJ9nbRAACgdPAdt33H8dyU6qCSlpYmSapXr57jlgAAgIJKS0tTbGxsrst4TH7iTIjKysrSnj17FB0dLY/HUyTrTE1NVb169bRr1y7FxMQUyTpLG/YB+6C8b7/EPijv2y+xD6Ti2wfGGKWlpal27doKC8t9FEqprqiEhYWpbt26xbLumJiYcvvD9GEfsA/K+/ZL7IPyvv0S+0Aqnn2QVyXFh8G0AAAgZBFUAABAyCKonMXr9WrSpEnyer2um+IM+4B9UN63X2IflPftl9gHUmjsg1I9mBYAAJRtVFQAAEDIIqgAAICQRVABAAAhi6ACAABCFkHlLH/729/UsGFDVapUSZ06ddKXX37puknFYvLkyfJ4PAGP5s2b++efOHFCo0ePVvXq1VWlShUNHDhQ+/fvd9ji8/fxxx+rb9++ql27tjwej+bPnx8w3xijiRMnKjExUZGRkerZs6e2bNkSsMyhQ4c0ePBgxcTEKC4uTnfccYfS09NLcCvOT177YOjQoef8Lvr06ROwTGneB1OmTFHHjh0VHR2tmjVrqn///tq8eXPAMvn57e/cuVPXXnutoqKiVLNmTf3+97/X6dOnS3JTCiU/23/llVee8xu48847A5YprdsvSdOmTVObNm38FzDr3LmzFi9e7J9flr9/n7z2Qcj9Bgz8Zs+ebSIiIsw//vEPs3HjRjN8+HATFxdn9u/f77ppRW7SpEmmVatWZu/evf7Hzz//7J9/5513mnr16pn333/frF692lx22WWmS5cuDlt8/hYtWmQefPBBM3fuXCPJzJs3L2D+448/bmJjY838+fPNunXrzC9/+UvTqFEjc/z4cf8yffr0MW3btjWff/65+eSTT0yTJk3MoEGDSnhLCi+vfTBkyBDTp0+fgN/FoUOHApYpzfugd+/eZvr06WbDhg1m7dq15pprrjH169c36enp/mXy+u2fPn3atG7d2vTs2dOsWbPGLFq0yNSoUcOMHz/exSYVSH62v3v37mb48OEBv4GUlBT//NK8/cYYs2DBAvPOO++Y77//3mzevNk88MADpmLFimbDhg3GmLL9/fvktQ9C7TdAUDnDpZdeakaPHu1/nZmZaWrXrm2mTJnisFXFY9KkSaZt27Y5zjty5IipWLGieeutt/zTNm3aZCSZlStXllALi9fZB+msrCyTkJBgnnzySf+0I0eOGK/Xa9544w1jjDHffvutkWRWrVrlX2bx4sXG4/GY3bt3l1jbi0qwoNKvX7+g7ylr+yA5OdlIMh999JExJn+//UWLFpmwsDCzb98+/zLTpk0zMTExJiMjo2Q34Dydvf3G2IPUmDFjgr6nLG2/T9WqVc0rr7xS7r7/M/n2gTGh9xug6+d/Tp48qa+++ko9e/b0TwsLC1PPnj21cuVKhy0rPlu2bFHt2rV1wQUXaPDgwdq5c6ck6auvvtKpU6cC9kXz5s1Vv379Mrsvtm3bpn379gVsc2xsrDp16uTf5pUrVyouLk4dOnTwL9OzZ0+FhYXpiy++KPE2F5fly5erZs2aatasmUaOHKmDBw/655W1fZCSkiJJqlatmqT8/fZXrlypiy66SLVq1fIv07t3b6Wmpmrjxo0l2Przd/b2+8yaNUs1atRQ69atNX78eB07dsw/ryxtf2ZmpmbPnq2jR4+qc+fO5e77l87dBz6h9Bso1TclLEoHDhxQZmZmwI6XpFq1aum7775z1Kri06lTJ82YMUPNmjXT3r179fDDD6tr167asGGD9u3bp4iICMXFxQW8p1atWtq3b5+bBhcz33bl9P375u3bt081a9YMmB8eHq5q1aqVmf3Sp08fXX/99WrUqJF++OEHPfDAA0pKStLKlStVoUKFMrUPsrKyNHbsWF1++eVq3bq1JOXrt79v374cfye+eaVFTtsvSb/61a/UoEED1a5dW998843uu+8+bd68WXPnzpVUNrZ//fr16ty5s06cOKEqVapo3rx5atmypdauXVtuvv9g+0AKvd8AQaWcSkpK8j9v06aNOnXqpAYNGujNN99UZGSkw5bBpZtvvtn//KKLLlKbNm3UuHFjLV++XD169HDYsqI3evRobdiwQZ9++qnrpjgRbPtHjBjhf37RRRcpMTFRPXr00A8//KDGjRuXdDOLRbNmzbR27VqlpKRozpw5GjJkiD766CPXzSpRwfZBy5YtQ+43QNfP/9SoUUMVKlQ4Z3T3/v37lZCQ4KhVJScuLk4XXnihtm7dqoSEBJ08eVJHjhwJWKYs7wvfduX2/SckJCg5OTlg/unTp3Xo0KEyu18uuOAC1ahRQ1u3bpVUdvbBXXfdpYULF+rDDz9U3bp1/dPz89tPSEjI8Xfim1caBNv+nHTq1EmSAn4DpX37IyIi1KRJE7Vv315TpkxR27Zt9de//rXcfP9S8H2QE9e/AYLK/0RERKh9+/Z6//33/dOysrL0/vvvB/TblVXp6en64YcflJiYqPbt26tixYoB+2Lz5s3auXNnmd0XjRo1UkJCQsA2p6am6osvvvBvc+fOnXXkyBF99dVX/mU++OADZWVl+f9DLmt++uknHTx4UImJiZJK/z4wxuiuu+7SvHnz9MEHH6hRo0YB8/Pz2+/cubPWr18fENiWLl2qmJgYf+k8VOW1/TlZu3atJAX8Bkrr9geTlZWljIyMMv/958a3D3Li/DdQ5MNzS7HZs2cbr9drZsyYYb799lszYsQIExcXFzCyuay49957zfLly822bdvMihUrTM+ePU2NGjVMcnKyMcaeole/fn3zwQcfmNWrV5vOnTubzp07O271+UlLSzNr1qwxa9asMZLM008/bdasWWN27NhhjLGnJ8fFxZm3337bfPPNN6Zfv345np7crl0788UXX5hPP/3UNG3atNScmmtM7vsgLS3NjBs3zqxcudJs27bNLFu2zFxyySWmadOm5sSJE/51lOZ9MHLkSBMbG2uWL18ecOrlsWPH/Mvk9dv3nZrZq1cvs3btWvPuu++a+Pj4UnF6al7bv3XrVvPII4+Y1atXm23btpm3337bXHDBBaZbt27+dZTm7TfGmPvvv9989NFHZtu2beabb74x999/v/F4PGbJkiXGmLL9/fvktg9C8TdAUDnLc889Z+rXr28iIiLMpZdeaj7//HPXTSoWN910k0lMTDQRERGmTp065qabbjJbt271zz9+/LgZNWqUqVq1qomKijIDBgwwe/fuddji8/fhhx8aSec8hgwZYoyxpyhPmDDB1KpVy3i9XtOjRw+zefPmgHUcPHjQDBo0yFSpUsXExMSYYcOGmbS0NAdbUzi57YNjx46ZXr16mfj4eFOxYkXToEEDM3z48HOCemneBzltuyQzffp0/zL5+e1v377dJCUlmcjISFOjRg1z7733mlOnTpXw1hRcXtu/c+dO061bN1OtWjXj9XpNkyZNzO9///uAa2gYU3q33xhjbr/9dtOgQQMTERFh4uPjTY8ePfwhxZiy/f375LYPQvE34DHGmKKv0wAAAJw/xqgAAICQRVABAAAhi6ACAABCFkEFAACELIIKAAAIWQQVAAAQsggqAAAgZBFUAJR6Ho9H8+fPd90MAMWAoALgvAwdOlQej+ecR58+fVw3DUAZEO66AQBKvz59+mj69OkB07xer6PWAChLqKgAOG9er1cJCQkBj6pVq0qy3TLTpk1TUlKSIiMjdcEFF2jOnDkB71+/fr2uuuoqRUZGqnr16hoxYoTS09MDlvnHP/6hVq1ayev1KjExUXfddVfA/AMHDmjAgAGKiopS06ZNtWDBAv+8w4cPa/DgwYqPj1dkZKSaNm16TrACEJoIKgCK3YQJEzRw4ECtW7dOgwcP1s0336xNmzZJko4eParevXuratWqWrVqld566y0tW7YsIIhMmzZNo0eP1ogRI7R+/XotWLBATZo0CfiMhx9+WDfeeKO++eYbXXPNNRo8eLAOHTrk//xvv/1Wixcv1qZNmzRt2jTVqFGj5HYAgMIrllsdAig3hgwZYipUqGAqV64c8HjssceMMfaOvXfeeWfAezp16mRGjhxpjDHmpZdeMlWrVjXp6en++e+8844JCwvz37m5du3a5sEHHwzaBknmoYce8r9OT083kszixYuNMcb07dvXDBs2rGg2GECJYowKgPP2i1/8QtOmTQuYVq1aNf/zzp07B8zr3Lmz1q5dK0natGmT2rZtq8qVK/vnX3755crKytLmzZvl8Xi0Z88e9ejRI9c2tGnTxv+8cuXKiomJUXJysiRp5MiRGjhwoL7++mv16tVL/fv3V5cuXQq1rQBKFkEFwHmrXLnyOV0xRSUyMjJfy1WsWDHgtcfjUVZWliQpKSlJO3bs0KJFi7R06VL16NFDo0eP1lNPPVXk7QVQtBijAqDYff755+e8btGihSSpRYsWWrdunY4ePeqfv2LFCoWFhalZs2aKjo5Ww4YN9f77759XG+Lj4zVkyBC9/vrreuaZZ/TSSy+d1/oAlAwqKgDOW0ZGhvbt2xcwLTw83D9g9a233lKHDh10xRVXaNasWfryyy/16quvSpIGDx6sSZMmaciQIZo8ebJ+/vln3X333br11ltVq1YtSdLkyZN15513qmbNmkpKSlJaWppWrFihu+++O1/tmzhxotq3b69WrVopIyNDCxcu9AclAKGNoALgvL377rtKTEwMmNasWTN99913kuwZObNnz9aoUaOUmJioN954Qy1btpQkRUVF6b333tOYMWPUsWNHRUVFaeDAgXr66af96xoyZIhOnDihv/zlLxo3bpxq1KihG264Id/ti4iI0Pjx47V9+3ZFRkaqa9eumj17dhFsOYDi5jHGGNeNAFB2eTwezZs3T/3793fdFAClEGNUAABAyCKoAACAkMUYFQDFit5lAOeDigoAAAhZBBUAABCyCCoAACBkEVQAAEDIIqgAAICQRVABAAAhi6ACAABCFkEFAACELIIKAAAIWf8P2HJWN2E5DGAAAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3SRXs6V6bNEX", + "outputId": "4ed1f452-e232-41a3-caa9-ea01e89369d0" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "313/313 [==============================] - 1s 3ms/step - loss: 1.5913 - accuracy: 0.7413\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[1.5913232564926147, 0.7412999868392944]" + ] + }, + "metadata": {}, + "execution_count": 23 + } + ], + "source": [ + "model.evaluate(x_test, to_categorical(y_test, num_classes))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "---" + ], + "metadata": { + "id": "XyPoVUwrRRj5" + } + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "vXA26MZUbOXO" + }, + "outputs": [], + "source": [ + "def cnn(input_shape, num_classes):\n", + " \"\"\"CNN Model from (McMahan et. al., 2017).\n", + "\n", + " Communication-efficient learning of deep networks from decentralized data\n", + " \"\"\"\n", + " input_shape = tuple(input_shape)\n", + "\n", + " weight_decay = 0.004\n", + " model = keras.Sequential(\n", + " [\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " input_shape=input_shape,\n", + " ),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),\n", + " keras.layers.BatchNormalization(),\n", + " keras.layers.Conv2D(\n", + " 64,\n", + " (5, 5),\n", + " padding=\"same\",\n", + " activation=\"relu\",\n", + " ),\n", + " keras.layers.BatchNormalization(),\n", + " keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(\n", + " 384, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(\n", + " 192, activation=\"relu\", kernel_regularizer=l2(weight_decay)\n", + " ),\n", + " keras.layers.Dense(num_classes, activation=\"softmax\"),\n", + " ]\n", + " )\n", + " optimizer = SGD(learning_rate=0.1)\n", + " model.compile(\n", + " loss=\"categorical_crossentropy\", optimizer=optimizer, metrics=[\"accuracy\"]\n", + " )\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "source": [ + "model_cnn = cnn(input_shape, num_classes)" + ], + "metadata": { + "id": "t098yVNYRxPu" + }, + "execution_count": 25, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "history_cnn = model_cnn.fit(x_train, to_categorical(y_train, num_classes), epochs=350, batch_size=100)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JinRA8quR2mr", + "outputId": "edc6a49c-3fa4-498d-fbb4-21cb439c9b38" + }, + "execution_count": 26, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/350\n", + "500/500 [==============================] - 4s 7ms/step - loss: 4.1634 - accuracy: 0.4622\n", + "Epoch 2/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 2.3282 - accuracy: 0.6234\n", + "Epoch 3/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 1.5241 - accuracy: 0.6978\n", + "Epoch 4/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 1.1409 - accuracy: 0.7442\n", + "Epoch 5/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.9513 - accuracy: 0.7783\n", + "Epoch 6/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.8526 - accuracy: 0.8004\n", + "Epoch 7/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7955 - accuracy: 0.8228\n", + "Epoch 8/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7653 - accuracy: 0.8402\n", + "Epoch 9/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7479 - accuracy: 0.8540\n", + "Epoch 10/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7359 - accuracy: 0.8678\n", + "Epoch 11/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7267 - accuracy: 0.8774\n", + "Epoch 12/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7274 - accuracy: 0.8839\n", + "Epoch 13/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7191 - accuracy: 0.8918\n", + "Epoch 14/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7182 - accuracy: 0.8971\n", + "Epoch 15/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7166 - accuracy: 0.9014\n", + "Epoch 16/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7239 - accuracy: 0.9033\n", + "Epoch 17/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7214 - accuracy: 0.9069\n", + "Epoch 18/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7103 - accuracy: 0.9122\n", + "Epoch 19/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7023 - accuracy: 0.9168\n", + "Epoch 20/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7128 - accuracy: 0.9147\n", + "Epoch 21/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7064 - accuracy: 0.9197\n", + "Epoch 22/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7090 - accuracy: 0.9177\n", + "Epoch 23/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7103 - accuracy: 0.9190\n", + "Epoch 24/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6981 - accuracy: 0.9232\n", + "Epoch 25/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7015 - accuracy: 0.9234\n", + "Epoch 26/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.7026 - accuracy: 0.9253\n", + "Epoch 27/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6889 - accuracy: 0.9264\n", + "Epoch 28/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6924 - accuracy: 0.9275\n", + "Epoch 29/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6818 - accuracy: 0.9303\n", + "Epoch 30/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6961 - accuracy: 0.9273\n", + "Epoch 31/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6967 - accuracy: 0.9277\n", + "Epoch 32/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6932 - accuracy: 0.9318\n", + "Epoch 33/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6812 - accuracy: 0.9331\n", + "Epoch 34/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6779 - accuracy: 0.9321\n", + "Epoch 35/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6898 - accuracy: 0.9312\n", + "Epoch 36/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6800 - accuracy: 0.9328\n", + "Epoch 37/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6785 - accuracy: 0.9340\n", + "Epoch 38/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6713 - accuracy: 0.9370\n", + "Epoch 39/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6832 - accuracy: 0.9345\n", + "Epoch 40/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6778 - accuracy: 0.9349\n", + "Epoch 41/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6610 - accuracy: 0.9378\n", + "Epoch 42/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6612 - accuracy: 0.9385\n", + "Epoch 43/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6545 - accuracy: 0.9393\n", + "Epoch 44/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6609 - accuracy: 0.9369\n", + "Epoch 45/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6648 - accuracy: 0.9382\n", + "Epoch 46/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6587 - accuracy: 0.9385\n", + "Epoch 47/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6492 - accuracy: 0.9420\n", + "Epoch 48/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6523 - accuracy: 0.9404\n", + "Epoch 49/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6648 - accuracy: 0.9378\n", + "Epoch 50/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6571 - accuracy: 0.9397\n", + "Epoch 51/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6493 - accuracy: 0.9413\n", + "Epoch 52/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6590 - accuracy: 0.9388\n", + "Epoch 53/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6542 - accuracy: 0.9412\n", + "Epoch 54/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6526 - accuracy: 0.9427\n", + "Epoch 55/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6311 - accuracy: 0.9462\n", + "Epoch 56/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6459 - accuracy: 0.9412\n", + "Epoch 57/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6436 - accuracy: 0.9438\n", + "Epoch 58/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6429 - accuracy: 0.9440\n", + "Epoch 59/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6459 - accuracy: 0.9421\n", + "Epoch 60/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6418 - accuracy: 0.9432\n", + "Epoch 61/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6357 - accuracy: 0.9444\n", + "Epoch 62/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6316 - accuracy: 0.9452\n", + "Epoch 63/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6348 - accuracy: 0.9451\n", + "Epoch 64/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6293 - accuracy: 0.9447\n", + "Epoch 65/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6339 - accuracy: 0.9453\n", + "Epoch 66/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6223 - accuracy: 0.9482\n", + "Epoch 67/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6169 - accuracy: 0.9483\n", + "Epoch 68/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6217 - accuracy: 0.9456\n", + "Epoch 69/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6262 - accuracy: 0.9456\n", + "Epoch 70/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6168 - accuracy: 0.9488\n", + "Epoch 71/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6166 - accuracy: 0.9465\n", + "Epoch 72/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6248 - accuracy: 0.9458\n", + "Epoch 73/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6089 - accuracy: 0.9510\n", + "Epoch 74/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6155 - accuracy: 0.9472\n", + "Epoch 75/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6207 - accuracy: 0.9480\n", + "Epoch 76/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6123 - accuracy: 0.9502\n", + "Epoch 77/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6173 - accuracy: 0.9474\n", + "Epoch 78/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6020 - accuracy: 0.9510\n", + "Epoch 79/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5970 - accuracy: 0.9512\n", + "Epoch 80/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6211 - accuracy: 0.9454\n", + "Epoch 81/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5945 - accuracy: 0.9522\n", + "Epoch 82/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6178 - accuracy: 0.9460\n", + "Epoch 83/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6082 - accuracy: 0.9504\n", + "Epoch 84/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5934 - accuracy: 0.9522\n", + "Epoch 85/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5979 - accuracy: 0.9512\n", + "Epoch 86/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5985 - accuracy: 0.9506\n", + "Epoch 87/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5924 - accuracy: 0.9520\n", + "Epoch 88/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5885 - accuracy: 0.9514\n", + "Epoch 89/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5934 - accuracy: 0.9515\n", + "Epoch 90/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.6033 - accuracy: 0.9507\n", + "Epoch 91/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5958 - accuracy: 0.9523\n", + "Epoch 92/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5970 - accuracy: 0.9505\n", + "Epoch 93/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5900 - accuracy: 0.9536\n", + "Epoch 94/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5916 - accuracy: 0.9512\n", + "Epoch 95/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5955 - accuracy: 0.9519\n", + "Epoch 96/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5943 - accuracy: 0.9520\n", + "Epoch 97/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5852 - accuracy: 0.9523\n", + "Epoch 98/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5784 - accuracy: 0.9533\n", + "Epoch 99/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5800 - accuracy: 0.9535\n", + "Epoch 100/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5691 - accuracy: 0.9552\n", + "Epoch 101/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5720 - accuracy: 0.9531\n", + "Epoch 102/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5766 - accuracy: 0.9541\n", + "Epoch 103/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5696 - accuracy: 0.9543\n", + "Epoch 104/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5753 - accuracy: 0.9538\n", + "Epoch 105/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5765 - accuracy: 0.9540\n", + "Epoch 106/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5590 - accuracy: 0.9576\n", + "Epoch 107/350\n", + "500/500 [==============================] - 4s 7ms/step - loss: 0.5675 - accuracy: 0.9537\n", + "Epoch 108/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5797 - accuracy: 0.9523\n", + "Epoch 109/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5745 - accuracy: 0.9549\n", + "Epoch 110/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5634 - accuracy: 0.9565\n", + "Epoch 111/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5626 - accuracy: 0.9556\n", + "Epoch 112/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5731 - accuracy: 0.9542\n", + "Epoch 113/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5737 - accuracy: 0.9539\n", + "Epoch 114/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5690 - accuracy: 0.9557\n", + "Epoch 115/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5670 - accuracy: 0.9558\n", + "Epoch 116/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5583 - accuracy: 0.9550\n", + "Epoch 117/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5479 - accuracy: 0.9570\n", + "Epoch 118/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5639 - accuracy: 0.9541\n", + "Epoch 119/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5530 - accuracy: 0.9580\n", + "Epoch 120/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5579 - accuracy: 0.9562\n", + "Epoch 121/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5522 - accuracy: 0.9573\n", + "Epoch 122/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5641 - accuracy: 0.9542\n", + "Epoch 123/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5519 - accuracy: 0.9582\n", + "Epoch 124/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5387 - accuracy: 0.9588\n", + "Epoch 125/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5470 - accuracy: 0.9570\n", + "Epoch 126/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5583 - accuracy: 0.9545\n", + "Epoch 127/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5439 - accuracy: 0.9590\n", + "Epoch 128/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5541 - accuracy: 0.9557\n", + "Epoch 129/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5357 - accuracy: 0.9598\n", + "Epoch 130/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5512 - accuracy: 0.9564\n", + "Epoch 131/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5377 - accuracy: 0.9593\n", + "Epoch 132/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5414 - accuracy: 0.9568\n", + "Epoch 133/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5476 - accuracy: 0.9556\n", + "Epoch 134/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5348 - accuracy: 0.9583\n", + "Epoch 135/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5451 - accuracy: 0.9572\n", + "Epoch 136/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5413 - accuracy: 0.9579\n", + "Epoch 137/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5257 - accuracy: 0.9604\n", + "Epoch 138/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5314 - accuracy: 0.9585\n", + "Epoch 139/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5326 - accuracy: 0.9591\n", + "Epoch 140/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5399 - accuracy: 0.9575\n", + "Epoch 141/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5402 - accuracy: 0.9588\n", + "Epoch 142/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5342 - accuracy: 0.9576\n", + "Epoch 143/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5380 - accuracy: 0.9577\n", + "Epoch 144/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5313 - accuracy: 0.9587\n", + "Epoch 145/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5303 - accuracy: 0.9589\n", + "Epoch 146/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5219 - accuracy: 0.9595\n", + "Epoch 147/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5418 - accuracy: 0.9567\n", + "Epoch 148/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5300 - accuracy: 0.9600\n", + "Epoch 149/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5088 - accuracy: 0.9607\n", + "Epoch 150/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5319 - accuracy: 0.9577\n", + "Epoch 151/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5168 - accuracy: 0.9621\n", + "Epoch 152/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5161 - accuracy: 0.9606\n", + "Epoch 153/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5109 - accuracy: 0.9618\n", + "Epoch 154/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5190 - accuracy: 0.9593\n", + "Epoch 155/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5273 - accuracy: 0.9586\n", + "Epoch 156/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5093 - accuracy: 0.9630\n", + "Epoch 157/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5210 - accuracy: 0.9589\n", + "Epoch 158/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5046 - accuracy: 0.9636\n", + "Epoch 159/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5142 - accuracy: 0.9598\n", + "Epoch 160/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5294 - accuracy: 0.9588\n", + "Epoch 161/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5079 - accuracy: 0.9622\n", + "Epoch 162/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4972 - accuracy: 0.9635\n", + "Epoch 163/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5120 - accuracy: 0.9601\n", + "Epoch 164/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5091 - accuracy: 0.9624\n", + "Epoch 165/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5115 - accuracy: 0.9612\n", + "Epoch 166/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5071 - accuracy: 0.9614\n", + "Epoch 167/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5067 - accuracy: 0.9628\n", + "Epoch 168/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5002 - accuracy: 0.9623\n", + "Epoch 169/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5094 - accuracy: 0.9602\n", + "Epoch 170/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5032 - accuracy: 0.9618\n", + "Epoch 171/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5024 - accuracy: 0.9618\n", + "Epoch 172/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4839 - accuracy: 0.9649\n", + "Epoch 173/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4931 - accuracy: 0.9620\n", + "Epoch 174/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5061 - accuracy: 0.9614\n", + "Epoch 175/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5023 - accuracy: 0.9620\n", + "Epoch 176/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5021 - accuracy: 0.9625\n", + "Epoch 177/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4821 - accuracy: 0.9651\n", + "Epoch 178/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4813 - accuracy: 0.9626\n", + "Epoch 179/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4906 - accuracy: 0.9630\n", + "Epoch 180/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4973 - accuracy: 0.9611\n", + "Epoch 181/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4970 - accuracy: 0.9629\n", + "Epoch 182/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4841 - accuracy: 0.9644\n", + "Epoch 183/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4872 - accuracy: 0.9631\n", + "Epoch 184/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4845 - accuracy: 0.9647\n", + "Epoch 185/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4756 - accuracy: 0.9648\n", + "Epoch 186/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4821 - accuracy: 0.9626\n", + "Epoch 187/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4819 - accuracy: 0.9633\n", + "Epoch 188/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.5000 - accuracy: 0.9617\n", + "Epoch 189/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4783 - accuracy: 0.9652\n", + "Epoch 190/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4778 - accuracy: 0.9641\n", + "Epoch 191/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4815 - accuracy: 0.9623\n", + "Epoch 192/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4892 - accuracy: 0.9640\n", + "Epoch 193/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4850 - accuracy: 0.9637\n", + "Epoch 194/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4871 - accuracy: 0.9641\n", + "Epoch 195/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4739 - accuracy: 0.9651\n", + "Epoch 196/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4779 - accuracy: 0.9636\n", + "Epoch 197/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4658 - accuracy: 0.9663\n", + "Epoch 198/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4821 - accuracy: 0.9623\n", + "Epoch 199/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4826 - accuracy: 0.9635\n", + "Epoch 200/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4732 - accuracy: 0.9656\n", + "Epoch 201/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4790 - accuracy: 0.9648\n", + "Epoch 202/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4675 - accuracy: 0.9658\n", + "Epoch 203/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4743 - accuracy: 0.9633\n", + "Epoch 204/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4667 - accuracy: 0.9653\n", + "Epoch 205/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4760 - accuracy: 0.9624\n", + "Epoch 206/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4736 - accuracy: 0.9651\n", + "Epoch 207/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4744 - accuracy: 0.9636\n", + "Epoch 208/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4623 - accuracy: 0.9664\n", + "Epoch 209/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4591 - accuracy: 0.9670\n", + "Epoch 210/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4700 - accuracy: 0.9645\n", + "Epoch 211/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4690 - accuracy: 0.9653\n", + "Epoch 212/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4686 - accuracy: 0.9649\n", + "Epoch 213/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4546 - accuracy: 0.9667\n", + "Epoch 214/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4710 - accuracy: 0.9645\n", + "Epoch 215/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4677 - accuracy: 0.9653\n", + "Epoch 216/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4796 - accuracy: 0.9629\n", + "Epoch 217/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4601 - accuracy: 0.9673\n", + "Epoch 218/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4571 - accuracy: 0.9667\n", + "Epoch 219/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4652 - accuracy: 0.9648\n", + "Epoch 220/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4613 - accuracy: 0.9658\n", + "Epoch 221/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4510 - accuracy: 0.9679\n", + "Epoch 222/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4689 - accuracy: 0.9653\n", + "Epoch 223/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4490 - accuracy: 0.9677\n", + "Epoch 224/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4579 - accuracy: 0.9645\n", + "Epoch 225/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4465 - accuracy: 0.9682\n", + "Epoch 226/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4486 - accuracy: 0.9673\n", + "Epoch 227/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4628 - accuracy: 0.9638\n", + "Epoch 228/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4438 - accuracy: 0.9689\n", + "Epoch 229/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4528 - accuracy: 0.9650\n", + "Epoch 230/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4560 - accuracy: 0.9656\n", + "Epoch 231/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4532 - accuracy: 0.9670\n", + "Epoch 232/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4497 - accuracy: 0.9671\n", + "Epoch 233/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4474 - accuracy: 0.9675\n", + "Epoch 234/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4517 - accuracy: 0.9672\n", + "Epoch 235/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4531 - accuracy: 0.9660\n", + "Epoch 236/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4524 - accuracy: 0.9662\n", + "Epoch 237/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4423 - accuracy: 0.9669\n", + "Epoch 238/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4500 - accuracy: 0.9658\n", + "Epoch 239/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4564 - accuracy: 0.9653\n", + "Epoch 240/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4247 - accuracy: 0.9709\n", + "Epoch 241/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4395 - accuracy: 0.9670\n", + "Epoch 242/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4506 - accuracy: 0.9656\n", + "Epoch 243/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4332 - accuracy: 0.9697\n", + "Epoch 244/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4385 - accuracy: 0.9674\n", + "Epoch 245/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4414 - accuracy: 0.9672\n", + "Epoch 246/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4494 - accuracy: 0.9664\n", + "Epoch 247/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4429 - accuracy: 0.9677\n", + "Epoch 248/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4385 - accuracy: 0.9683\n", + "Epoch 249/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4283 - accuracy: 0.9697\n", + "Epoch 250/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4362 - accuracy: 0.9677\n", + "Epoch 251/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4360 - accuracy: 0.9678\n", + "Epoch 252/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4337 - accuracy: 0.9684\n", + "Epoch 253/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4485 - accuracy: 0.9664\n", + "Epoch 254/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4364 - accuracy: 0.9686\n", + "Epoch 255/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4394 - accuracy: 0.9681\n", + "Epoch 256/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4211 - accuracy: 0.9692\n", + "Epoch 257/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4226 - accuracy: 0.9694\n", + "Epoch 258/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4358 - accuracy: 0.9669\n", + "Epoch 259/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4260 - accuracy: 0.9696\n", + "Epoch 260/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4276 - accuracy: 0.9690\n", + "Epoch 261/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4286 - accuracy: 0.9683\n", + "Epoch 262/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4297 - accuracy: 0.9690\n", + "Epoch 263/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4210 - accuracy: 0.9696\n", + "Epoch 264/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4301 - accuracy: 0.9681\n", + "Epoch 265/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4283 - accuracy: 0.9687\n", + "Epoch 266/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4182 - accuracy: 0.9713\n", + "Epoch 267/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4202 - accuracy: 0.9681\n", + "Epoch 268/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4292 - accuracy: 0.9686\n", + "Epoch 269/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4213 - accuracy: 0.9699\n", + "Epoch 270/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4239 - accuracy: 0.9688\n", + "Epoch 271/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4213 - accuracy: 0.9686\n", + "Epoch 272/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4120 - accuracy: 0.9706\n", + "Epoch 273/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4196 - accuracy: 0.9697\n", + "Epoch 274/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4181 - accuracy: 0.9694\n", + "Epoch 275/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4176 - accuracy: 0.9692\n", + "Epoch 276/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4199 - accuracy: 0.9693\n", + "Epoch 277/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4298 - accuracy: 0.9686\n", + "Epoch 278/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4201 - accuracy: 0.9696\n", + "Epoch 279/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4046 - accuracy: 0.9715\n", + "Epoch 280/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4024 - accuracy: 0.9707\n", + "Epoch 281/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4200 - accuracy: 0.9685\n", + "Epoch 282/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4064 - accuracy: 0.9710\n", + "Epoch 283/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3937 - accuracy: 0.9725\n", + "Epoch 284/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4107 - accuracy: 0.9690\n", + "Epoch 285/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4096 - accuracy: 0.9709\n", + "Epoch 286/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4103 - accuracy: 0.9696\n", + "Epoch 287/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4256 - accuracy: 0.9673\n", + "Epoch 288/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4131 - accuracy: 0.9715\n", + "Epoch 289/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3974 - accuracy: 0.9726\n", + "Epoch 290/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4112 - accuracy: 0.9703\n", + "Epoch 291/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3966 - accuracy: 0.9720\n", + "Epoch 292/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4136 - accuracy: 0.9687\n", + "Epoch 293/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4141 - accuracy: 0.9709\n", + "Epoch 294/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4057 - accuracy: 0.9714\n", + "Epoch 295/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3946 - accuracy: 0.9728\n", + "Epoch 296/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4176 - accuracy: 0.9682\n", + "Epoch 297/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4128 - accuracy: 0.9701\n", + "Epoch 298/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4058 - accuracy: 0.9712\n", + "Epoch 299/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3820 - accuracy: 0.9740\n", + "Epoch 300/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4019 - accuracy: 0.9694\n", + "Epoch 301/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4020 - accuracy: 0.9713\n", + "Epoch 302/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4151 - accuracy: 0.9681\n", + "Epoch 303/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3961 - accuracy: 0.9724\n", + "Epoch 304/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3941 - accuracy: 0.9709\n", + "Epoch 305/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3992 - accuracy: 0.9710\n", + "Epoch 306/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3982 - accuracy: 0.9722\n", + "Epoch 307/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3842 - accuracy: 0.9727\n", + "Epoch 308/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4010 - accuracy: 0.9696\n", + "Epoch 309/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4085 - accuracy: 0.9690\n", + "Epoch 310/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3894 - accuracy: 0.9731\n", + "Epoch 311/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4014 - accuracy: 0.9696\n", + "Epoch 312/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3918 - accuracy: 0.9729\n", + "Epoch 313/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3906 - accuracy: 0.9708\n", + "Epoch 314/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3815 - accuracy: 0.9746\n", + "Epoch 315/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3944 - accuracy: 0.9701\n", + "Epoch 316/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.4006 - accuracy: 0.9704\n", + "Epoch 317/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3836 - accuracy: 0.9748\n", + "Epoch 318/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3836 - accuracy: 0.9722\n", + "Epoch 319/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3873 - accuracy: 0.9715\n", + "Epoch 320/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3858 - accuracy: 0.9728\n", + "Epoch 321/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3900 - accuracy: 0.9710\n", + "Epoch 322/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3927 - accuracy: 0.9719\n", + "Epoch 323/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3863 - accuracy: 0.9711\n", + "Epoch 324/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3857 - accuracy: 0.9726\n", + "Epoch 325/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3778 - accuracy: 0.9728\n", + "Epoch 326/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3951 - accuracy: 0.9698\n", + "Epoch 327/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3871 - accuracy: 0.9726\n", + "Epoch 328/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3910 - accuracy: 0.9707\n", + "Epoch 329/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3787 - accuracy: 0.9735\n", + "Epoch 330/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3874 - accuracy: 0.9707\n", + "Epoch 331/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3901 - accuracy: 0.9715\n", + "Epoch 332/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3710 - accuracy: 0.9741\n", + "Epoch 333/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3874 - accuracy: 0.9715\n", + "Epoch 334/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3874 - accuracy: 0.9722\n", + "Epoch 335/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3768 - accuracy: 0.9730\n", + "Epoch 336/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3739 - accuracy: 0.9738\n", + "Epoch 337/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3883 - accuracy: 0.9711\n", + "Epoch 338/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3715 - accuracy: 0.9732\n", + "Epoch 339/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3740 - accuracy: 0.9730\n", + "Epoch 340/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3902 - accuracy: 0.9715\n", + "Epoch 341/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3779 - accuracy: 0.9727\n", + "Epoch 342/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3883 - accuracy: 0.9708\n", + "Epoch 343/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3739 - accuracy: 0.9741\n", + "Epoch 344/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3823 - accuracy: 0.9714\n", + "Epoch 345/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3729 - accuracy: 0.9736\n", + "Epoch 346/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3730 - accuracy: 0.9731\n", + "Epoch 347/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3716 - accuracy: 0.9722\n", + "Epoch 348/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3841 - accuracy: 0.9722\n", + "Epoch 349/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3659 - accuracy: 0.9750\n", + "Epoch 350/350\n", + "500/500 [==============================] - 3s 7ms/step - loss: 0.3740 - accuracy: 0.9721\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "model_cnn.evaluate(x_test, to_categorical(y_test, num_classes))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eIzsv0QLR_tt", + "outputId": "0a7eb8e7-4d7f-40ef-e1b0-b3979cc65854" + }, + "execution_count": 27, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "313/313 [==============================] - 1s 2ms/step - loss: 1.4919 - accuracy: 0.7581\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[1.491928219795227, 0.7580999732017517]" + ] + }, + "metadata": {}, + "execution_count": 27 + } + ] + }, + { + "cell_type": "code", + "source": [ + "loss = history_cnn.history['loss']\n", + "epochs = range(1, len(loss) + 1)\n", + "\n", + "plt.plot(epochs, loss, 'b', label='Training Loss')\n", + "plt.title('Training Loss')\n", + "plt.xlabel('Epochs')\n", + "plt.ylabel('Loss')\n", + "plt.legend()\n", + "plt.show()\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 472 + }, + "id": "rOXE49XhSGBy", + "outputId": "7bf4879e-632d-4762-977f-522402eee644" + }, + "execution_count": 28, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOn0lEQVR4nO3dd3hUZeL28XsSkiGBJISSAoQisPQmNaCAgkJgWUB2VcQlYFsQXLDsKlZA3eBiF0VZlawFWeEFVASRIiC9CApIEUWCkgQRkhBKCMl5/3h+GRhSCTM5Kd/Pdc2VzKnPORmdm6cdh2VZlgAAAMoJH7sLAAAA4EmEGwAAUK4QbgAAQLlCuAEAAOUK4QYAAJQrhBsAAFCuEG4AAEC5QrgBAADlCuEGAACUK4QbAF43cuRINWjQoFj7Tpo0SQ6Hw7MFAlCuEW6ACszhcBTptWrVKruLaouRI0eqatWqdhcDwGVy8GwpoOL64IMP3N6/9957WrZsmd5//3235TfccIPCw8OLfZ7MzExlZ2fL6XRe9r7nz5/X+fPnVbly5WKfv7hGjhypefPmKT09vcTPDaD4KtldAAD2uf32293eb9y4UcuWLcu1/FKnT59WYGBgkc/j5+dXrPJJUqVKlVSpEv+rAlB0NEsBKFCvXr3UqlUrbdu2TT169FBgYKAeffRRSdInn3yiAQMGqHbt2nI6nWrUqJGefvppZWVluR3j0j43P//8sxwOh55//nnNnDlTjRo1ktPpVKdOnbRlyxa3ffPqc+NwODRu3DgtXLhQrVq1ktPpVMuWLfXFF1/kKv+qVavUsWNHVa5cWY0aNdJbb73l8X48c+fOVYcOHRQQEKCaNWvq9ttv16+//uq2TVJSkkaNGqW6devK6XQqMjJSgwYN0s8//+zaZuvWrerbt69q1qypgIAANWzYUHfccYfHyglUFPxzCEChfv/9d8XExOjWW2/V7bff7mqiio+PV9WqVfXAAw+oatWqWrlypZ588kmlpaVp2rRphR539uzZOnnypP72t7/J4XDo3//+t2666Sb99NNPhdb2rF27VvPnz9e9996roKAgvfrqqxo6dKgSEhJUo0YNSdL27dvVr18/RUZGavLkycrKytKUKVNUq1atK78p/yc+Pl6jRo1Sp06dFBcXp+TkZL3yyitat26dtm/frmrVqkmShg4dqt27d+u+++5TgwYNdPToUS1btkwJCQmu9zfeeKNq1aqlRx55RNWqVdPPP/+s+fPne6ysQIVhAcD/GTt2rHXp/xZ69uxpSbLefPPNXNufPn0617K//e1vVmBgoHX27FnXstjYWKt+/fqu9wcPHrQkWTVq1LCOHz/uWv7JJ59YkqzPPvvMteypp57KVSZJlr+/v3XgwAHXsm+//daSZL322muuZQMHDrQCAwOtX3/91bXshx9+sCpVqpTrmHmJjY21qlSpku/6c+fOWWFhYVarVq2sM2fOuJYvWrTIkmQ9+eSTlmVZ1okTJyxJ1rRp0/I91oIFCyxJ1pYtWwotF4CC0SwFoFBOp1OjRo3KtTwgIMD1+8mTJ3Xs2DFde+21On36tPbu3VvocW+55RaFhoa63l977bWSpJ9++qnQffv06aNGjRq53rdp00bBwcGufbOysrR8+XINHjxYtWvXdm3XuHFjxcTEFHr8oti6dauOHj2qe++9163D84ABA9SsWTN9/vnnksx98vf316pVq3TixIk8j5VTw7No0SJlZmZ6pHxARUW4AVCoOnXqyN/fP9fy3bt3a8iQIQoJCVFwcLBq1arl6oycmppa6HHr1avn9j4n6OQXAAraN2f/nH2PHj2qM2fOqHHjxrm2y2tZcRw6dEiS1LRp01zrmjVr5lrvdDr13HPPacmSJQoPD1ePHj3073//W0lJSa7te/bsqaFDh2ry5MmqWbOmBg0apFmzZikjI8MjZQUqEsINgEJdXEOTIyUlRT179tS3336rKVOm6LPPPtOyZcv03HPPSZKys7MLPa6vr2+ey60izFBxJfvaYcKECdq/f7/i4uJUuXJlPfHEE2revLm2b98uyXSSnjdvnjZs2KBx48bp119/1R133KEOHTowFB24TIQbAMWyatUq/f7774qPj9f48eP1xz/+UX369HFrZrJTWFiYKleurAMHDuRal9ey4qhfv74kad++fbnW7du3z7U+R6NGjfTggw/qyy+/1K5du3Tu3Dm98MILbtt07dpVzz77rLZu3aoPP/xQu3fv1pw5czxSXqCiINwAKJacmpOLa0rOnTunN954w64iufH19VWfPn20cOFCHTlyxLX8wIEDWrJkiUfO0bFjR4WFhenNN990az5asmSJ9uzZowEDBkgy8wKdPXvWbd9GjRopKCjItd+JEydy1Tq1a9dOkmiaAi4TQ8EBFEu3bt0UGhqq2NhY/f3vf5fD4dD7779fqpqFJk2apC+//FLdu3fXmDFjlJWVpenTp6tVq1basWNHkY6RmZmpZ555Jtfy6tWr695779Vzzz2nUaNGqWfPnho2bJhrKHiDBg10//33S5L279+v3r176+abb1aLFi1UqVIlLViwQMnJybr11lslSf/973/1xhtvaMiQIWrUqJFOnjyp//znPwoODlb//v09dk+AioBwA6BYatSooUWLFunBBx/U448/rtDQUN1+++3q3bu3+vbta3fxJEkdOnTQkiVL9NBDD+mJJ55QVFSUpkyZoj179hRpNJdkaqOeeOKJXMsbNWqke++9VyNHjlRgYKCmTp2qhx9+WFWqVNGQIUP03HPPuUZARUVFadiwYVqxYoXef/99VapUSc2aNdPHH3+soUOHSjIdijdv3qw5c+YoOTlZISEh6ty5sz788EM1bNjQY/cEqAh4thSACmfw4MHavXu3fvjhB7uLAsAL6HMDoFw7c+aM2/sffvhBixcvVq9evewpEACvo+YGQLkWGRmpkSNH6qqrrtKhQ4c0Y8YMZWRkaPv27WrSpIndxQPgBfS5AVCu9evXTx999JGSkpLkdDoVHR2tf/3rXwQboByj5gYAAJQr9LkBAADlCuEGAACUKxWuz012draOHDmioKAgORwOu4sDAACKwLIsnTx5UrVr15aPT8F1MxUu3Bw5ckRRUVF2FwMAABTD4cOHVbdu3QK3qXDhJigoSJK5OcHBwTaXBgAAFEVaWpqioqJc3+MFqXDhJqcpKjg4mHADAEAZU5QuJXQoBgAA5QrhBgAAlCuEGwAAUK5UuD43AAB7ZGdn69y5c3YXA6WYv79/ocO8i4JwAwDwunPnzungwYPKzs62uygoxXx8fNSwYUP5+/tf0XEINwAAr7IsS4mJifL19VVUVJRH/mWO8idnkt3ExETVq1fviibaJdwAALzq/PnzOn36tGrXrq3AwEC7i4NSrFatWjpy5IjOnz8vPz+/Yh+H+AwA8KqsrCxJuuKmBpR/OZ+RnM9McRFuAAAlguf5oTCe+owQbgAAQLlCuAEAoIQ0aNBAL7/8cpG3X7VqlRwOh1JSUrxWpvKIcAMAwCUcDkeBr0mTJhXruFu2bNE999xT5O27deumxMREhYSEFOt8RVXeQhSjpTwkI0NKTpZ8fKRCnsQOACjlEhMTXb//73//05NPPql9+/a5llWtWtX1u2VZysrKUqVKhX+l1qpV67LK4e/vr4iIiMvaB9TceMw330j160u9etldEgDAlYqIiHC9QkJC5HA4XO/37t2roKAgLVmyRB06dJDT6dTatWv1448/atCgQQoPD1fVqlXVqVMnLV++3O24lzZLORwOvf322xoyZIgCAwPVpEkTffrpp671l9aoxMfHq1q1alq6dKmaN2+uqlWrql+/fm5h7Pz58/r73/+uatWqqUaNGnr44YcVGxurwYMHF/t+nDhxQiNGjFBoaKgCAwMVExOjH374wbX+0KFDGjhwoEJDQ1WlShW1bNlSixcvdu07fPhw1apVSwEBAWrSpIlmzZpV7LIUBeHGQ3LmpGLyTQAomGVJp07Z87Isz13HI488oqlTp2rPnj1q06aN0tPT1b9/f61YsULbt29Xv379NHDgQCUkJBR4nMmTJ+vmm2/Wd999p/79+2v48OE6fvx4vtufPn1azz//vN5//32tWbNGCQkJeuihh1zrn3vuOX344YeaNWuW1q1bp7S0NC1cuPCKrnXkyJHaunWrPv30U23YsEGWZal///7KzMyUJI0dO1YZGRlas2aNdu7cqeeee85Vu/XEE0/o+++/15IlS7Rnzx7NmDFDNWvWvKLyFMqqYFJTUy1JVmpqqkePu3mzZUmWVb++Rw8LAGXemTNnrO+//946c+aMZVmWlZ5u/n9pxys9/fLLP2vWLCskJMT1/quvvrIkWQsXLix035YtW1qvvfaa6339+vWtl156yfVekvX444+73qenp1uSrCVLlrid68SJE66ySLIOHDjg2uf111+3wsPDXe/Dw8OtadOmud6fP3/eqlevnjVo0KB8y3npeS62f/9+S5K1bt0617Jjx45ZAQEB1scff2xZlmW1bt3amjRpUp7HHjhwoDVq1Kh8z32xSz8rF7uc729qbjyEmhsAqFg6duzo9j49PV0PPfSQmjdvrmrVqqlq1aras2dPoTU3bdq0cf1epUoVBQcH6+jRo/luHxgYqEaNGrneR0ZGurZPTU1VcnKyOnfu7Frv6+urDh06XNa1XWzPnj2qVKmSunTp4lpWo0YNNW3aVHv27JEk/f3vf9czzzyj7t2766mnntJ3333n2nbMmDGaM2eO2rVrp3/+859av359sctSVIQbDyHcAEDRBAZK6en2vDz59IcqVaq4vX/ooYe0YMEC/etf/9LXX3+tHTt2qHXr1oU+Cf3Sxww4HI4CHzCa1/aWJ9vbiuGuu+7STz/9pL/+9a/auXOnOnbsqNdee02SFBMTo0OHDun+++/XkSNH1Lt3b7dmNG8g3HgI4QYAisbhkKpUseflzUmS161bp5EjR2rIkCFq3bq1IiIi9PPPP3vvhHkICQlReHi4tmzZ4lqWlZWlb775ptjHbN68uc6fP69Nmza5lv3+++/at2+fWrRo4VoWFRWl0aNHa/78+XrwwQf1n//8x7WuVq1aio2N1QcffKCXX35ZM2fOLHZ5ioKh4B5CuAGAiq1JkyaaP3++Bg4cKIfDoSeeeKLAGhhvue+++xQXF6fGjRurWbNmeu2113TixIkiPdpg586dCgoKcr13OBxq27atBg0apLvvvltvvfWWgoKC9Mgjj6hOnToaNGiQJGnChAmKiYnRH/7wB504cUJfffWVmjdvLkl68skn1aFDB7Vs2VIZGRlatGiRa523EG48hHADABXbiy++qDvuuEPdunVTzZo19fDDDystLa3Ey/Hwww8rKSlJI0aMkK+vr+655x717dtXvr6+he7bo0cPt/e+vr46f/68Zs2apfHjx+uPf/yjzp07px49emjx4sWuJrKsrCyNHTtWv/zyi4KDg9WvXz+99NJLksxcPRMnTtTPP/+sgIAAXXvttZozZ47nL/wiDsvuhroSlpaWppCQEKWmpio4ONhjx92zR2rRQqpRQzp2zGOHBYAy7+zZszp48KAaNmyoypUr212cCic7O1vNmzfXzTffrKefftru4hSooM/K5Xx/U3PjIdTcAABKg0OHDunLL79Uz549lZGRoenTp+vgwYO67bbb7C5aiaFDsYcQbgAApYGPj4/i4+PVqVMnde/eXTt37tTy5cu93s+lNCk14Wbq1KlyOByaMGFCgdvNnTtXzZo1U+XKldW6dWvX9M52I9wAAEqDqKgorVu3TqmpqUpLS9P69etz9aUp70pFuNmyZYveeustt4mM8rJ+/XoNGzZMd955p7Zv367Bgwdr8ODB2rVrVwmVNH+EGwAASgfbw016erqGDx+u//znPwoNDS1w21deeUX9+vXTP/7xDzVv3lxPP/20rr76ak2fPr2ESps/wg0AFKyCjV9BMXjqM2J7uBk7dqwGDBigPn36FLrthg0bcm3Xt29fbdiwId99MjIylJaW5vbyBsINAOQtZwhyYTP1AjmfkaIMWy+IraOl5syZo2+++cZtJsWCJCUlKTw83G1ZeHi4kpKS8t0nLi5OkydPvqJyFgXhBgDyVqlSJQUGBuq3336Tn5+ffHxs/3c1SqHs7Gz99ttvCgwMVKVKVxZPbAs3hw8f1vjx47Vs2TKvznswceJEPfDAA673aWlpioqK8vh5CDcAkDeHw6HIyEgdPHhQhw4dsrs4KMV8fHxUr169Is2mXBDbws22bdt09OhRXX311a5lWVlZWrNmjaZPn66MjIxc1VIRERFKTk52W5acnKyIiIh8z+N0OuV0Oj1b+DwQbgAgf/7+/mrSpAlNUyiQv7+/R2r2bAs3vXv31s6dO92WjRo1Ss2aNdPDDz+cZ3tbdHS0VqxY4TZcfNmyZYqOjvZ2cQuV87ewLPPy5sPZAKAs8vHxYYZilAjbwk1QUJBatWrltqxKlSqqUaOGa/mIESNUp04dxcXFSZLGjx+vnj176oUXXtCAAQM0Z84cbd261etPFy2Ki4NmdrZ0hX2hAABAMZXqXl0JCQlKTEx0ve/WrZtmz56tmTNnqm3btpo3b54WLlyYKyTZ4dJwAwAA7MGDMz12XCkkxPx+9qxUAt18AACoMC7n+7tU19yUJdTcAABQOhBuPIRwAwBA6UC48RDCDQAApQPhxkMINwAAlA6EGw8h3AAAUDoQbjyEcAMAQOlAuPGQi2ckJtwAAGAfwo2HOBwXAg7hBgAA+xBuPIiHZwIAYD/CjQcRbgAAsB/hxoMINwAA2I9w40GEGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP8KNBxFuAACwH+HGgwg3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9CDcelBNusrLsLQcAABUZ4caDqLkBAMB+hBsPItwAAGA/wo0H+fqan4QbAADsQ7jxIGpuAACwH+HGgwg3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9bA03M2bMUJs2bRQcHKzg4GBFR0dryZIl+W4fHx8vh8Ph9qpcuXIJlrhghBsAAOxXyc6T161bV1OnTlWTJk1kWZb++9//atCgQdq+fbtatmyZ5z7BwcHat2+f673D4Sip4haKcAMAgP1sDTcDBw50e//ss89qxowZ2rhxY77hxuFwKCIioiSKd9kINwAA2K/U9LnJysrSnDlzdOrUKUVHR+e7XXp6uurXr6+oqCgNGjRIu3fvLsFSFoxwAwCA/WytuZGknTt3Kjo6WmfPnlXVqlW1YMECtWjRIs9tmzZtqnfffVdt2rRRamqqnn/+eXXr1k27d+9W3bp189wnIyNDGRkZrvdpaWleuQ6JcAMAQGlge81N06ZNtWPHDm3atEljxoxRbGysvv/++zy3jY6O1ogRI9SuXTv17NlT8+fPV61atfTWW2/le/y4uDiFhIS4XlFRUd66FMINAAClgO3hxt/fX40bN1aHDh0UFxentm3b6pVXXinSvn5+fmrfvr0OHDiQ7zYTJ05Uamqq63X48GFPFT0Xwg0AAPazPdxcKjs7260ZqSBZWVnauXOnIiMj893G6XS6hprnvLyFcAMAgP1s7XMzceJExcTEqF69ejp58qRmz56tVatWaenSpZKkESNGqE6dOoqLi5MkTZkyRV27dlXjxo2VkpKiadOm6dChQ7rrrrvsvAwXwg0AAPazNdwcPXpUI0aMUGJiokJCQtSmTRstXbpUN9xwgyQpISFBPj4XKpdOnDihu+++W0lJSQoNDVWHDh20fv36fDsglzTCDQAA9nNYlmXZXYiSlJaWppCQEKWmpnq8iWrAAGnxYmnWLGnkSI8eGgCACu1yvr9LXZ+bsoyaGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP8KNB+WEm6wse8sBAEBFRrjxIGpuAACwH+HGg3x9zU/CDQAA9iHceBA1NwAA2I9w40GEGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP8KNBxFuAACwH+HGgwg3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9CDceRLgBAMB+hBsPItwAAGA/wo0HEW4AALAf4caDCDcAANiPcONBhBsAAOxHuPEgwg0AAPYj3HgQ4QYAAPsRbjyIcAMAgP0INx5EuAEAwH6EGw8i3AAAYD/CjQflhJusLHvLAQBARUa48SBqbgAAsJ+t4WbGjBlq06aNgoODFRwcrOjoaC1ZsqTAfebOnatmzZqpcuXKat26tRYvXlxCpS0c4QYAAPvZGm7q1q2rqVOnatu2bdq6dauuv/56DRo0SLt3785z+/Xr12vYsGG68847tX37dg0ePFiDBw/Wrl27SrjkefP1NT8JNwAA2MdhWZZldyEuVr16dU2bNk133nlnrnW33HKLTp06pUWLFrmWde3aVe3atdObb75ZpOOnpaUpJCREqampCg4O9li5JWnGDOnee6WhQ6V58zx6aAAAKrTL+f4uNX1usrKyNGfOHJ06dUrR0dF5brNhwwb16dPHbVnfvn21YcOGfI+bkZGhtLQ0t5e30CwFAID9bA83O3fuVNWqVeV0OjV69GgtWLBALVq0yHPbpKQkhYeHuy0LDw9XUlJSvsePi4tTSEiI6xUVFeXR8l+McAMAgP1sDzdNmzbVjh07tGnTJo0ZM0axsbH6/vvvPXb8iRMnKjU11fU6fPiwx459KcINAAD2q2R3Afz9/dW4cWNJUocOHbRlyxa98soreuutt3JtGxERoeTkZLdlycnJioiIyPf4TqdTTqfTs4XOB+EGAAD72V5zc6ns7GxlZGTkuS46OlorVqxwW7Zs2bJ8++iUNMINAAD2s7XmZuLEiYqJiVG9evV08uRJzZ49W6tWrdLSpUslSSNGjFCdOnUUFxcnSRo/frx69uypF154QQMGDNCcOXO0detWzZw5087LcCHcAABgP1vDzdGjRzVixAglJiYqJCREbdq00dKlS3XDDTdIkhISEuTjc6FyqVu3bpo9e7Yef/xxPfroo2rSpIkWLlyoVq1a2XUJbgg3AADYr9TNc+Nt3pzn5qOPpNtuk3r3lpYv9+ihAQCo0MrkPDflATU3AADYj3DjQYQbAADsR7jxIMINAAD2I9x4EOEGAAD7EW48iHADAID9CDceRLgBAMB+hBsPItwAAGA/wo0HEW4AALAf4caDCDcAANiPcONBhBsAAOxHuPEgwg0AAPYj3HgQ4QYAAPsRbjyIcAMAgP0INx5EuAEAwH6EGw/KCTdZWfaWAwCAioxw40HU3AAAYD/CjQf5+pqfhBsAAOxDuPEgam4AALAf4caDCDcAANiPcONBhBsAAOxHuPEgwg0AAPYj3HgQ4QYAAPsRbjyIcAMAgP0INx5EuAEAwH6EGw8i3AAAYD/CjQcRbgAAsB/hxoMINwAA2I9w40GEGwAA7Ee48SDCDQAA9iPceBDhBgAA+xFuPIhwAwCA/Qg3HkS4AQDAfoQbDyLcAABgP1vDTVxcnDp16qSgoCCFhYVp8ODB2rdvX4H7xMfHy+FwuL0qV65cQiUumM9Fd9Oy7CsHAAAVma3hZvXq1Ro7dqw2btyoZcuWKTMzUzfeeKNOnTpV4H7BwcFKTEx0vQ4dOlRCJS7YxeGG2hsAAOxRyc6Tf/HFF27v4+PjFRYWpm3btqlHjx757udwOBQREeHt4l22S8ONr699ZQEAoKIqVX1uUlNTJUnVq1cvcLv09HTVr19fUVFRGjRokHbv3p3vthkZGUpLS3N7eQs1NwAA2K/UhJvs7GxNmDBB3bt3V6tWrfLdrmnTpnr33Xf1ySef6IMPPlB2dra6deumX375Jc/t4+LiFBIS4npFRUV56xIINwAAlAIOyyodXV/HjBmjJUuWaO3atapbt26R98vMzFTz5s01bNgwPf3007nWZ2RkKCMjw/U+LS1NUVFRSk1NVXBwsEfKnuP0aalKFfN7evqF3wEAwJVJS0tTSEhIkb6/be1zk2PcuHFatGiR1qxZc1nBRpL8/PzUvn17HThwIM/1TqdTTqfTE8UsFDU3AADYz9ZmKcuyNG7cOC1YsEArV65Uw4YNL/sYWVlZ2rlzpyIjI71QwstzcbjJyrKvHAAAVGS21tyMHTtWs2fP1ieffKKgoCAlJSVJkkJCQhQQECBJGjFihOrUqaO4uDhJ0pQpU9S1a1c1btxYKSkpmjZtmg4dOqS77rrLtuvIQc0NAAD2szXczJgxQ5LUq1cvt+WzZs3SyJEjJUkJCQnyuSg1nDhxQnfffbeSkpIUGhqqDh06aP369WrRokVJFTtfhBsAAOxXajoUl5TL6ZBUHA6H+ZmcLIWFefzwAABUSJfz/V1qhoKXFzxfCgAAexFuPIxwAwCAvYoVbg4fPuw2ad7mzZs1YcIEzZw502MFK6sINwAA2KtY4ea2227TV199JUlKSkrSDTfcoM2bN+uxxx7TlClTPFrAsoZwAwCAvYoVbnbt2qXOnTtLkj7++GO1atVK69ev14cffqj4+HhPlq/MIdwAAGCvYoWbzMxM16y/y5cv15/+9CdJUrNmzZSYmOi50pVBhBsAAOxVrHDTsmVLvfnmm/r666+1bNky9evXT5J05MgR1ahRw6MFLGsINwAA2KtY4ea5557TW2+9pV69emnYsGFq27atJOnTTz91NVdVVIQbAADsVawZinv16qVjx44pLS1NoaGhruX33HOPAgMDPVa4sign3PBsKQAA7FGsmpszZ84oIyPDFWwOHTqkl19+Wfv27VNYBZ+W18/P/Dx/3t5yAABQURUr3AwaNEjvvfeeJCklJUVdunTRCy+8oMGDB7ueF1VRVfq/urDMTHvLAQBARVWscPPNN9/o2muvlSTNmzdP4eHhOnTokN577z29+uqrHi1gWZNTc0O4AQDAHsUKN6dPn1ZQUJAk6csvv9RNN90kHx8fde3aVYcOHfJoAcsamqUAALBXscJN48aNtXDhQh0+fFhLly7VjTfeKEk6evSoV560XZbQLAUAgL2KFW6efPJJPfTQQ2rQoIE6d+6s6OhoSaYWp3379h4tYFlDsxQAAPYq1lDwP//5z7rmmmuUmJjomuNGknr37q0hQ4Z4rHBlEc1SAADYq1jhRpIiIiIUERHhejp43bp1K/wEfhLNUgAA2K1YzVLZ2dmaMmWKQkJCVL9+fdWvX1/VqlXT008/rewKPjUvzVIAANirWDU3jz32mN555x1NnTpV3bt3lyStXbtWkyZN0tmzZ/Xss896tJBlCc1SAADYq1jh5r///a/efvtt19PAJalNmzaqU6eO7r333godbmiWAgDAXsVqljp+/LiaNWuWa3mzZs10/PjxKy5UWUazFAAA9ipWuGnbtq2mT5+ea/n06dPVpk2bKy5UWUazFAAA9ipWs9S///1vDRgwQMuXL3fNcbNhwwYdPnxYixcv9mgByxqapQAAsFexam569uyp/fv3a8iQIUpJSVFKSopuuukm7d69W++//76ny1im0CwFAIC9ij3PTe3atXN1HP7222/1zjvvaObMmVdcsLIqp+aGZikAAOxRrJob5I+aGwAA7EW48TDCDQAA9iLceBjNUgAA2Ouy+tzcdNNNBa5PSUm5krKUC9TcAABgr8sKNyEhIYWuHzFixBUVqKwj3AAAYK/LCjezZs3yVjnKDZqlAACwF31uPIyaGwAA7EW48TDCDQAA9rI13MTFxalTp04KCgpSWFiYBg8erH379hW639y5c9WsWTNVrlxZrVu3LlWPfKBZCgAAe9kablavXq2xY8dq48aNWrZsmTIzM3XjjTfq1KlT+e6zfv16DRs2THfeeae2b9+uwYMHa/Dgwdq1a1cJljx/1NwAAGAvh2VZlt2FyPHbb78pLCxMq1evVo8ePfLc5pZbbtGpU6e0aNEi17KuXbuqXbt2evPNNws9R1pamkJCQpSamqrg4GCPlT3HK69IEyZIw4ZJs2d7/PAAAFRIl/P9Xar63KSmpkqSqlevnu82GzZsUJ8+fdyW9e3bVxs2bMhz+4yMDKWlpbm9vImnggMAYK9SE26ys7M1YcIEde/eXa1atcp3u6SkJIWHh7stCw8PV1JSUp7bx8XFKSQkxPWKioryaLkvRbMUAAD2KjXhZuzYsdq1a5fmzJnj0eNOnDhRqamprtfhw4c9evxL5YQbOhQDAGCPy5rEz1vGjRunRYsWac2aNapbt26B20ZERCg5OdltWXJysiIiIvLc3ul0yul0eqyshaFZCgAAe9lac2NZlsaNG6cFCxZo5cqVatiwYaH7REdHa8WKFW7Lli1bpujoaG8V87LQLAUAgL1srbkZO3asZs+erU8++URBQUGufjMhISEKCAiQJI0YMUJ16tRRXFycJGn8+PHq2bOnXnjhBQ0YMEBz5szR1q1bNXPmTNuu42I0SwEAYC9ba25mzJih1NRU9erVS5GRka7X//73P9c2CQkJSkxMdL3v1q2bZs+erZkzZ6pt27aaN2+eFi5cWGAn5JJEsxQAAPayteamKFPsrFq1Kteyv/zlL/rLX/7ihRJdOZqlAACwV6kZLVVe8PgFAADsRbjxMGpuAACwF+HGwwg3AADYi3DjYTRLAQBgL8KNh1FzAwCAvQg3Hka4AQDAXoQbD6NZCgAAexFuPIyaGwAA7EW48TDCDQAA9iLceBjNUgAA2Itw42EX19wU4ekSAADAwwg3HpYTbiQpK8u+cgAAUFERbjys0kWPIqVpCgCAkke48bCLa27oVAwAQMkj3HjYxeGGmhsAAEoe4cbDfH0v/E7NDQAAJY9w42EOx4V+N4QbAABKHuHGC3KapmiWAgCg5BFuvICaGwAA7EO48QIewQAAgH0IN15AsxQAAPYh3HgBzVIAANiHcOMFNEsBAGAfwo0X8GRwAADsQ7jxAmpuAACwD+HGCwg3AADYh3DjBTRLAQBgH8KNF1BzAwCAfQg3XkC4AQDAPoQbL6BZCgAA+xBuvICaGwAA7EO48QLCDQAA9iHceAHNUgAA2MfWcLNmzRoNHDhQtWvXlsPh0MKFCwvcftWqVXI4HLleSUlJJVPgIqLmBgAA+9gabk6dOqW2bdvq9ddfv6z99u3bp8TERNcrLCzMSyUsHsINAAD2qWTnyWNiYhQTE3PZ+4WFhalatWqeL5CH5ISbc+fsLQcAABVRmexz065dO0VGRuqGG27QunXr7C5OLgEB5ufZs/aWAwCAisjWmpvLFRkZqTfffFMdO3ZURkaG3n77bfXq1UubNm3S1Vdfnec+GRkZysjIcL1PS0vzejkDA83P06e9fioAAHCJMhVumjZtqqZNm7red+vWTT/++KNeeuklvf/++3nuExcXp8mTJ5dUESURbgAAsFOZbJa6WOfOnXXgwIF810+cOFGpqamu1+HDh71eppxwc+aM108FAAAuUaZqbvKyY8cORUZG5rve6XTK6XSWYIku9Lmh5gYAgJJna7hJT093q3U5ePCgduzYoerVq6tevXqaOHGifv31V7333nuSpJdfflkNGzZUy5YtdfbsWb399ttauXKlvvzyS7suIU80SwEAYB9bw83WrVt13XXXud4/8MADkqTY2FjFx8crMTFRCQkJrvXnzp3Tgw8+qF9//VWBgYFq06aNli9f7naM0oBwAwCAfRyWZVl2F6IkpaWlKSQkRKmpqQoODvbKOebNk/7yF6lHD2n1aq+cAgCACuVyvr/LfIfi0og+NwAA2Idw4wU0SwEAYB/CjRcQbgAAsA/hxguY5wYAAPsQbryAPjcAANiHcOMFFzdLVayxaAAA2I9w4wU54SYrS8rMtLcsAABUNIQbL8gJNxL9bgAAKGmEGy/w85N8/u/O0u8GAICSRbjxAoeD4eAAANiFcOMlhBsAAOxBuPES5roBAMAehBsvYa4bAADsQbjxEpqlAACwB+HGSwg3AADYg3DjJfS5AQDAHoQbL6HPDQAA9iDceAnNUgAA2INw4yU0SwEAYA/CjZdQcwMAgD0IN15CnxsAAOxBuPESam4AALAH4cZL6HMDAIA9CDdeQs0NAAD2INx4CX1uAACwB+HGS6pWNT9PnrS3HAAAVDSEGy+pUcP8/P13e8sBAEBFQ7jxkpo1zU/CDQAAJYtw4yU5NTfHjkmWZW9ZAACoSAg3XpITbrKypNRUe8sCAEBFQrjxksqVL3QqpmkKAICSQ7jxooubpgAAQMkg3HhRTqdiwg0AACWHcONFhBsAAEqereFmzZo1GjhwoGrXri2Hw6GFCxcWus+qVat09dVXy+l0qnHjxoqPj/d6OYuL4eAAAJQ8W8PNqVOn1LZtW73++utF2v7gwYMaMGCArrvuOu3YsUMTJkzQXXfdpaVLl3q5pMVDnxsAAEpeJTtPHhMTo5iYmCJv/+abb6phw4Z64YUXJEnNmzfX2rVr9dJLL6lv377eKmax0SwFAEDJK1N9bjZs2KA+ffq4Levbt682bNiQ7z4ZGRlKS0tze5UUwg0AACWvTIWbpKQkhYeHuy0LDw9XWlqazpw5k+c+cXFxCgkJcb2ioqJKoqiS6HMDAIAdylS4KY6JEycqNTXV9Tp8+HCJnZs+NwAAlDxb+9xcroiICCUnJ7stS05OVnBwsAICAvLcx+l0yul0lkTxcqFZCgCAklemam6io6O1YsUKt2XLli1TdHS0TSUq2MXNUjw8EwCAkmFruElPT9eOHTu0Y8cOSWao944dO5SQkCDJNCmNGDHCtf3o0aP1008/6Z///Kf27t2rN954Qx9//LHuv/9+O4pfqFq1JB8f8/DMxES7SwMAQMVga7jZunWr2rdvr/bt20uSHnjgAbVv315PPvmkJCkxMdEVdCSpYcOG+vzzz7Vs2TK1bdtWL7zwgt5+++1SOQxckvz8pIYNze8//GBvWQAAqCgcllWxGkzS0tIUEhKi1NRUBQcHe/18/ftLS5ZIM2dKd9/t9dMBAFAuXc73d5nqc1MW/eEP5uf+/faWAwCAioJw42U54YZmKQAASgbhxsuaNDE/qbkBAKBkEG68LKfm5sABM2oKAAB4F+HGy6KiJKdTysyUDh2yuzQAAJR/hBsv8/G50DS1d6+9ZQEAoCIg3JSA/5vGR+vW2VsOAAAqAsJNCbj+evNz5Up7ywEAQEVAuCkB111nfm7ZIp08aW9ZAAAo7wg3JaB+femqq8xoqa+/trs0AACUb4SbEpJTe7N0qb3lAACgvCPclJAhQ8zPd9+Vjh+3tywAAJRnhJsS0r+/1LatlJ4uvfqq3aUBAKD8ItyUEIdDeuwx8/u//y1t22ZveQAAKK8INyVo6FCpXz/pzBnpT3+Sdu2yu0QAAJQ/hJsS5OMjzZkjtWwpHTkide0qvfOO9Ntv0rlzdpcOAIDyoZLdBahoQkKk1aulW26RVqyQ7rrLLK9eXXr4YaldO7PN//4nffut1KKFNGGC1KiRnaUGAKDscFiWZdldiJKUlpamkJAQpaamKjg42LZynD8vvfKKNHly4RP7BQdL06dLqanSvHnSTTdJlSqZB3GGhEh//7tUtWrJlBsAADtczvc34cZmWVnmFR8vzZ0rJSebZqrmzU3tznvvSevXF3yM5s2lGjWkX36RGjaUZs0yEwdKJhAtXWr6+bRqZWqGfH29fVUAAHgW4aYApS3cFCYzU3r+eelf/zK//+1v0tatphnrqqtM81Vysvs+1aqZ/j19+phHPhw8eGFd27bSggVSdrZpBuvYUXrwQcnPz6w/dszUClWrVlJXCABA4Qg3BShr4SZHaqrpdFyrlvvyX36RPvxQCg83tTX33ivt3eu+TZ06UtOm0ubNZp6dgAATYHKawyIiTCfnHj2kqVNNk9nQodJbb5kmMQAA7Ea4KUBZDTdFlZIiffmlFBRkamb8/KRPPzUB59dfTWjZtMls27atdPhw/jMm33CD9MQTpjkrMNB0gN6yxTRt/fGPpnkrM1N64w0TkP7yF1NjBACApxFuClDew82lLMtMIJgjO1vas0dKSpKuucbUBu3aJS1eLL3/vnTbbVLfvtKAAdKpU2afwEDTYfno0QvHadrUBJ9Zs0zokaQ2bcw8Pu+8Y5rN+vSROnSQYmNNTREAAMVFuClARQs3xbV8uTRxoglBv/xiltWuLV17rVn3++8Xtq1SxdTipKXlfazRo02gSk2VGjQwYWvvXqluXTOZ4e2308cHAFAwwk0BCDeXx7KktWtNX50+fUwzV0qKeZTEmjWmyerRR03wee456YsvpPvuM0PU16yRXn+98HNUrix162aCVM2a5pjbt0v332+aunKkpZnmtpyaqEtrpQAA5RfhpgCEm5J1++2mw3NUlDRtmqnxOX9eatzY1ObMmiXt3p3//sOHm+HtCxea5rP27U1fotWrpQ8+MMFq9Gjp44+l77+Xbr7ZNL0FBZk+RefPS5MmmWMMH26GxV93HR2lAaCsIdwUgHBTss6cMcPVb7zR1O5cyrJMaFm3znR63rvXTE5oWaajclEEBJjzXMzXV/p//09KSDCTHDocpnZo3ToTtP77XxNy8nPy5IVRZQAA+xFuCkC4KTs2bjQ1O2lpUv/+UufO0osvmjAUGmr67Lz1ltm2ZUszmeH/+3+m6ezcOfOzUqXcwUeSnE4zv89PP0kZGVKvXtJf/2o6T7/6quksHREhDRtmjnHbbSYUbdok7d9vht737Wv6GwEAvI9wUwDCTflhWWZCwmrVTC2Mw2FqXCpVku64wzykVDLP59q/3zRR3XefGf6+cGHu4/n4mHCTnl6080dFmQkWhw41s0z7+xe93BL9hQDgchBuCkC4qRgsy3RKXr3a9MP58ktT6/Liiyb8/P3vptbm+utNqHn/fVMjJJnmsaeeks6eNf2BEhPNXEGSaVq7+mrzUNPDh80yX1/zmjDBHG/pUrNu0CBpxw7T1Fatmulcffy4eTZYQoJ5vMakSVJYmDnP55+bGqdrrjFNaEUNSwBQERBuCkC4QX4SEkz4aNMm92SEx4+bUJTzkTl92nSQfvXV/CdBvFTnziZQHTt2YVnt2mZOofh4MyFiDh8f03n6uedMYHr9dRPYxowxASggwHSo3r/fDNkPDDT7WZZpxgsJKfZtAIBSiXBTAMINPOnsWfNsr82bpUceMTU43bqZB5l+9pkZsdW9uxnVldP35+qrpccfN8Pp9+y5cKzrrjP9fJYudQ9MnTub40umlufoUdPfKGf+oU6dzLmyskxt0IYN0rPPSgcOmFFid91lnib/5z+b/TMzzbD7qKiSuUcA4AmEmwIQbmCHzz4zTVC33iqNH2+anNLTTYfpw4fNTM4332z64ViWCS7//veFGhsp96gwHx8zc3Ramgkq6enSiRO5z12zpqktuuoq0z/ogw9MU9tdd0k7d5p1/fubh7NWrWr22bHD9Es6e9aEs3r1zHY9elx4yOru3WZE29/+Zmq7AMCbyly4ef311zVt2jQlJSWpbdu2eu2119S5c+c8t42Pj9eoUaPcljmdTp09e7ZI5yLcoCxZudL0/7nlFvM8r40bTU3O6tXmERg1a0oxMaa5SzJNWV27SjNmmCavI0eKfq5rrjEPZt2/P/+5h+rWNc8Wu/ZaE7wSEkyT2EcfmfXx8SYkNWt2YZ+sLDO5Y6tW5uGuF9u40fSFGj3ajGADgPyUqXDzv//9TyNGjNCbb76pLl266OWXX9bcuXO1b98+hYWF5do+Pj5e48eP1759+1zLHA6HwsPDi3Q+wg3Km99+k55+2jSBjRxpmsYOHTIdox9+WPr6a9M/6I03TBCJiTE1Mg88YOYfGjxYuvNO91FiDoc0ZIhpXps/39QYBQS4P3ZDMv2Qzp+/MCT+1CkpMtKct3Vrs+6hh0wNUViY6ci9cqUZZXbsmAls585J//iHqamyLPPiAawALlWmwk2XLl3UqVMnTZ8+XZKUnZ2tqKgo3XfffXrkkUdybR8fH68JEyYoJSWlWOcj3ADGxY+v+Ppr0w+oWzfT96dZM9MUlbNddrbpq7N8uanZeeEF0xy2fLn0z3+aR21cjtBQE6ZyOlE7HGbo/quvmifP169vmub+8AfTp6hWLTMX0euvm07dt90mNWpk1ktmdupPP5U6djRzEDVrdqH5DED5UGbCzblz5xQYGKh58+Zp8ODBruWxsbFKSUnRJ598kmuf+Ph43XXXXapTp46ys7N19dVX61//+pdatmyZ5zkyMjKUkZHhep+WlqaoqCjCDXAFMjNNbU5wsHTwoGmqysgwnaEXLTLLVq408w6NHSuNGGGeGH/8uOkf9PPP5jh/+Yt5ttj77xevHJMmmUA2dKg5V47mzc1s1KGhpsnu1VfNUP+aNc0+FzebXWz6dOmZZ8wM1n37Fq9MALzjcsKNrZPLHzt2TFlZWbmalMLDw7V3794892natKneffddtWnTRqmpqXr++efVrVs37d69W3Xr1s21fVxcnCZPnuyV8gMVlZ/fhZqRhg3NvD+ZmVKTJlLPnmZ5Zqap8cnpS7NnjxlC7+dnmtF69jTh5uxZ07w1a5ZpBvvoI9MsFRkpffONCUKbNpl5gAYMMM1b69eb+YMmTbpQphYtzLESE825hg83/Xxeesk0j+X48ssL4SY62sxP9M03pp/R44+bmqrYWLPs2DHTf+nDD811BgdL1aubsJTTf+j4cTOZ5Nmzpnx165rr9NS/nX77zYQyJn0Eis7WmpsjR46oTp06Wr9+vaKjo13L//nPf2r16tXatGlTocfIzMxU8+bNNWzYMD399NO51lNzA5QNmzebWp3IyLzXnz1ranlyTJggvfKK+f2uu0yti9Mpbd1qanMunjeob1/zKI0ZM0xQKoifn/u+ealWzdRU9e5tmun27zfLAwJMqDlxQnrySTPkvm5dc7wFC8y1+fiY6QAefdSMmtu2Tfr1V/Ow18hIE7w2bTKj6L76yvSVuusuaebM3AHHssy+efy7Dih3ykzNTc2aNeXr66vk5GS35cnJyYqIiCjSMfz8/NS+fXsdOHAgz/VOp1NOhmEApV4+AyRdLg42kukk3bKlqbHp3v3C8o4dzTD2t94yNTZ33206TUumpujjjy/M/vzSSyYMhYeb87dqJQ0caDpdp6aajtK9epnHdqSkmM7Pzz13YTTZokXmZ+3apgP3li0Xhus//nj+1/LZZ9KyZVLjxtJ777mvGzhQWrzYjDLL8fbbpg9UTkhq1swEnldflZYsMQGoXz8TrsLDTRPd009LXbqYJjvJ9HEKDCxaZ+3sbBPeAgIK3xYojUpFh+LOnTvrtddek2Q6FNerV0/jxo3Ls0PxpbKystSyZUv1799fL774YqHb06EYQI6sLDOnT8uW7uEpK8s0LeX1ZPjUVBNqQkKkceNMCFi50gSO+fPNaLWcTtfXX28eA5Kaaob05zz/bMoU9z5CV19tQtHevRfmNapb19TmtGt3Yah9fvz8TJkDA81Q/NmzzTB7Hx8zr9Lnn5syVa1qpgu4+moTrDZulH780VxrzZpmZNuxY+YYP/0krVhhpgj45BPTafuBBy40M+7YYcJc374XOp9f6vffzb0ZMODCLNpAcZWZDsWSGQoeGxurt956S507d9bLL7+sjz/+WHv37lV4eLhGjBihOnXqKC4uTpI0ZcoUde3aVY0bN1ZKSoqmTZumhQsXatu2bWrRokWh5yPcAPCUrCxTm3JprdLF8hre/sMPptNyQoJ5Gv0NN5jls2ebGpkbb5TmzjWhJTvbDJPPGV9xzz2mCe/NN8371q3NUHtv6NXLhLacc912m5lfae1aU/MkmTB3/fVmGoLBg80IunPnTNPdvfeaa+3QwTThNWpkyi+ZsBQamn8zJHCpMhVuJGn69OmuSfzatWunV199VV26dJEk9erVSw0aNFB8fLwk6f7779f8+fOVlJSk0NBQdejQQc8884zat29fpHMRbgCUZpf2LcrP5s1m2zZtTPNat27m0RtLl5rOz/ffb15795omqj//2Twq5JtvzGv/flMr1LGjqSH69FPTb6lWLRNApk690DTmcJhXdrZ7GS4NVlWrus+XlJcZM0wt1QMPmNqcMWPMeX77zTxapEEDE6ri403zYb9+ZpuDB02NUtOm5lqOHTMd2KtXz30OyzIj8ObNM3Mr9elT+P2UzDE3bjTnvLTGDvYrc+GmJBFuAFQUOWGkqJMiHj5smqcCAqTbbzejxCpVMv2UfvvNTLZ4zTXmcR29epnmvP37zSM9pk41tVh16phwdeiQabp79VVTI/XLL2am6uKoX9+Mgjt3Lve6m24yAWftWhPSqlUznay///7CNo8+aob4r1kjPfigNGqUmaJAMs2DAQHm+D17mhDVt6+55qAgE8QCA03frD17TBNcjx5m33/9ywSioUNNv6+LO3yfO2fCVffuuWfmRvEQbgpAuAGAwh06ZGp+7rzT9JkpzJ49pn/RTTflXfOUnW2atf73PxMaHn/c1BKtW2emAAgLM81Us2aZ2pP77jOdo5955kJtUK9epikvIcHsm5iYf3n8/U1I+ewz8/7OO02fqJznrz34oAkteQyylWQeY+JwSN99Z0bavf32hXV9+pjJLh977MKyzp1NbVNkpHk8yhtvmGkHatc29yUszAQlPz8TGLOyTO3SjBnSn/5k+jv98IMJjJeOiktKMqGtKDV65RnhpgCEGwCwh2WZR3RUqVLwvD0XN80lJppwcNVVpklKMsGgUiXTJPbAAyYwjBljjp+SYsJShw5SRISZLmDChAvHrlfPhKO8NG8uPfusCUJ5PYS2a1dTlotmF1H37mY4/9mzZuReSkruZ7o1bWrCzdq1powDB5opAz78MHe5BgwwtWZdu5owNH686d/Urp0JTcHBF2YXP3rUzPn088+mT1Pr1uZYv/9upiJo2dLURiUnmzCWnS0dOGBq8po0Me/Pnzf3L+fvcf68+d3X17w/eNDU5gUFFb3J1FsINwUg3ABAxTJ3rhnhdu6c6Z/05ZcmwJw/b2qQxo83X+YhIeaLf+NGU+sTHGxGl332mWnKeucd82V/221mLqK2bc2IseRk8/vx4+Z84eGmL9OwYSZ0XTwy7mK+vqYWaOnS3Ov8/EyQuHjfVq1M7dbmzWbUWlqa+z6RkabTdmZm7vmc+vUzk20mJpprfOYZU7v0yy9m5NysWSbADBhgwuHy5SZU3X67CYnt2pn3U6eacOrvbya79PU1YevVV80Iu6lTTfDMzPT8CDnCTQEINwCArVtNM9Ctt+Zdi5SWZgKEv78ZLt+o0YXtzp0zcxFde60JApL07rsmMAUEmMCT80SgX34xw+FPnzaBKTHR1NwcP26aumJjpb/9zdS23Hefmfdo925TGySZjuL//KcJSjlzKF2sdWsTapYvd+/wHRRkwkVKiqnpypml29fXfQ6li+U8CFcyNWUJCe6ze1+qWTNz7jNnTCCUTIg6eNCMnJs6Nf99i4NwUwDCDQDA0yzLBJw//MGEnoIkJ5umq/wG+VqWqWn6/nsz2isgwDQnrVxpanM6dTJBKzTU1BJJponql19MU9eyZaaG6vrrTThZt06aPNlMTjlmjFm+dauZ72j+fPMYk//+15y3QwcT5nKeTX3LLSYQ/fKLCVKvv25qck6dcq9VqlTJ7J8TnKKiTHj05By6hJsCEG4AABXZsWMm1Pz5zxeG0qemSvv2mea1AwekVasu9Nm5uGZr714TXE6eNKPPsrJMLdQ115gaqylTzHPdHnnENPN5EuGmAIQbAADKnsv5/i7i7AcAAABlA+EGAACUK4QbAABQrhBuAABAuUK4AQAA5QrhBgAAlCuEGwAAUK4QbgAAQLlCuAEAAOUK4QYAAJQrhBsAAFCuEG4AAEC5QrgBAADlCuEGAACUK5XsLkBJsyxLknl0OgAAKBtyvrdzvscLUuHCzcmTJyVJUVFRNpcEAABcrpMnTyokJKTAbRxWUSJQOZKdna0jR44oKChIDofDY8dNS0tTVFSUDh8+rODgYI8dt6yo6NcvcQ8q+vVL3AOJe1DRr1/y3j2wLEsnT55U7dq15eNTcK+aCldz4+Pjo7p163rt+MHBwRX2Ay1x/RL3oKJfv8Q9kLgHFf36Je/cg8JqbHLQoRgAAJQrhBsAAFCuEG48xOl06qmnnpLT6bS7KLao6NcvcQ8q+vVL3AOJe1DRr18qHfegwnUoBgAA5Rs1NwAAoFwh3AAAgHKFcAMAAMoVwg0AAChXCDce8Prrr6tBgwaqXLmyunTpos2bN9tdJK+YNGmSHA6H26tZs2au9WfPntXYsWNVo0YNVa1aVUOHDlVycrKNJb5ya9as0cCBA1W7dm05HA4tXLjQbb1lWXryyScVGRmpgIAA9enTRz/88IPbNsePH9fw4cMVHBysatWq6c4771R6enoJXsWVKewejBw5Mtfnol+/fm7blOV7EBcXp06dOikoKEhhYWEaPHiw9u3b57ZNUT77CQkJGjBggAIDAxUWFqZ//OMfOn/+fEleSrEU5fp79eqV6zMwevRot23K6vVL0owZM9SmTRvXpHTR0dFasmSJa315/vvnKOwelLrPgIUrMmfOHMvf39969913rd27d1t33323Va1aNSs5OdnuonncU089ZbVs2dJKTEx0vX777TfX+tGjR1tRUVHWihUrrK1bt1pdu3a1unXrZmOJr9zixYutxx57zJo/f74lyVqwYIHb+qlTp1ohISHWwoULrW+//db605/+ZDVs2NA6c+aMa5t+/fpZbdu2tTZu3Gh9/fXXVuPGja1hw4aV8JUUX2H3IDY21urXr5/b5+L48eNu25Tle9C3b19r1qxZ1q5du6wdO3ZY/fv3t+rVq2elp6e7tinss3/+/HmrVatWVp8+fazt27dbixcvtmrWrGlNnDjRjku6LEW5/p49e1p3332322cgNTXVtb4sX79lWdann35qff7559b+/futffv2WY8++qjl5+dn7dq1y7Ks8v33z1HYPShtnwHCzRXq3LmzNXbsWNf7rKwsq3bt2lZcXJyNpfKOp556ymrbtm2e61JSUiw/Pz9r7ty5rmV79uyxJFkbNmwooRJ616Vf7NnZ2VZERIQ1bdo017KUlBTL6XRaH330kWVZlvX9999bkqwtW7a4tlmyZInlcDisX3/9tcTK7in5hZtBgwblu095uwdHjx61JFmrV6+2LKton/3FixdbPj4+VlJSkmubGTNmWMHBwVZGRkbJXsAVuvT6Lct8sY0fPz7ffcrT9ecIDQ213n777Qr3979Yzj2wrNL3GaBZ6gqcO3dO27ZtU58+fVzLfHx81KdPH23YsMHGknnPDz/8oNq1a+uqq67S8OHDlZCQIEnatm2bMjMz3e5Fs2bNVK9evXJ7Lw4ePKikpCS3aw4JCVGXLl1c17xhwwZVq1ZNHTt2dG3Tp08f+fj4aNOmTSVeZm9ZtWqVwsLC1LRpU40ZM0a///67a115uwepqamSpOrVq0sq2md/w4YNat26tcLDw13b9O3bV2lpadq9e3cJlv7KXXr9OT788EPVrFlTrVq10sSJE3X69GnXuvJ0/VlZWZozZ45OnTql6OjoCvf3l3Lfgxyl6TNQ4R6c6UnHjh1TVlaW2x9LksLDw7V3716bSuU9Xbp0UXx8vJo2barExERNnjxZ1157rXbt2qWkpCT5+/urWrVqbvuEh4crKSnJngJ7Wc515fX3z1mXlJSksLAwt/WVKlVS9erVy8196devn2666SY1bNhQP/74ox599FHFxMRow4YN8vX1LVf3IDs7WxMmTFD37t3VqlUrSSrSZz8pKSnPz0nOurIir+uXpNtuu03169dX7dq19d133+nhhx/Wvn37NH/+fEnl4/p37typ6OhonT17VlWrVtWCBQvUokUL7dixo8L8/fO7B1Lp+wwQblBkMTExrt/btGmjLl26qH79+vr4448VEBBgY8lgp1tvvdX1e+vWrdWmTRs1atRIq1atUu/evW0smeeNHTtWu3bt0tq1a+0uii3yu/577rnH9Xvr1q0VGRmp3r1768cff1SjRo1Kuphe0bRpU+3YsUOpqamaN2+eYmNjtXr1aruLVaLyuwctWrQodZ8BmqWuQM2aNeXr65urV3xycrIiIiJsKlXJqVatmv7whz/owIEDioiI0Llz55SSkuK2TXm+FznXVdDfPyIiQkePHnVbf/78eR0/frzc3perrrpKNWvW1IEDBySVn3swbtw4LVq0SF999ZXq1q3rWl6Uz35ERESen5OcdWVBftefly5dukiS22egrF+/v7+/GjdurA4dOiguLk5t27bVK6+8UmH+/lL+9yAvdn8GCDdXwN/fXx06dNCKFStcy7Kzs7VixQq3dsjyKj09XT/++KMiIyPVoUMH+fn5ud2Lffv2KSEhodzei4YNGyoiIsLtmtPS0rRp0ybXNUdHRyslJUXbtm1zbbNy5UplZ2e7/uMvb3755Rf9/vvvioyMlFT274FlWRo3bpwWLFiglStXqmHDhm7ri/LZj46O1s6dO91C3rJlyxQcHOyq1i+tCrv+vOzYsUOS3D4DZfX685Odna2MjIxy//cvSM49yIvtnwGPd1GuYObMmWM5nU4rPj7e+v7776177rnHqlatmluP8PLiwQcftFatWmUdPHjQWrdundWnTx+rZs2a1tGjRy3LMsMh69WrZ61cudLaunWrFR0dbUVHR9tc6itz8uRJa/v27db27dstSdaLL75obd++3Tp06JBlWWYoeLVq1axPPvnE+u6776xBgwblORS8ffv21qZNm6y1a9daTZo0KTPDoC2r4Htw8uRJ66GHHrI2bNhgHTx40Fq+fLl19dVXW02aNLHOnj3rOkZZvgdjxoyxQkJCrFWrVrkNcz19+rRrm8I++znDYG+88UZrx44d1hdffGHVqlWrTAwFLuz6Dxw4YE2ZMsXaunWrdfDgQeuTTz6xrrrqKqtHjx6uY5Tl67csy3rkkUes1atXWwcPHrS+++4765FHHrEcDof15ZdfWpZVvv/+OQq6B6XxM0C48YDXXnvNqlevnuXv72917tzZ2rhxo91F8opbbrnFioyMtPz9/a06depYt9xyi3XgwAHX+jNnzlj33nuvFRoaagUGBlpDhgyxEhMTbSzxlfvqq68sSblesbGxlmWZ4eBPPPGEFR4ebjmdTqt3797Wvn373I7x+++/W8OGDbOqVq1qBQcHW6NGjbJOnjxpw9UUT0H34PTp09aNN95o1apVy/Lz87Pq169v3X333bnCfVm+B3lduyRr1qxZrm2K8tn/+eefrZiYGCsgIMCqWbOm9eCDD1qZmZklfDWXr7DrT0hIsHr06GFVr17dcjqdVuPGja1//OMfbnOcWFbZvX7Lsqw77rjDql+/vuXv72/VqlXL6t27tyvYWFb5/vvnKOgelMbPgMOyLMvz9UEAAAD2oM8NAAAoVwg3AACgXCHcAACAcoVwAwAAyhXCDQAAKFcINwAAoFwh3AAAgHKFcAOgQnI4HFq4cKHdxQDgBYQbACVu5MiRcjgcuV79+vWzu2gAyoFKdhcAQMXUr18/zZo1y22Z0+m0qTQAyhNqbgDYwul0KiIiwu0VGhoqyTQZzZgxQzExMQoICNBVV12lefPmue2/c+dOXX/99QoICFCNGjV0zz33KD093W2bd999Vy1btpTT6VRkZKTGjRvntv7YsWMaMmSIAgMD1aRJE3366aeudSdOnNDw4cNVq1YtBQQEqEmTJrnCGIDSiXADoFR64oknNHToUH377bcaPny4br31Vu3Zs0eSdOrUKfXt21ehoaHasmWL5s6dq+XLl7uFlxkzZmjs2LG65557tHPnTn366adq3Lix2zkmT56sm2++Wd9995369++v4cOH6/jx467zf//991qyZIn27NmjGTNmqGbNmiV3AwAUn1cexwkABYiNjbV8fX2tKlWquL2effZZy7LMk6hHjx7ttk+XLl2sMWPGWJZlWTNnzrRCQ0Ot9PR01/rPP//c8vHxcT2RvHbt2tZjjz2WbxkkWY8//rjrfXp6uiXJWrJkiWVZljVw4EBr1KhRnrlgACWKPjcAbHHddddpxowZbsuqV6/u+j06OtptXXR0tHbs2CFJ2rNnj9q2basqVaq41nfv3l3Z2dnat2+fHA6Hjhw5ot69exdYhjZt2rh+r1KlioKDg3X06FFJ0pgxYzR06FB98803uvHGGzV48GB169atWNcKoGQRbgDYokqVKrmaiTwlICCgSNv5+fm5vXc4HMrOzpYkxcTE6NChQ1q8eLGWLVum3r17a+zYsXr++ec9Xl4AnkWfGwCl0saNG3O9b968uSSpefPm+vbbb3Xq1CnX+nXr1snHx0dNmzZVUFCQGjRooBUrVlxRGWrVqqXY2Fh98MEHevnllzVz5swrOh6AkkHNDQBbZGRkKCkpyW1ZpUqVXJ12586dq44dO+qaa67Rhx9+qM2bN+udd96RJA0fPlxPPfWUYmNjNWnSJP3222+677779Ne//lXh4eGSpEmTJmn06NEKCwtTTEyMTp48qXXr1um+++4rUvmefPJJdejQQS1btlRGRoYWLVrkClcASjfCDQBbfPHFF4qMjHRb1rRpU+3du1eSGck0Z84c3XvvvYqMjNRHH32kFi1aSJICAwO1dOlSjR8/Xp06dVJgYKCGDh2qF1980XWs2NhYnT17Vi+99JIeeugh1axZU3/+85+LXD5/f39NnDhRP//8swICAnTttddqzpw5HrhyAN7msCzLsrsQAHAxh8OhBQsWaPDgwXYXBUAZRJ8bAABQrhBuAABAuUKfGwClDq3lAK4ENTcAAKBcIdwAAIByhXADAADKFcINAAAoVwg3AACgXCHcAACAcoVwAwAAyhXCDQAAKFcINwAAoFz5/yAGBPoE3TruAAAAAElFTkSuQmCC\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "dXPARCwlSpqV" + }, + "execution_count": 28, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/baselines/fedavgm/_static/Figure6_cifar10_num-rounds=1000_concentration=1.png b/baselines/fedavgm/_static/Figure6_cifar10_num-rounds=1000_concentration=1.png new file mode 100644 index 000000000000..1668474caed6 Binary files /dev/null and b/baselines/fedavgm/_static/Figure6_cifar10_num-rounds=1000_concentration=1.png differ diff --git a/baselines/fedavgm/_static/concentration_cifar10.png b/baselines/fedavgm/_static/concentration_cifar10.png new file mode 100644 index 000000000000..0755ef8d66be Binary files /dev/null and b/baselines/fedavgm/_static/concentration_cifar10.png differ diff --git a/baselines/fedavgm/_static/concentration_cifar10_v2.png b/baselines/fedavgm/_static/concentration_cifar10_v2.png new file mode 100644 index 000000000000..bd3b9db1ff11 Binary files /dev/null and b/baselines/fedavgm/_static/concentration_cifar10_v2.png differ diff --git a/baselines/fedavgm/_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png b/baselines/fedavgm/_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png new file mode 100644 index 000000000000..042527a3ac21 Binary files /dev/null and b/baselines/fedavgm/_static/custom-fedavgm_vs_fedavgm_rounds=1000_fmnist.png differ diff --git a/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10.png b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10.png new file mode 100644 index 000000000000..771e13514363 Binary files /dev/null and b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10.png differ diff --git a/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png new file mode 100644 index 000000000000..005aabbf6752 Binary files /dev/null and b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=10000_cifar10_w_1e-9.png differ diff --git a/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png new file mode 100644 index 000000000000..313c8299336f Binary files /dev/null and b/baselines/fedavgm/_static/fedavgm_vs_fedavg_rounds=1000_fmnist.png differ diff --git a/baselines/fedavgm/conf-colab.sh b/baselines/fedavgm/conf-colab.sh new file mode 100644 index 000000000000..822fe2f273e1 --- /dev/null +++ b/baselines/fedavgm/conf-colab.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Shellscript to configure the environment on the Google Colab terminal + +# fix issue with ctypes on Colab instance +apt-get update +apt-get install -y libffi-dev + +# Install pyenv +curl https://pyenv.run | bash +export PYENV_ROOT="$HOME/.pyenv" +command -v pyenv >/dev/null || export PATH="$PYENV_ROOT/bin:$PATH" +eval "$(pyenv init -)" + +# this version is specific to the FedAvgM baseline +pyenv install 3.10.6 +pyenv global 3.10.6 + +# install Poetry +curl -sSL https://install.python-poetry.org | python3 - +export PATH="/root/.local/bin:$PATH" + +# install and set environment with Poetry +poetry install +poetry shell diff --git a/baselines/fedavgm/fedavgm/__init__.py b/baselines/fedavgm/fedavgm/__init__.py new file mode 100644 index 000000000000..a5e567b59135 --- /dev/null +++ b/baselines/fedavgm/fedavgm/__init__.py @@ -0,0 +1 @@ +"""Template baseline package.""" diff --git a/baselines/fedavgm/fedavgm/client.py b/baselines/fedavgm/fedavgm/client.py new file mode 100644 index 000000000000..6500bdc9c737 --- /dev/null +++ b/baselines/fedavgm/fedavgm/client.py @@ -0,0 +1,70 @@ +"""Define the Flower Client and function to instantiate it.""" + +import math + +import flwr as fl +from hydra.utils import instantiate +from keras.utils import to_categorical + + +class FlowerClient(fl.client.NumPyClient): + """Standard Flower client.""" + + # pylint: disable=too-many-arguments + def __init__(self, x_train, y_train, x_val, y_val, model, num_classes) -> None: + # local model + self.model = instantiate(model) + + # local dataset + self.x_train, self.y_train = x_train, to_categorical( + y_train, num_classes=num_classes + ) + self.x_val, self.y_val = x_val, to_categorical(y_val, num_classes=num_classes) + + def get_parameters(self, config): + """Return the parameters of the current local model.""" + return self.model.get_weights() + + def fit(self, parameters, config): + """Implement distributed fit function for a given client.""" + self.model.set_weights(parameters) + + self.model.fit( + self.x_train, + self.y_train, + epochs=config["local_epochs"], + batch_size=config["batch_size"], + verbose=False, + ) + return self.model.get_weights(), len(self.x_train), {} + + def evaluate(self, parameters, config): + """Implement distributed evaluation for a given client.""" + self.model.set_weights(parameters) + loss, acc = self.model.evaluate(self.x_val, self.y_val, verbose=False) + return loss, len(self.x_val), {"accuracy": acc} + + +def generate_client_fn(partitions, model, num_classes): + """Generate the client function that creates the Flower Clients.""" + + def client_fn(cid: str) -> FlowerClient: + """Create a Flower client representing a single organization.""" + full_x_train_cid, full_y_train_cid = partitions[int(cid)] + + # Use 10% of the client's training data for validation + split_idx = math.floor(len(full_x_train_cid) * 0.9) + x_train_cid, y_train_cid = ( + full_x_train_cid[:split_idx], + full_y_train_cid[:split_idx], + ) + x_val_cid, y_val_cid = ( + full_x_train_cid[split_idx:], + full_y_train_cid[split_idx:], + ) + + return FlowerClient( + x_train_cid, y_train_cid, x_val_cid, y_val_cid, model, num_classes + ) + + return client_fn diff --git a/baselines/fedavgm/fedavgm/common.py b/baselines/fedavgm/fedavgm/common.py new file mode 100644 index 000000000000..0ce9d04dc544 --- /dev/null +++ b/baselines/fedavgm/fedavgm/common.py @@ -0,0 +1,494 @@ +# Copyright 2020 Adap GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Commonly used functions for generating partitioned datasets.""" + +# pylint: disable=invalid-name + + +from typing import List, Optional, Tuple, Union + +import numpy as np +from numpy.random import BitGenerator, Generator, SeedSequence + +XY = Tuple[np.ndarray, np.ndarray] +XYList = List[XY] +PartitionedDataset = Tuple[XYList, XYList] + + +def float_to_int(i: float) -> int: + """Return float as int but raise if decimal is dropped.""" + if not i.is_integer(): + raise Exception("Cast would drop decimals") + + return int(i) + + +def sort_by_label(x: np.ndarray, y: np.ndarray) -> XY: + """Sort by label. + + Assuming two labels and four examples the resulting label order would be 1,1,2,2 + """ + idx = np.argsort(y, axis=0).reshape((y.shape[0])) + return (x[idx], y[idx]) + + +def sort_by_label_repeating(x: np.ndarray, y: np.ndarray) -> XY: + """Sort by label in repeating groups. + + Assuming two labels and four examples the resulting label order would be 1,2,1,2. + + Create sorting index which is applied to by label sorted x, y + + .. code-block:: python + + # given: + y = [ + 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9 + ] + + # use: + idx = [ + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19 + ] + + # so that y[idx] becomes: + y = [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 + ] + """ + x, y = sort_by_label(x, y) + + num_example = x.shape[0] + num_class = np.unique(y).shape[0] + idx = ( + np.array(range(num_example), np.int64) + .reshape((num_class, num_example // num_class)) + .transpose() + .reshape(num_example) + ) + + return (x[idx], y[idx]) + + +def split_at_fraction(x: np.ndarray, y: np.ndarray, fraction: float) -> Tuple[XY, XY]: + """Split x, y at a certain fraction.""" + splitting_index = float_to_int(x.shape[0] * fraction) + # Take everything BEFORE splitting_index + x_0, y_0 = x[:splitting_index], y[:splitting_index] + # Take everything AFTER splitting_index + x_1, y_1 = x[splitting_index:], y[splitting_index:] + return (x_0, y_0), (x_1, y_1) + + +def shuffle(x: np.ndarray, y: np.ndarray) -> XY: + """Shuffle x and y.""" + idx = np.random.permutation(len(x)) + return x[idx], y[idx] + + +def partition(x: np.ndarray, y: np.ndarray, num_partitions: int) -> List[XY]: + """Return x, y as list of partitions.""" + return list(zip(np.split(x, num_partitions), np.split(y, num_partitions))) + + +def combine_partitions(xy_list_0: XYList, xy_list_1: XYList) -> XYList: + """Combine two lists of ndarray Tuples into one list.""" + return [ + (np.concatenate([x_0, x_1], axis=0), np.concatenate([y_0, y_1], axis=0)) + for (x_0, y_0), (x_1, y_1) in zip(xy_list_0, xy_list_1) + ] + + +def shift(x: np.ndarray, y: np.ndarray) -> XY: + """Shift x_1, y_1. + + so that the first half contains only labels 0 to 4 and the second half 5 to 9. + """ + x, y = sort_by_label(x, y) + + (x_0, y_0), (x_1, y_1) = split_at_fraction(x, y, fraction=0.5) + (x_0, y_0), (x_1, y_1) = shuffle(x_0, y_0), shuffle(x_1, y_1) + x, y = np.concatenate([x_0, x_1], axis=0), np.concatenate([y_0, y_1], axis=0) + return x, y + + +def create_partitions( + unpartitioned_dataset: XY, + iid_fraction: float, + num_partitions: int, +) -> XYList: + """Create partitioned version of a training or test set. + + Currently tested and supported are MNIST, FashionMNIST and CIFAR-10/100 + """ + x, y = unpartitioned_dataset + + x, y = shuffle(x, y) + x, y = sort_by_label_repeating(x, y) + + (x_0, y_0), (x_1, y_1) = split_at_fraction(x, y, fraction=iid_fraction) + + # Shift in second split of dataset the classes into two groups + x_1, y_1 = shift(x_1, y_1) + + xy_0_partitions = partition(x_0, y_0, num_partitions) + xy_1_partitions = partition(x_1, y_1, num_partitions) + + xy_partitions = combine_partitions(xy_0_partitions, xy_1_partitions) + + # Adjust x and y shape + return [adjust_xy_shape(xy) for xy in xy_partitions] + + +def create_partitioned_dataset( + keras_dataset: Tuple[XY, XY], + iid_fraction: float, + num_partitions: int, +) -> Tuple[PartitionedDataset, XY]: + """Create partitioned version of keras dataset. + + Currently tested and supported are MNIST, FashionMNIST and CIFAR-10/100 + """ + xy_train, xy_test = keras_dataset + + xy_train_partitions = create_partitions( + unpartitioned_dataset=xy_train, + iid_fraction=iid_fraction, + num_partitions=num_partitions, + ) + + xy_test_partitions = create_partitions( + unpartitioned_dataset=xy_test, + iid_fraction=iid_fraction, + num_partitions=num_partitions, + ) + + return (xy_train_partitions, xy_test_partitions), adjust_xy_shape(xy_test) + + +def log_distribution(xy_partitions: XYList) -> None: + """Print label distribution for list of paritions.""" + distro = [np.unique(y, return_counts=True) for _, y in xy_partitions] + for d in distro: + print(d) + + +def adjust_xy_shape(xy: XY) -> XY: + """Adjust shape of both x and y.""" + x, y = xy + if x.ndim == 3: + x = adjust_x_shape(x) + if y.ndim == 2: + y = adjust_y_shape(y) + return (x, y) + + +def adjust_x_shape(nda: np.ndarray) -> np.ndarray: + """Turn shape (x, y, z) into (x, y, z, 1).""" + nda_adjusted = np.reshape(nda, (nda.shape[0], nda.shape[1], nda.shape[2], 1)) + return nda_adjusted + + +def adjust_y_shape(nda: np.ndarray) -> np.ndarray: + """Turn shape (x, 1) into (x).""" + nda_adjusted = np.reshape(nda, (nda.shape[0])) + return nda_adjusted + + +def split_array_at_indices( + x: np.ndarray, split_idx: np.ndarray +) -> List[List[np.ndarray]]: + """Split the array `x`. + + into list of elements using starting indices from + `split_idx`. + + This function should be used with `unique_indices` from `np.unique()` after + sorting by label. + + Args: + x (np.ndarray): Original array of dimension (N,a,b,c,...) + split_idx (np.ndarray): 1-D array contaning increasing number of + indices to be used as partitions. Initial value must be zero. Last value + must be less than N. + + Returns + ------- + List[List[np.ndarray]]: List of list of samples. + """ + if split_idx.ndim != 1: + raise ValueError("Variable `split_idx` must be a 1-D numpy array.") + if split_idx.dtype != np.int64: + raise ValueError("Variable `split_idx` must be of type np.int64.") + if split_idx[0] != 0: + raise ValueError("First value of `split_idx` must be 0.") + if split_idx[-1] >= x.shape[0]: + raise ValueError( + """Last value in `split_idx` must be less than + the number of samples in `x`.""" + ) + if not np.all(split_idx[:-1] <= split_idx[1:]): + raise ValueError("Items in `split_idx` must be in increasing order.") + + num_splits: int = len(split_idx) + split_idx = np.append(split_idx, x.shape[0]) + + list_samples_split: List[List[np.ndarray]] = [[] for _ in range(num_splits)] + for j in range(num_splits): + tmp_x = x[split_idx[j] : split_idx[j + 1]] # noqa: E203 + for sample in tmp_x: + list_samples_split[j].append(sample) + + return list_samples_split + + +def exclude_classes_and_normalize( + distribution: np.ndarray, exclude_dims: List[bool], eps: float = 1e-5 +) -> np.ndarray: + """Excludes classes from a distribution. + + This function is particularly useful when sampling without replacement. + Classes for which no sample is available have their probabilities are set to 0. + Classes that had probabilities originally set to 0 are incremented with + `eps` to allow sampling from remaining items. + + Args: + distribution (np.array): Distribution being used. + exclude_dims (List[bool]): Dimensions to be excluded. + eps (float, optional): Small value to be addad to non-excluded dimensions. + Defaults to 1e-5. + + Returns + ------- + np.ndarray: Normalized distributions. + """ + if np.any(distribution < 0) or (not np.isclose(np.sum(distribution), 1.0)): + raise ValueError("distribution must sum to 1 and have only positive values.") + + if distribution.size != len(exclude_dims): + raise ValueError( + """Length of distribution must be equal + to the length `exclude_dims`.""" + ) + if eps < 0: + raise ValueError("""The value of `eps` must be positive and small.""") + + distribution[[not x for x in exclude_dims]] += eps + distribution[exclude_dims] = 0.0 + sum_rows = np.sum(distribution) + np.finfo(float).eps + distribution = distribution / sum_rows + + return distribution + + +def sample_without_replacement( + distribution: np.ndarray, + list_samples: List[List[np.ndarray]], + num_samples: int, + empty_classes: List[bool], +) -> Tuple[XY, List[bool]]: + """Sample from a list without replacement. + + using a given distribution. + + Args: + distribution (np.ndarray): Distribution used for sampling. + list_samples(List[List[np.ndarray]]): List of samples. + num_samples (int): Total number of items to be sampled. + empty_classes (List[bool]): List of booleans indicating which classes are empty. + This is useful to differentiate which classes should still be sampled. + + Returns + ------- + XY: Dataset contaning samples + List[bool]: empty_classes. + """ + if np.sum([len(x) for x in list_samples]) < num_samples: + raise ValueError( + """Number of samples in `list_samples` is less than `num_samples`""" + ) + + # Make sure empty classes are not sampled + # and solves for rare cases where + if not empty_classes: + empty_classes = len(distribution) * [False] + + distribution = exclude_classes_and_normalize( + distribution=distribution, exclude_dims=empty_classes + ) + + data: List[np.ndarray] = [] + target: List[np.ndarray] = [] + + for _ in range(num_samples): + sample_class = np.where(np.random.multinomial(1, distribution) == 1)[0][0] + sample: np.ndarray = list_samples[sample_class].pop() + + data.append(sample) + target.append(sample_class) + + # If last sample of the class was drawn, then set the + # probability density function (PDF) to zero for that class. + if len(list_samples[sample_class]) == 0: + empty_classes[sample_class] = True + # Be careful to distinguish between classes that had zero probability + # and classes that are now empty + distribution = exclude_classes_and_normalize( + distribution=distribution, exclude_dims=empty_classes + ) + data_array: np.ndarray = np.concatenate([data], axis=0) + target_array: np.ndarray = np.array(target, dtype=np.int64) + + return (data_array, target_array), empty_classes + + +def get_partitions_distributions(partitions: XYList) -> Tuple[np.ndarray, List[int]]: + """Evaluate the distribution over classes for a set of partitions. + + Args: + partitions (XYList): Input partitions + + Returns + ------- + np.ndarray: Distributions of size (num_partitions, num_classes) + """ + # Get largest available label + labels = set() + for _, y in partitions: + labels.update(set(y)) + list_labels = sorted(labels) + bin_edges = np.arange(len(list_labels) + 1) + + # Pre-allocate distributions + distributions = np.zeros((len(partitions), len(list_labels)), dtype=np.float32) + for idx, (_, _y) in enumerate(partitions): + hist, _ = np.histogram(_y, bin_edges) + distributions[idx] = hist / hist.sum() + + return distributions, list_labels + + +def create_lda_partitions( + dataset: XY, + dirichlet_dist: Optional[np.ndarray] = None, + num_partitions: int = 100, + concentration: Union[float, np.ndarray, List[float]] = 0.5, + accept_imbalanced: bool = False, + seed: Optional[Union[int, SeedSequence, BitGenerator, Generator]] = None, +) -> Tuple[XYList, np.ndarray]: + r"""Create imbalanced non-iid partitions using Latent Dirichlet Allocation (LDA). + + without resampling. + + Args: + dataset (XY): Dataset containing samples X and labels Y. + dirichlet_dist (numpy.ndarray, optional): previously generated distribution to + be used. This is useful when applying the same distribution for train and + validation sets. + num_partitions (int, optional): Number of partitions to be created. + Defaults to 100. + concentration (float, np.ndarray, List[float]): Dirichlet Concentration + (:math:`\\alpha`) parameter. Set to float('inf') to get uniform partitions. + An :math:`\\alpha \\to \\Inf` generates uniform distributions over classes. + An :math:`\\alpha \\to 0.0` generates one class per client. Defaults to 0.5. + accept_imbalanced (bool): Whether or not to accept imbalanced output classes. + Default False. + seed (None, int, SeedSequence, BitGenerator, Generator): + A seed to initialize the BitGenerator for generating the Dirichlet + distribution. This is defined in Numpy's official documentation as follows: + If None, then fresh, unpredictable entropy will be pulled from the OS. + One may also pass in a SeedSequence instance. + Additionally, when passed a BitGenerator, it will be wrapped by Generator. + If passed a Generator, it will be returned unaltered. + See official Numpy Documentation for further details. + + Returns + ------- + Tuple[XYList, numpy.ndarray]: List of XYList containing partitions + for each dataset and the dirichlet probability density functions. + """ + # pylint: disable=too-many-arguments,too-many-locals + + x, y = dataset + x, y = shuffle(x, y) + x, y = sort_by_label(x, y) + + if (x.shape[0] % num_partitions) and (not accept_imbalanced): + raise ValueError( + """Total number of samples must be a multiple of `num_partitions`. + If imbalanced classes are allowed, set + `accept_imbalanced=True`.""" + ) + + num_samples = num_partitions * [0] + for j in range(x.shape[0]): + num_samples[j % num_partitions] += 1 + + # Get number of classes and verify if they matching with + classes, start_indices = np.unique(y, return_index=True) + + # Make sure that concentration is np.array and + # check if concentration is appropriate + concentration = np.asarray(concentration) + + # Check if concentration is Inf, if so create uniform partitions + partitions: List[XY] = [(_, _) for _ in range(num_partitions)] + if float("inf") in concentration: + partitions = create_partitions( + unpartitioned_dataset=(x, y), + iid_fraction=1.0, + num_partitions=num_partitions, + ) + dirichlet_dist = get_partitions_distributions(partitions)[0] + + return partitions, dirichlet_dist + + if concentration.size == 1: + concentration = np.repeat(concentration, classes.size) + elif concentration.size != classes.size: # Sequence + raise ValueError( + f"The size of the provided concentration ({concentration.size}) ", + f"must be either 1 or equal number of classes {classes.size})", + ) + + # Split into list of list of samples per class + list_samples_per_class: List[List[np.ndarray]] = split_array_at_indices( + x, start_indices + ) + + if dirichlet_dist is None: + dirichlet_dist = np.random.default_rng(seed).dirichlet( + alpha=concentration, size=num_partitions + ) + + if dirichlet_dist.size != 0: + if dirichlet_dist.shape != (num_partitions, classes.size): + raise ValueError( + f"""The shape of the provided dirichlet distribution + ({dirichlet_dist.shape}) must match the provided number + of partitions and classes ({num_partitions},{classes.size})""" + ) + + # Assuming balanced distribution + empty_classes = classes.size * [False] + for partition_id in range(num_partitions): + partitions[partition_id], empty_classes = sample_without_replacement( + distribution=dirichlet_dist[partition_id].copy(), + list_samples=list_samples_per_class, + num_samples=num_samples[partition_id], + empty_classes=empty_classes, + ) + + return partitions, dirichlet_dist diff --git a/baselines/fedavgm/fedavgm/conf/base.yaml b/baselines/fedavgm/fedavgm/conf/base.yaml new file mode 100644 index 000000000000..3c2c281911a3 --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/base.yaml @@ -0,0 +1,24 @@ +--- +num_clients: 10 +num_rounds: 5 # original experiments (paper) uses 10000 +fraction_evaluate: 0 # fraction of clients usied during validation +num_cpus: 1 +num_gpus: 0 + +noniid: + concentration: 0.1 # concentrations used in the paper [100., 10., 1., 0.5, 0.2, 0.1, 0.05, 0.0] + +server: + momentum: 0.9 + learning_rate: 1.0 + reporting_fraction: 0.05 # values used in the paper 0.05, 0.1, 0.2 (not used for Figure 5), 0.4 + +client: + local_epochs: 1 # in the paper it is used 1 or 5 + batch_size: 64 # in the paper fixed at 64 + lr: 0.01 # client learning rate + +defaults: + - strategy: custom-fedavgm + - model: cnn + - dataset: cifar10 diff --git a/baselines/fedavgm/fedavgm/conf/dataset/cifar10.yaml b/baselines/fedavgm/fedavgm/conf/dataset/cifar10.yaml new file mode 100644 index 000000000000..4894ba5d675f --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/dataset/cifar10.yaml @@ -0,0 +1,4 @@ +--- +_target_: fedavgm.dataset.cifar10 +num_classes: 10 +input_shape: [32, 32, 3] \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/dataset/fmnist.yaml b/baselines/fedavgm/fedavgm/conf/dataset/fmnist.yaml new file mode 100644 index 000000000000..2dfa07f1c60a --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/dataset/fmnist.yaml @@ -0,0 +1,4 @@ +--- +_target_: fedavgm.dataset.fmnist +num_classes: 10 +input_shape: [28, 28, 1] \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/model/cnn.yaml b/baselines/fedavgm/fedavgm/conf/model/cnn.yaml new file mode 100644 index 000000000000..c25463693c7f --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/model/cnn.yaml @@ -0,0 +1,5 @@ +--- +_target_: fedavgm.models.cnn +input_shape: ${dataset.input_shape} +num_classes: ${dataset.num_classes} +learning_rate: ${client.lr} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/model/tf_example.yaml b/baselines/fedavgm/fedavgm/conf/model/tf_example.yaml new file mode 100644 index 000000000000..8c2a670ee978 --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/model/tf_example.yaml @@ -0,0 +1,5 @@ +--- +_target_: fedavgm.models.tf_example +input_shape: ${dataset.input_shape} +num_classes: ${dataset.num_classes} +learning_rate: ${client.lr} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/strategy/custom-fedavgm.yaml b/baselines/fedavgm/fedavgm/conf/strategy/custom-fedavgm.yaml new file mode 100644 index 000000000000..526c9714ed73 --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/strategy/custom-fedavgm.yaml @@ -0,0 +1,13 @@ +--- +_target_: fedavgm.strategy.CustomFedAvgM +min_available_clients: ${num_clients} +fraction_fit: ${server.reporting_fraction} +fraction_evaluate: ${fraction_evaluate} +server_learning_rate: ${server.learning_rate} +server_momentum: ${server.momentum} +on_fit_config_fn: + _target_: fedavgm.server.get_on_fit_config + config: ${client} +initial_parameters: + _target_: fedavgm.models.model_to_parameters + model: ${model} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/strategy/fedavg.yaml b/baselines/fedavgm/fedavgm/conf/strategy/fedavg.yaml new file mode 100644 index 000000000000..1b2cde85fe6c --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/strategy/fedavg.yaml @@ -0,0 +1,8 @@ +--- +_target_: flwr.server.strategy.FedAvg +min_available_clients: ${num_clients} +fraction_fit: ${server.reporting_fraction} +fraction_evaluate: ${fraction_evaluate} +on_fit_config_fn: + _target_: fedavgm.server.get_on_fit_config + config: ${client} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/conf/strategy/fedavgm.yaml b/baselines/fedavgm/fedavgm/conf/strategy/fedavgm.yaml new file mode 100644 index 000000000000..ce88887c02ab --- /dev/null +++ b/baselines/fedavgm/fedavgm/conf/strategy/fedavgm.yaml @@ -0,0 +1,13 @@ +--- +_target_: flwr.server.strategy.FedAvgM +min_available_clients: ${num_clients} +fraction_fit: ${server.reporting_fraction} +fraction_evaluate: ${fraction_evaluate} +server_learning_rate: ${server.learning_rate} +server_momentum: ${server.momentum} +on_fit_config_fn: + _target_: fedavgm.server.get_on_fit_config + config: ${client} +initial_parameters: + _target_: fedavgm.models.model_to_parameters + model: ${model} \ No newline at end of file diff --git a/baselines/fedavgm/fedavgm/dataset.py b/baselines/fedavgm/fedavgm/dataset.py new file mode 100644 index 000000000000..939a42fda5ae --- /dev/null +++ b/baselines/fedavgm/fedavgm/dataset.py @@ -0,0 +1,57 @@ +"""Dataset utilities for federated learning.""" + +import numpy as np +from tensorflow import keras + +from fedavgm.common import create_lda_partitions + + +def cifar10(num_classes, input_shape): + """Prepare the CIFAR-10. + + This method considers CIFAR-10 for creating both train and test sets. The sets are + already normalized. + """ + print(f">>> [Dataset] Loading CIFAR-10. {num_classes} | {input_shape}.") + (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + input_shape = x_train.shape[1:] + num_classes = len(np.unique(y_train)) + + return x_train, y_train, x_test, y_test, input_shape, num_classes + + +def fmnist(num_classes, input_shape): + """Prepare the FMNIST. + + This method considers FMNIST for creating both train and test sets. The sets are + already normalized. + """ + print(f">>> [Dataset] Loading FMNIST. {num_classes} | {input_shape}.") + (x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data() + x_train = x_train.astype("float32") / 255 + x_test = x_test.astype("float32") / 255 + input_shape = x_train.shape[1:] + num_classes = len(np.unique(y_train)) + + return x_train, y_train, x_test, y_test, input_shape, num_classes + + +def partition(x_train, y_train, num_clients, concentration): + """Create non-iid partitions. + + The partitions uses a LDA distribution based on concentration. + """ + print( + f">>> [Dataset] {num_clients} clients, non-iid concentration {concentration}..." + ) + dataset = [x_train, y_train] + partitions, _ = create_lda_partitions( + dataset, + num_partitions=num_clients, + # concentration=concentration * num_classes, + concentration=concentration, + seed=1234, + ) + return partitions diff --git a/baselines/fedavgm/fedavgm/dataset_preparation.py b/baselines/fedavgm/fedavgm/dataset_preparation.py new file mode 100644 index 000000000000..dab1967d8399 --- /dev/null +++ b/baselines/fedavgm/fedavgm/dataset_preparation.py @@ -0,0 +1 @@ +"""Require to download dataset or additional preparation.""" diff --git a/baselines/fedavgm/fedavgm/main.py b/baselines/fedavgm/fedavgm/main.py new file mode 100644 index 000000000000..915cad28f212 --- /dev/null +++ b/baselines/fedavgm/fedavgm/main.py @@ -0,0 +1,100 @@ +"""Create and connect the building blocks for your experiments; start the simulation. + +It includes processioning the dataset, instantiate strategy, specify how the global +model is going to be evaluated, etc. At the end, this script saves the results. +""" + +import pickle +from pathlib import Path + +import flwr as fl +import hydra +import numpy as np +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from fedavgm.client import generate_client_fn +from fedavgm.dataset import partition +from fedavgm.server import get_evaluate_fn + + +# pylint: disable=too-many-locals +@hydra.main(config_path="conf", config_name="base", version_base=None) +def main(cfg: DictConfig) -> None: + """Run the baseline. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + np.random.seed(2020) + + # 1. Print parsed config + print(OmegaConf.to_yaml(cfg)) + + # 2. Prepare your dataset + x_train, y_train, x_test, y_test, input_shape, num_classes = instantiate( + cfg.dataset + ) + + partitions = partition(x_train, y_train, cfg.num_clients, cfg.noniid.concentration) + + print(f">>> [Model]: Num. Classes {num_classes} | Input shape: {input_shape}") + + # 3. Define your clients + client_fn = generate_client_fn(partitions, cfg.model, num_classes) + + # 4. Define your strategy + evaluate_fn = get_evaluate_fn( + instantiate(cfg.model), x_test, y_test, cfg.num_rounds, num_classes + ) + + strategy = instantiate(cfg.strategy, evaluate_fn=evaluate_fn) + + # 5. Start Simulation + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + strategy=strategy, + client_resources={"num_cpus": cfg.num_cpus, "num_gpus": cfg.num_gpus}, + ) + + _, final_acc = history.metrics_centralized["accuracy"][-1] + + # 6. Save your results + save_path = HydraConfig.get().runtime.output_dir + + strategy_name = strategy.__class__.__name__ + dataset_type = "cifar10" if cfg.dataset.input_shape == [32, 32, 3] else "fmnist" + + def format_variable(x): + return f"{x!r}" if isinstance(x, bytes) else x + + file_suffix: str = ( + f"_{format_variable(strategy_name)}" + f"_{format_variable(dataset_type)}" + f"_clients={format_variable(cfg.num_clients)}" + f"_rounds={format_variable(cfg.num_rounds)}" + f"_C={format_variable(cfg.server.reporting_fraction)}" + f"_E={format_variable(cfg.client.local_epochs)}" + f"_alpha={format_variable(cfg.noniid.concentration)}" + f"_server-momentum={format_variable(cfg.server.momentum)}" + f"_client-lr={format_variable(cfg.client.lr)}" + f"_acc={format_variable(final_acc):.4f}" + ) + + filename = "results" + file_suffix + ".pkl" + + print(f">>> Saving {filename}...") + results_path = Path(save_path) / filename + results = {"history": history} + + with open(str(results_path), "wb") as hist_file: + pickle.dump(results, hist_file, protocol=pickle.HIGHEST_PROTOCOL) + + +if __name__ == "__main__": + main() diff --git a/baselines/fedavgm/fedavgm/models.py b/baselines/fedavgm/fedavgm/models.py new file mode 100644 index 000000000000..a151c4d9db76 --- /dev/null +++ b/baselines/fedavgm/fedavgm/models.py @@ -0,0 +1,121 @@ +"""CNN model architecture.""" + +from flwr.common import ndarrays_to_parameters +from keras.optimizers import SGD +from keras.regularizers import l2 +from tensorflow import keras +from tensorflow.nn import local_response_normalization # pylint: disable=import-error + + +def cnn(input_shape, num_classes, learning_rate): + """CNN Model from (McMahan et. al., 2017). + + Communication-efficient learning of deep networks from decentralized data + """ + input_shape = tuple(input_shape) + + weight_decay = 0.004 + model = keras.Sequential( + [ + keras.layers.Conv2D( + 64, + (5, 5), + padding="same", + activation="relu", + input_shape=input_shape, + ), + keras.layers.MaxPooling2D((3, 3), strides=(2, 2)), + keras.layers.BatchNormalization(), + keras.layers.Conv2D( + 64, + (5, 5), + padding="same", + activation="relu", + ), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), strides=(2, 2)), + keras.layers.Flatten(), + keras.layers.Dense( + 384, activation="relu", kernel_regularizer=l2(weight_decay) + ), + keras.layers.Dense( + 192, activation="relu", kernel_regularizer=l2(weight_decay) + ), + keras.layers.Dense(num_classes, activation="softmax"), + ] + ) + optimizer = SGD(learning_rate=learning_rate) + model.compile( + loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"] + ) + + return model + + +def tf_example(input_shape, num_classes, learning_rate): + """CNN Model from TensorFlow v1.x example. + + This is the model referenced on the FedAvg paper. + + Reference: + https://web.archive.org/web/20170807002954/https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10.py + """ + input_shape = tuple(input_shape) + + weight_decay = 0.004 + model = keras.Sequential( + [ + keras.layers.Conv2D( + 64, + (5, 5), + padding="same", + activation="relu", + input_shape=input_shape, + ), + keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding="same"), + keras.layers.Lambda( + local_response_normalization, + arguments={ + "depth_radius": 4, + "bias": 1.0, + "alpha": 0.001 / 9.0, + "beta": 0.75, + }, + ), + keras.layers.Conv2D( + 64, + (5, 5), + padding="same", + activation="relu", + ), + keras.layers.Lambda( + local_response_normalization, + arguments={ + "depth_radius": 4, + "bias": 1.0, + "alpha": 0.001 / 9.0, + "beta": 0.75, + }, + ), + keras.layers.MaxPooling2D((3, 3), strides=(2, 2), padding="same"), + keras.layers.Flatten(), + keras.layers.Dense( + 384, activation="relu", kernel_regularizer=l2(weight_decay) + ), + keras.layers.Dense( + 192, activation="relu", kernel_regularizer=l2(weight_decay) + ), + keras.layers.Dense(num_classes, activation="softmax"), + ] + ) + optimizer = SGD(learning_rate=learning_rate) + model.compile( + loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"] + ) + + return model + + +def model_to_parameters(model): + """Retrieve model weigths and convert to ndarrays.""" + return ndarrays_to_parameters(model.get_weights()) diff --git a/baselines/fedavgm/fedavgm/server.py b/baselines/fedavgm/fedavgm/server.py new file mode 100644 index 000000000000..c997c035f638 --- /dev/null +++ b/baselines/fedavgm/fedavgm/server.py @@ -0,0 +1,45 @@ +"""Define the Flower Server and function to instantiate it.""" + +from keras.utils import to_categorical +from omegaconf import DictConfig + + +def get_on_fit_config(config: DictConfig): + """Generate the function for config. + + The config dict is sent to the client fit() method. + """ + + def fit_config_fn(server_round: int): # pylint: disable=unused-argument + # option to use scheduling of learning rate based on round + # if server_round > 50: + # lr = config.lr / 10 + return { + "local_epochs": config.local_epochs, + "batch_size": config.batch_size, + } + + return fit_config_fn + + +def get_evaluate_fn(model, x_test, y_test, num_rounds, num_classes): + """Generate the function for server global model evaluation. + + The method evaluate_fn runs after global model aggregation. + """ + + def evaluate_fn( + server_round: int, parameters, config + ): # pylint: disable=unused-argument + if server_round == num_rounds: # evaluates global model just on the last round + # instantiate the model + model.set_weights(parameters) + + y_test_cat = to_categorical(y_test, num_classes=num_classes) + loss, accuracy = model.evaluate(x_test, y_test_cat, verbose=False) + + return loss, {"accuracy": accuracy} + + return None + + return evaluate_fn diff --git a/baselines/fedavgm/fedavgm/strategy.py b/baselines/fedavgm/fedavgm/strategy.py new file mode 100644 index 000000000000..cd0a27254fce --- /dev/null +++ b/baselines/fedavgm/fedavgm/strategy.py @@ -0,0 +1,201 @@ +"""Optionally define a custom strategy. + +Needed only when the strategy is not yet implemented in Flower or because you want to +extend or modify the functionality of an existing strategy. +""" + +from logging import WARNING +from typing import Callable, Dict, List, Optional, Tuple, Union + +from flwr.common import ( + FitRes, + MetricsAggregationFn, + NDArrays, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.common.logger import log +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from flwr.server.strategy import FedAvg +from flwr.server.strategy.aggregate import aggregate + + +class CustomFedAvgM(FedAvg): + """Re-implmentation of FedAvgM. + + This implementation of FedAvgM diverges from original (Flwr v1.5.0) implementation. + Here, the re-implementation introduces the Nesterov Accelerated Gradient (NAG), + same as reported in the original FedAvgM paper: + + https://arxiv.org/pdf/1909.06335.pdf + """ + + def __init__( + self, + *, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + evaluate_fn: Optional[ + Callable[ + [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + accept_failures: bool = True, + initial_parameters: Parameters, + fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + server_learning_rate: float = 1.0, + server_momentum: float = 0.9, + ) -> None: + """Federated Averaging with Momentum strategy. + + Implementation based on https://arxiv.org/pdf/1909.06335.pdf + + Parameters + ---------- + fraction_fit : float, optional + Fraction of clients used during training. Defaults to 0.1. + fraction_evaluate : float, optional + Fraction of clients used during validation. Defaults to 0.1. + min_fit_clients : int, optional + Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients : int, optional + Minimum number of clients used during validation. Defaults to 2. + min_available_clients : int, optional + Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]]]] + Optional function used for validation. Defaults to None. + on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure training. Defaults to None. + on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + Function used to configure validation. Defaults to None. + accept_failures : bool, optional + Whether or not accept rounds containing failures. Defaults to True. + initial_parameters : Parameters + Initial global model parameters. + server_learning_rate: float + Server-side learning rate used in server-side optimization. + Defaults to 1.0. + server_momentum: float + Server-side momentum factor used for FedAvgM. Defaults to 0.9. + """ + super().__init__( + fraction_fit=fraction_fit, + fraction_evaluate=fraction_evaluate, + min_fit_clients=min_fit_clients, + min_evaluate_clients=min_evaluate_clients, + min_available_clients=min_available_clients, + evaluate_fn=evaluate_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + ) + self.server_learning_rate = server_learning_rate + self.server_momentum = server_momentum + self.momentum_vector: Optional[NDArrays] = None + + def __repr__(self) -> str: + """Compute a string representation of the strategy.""" + rep = f"FedAvgM(accept_failures={self.accept_failures})" + return rep + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters.""" + return self.initial_parameters + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average.""" + if not results: + return None, {} + + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + + fedavg_result = aggregate(weights_results) # parameters_aggregated from FedAvg + + # original implementation follows convention described in + # https://pytorch.org/docs/stable/generated/torch.optim.SGD.html + + # do the check for self.initial_parameters being set + assert ( + self.initial_parameters is not None + ), "Initial parameters must be set for CustomFedAvgM strategy" + + # remember that updates are the opposite of gradients + pseudo_gradient: NDArrays = [ + x - y + for x, y in zip( + parameters_to_ndarrays(self.initial_parameters), fedavg_result + ) + ] + + if server_round > 1: + assert self.momentum_vector, "Momentum should have been created on round 1." + + self.momentum_vector = [ + self.server_momentum * v + w + for w, v in zip(pseudo_gradient, self.momentum_vector) + ] + else: # Round 1 + # Initialize server-side model + assert ( + self.initial_parameters is not None + ), "When using server-side optimization, model needs to be initialized." + # Initialize momentum vector + self.momentum_vector = pseudo_gradient + + # Applying Nesterov + pseudo_gradient = [ + g + self.server_momentum * v + for g, v in zip(pseudo_gradient, self.momentum_vector) + ] + + # Federated Averaging with Server Momentum + fedavgm_result = [ + w - self.server_learning_rate * v + for w, v in zip( + parameters_to_ndarrays(self.initial_parameters), pseudo_gradient + ) + ] + + # Update current weights + self.initial_parameters = ndarrays_to_parameters(fedavgm_result) + + parameters_aggregated = ndarrays_to_parameters(fedavgm_result) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + return parameters_aggregated, metrics_aggregated diff --git a/baselines/fedavgm/fedavgm/utils.py b/baselines/fedavgm/fedavgm/utils.py new file mode 100644 index 000000000000..42a3f372e6ad --- /dev/null +++ b/baselines/fedavgm/fedavgm/utils.py @@ -0,0 +1,61 @@ +"""Define any utility function. + +They are not directly relevant to the other (more FL specific) python modules. For +example, you may define here things like: loading a model from a checkpoint, saving +results, plotting. +""" + +import matplotlib.pyplot as plt +import numpy as np + +from fedavgm.dataset import cifar10, partition + +# pylint: disable=too-many-locals + + +def plot_concentrations_cifar10(): + """Create a plot with different concentrations for dataset using LDA.""" + x_train, y_train, x_test, y_test, _, num_classes = cifar10(10, (32, 32, 3)) + x = np.concatenate((x_train, x_test), axis=0) + y = np.concatenate((y_train, y_test), axis=0) + num_clients = 30 + + # Simulated different concentrations for partitioning + concentration_values = [np.inf, 100, 1, 0.1, 0.01, 1e-10] + color = plt.get_cmap("RdYlGn")(np.linspace(0.15, 0.85, num_classes)) + num_plots = len(concentration_values) + fig, axs = plt.subplots(1, num_plots, figsize=(15, 5), sharey=True) + + pos = axs[0].get_position() + pos.x0 += 0.1 + axs[0].set_position(pos) + + for i, concentration in enumerate(concentration_values): + partitions = partition(x, y, num_clients, concentration) + + for client in range(num_clients): + _, y_client = partitions[client] + lefts = [0] + axis = axs[i] + class_counts = np.bincount(y_client, minlength=num_classes) + np.sum(class_counts > 0) + + class_distribution = class_counts.astype(np.float16) / len(y_client) + + for idx, val in enumerate(class_distribution[:-1]): + lefts.append(lefts[idx] + val) + + axis.barh(client, class_distribution, left=lefts, color=color) + axis.set_xticks([]) + axis.set_yticks([]) + axis.set_xlabel("Class distribution") + axis.set_title(f"Concentration = {concentration}") + + fig.text(0, 0.5, "Client", va="center", rotation="vertical") + plt.tight_layout() + plt.savefig("../_static/concentration_cifar10_v2.png") + print(">>> Concentration plot created") + + +if __name__ == "__main__": + plot_concentrations_cifar10() diff --git a/baselines/fedavgm/pyproject.toml b/baselines/fedavgm/pyproject.toml new file mode 100644 index 000000000000..298deafd8932 --- /dev/null +++ b/baselines/fedavgm/pyproject.toml @@ -0,0 +1,139 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "fedavgm" +version = "1.0.0" +description = "FedAvgM: Measuring the effects of non-identical data distribution for federated visual classification" +license = "Apache-2.0" +authors = ["Gustavo Bertoli"] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.9, <3.12.0" # changed! original baseline template uses >= 3.8.15 +flwr = "1.5.0" +ray = "2.6.3" +hydra-core = "1.3.2" # don't change this +cython = "^3.0.0" +tensorflow = "2.10" +numpy = "1.25.2" +matplotlib = "^3.7.2" + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/baselines/heterofl/LICENSE b/baselines/heterofl/LICENSE new file mode 100644 index 000000000000..d64569567334 --- /dev/null +++ b/baselines/heterofl/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/baselines/heterofl/README.md b/baselines/heterofl/README.md new file mode 100644 index 000000000000..6e9c32077e9b --- /dev/null +++ b/baselines/heterofl/README.md @@ -0,0 +1,200 @@ +--- +title: "HeteroFL: Computation And Communication Efficient Federated Learning For Heterogeneous Clients" +url: https://openreview.net/forum?id=TNkPBBYFkXg +labels: [system heterogeneity, image classification] +dataset: [MNIST, CIFAR-10] +--- + +# HeteroFL: Computation And Communication Efficient Federated Learning For Heterogeneous Clients + +**Paper:** [openreview.net/forum?id=TNkPBBYFkXg](https://openreview.net/forum?id=TNkPBBYFkXg) + +**Authors:** Enmao Diao, Jie Ding, Vahid Tarokh + +**Abstract:** Federated Learning (FL) is a method of training machine learning models on private data distributed over a large number of possibly heterogeneous clients such as mobile phones and IoT devices. In this work, we propose a new federated learning framework named HeteroFL to address heterogeneous clients equipped with very different computation and communication capabilities. Our solution can enable the training of heterogeneous local models with varying computation complexities and still produce a single global inference model. For the first time, our method challenges the underlying assumption of existing work that local models have to share the same architecture as the global model. We demonstrate several strategies to enhance FL training and conduct extensive empirical evaluations, including five computation complexity levels of three model architecture on three datasets. We show that adaptively distributing subnetworks according to clients’ capabilities is both computation and communication efficient. + + +## About this baseline + +**What’s implemented:** The code in this directory is an implementation of HeteroFL in PyTorch using Flower. The code incorporates references from the authors' implementation. Implementation of custom model split and aggregation as suggested by [@negedng](https://github.com/negedng), is available [here](https://github.com/msck72/heterofl_custom_aggregation). By modifying the configuration in the `base.yaml`, the results in the paper can be replicated, with both fixed and dynamic computational complexities among clients. + +**Key Terminology:** ++ *Model rate* defines the computational complexity of a client. Authors have defined five different computation complexity levels {a, b, c, d, e} with the hidden channel shrinkage ratio r = 0.5. + ++ *Model split mode* specifies whether the computational complexities of clients are fixed (throughout the experiment), or whether they are dynamic (change their mode_rate/computational-complexity every-round). + ++ *Model mode* determines the proportionality of clients with various computation complexity levels, for example, a4-b2-e4 determines at each round, proportion of clients with computational complexity level a = 4 / (4 + 2 + 4) * num_clients, similarly, proportion of clients with computational complexity level b = 2 / (4 + 2 + 4) * num_clients and so on. + +**Implementation Insights:** +*ModelRateManager* manages the model rate of client in simulation, which changes the model rate based on the model mode of the setup and *ClientManagerHeterofl* keeps track of model rates of the clients, so configure fit knows which/how-much subset of the model that needs to be sent to the client. + +**Datasets:** The code utilized benchmark MNIST and CIFAR-10 datasets from Pytorch's torchvision for its experimentation. + +**Hardware Setup:** The experiments were run on Google colab pro with 50GB RAM and T4 TPU. For MNIST dataset & CNN model, it approximately takes 1.5 hours to complete 200 rounds while for CIFAR10 dataset & ResNet18 model it takes around 3-4 hours to complete 400 rounds (may vary based on the model-mode of the setup). + +**Contributors:** M S Chaitanya Kumar [(github.com/msck72)](https://github.com/msck72) + + +## Experimental Setup + +**Task:** Image Classification. +**Model:** This baseline uses two models: ++ Convolutional Neural Network(CNN) model is used for MNIST dataset. ++ PreResNet (preactivated ResNet) model is used for CIFAR10 dataset. + +These models use static batch normalization (sBN) and they incorporate a Scaler module following each convolutional layer. + +**Dataset:** This baseline includes MNIST and CIFAR10 datasets. + +| Dataset | #Classes | IID Partition | non-IID Partition | +| :---: | :---: | :---: | :---: | +| MNIST
CIFAR10 | 10| Distribution of equal number of data examples among n clients | Distribution of data examples such that each client has at most 2 (customizable) classes | + + +**Training Hyperparameters:** + +| Description | Data Setting | MNIST | CIFAR-10 | +| :---: | :---: | :---:| :---: | +Total Clients | both | 100 | 100 | +Clients Per Round | both | 100 | 100 +Local Epcohs | both | 5 | 5 +Num. ROunds | IID
non-IID| 200
400 | 400
800 +Optimizer | both | SGD | SGD +Momentum | both | 0.9 | 0.9 +Weight-decay | both | 5.00e-04 | 5.00e-04 +Learning Rate | both | 0.01 | 0.1 +Decay Schedule | IID
non-IID| [100]
[150, 250] | [200]
[300,500] +Hidden Layers | both | [64 , 128 , 256 , 512] | [64 , 128 , 256 , 512] + + +The hyperparameters of Fedavg baseline are available in [Liang et al (2020)](https://arxiv.org/abs/2001.01523). + +## Environment Setup + +To construct the Python environment, simply run: + +```bash +# Set python version +pyenv install 3.10.6 +pyenv local 3.10.6 + +# Tell poetry to use python 3.10 +poetry env use 3.10.6 + +# install the base Poetry environment +poetry install + +# activate the environment +poetry shell +``` + + +## Running the Experiments +To run HeteroFL experiments in poetry activated environment: +```bash +# The main experiment implemented in your baseline using default hyperparameters (that should be setup in the Hydra configs) +# should run (including dataset download and necessary partitioning) by executing the command: + +python -m heterofl.main # Which runs the heterofl with arguments availbale in heterfl/conf/base.yaml + +# We could override the settings that were specified in base.yaml using the command-line-arguments +# Here's an example for changing the dataset name, non-iid and model +python -m heterofl.main dataset.dataset_name='CIFAR10' dataset.iid=False model.model_name='resnet18' + +# Similarly, another example for changing num_rounds, model_split_mode, and model_mode +python -m heterofl.main num_rounds=400 control.model_split_mode='dynamic' control.model_mode='a1-b1' + +# Similarly, another example for changing num_rounds, model_split_mode, and model_mode +python -m heterofl.main num_rounds=400 control.model_split_mode='dynamic' control.model_mode='a1-b1' + +``` +To run FedAvg experiments: +```bash +python -m heterofl.main --config-name fedavg +# Similarly to the commands illustrated above, we can modify the default settings in the fedavg.yaml file. +``` + +## Expected Results + +```bash +# running the multirun for IID-MNIST with various model-modes using default config +python -m heterofl.main --multirun control.model_mode='a1','a1-e1','a1-b1-c1-d1-e1' + +# running the multirun for IID-CIFAR10 dataset with various model-modes by modifying default config +python -m heterofl.main --multirun control.model_mode='a1','a1-e1','a1-b1-c1-d1-e1' dataset.dataset_name='CIFAR10' model.model_name='resnet18' num_rounds=400 optim_scheduler.lr=0.1 strategy.milestones=[150, 250] + +# running the multirun for non-IID-MNIST with various model-modes by modifying default config +python -m heterofl.main --multirun control.model_mode='a1','a1-e1','a1-b1-c1-d1-e1' dataset.iid=False num_rounds=400 optim_scheduler.milestones=[200] + +# similarly, we can perform for various model-modes, datasets. But we cannot multirun with both non-iid and iid at once for reproducing the tables below, since the number of rounds and milestones for MultiStepLR are different for non-iid and iid. The tables below are the reproduced results of various multiruns. + +#To reproduce the fedavg results +#for MNIST dataset +python -m heterofl.main --config-name fedavg --multirun dataset.iid=True,False +# for CIFAR10 dataset +python -m heterofl.main --config-name fedavg --multirun num_rounds=1800 dataset.dataset_name='CIFAR10' dataset.iid=True,False dataset.batch_size.train=50 dataset.batch_size.test=128 model.model_name='CNNCifar' optim_scheduler.lr=0.1 +``` +
+ +Results of the combination of various computation complexity levels for **MNIST** dataset with **dynamic** scenario(where a client does not belong to a fixed computational complexity level): + +| Model | Ratio | Parameters | FLOPS | Space(MB) | IID-accuracy | non-IId local-acc | non-IID global-acc | +| :--: | :----: | :-----: | :-------: | :-------: | :----------: | :---------------: | :----------------: | +| a | 1 | 1556.874 K | 80.504 M | 5.939 | 99.47 | 99.82 | 98.87 | +| a-e | 0.502 | 781.734 K | 40.452 M | 2.982 | 99.49 | 99.86 | 98.9 | +| a-b-c-d-e | 0.267 | 415.807 K | 21.625 M | 1.586 | 99.23 | 99.84 | 98.5 | +| b | 1 | 391.37 K | 20.493 M | 1.493 | 99.54 | 99.81 | 98.81 | +| b-e | 0.508 | 198.982 K | 10.447 M | 0.759 | 99.48 | 99.87 | 98.98 | +| b-c-d-e | 0.334 | 130.54 K | 6.905 M | 0.498 | 99.34 | 99.81 | 98.73 | +| c | 1 | 98.922 K | 5.307 M | 0.377 | 99.37 | 99.64 | 97.14 | +| c-e | 0.628 | 62.098 K | 3.363 M | 0.237 | 99.16 | 99.72 | 97.68 | +| c-d-e | 0.441 | 43.5965 K | 2.375 M | 0.166 | 99.28 | 99.69 | 97.27 | +| d | 1 | 25.274 K | 1.418 M | 0.096 | 99.07 | 99.77 | 97.58 | +| d-e | 0.63 | 15.934 K | 0.909 M | 0.0608 | 99.12 | 99.65 | 97.33 | +| e | 1 | 6.594 K | 0.4005 M | 0.025 | 98.46 | 99.53 | 96.5 | +| FedAvg | 1 | 633.226 K | 1.264128 M | 2.416 | 97.85 | 97.76 | 97.74 | + + +
+ +Results of the combination of various computation complexity levels for **CIFAR10** dataset with **dynamic** scenario(where a client does not belong to a fixed computational complexity level): +> *The HeteroFL paper reports a model with 1.8M parameters for their FedAvg baseline. However, as stated by the paper authors, those results are borrowed from [Liang et al (2020)](https://arxiv.org/abs/2001.01523), which uses a small CNN with fewer parameters (~64K as shown in this table below). We believe the HeteroFL authors made a mistake when reporting the number of parameters. We borrowed the model from Liang et al (2020)'s [repo](https://github.com/pliang279/LG-FedAvg/blob/master/models/Nets.py). As in the paper, FedAvg was run for 1800 rounds.* + + +| Model | Ratio | Parameters | FLOPS | Space(MB) | IID-acc | non-IId local-acc
Final   Best| non-IID global-acc
Final    Best| +| :--: | :----: | :-----: | :-------: | :-------: | :----------: | :-----: | :------: | + a | 1 | 9622 K | 330.2 M | 36.705 | 90.83 | 89.04    92.41 | 48.72    59.29 | + a-e | 0.502 | 4830 K | 165.9 M | 18.426 | 89.98 | 87.98    91.25 | 50.16    57.66 | + a-b-c-d-e | 0.267 | 2565 K | 88.4 M | 9.785 | 87.46 | 89.75    91.19 | 46.96    55.6 | + b | 1 | 2409 K | 83.3 M | 9.189 | 88.59 | 89.31    92.07 | 49.85    60.79 | + b-e | 0.508 | 1224 K | 42.4 M | 4.667 | 89.23 | 90.93    92.3 | 55.46    61.98 | + b-c-d-e | 0.332 | 801 K | 27.9 M | 3.054 | 87.61 | 89.23    91.83 | 51.59    59.4 | + c | 1 | 604 K | 21.2 M | 2.303 | 85.74 | 89.83    91.75 | 44.03    58.26 | + c-e | 0.532 | 321 K | 11.4 M | 1.225 | 87.32 | 89.28    91.56 | 53.43    59.5 | + c-d-e | 0.438 | 265 K | 9.4 M | 1.010 | 85.59 | 91.48    92.05 | 58.26    61.79 | + d | 1 | 152 K | 5.5 M | 0.579 | 82.91 | 90.81    91.47 | 55.95    58.34 | + d-e | 0.626 | 95 K | 3.5 M | 0.363 | 82.77 | 88.79    90.13 | 48.49    54.18 | + e | 1 | 38 K | 1.5 M | 0.146 | 76.53 | 90.05    90.91 | 54.68    57.05 | +|FedAvg | 1 | 64 K| 1.3 M | 0.2446 | 70.65 | 53.12    58.6 | 52.93    58.47 | + + + diff --git a/baselines/heterofl/heterofl/__init__.py b/baselines/heterofl/heterofl/__init__.py new file mode 100644 index 000000000000..a5e567b59135 --- /dev/null +++ b/baselines/heterofl/heterofl/__init__.py @@ -0,0 +1 @@ +"""Template baseline package.""" diff --git a/baselines/heterofl/heterofl/client.py b/baselines/heterofl/heterofl/client.py new file mode 100644 index 000000000000..cf325cb7e85b --- /dev/null +++ b/baselines/heterofl/heterofl/client.py @@ -0,0 +1,133 @@ +"""Defines the MNIST Flower Client and a function to instantiate it.""" + +from typing import Callable, Dict, List, Optional, Tuple + +import flwr as fl +import torch +from flwr.common.typing import NDArrays + +from heterofl.models import create_model, get_parameters, set_parameters, test, train + +# from torch.utils.data import DataLoader + + +class FlowerNumPyClient(fl.client.NumPyClient): + """Standard Flower client for training.""" + + def __init__( + self, + # cid: str, + net: torch.nn.Module, + dataloader, + model_rate: Optional[float], + client_train_settings: Dict, + ): + # self.cid = cid + self.net = net + self.trainloader = dataloader["trainloader"] + self.label_split = dataloader["label_split"] + self.valloader = dataloader["valloader"] + self.model_rate = model_rate + self.client_train_settings = client_train_settings + self.client_train_settings["device"] = torch.device( + "cuda:0" if torch.cuda.is_available() else "cpu" + ) + # print( + # "Client_with model rate = {} , cid of client = {}".format( + # self.model_rate, self.cid + # ) + # ) + + def get_parameters(self, config) -> NDArrays: + """Return the parameters of the current net.""" + # print(f"[Client {self.cid}] get_parameters") + return get_parameters(self.net) + + def fit(self, parameters, config) -> Tuple[NDArrays, int, Dict]: + """Implement distributed fit function for a given client.""" + # print(f"cid = {self.cid}") + set_parameters(self.net, parameters) + if "lr" in config: + self.client_train_settings["lr"] = config["lr"] + train( + self.net, + self.trainloader, + self.label_split, + self.client_train_settings, + ) + return get_parameters(self.net), len(self.trainloader), {} + + def evaluate(self, parameters, config) -> Tuple[float, int, Dict]: + """Implement distributed evaluation for a given client.""" + set_parameters(self.net, parameters) + loss, accuracy = test( + self.net, self.valloader, device=self.client_train_settings["device"] + ) + return float(loss), len(self.valloader), {"accuracy": float(accuracy)} + + +def gen_client_fn( + model_config: Dict, + client_to_model_rate_mapping: Optional[List[float]], + client_train_settings: Dict, + data_loaders, +) -> Callable[[str], FlowerNumPyClient]: # pylint: disable=too-many-arguments + """Generate the client function that creates the Flower Clients. + + Parameters + ---------- + model_config : Dict + Dict that contains all the information required to + create a model (data_shape , hidden_layers , classes_size...) + client_to_model_rate: List[float] + List tha contains model_rates of clients. + model_rate of client with cid i = client_to_model_rate_mapping[i] + client_train_settings : Dict + Dict that contains information regarding optimizer , lr , + momentum , device required by the client to train + trainloaders: List[DataLoader] + A list of DataLoaders, each pointing to the dataset training partition + belonging to a particular client. + label_split: torch.tensor + A Tensor of tensors that conatins the labels of the partitioned dataset. + label_split of client with cid i = label_split[i] + valloaders: List[DataLoader] + A list of DataLoaders, each pointing to the dataset validation partition + belonging to a particular client. + + Returns + ------- + Callable[[str], FlowerClient] + A tuple containing the client function that creates Flower Clients + """ + + def client_fn(cid: str) -> FlowerNumPyClient: + """Create a Flower client representing a single organization.""" + # Note: each client gets a different trainloader/valloader, so each client + # will train and evaluate on their own unique data + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + client_dataloader = { + "trainloader": data_loaders["trainloaders"][int(cid)], + "valloader": data_loaders["valloaders"][int(cid)], + "label_split": data_loaders["label_split"][int(cid)], + } + # trainloader = data_loaders["trainloaders"][int(cid)] + # valloader = data_loaders["valloaders"][int(cid)] + model_rate = None + if client_to_model_rate_mapping is not None: + model_rate = client_to_model_rate_mapping[int(cid)] + + return FlowerNumPyClient( + # cid=cid, + net=create_model( + model_config, + model_rate=model_rate, + device=device, + ), + dataloader=client_dataloader, + model_rate=model_rate, + client_train_settings=client_train_settings, + ) + + return client_fn diff --git a/baselines/heterofl/heterofl/client_manager_heterofl.py b/baselines/heterofl/heterofl/client_manager_heterofl.py new file mode 100644 index 000000000000..be5b2227a159 --- /dev/null +++ b/baselines/heterofl/heterofl/client_manager_heterofl.py @@ -0,0 +1,207 @@ +"""HeteroFL ClientManager.""" + +import random +import threading +from logging import INFO +from typing import Dict, List, Optional + +import flwr as fl +import torch +from flwr.common.logger import log +from flwr.server.client_proxy import ClientProxy +from flwr.server.criterion import Criterion + +# from heterofl.utils import ModelRateManager + + +class ClientManagerHeteroFL(fl.server.ClientManager): + """Provides a pool of available clients.""" + + def __init__( + self, + model_rate_manager=None, + clients_to_model_rate_mapping=None, + client_label_split: Optional[list[torch.tensor]] = None, + ) -> None: + super().__init__() + self.clients: Dict[str, ClientProxy] = {} + + self.is_simulation = False + if model_rate_manager is not None and clients_to_model_rate_mapping is not None: + self.is_simulation = True + + self.model_rate_manager = model_rate_manager + + # have a common array in simulation to access in the client_fn and server side + if self.is_simulation is True: + self.clients_to_model_rate_mapping = clients_to_model_rate_mapping + ans = self.model_rate_manager.create_model_rate_mapping( + len(clients_to_model_rate_mapping) + ) + # copy self.clients_to_model_rate_mapping , ans + for i, model_rate in enumerate(ans): + self.clients_to_model_rate_mapping[i] = model_rate + + # shall handle in case of not_simulation... + self.client_label_split = client_label_split + + self._cv = threading.Condition() + + def __len__(self) -> int: + """Return the length of clients Dict. + + Returns + ------- + len : int + Length of Dict (self.clients). + """ + return len(self.clients) + + def num_available(self) -> int: + """Return the number of available clients. + + Returns + ------- + num_available : int + The number of currently available clients. + """ + return len(self) + + def wait_for(self, num_clients: int, timeout: int = 86400) -> bool: + """Wait until at least `num_clients` are available. + + Blocks until the requested number of clients is available or until a + timeout is reached. Current timeout default: 1 day. + + Parameters + ---------- + num_clients : int + The number of clients to wait for. + timeout : int + The time in seconds to wait for, defaults to 86400 (24h). + + Returns + ------- + success : bool + """ + with self._cv: + return self._cv.wait_for( + lambda: len(self.clients) >= num_clients, timeout=timeout + ) + + def register(self, client: ClientProxy) -> bool: + """Register Flower ClientProxy instance. + + Parameters + ---------- + client : flwr.server.client_proxy.ClientProxy + + Returns + ------- + success : bool + Indicating if registration was successful. False if ClientProxy is + already registered or can not be registered for any reason. + """ + if client.cid in self.clients: + return False + + self.clients[client.cid] = client + + # in case of not a simulation, this type of method can be used + # if self.is_simulation is False: + # prop = client.get_properties(None, timeout=86400) + # self.clients_to_model_rate_mapping[int(client.cid)] = prop["model_rate"] + # self.client_label_split[int(client.cid)] = prop["label_split"] + + with self._cv: + self._cv.notify_all() + + return True + + def unregister(self, client: ClientProxy) -> None: + """Unregister Flower ClientProxy instance. + + This method is idempotent. + + Parameters + ---------- + client : flwr.server.client_proxy.ClientProxy + """ + if client.cid in self.clients: + del self.clients[client.cid] + + with self._cv: + self._cv.notify_all() + + def all(self) -> Dict[str, ClientProxy]: + """Return all available clients.""" + return self.clients + + def get_client_to_model_mapping(self, cid) -> float: + """Return model rate of client with cid.""" + return self.clients_to_model_rate_mapping[int(cid)] + + def get_all_clients_to_model_mapping(self) -> List[float]: + """Return all available clients to model rate mapping.""" + return self.clients_to_model_rate_mapping.copy() + + def update(self, server_round: int) -> None: + """Update the client to model rate mapping.""" + if self.is_simulation is True: + if ( + server_round == 1 and self.model_rate_manager.model_split_mode == "fix" + ) or (self.model_rate_manager.model_split_mode == "dynamic"): + ans = self.model_rate_manager.create_model_rate_mapping( + self.num_available() + ) + # copy self.clients_to_model_rate_mapping , ans + for i, model_rate in enumerate(ans): + self.clients_to_model_rate_mapping[i] = model_rate + print( + "clients to model rate mapping ", self.clients_to_model_rate_mapping + ) + return + + # to be handled in case of not a simulation, i.e. to get the properties + # again from the clients as they can change the model_rate + # for i in range(self.num_available): + # # need to test this , accumilates the + # # changing model rate of the client + # self.clients_to_model_rate_mapping[i] = + # self.clients[str(i)].get_properties['model_rate'] + # return + + def sample( + self, + num_clients: int, + min_num_clients: Optional[int] = None, + criterion: Optional[Criterion] = None, + ) -> List[ClientProxy]: + """Sample a number of Flower ClientProxy instances.""" + # Block until at least num_clients are connected. + if min_num_clients is None: + min_num_clients = num_clients + self.wait_for(min_num_clients) + # Sample clients which meet the criterion + available_cids = list(self.clients) + if criterion is not None: + available_cids = [ + cid for cid in available_cids if criterion.select(self.clients[cid]) + ] + + if num_clients > len(available_cids): + log( + INFO, + "Sampling failed: number of available clients" + " (%s) is less than number of requested clients (%s).", + len(available_cids), + num_clients, + ) + return [] + + random_indices = torch.randperm(len(available_cids))[:num_clients] + # Use the random indices to select clients + sampled_cids = [available_cids[i] for i in random_indices] + sampled_cids = random.sample(available_cids, num_clients) + print(f"Sampled CIDS = {sampled_cids}") + return [self.clients[cid] for cid in sampled_cids] diff --git a/baselines/heterofl/heterofl/conf/base.yaml b/baselines/heterofl/heterofl/conf/base.yaml new file mode 100644 index 000000000000..42edf419cc38 --- /dev/null +++ b/baselines/heterofl/heterofl/conf/base.yaml @@ -0,0 +1,47 @@ +num_clients: 100 +num_epochs: 5 +num_rounds: 800 +seed: 0 +client_resources: + num_cpus: 1 + num_gpus: 0.08 + +control: + model_split_mode: 'dynamic' + model_mode: 'a1-b1-c1-d1-e1' + +dataset: + dataset_name: 'CIFAR10' + iid: False + shard_per_user : 2 # only used in case of non-iid (i.e. iid = false) + balance: false + batch_size: + train: 10 + test: 50 + shuffle: + train: true + test: false + + +model: + model_name: resnet18 # use 'conv' for MNIST + hidden_layers: [64 , 128 , 256 , 512] + norm: bn + scale: 1 + mask: 1 + + +optim_scheduler: + optimizer: SGD + lr: 0.1 + momentum: 0.9 + weight_decay: 5.00e-04 + scheduler: MultiStepLR + milestones: [300, 500] + +strategy: + _target_: heterofl.strategy.HeteroFL + fraction_fit: 0.1 + fraction_evaluate: 0.1 + min_fit_clients: 10 + min_evaluate_clients: 10 diff --git a/baselines/heterofl/heterofl/conf/fedavg.yaml b/baselines/heterofl/heterofl/conf/fedavg.yaml new file mode 100644 index 000000000000..d67d0950654a --- /dev/null +++ b/baselines/heterofl/heterofl/conf/fedavg.yaml @@ -0,0 +1,41 @@ +num_clients: 100 +num_epochs: 1 +num_rounds: 800 +seed: 0 +clip: False +enable_train_on_train_data_while_testing: False +client_resources: + num_cpus: 1 + num_gpus: 0.4 + +dataset: + dataset_name: 'MNIST' + iid: False + shard_per_user : 2 + balance: False + batch_size: + train: 10 + test: 10 + shuffle: + train: true + test: false + + +model: + model_name: MLP #use CNNCifar for CIFAR10 + +optim_scheduler: + optimizer: SGD + lr: 0.05 + lr_decay_rate: 1.0 + momentum: 0.5 + weight_decay: 0 + scheduler: MultiStepLR + milestones: [] + +strategy: + _target_: flwr.server.strategy.FedAvg + fraction_fit: 0.1 + fraction_evaluate: 0.1 + min_fit_clients: 10 + min_evaluate_clients: 10 diff --git a/baselines/heterofl/heterofl/dataset.py b/baselines/heterofl/heterofl/dataset.py new file mode 100644 index 000000000000..0e0f4b726842 --- /dev/null +++ b/baselines/heterofl/heterofl/dataset.py @@ -0,0 +1,83 @@ +"""Utilities for creation of DataLoaders for clients and server.""" + +from typing import List, Optional, Tuple + +import torch +from omegaconf import DictConfig +from torch.utils.data import DataLoader + +from heterofl.dataset_preparation import _partition_data + + +def load_datasets( # pylint: disable=too-many-arguments + strategy_name: str, + config: DictConfig, + num_clients: int, + seed: Optional[int] = 42, +) -> Tuple[ + DataLoader, List[DataLoader], List[torch.tensor], List[DataLoader], DataLoader +]: + """Create the dataloaders to be fed into the model. + + Parameters + ---------- + config: DictConfig + Parameterises the dataset partitioning process + num_clients : int + The number of clients that hold a part of the data + seed : int, optional + Used to set a fix seed to replicate experiments, by default 42 + + Returns + ------- + Tuple[DataLoader, DataLoader, DataLoader, DataLoader] + The entire trainset Dataloader for testing purposes, + The DataLoader for training, the DataLoader for validation, + the DataLoader for testing. + """ + print(f"Dataset partitioning config: {config}") + trainset, datasets, label_split, client_testsets, testset = _partition_data( + num_clients, + dataset_name=config.dataset_name, + strategy_name=strategy_name, + iid=config.iid, + dataset_division={ + "shard_per_user": config.shard_per_user, + "balance": config.balance, + }, + seed=seed, + ) + # Split each partition into train/val and create DataLoader + entire_trainloader = DataLoader( + trainset, batch_size=config.batch_size.train, shuffle=config.shuffle.train + ) + + trainloaders = [] + valloaders = [] + for dataset in datasets: + trainloaders.append( + DataLoader( + dataset, + batch_size=config.batch_size.train, + shuffle=config.shuffle.train, + ) + ) + + for client_testset in client_testsets: + valloaders.append( + DataLoader( + client_testset, + batch_size=config.batch_size.test, + shuffle=config.shuffle.test, + ) + ) + + return ( + entire_trainloader, + trainloaders, + label_split, + valloaders, + DataLoader( + testset, batch_size=config.batch_size.test, shuffle=config.shuffle.test + ), + ) diff --git a/baselines/heterofl/heterofl/dataset_preparation.py b/baselines/heterofl/heterofl/dataset_preparation.py new file mode 100644 index 000000000000..525e815e9e98 --- /dev/null +++ b/baselines/heterofl/heterofl/dataset_preparation.py @@ -0,0 +1,357 @@ +"""Functions for dataset download and processing.""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch.utils.data import ConcatDataset, Dataset, Subset, random_split +from torchvision import transforms + +import heterofl.datasets as dt + + +def _download_data(dataset_name: str, strategy_name: str) -> Tuple[Dataset, Dataset]: + root = "./data/{}".format(dataset_name) + if dataset_name == "MNIST": + trainset = dt.MNIST( + root=root, + split="train", + subset="label", + transform=dt.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + testset = dt.MNIST( + root=root, + split="test", + subset="label", + transform=dt.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + elif dataset_name == "CIFAR10": + normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + if strategy_name == "heterofl": + normalize = transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) + ) + trainset = dt.CIFAR10( + root=root, + split="train", + subset="label", + transform=dt.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + testset = dt.CIFAR10( + root=root, + split="test", + subset="label", + transform=dt.Compose( + [ + transforms.ToTensor(), + normalize, + ] + ), + ) + else: + raise ValueError(f"{dataset_name} is not valid") + + return trainset, testset + + +# pylint: disable=too-many-arguments +def _partition_data( + num_clients: int, + dataset_name: str, + strategy_name: str, + iid: Optional[bool] = False, + dataset_division=None, + seed: Optional[int] = 42, +) -> Tuple[Dataset, List[Dataset], List[torch.tensor], List[Dataset], Dataset]: + trainset, testset = _download_data(dataset_name, strategy_name) + + if dataset_name in ("MNIST", "CIFAR10"): + classes_size = 10 + + if dataset_division["balance"]: + trainset = _balance_classes(trainset, seed) + + if iid: + datasets, label_split = iid_partition(trainset, num_clients, seed=seed) + client_testsets, _ = iid_partition(testset, num_clients, seed=seed) + else: + datasets, label_split = non_iid( + {"dataset": trainset, "classes_size": classes_size}, + num_clients, + dataset_division["shard_per_user"], + ) + client_testsets, _ = non_iid( + { + "dataset": testset, + "classes_size": classes_size, + }, + num_clients, + dataset_division["shard_per_user"], + label_split, + ) + + tensor_label_split = [] + for i in label_split: + tensor_label_split.append(torch.Tensor(i)) + label_split = tensor_label_split + + return trainset, datasets, label_split, client_testsets, testset + + +def iid_partition( + dataset: Dataset, num_clients: int, seed: Optional[int] = 42 +) -> Tuple[List[Dataset], List[torch.tensor]]: + """IID partition of dataset among clients.""" + partition_size = int(len(dataset) / num_clients) + lengths = [partition_size] * num_clients + + divided_dataset = random_split( + dataset, lengths, torch.Generator().manual_seed(seed) + ) + label_split = [] + for i in range(num_clients): + label_split.append( + torch.unique(torch.Tensor([target for _, target in divided_dataset[i]])) + ) + + return divided_dataset, label_split + + +def non_iid( + dataset_info, + num_clients: int, + shard_per_user: int, + label_split=None, + seed=42, +) -> Tuple[List[Dataset], List]: + """Non-IID partition of dataset among clients. + + Adopted from authors (of heterofl) implementation. + """ + data_split: Dict[int, List] = {i: [] for i in range(num_clients)} + + label_idx_split, shard_per_class = _split_dataset_targets_idx( + dataset_info["dataset"], + shard_per_user, + num_clients, + dataset_info["classes_size"], + ) + + if label_split is None: + label_split = list(range(dataset_info["classes_size"])) * shard_per_class + label_split = torch.tensor(label_split)[ + torch.randperm( + len(label_split), generator=torch.Generator().manual_seed(seed) + ) + ].tolist() + label_split = np.array(label_split).reshape((num_clients, -1)).tolist() + + for i, _ in enumerate(label_split): + label_split[i] = np.unique(label_split[i]).tolist() + + for i in range(num_clients): + for label_i in label_split[i]: + idx = torch.arange(len(label_idx_split[label_i]))[ + torch.randperm( + len(label_idx_split[label_i]), + generator=torch.Generator().manual_seed(seed), + )[0] + ].item() + data_split[i].extend(label_idx_split[label_i].pop(idx)) + + return ( + _get_dataset_from_idx(dataset_info["dataset"], data_split, num_clients), + label_split, + ) + + +def _split_dataset_targets_idx(dataset, shard_per_user, num_clients, classes_size): + label = np.array(dataset.target) if hasattr(dataset, "target") else dataset.targets + label_idx_split: Dict = {} + for i, _ in enumerate(label): + label_i = label[i].item() + if label_i not in label_idx_split: + label_idx_split[label_i] = [] + label_idx_split[label_i].append(i) + + shard_per_class = int(shard_per_user * num_clients / classes_size) + + for label_i in label_idx_split: + label_idx = label_idx_split[label_i] + num_leftover = len(label_idx) % shard_per_class + leftover = label_idx[-num_leftover:] if num_leftover > 0 else [] + new_label_idx = ( + np.array(label_idx[:-num_leftover]) + if num_leftover > 0 + else np.array(label_idx) + ) + new_label_idx = new_label_idx.reshape((shard_per_class, -1)).tolist() + + for i, leftover_label_idx in enumerate(leftover): + new_label_idx[i] = np.concatenate([new_label_idx[i], [leftover_label_idx]]) + label_idx_split[label_i] = new_label_idx + return label_idx_split, shard_per_class + + +def _get_dataset_from_idx(dataset, data_split, num_clients): + divided_dataset = [None for i in range(num_clients)] + for i in range(num_clients): + divided_dataset[i] = Subset(dataset, data_split[i]) + return divided_dataset + + +def _balance_classes( + trainset: Dataset, + seed: Optional[int] = 42, +) -> Dataset: + class_counts = np.bincount(trainset.target) + targets = torch.Tensor(trainset.target) + smallest = np.min(class_counts) + idxs = targets.argsort() + tmp = [Subset(trainset, idxs[: int(smallest)])] + tmp_targets = [targets[idxs[: int(smallest)]]] + for count in np.cumsum(class_counts): + tmp.append(Subset(trainset, idxs[int(count) : int(count + smallest)])) + tmp_targets.append(targets[idxs[int(count) : int(count + smallest)]]) + unshuffled = ConcatDataset(tmp) + unshuffled_targets = torch.cat(tmp_targets) + shuffled_idxs = torch.randperm( + len(unshuffled), generator=torch.Generator().manual_seed(seed) + ) + shuffled = Subset(unshuffled, shuffled_idxs) + shuffled.targets = unshuffled_targets[shuffled_idxs] + + return shuffled + + +def _sort_by_class( + trainset: Dataset, +) -> Dataset: + class_counts = np.bincount(trainset.targets) + idxs = trainset.targets.argsort() # sort targets in ascending order + + tmp = [] # create subset of smallest class + tmp_targets = [] # same for targets + + start = 0 + for count in np.cumsum(class_counts): + tmp.append( + Subset(trainset, idxs[start : int(count + start)]) + ) # add rest of classes + tmp_targets.append(trainset.targets[idxs[start : int(count + start)]]) + start += count + sorted_dataset = ConcatDataset(tmp) # concat dataset + sorted_dataset.targets = torch.cat(tmp_targets) # concat targets + return sorted_dataset + + +# pylint: disable=too-many-locals, too-many-arguments +def _power_law_split( + sorted_trainset: Dataset, + num_partitions: int, + num_labels_per_partition: int = 2, + min_data_per_partition: int = 10, + mean: float = 0.0, + sigma: float = 2.0, +) -> Dataset: + """Partition the dataset following a power-law distribution. It follows the. + + implementation of Li et al 2020: https://arxiv.org/abs/1812.06127 with default + values set accordingly. + + Parameters + ---------- + sorted_trainset : Dataset + The training dataset sorted by label/class. + num_partitions: int + Number of partitions to create + num_labels_per_partition: int + Number of labels to have in each dataset partition. For + example if set to two, this means all training examples in + a given partition will be long to the same two classes. default 2 + min_data_per_partition: int + Minimum number of datapoints included in each partition, default 10 + mean: float + Mean value for LogNormal distribution to construct power-law, default 0.0 + sigma: float + Sigma value for LogNormal distribution to construct power-law, default 2.0 + + Returns + ------- + Dataset + The partitioned training dataset. + """ + targets = sorted_trainset.targets + full_idx = list(range(len(targets))) + + class_counts = np.bincount(sorted_trainset.targets) + labels_cs = np.cumsum(class_counts) + labels_cs = [0] + labels_cs[:-1].tolist() + + partitions_idx: List[List[int]] = [] + num_classes = len(np.bincount(targets)) + hist = np.zeros(num_classes, dtype=np.int32) + + # assign min_data_per_partition + min_data_per_class = int(min_data_per_partition / num_labels_per_partition) + for u_id in range(num_partitions): + partitions_idx.append([]) + for cls_idx in range(num_labels_per_partition): + # label for the u_id-th client + cls = (u_id + cls_idx) % num_classes + # record minimum data + indices = list( + full_idx[ + labels_cs[cls] + + hist[cls] : labels_cs[cls] + + hist[cls] + + min_data_per_class + ] + ) + partitions_idx[-1].extend(indices) + hist[cls] += min_data_per_class + + # add remaining images following power-law + probs = np.random.lognormal( + mean, + sigma, + (num_classes, int(num_partitions / num_classes), num_labels_per_partition), + ) + remaining_per_class = class_counts - hist + # obtain how many samples each partition should be assigned for each of the + # labels it contains + # pylint: disable=too-many-function-args + probs = ( + remaining_per_class.reshape(-1, 1, 1) + * probs + / np.sum(probs, (1, 2), keepdims=True) + ) + + for u_id in range(num_partitions): + for cls_idx in range(num_labels_per_partition): + cls = (u_id + cls_idx) % num_classes + count = int(probs[cls, u_id // num_classes, cls_idx]) + + # add count of specific class to partition + indices = full_idx[ + labels_cs[cls] + hist[cls] : labels_cs[cls] + hist[cls] + count + ] + partitions_idx[u_id].extend(indices) + hist[cls] += count + + # construct subsets + partitions = [Subset(sorted_trainset, p) for p in partitions_idx] + return partitions diff --git a/baselines/heterofl/heterofl/datasets/__init__.py b/baselines/heterofl/heterofl/datasets/__init__.py new file mode 100644 index 000000000000..91251db77302 --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/__init__.py @@ -0,0 +1,9 @@ +"""Dataset module. + +The entire datasets module is adopted from authors implementation. +""" +from .cifar import CIFAR10 +from .mnist import MNIST +from .utils import Compose + +__all__ = ("MNIST", "CIFAR10", "Compose") diff --git a/baselines/heterofl/heterofl/datasets/cifar.py b/baselines/heterofl/heterofl/datasets/cifar.py new file mode 100644 index 000000000000..c75194bc8ee7 --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/cifar.py @@ -0,0 +1,150 @@ +"""CIFAR10 dataset class, adopted from authors implementation.""" +import os +import pickle + +import anytree +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +from heterofl.datasets.utils import ( + download_url, + extract_file, + make_classes_counts, + make_flat_index, + make_tree, +) +from heterofl.utils import check_exists, load, makedir_exist_ok, save + + +# pylint: disable=too-many-instance-attributes +class CIFAR10(Dataset): + """CIFAR10 dataset.""" + + data_name = "CIFAR10" + file = [ + ( + "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz", + "c58f30108f718f92721af3b95e74349a", + ) + ] + + def __init__(self, root, split, subset, transform=None): + self.root = os.path.expanduser(root) + self.split = split + self.subset = subset + self.transform = transform + if not check_exists(self.processed_folder): + self.process() + self.img, self.target = load( + os.path.join(self.processed_folder, "{}.pt".format(self.split)) + ) + self.target = self.target[self.subset] + self.classes_counts = make_classes_counts(self.target) + self.classes_to_labels, self.classes_size = load( + os.path.join(self.processed_folder, "meta.pt") + ) + self.classes_to_labels, self.classes_size = ( + self.classes_to_labels[self.subset], + self.classes_size[self.subset], + ) + + def __getitem__(self, index): + """Get the item with index.""" + img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index]) + inp = {"img": img, self.subset: target} + if self.transform is not None: + inp = self.transform(inp) + return inp["img"], inp["label"] + + def __len__(self): + """Length of the dataset.""" + return len(self.img) + + @property + def processed_folder(self): + """Return path of processed folder.""" + return os.path.join(self.root, "processed") + + @property + def raw_folder(self): + """Return path of raw folder.""" + return os.path.join(self.root, "raw") + + def process(self): + """Save the dataset accordingly.""" + if not check_exists(self.raw_folder): + self.download() + train_set, test_set, meta = self.make_data() + save(train_set, os.path.join(self.processed_folder, "train.pt")) + save(test_set, os.path.join(self.processed_folder, "test.pt")) + save(meta, os.path.join(self.processed_folder, "meta.pt")) + + def download(self): + """Download dataset from the url.""" + makedir_exist_ok(self.raw_folder) + for url, md5 in self.file: + filename = os.path.basename(url) + download_url(url, self.raw_folder, filename, md5) + extract_file(os.path.join(self.raw_folder, filename)) + + def __repr__(self): + """Represent CIFAR10 as string.""" + fmt_str = ( + f"Dataset {self.__class__.__name__}\nSize: {self.__len__()}\n" + f"Root: {self.root}\nSplit: {self.split}\nSubset: {self.subset}\n" + f"Transforms: {self.transform.__repr__()}" + ) + return fmt_str + + def make_data(self): + """Make data.""" + train_filenames = [ + "data_batch_1", + "data_batch_2", + "data_batch_3", + "data_batch_4", + "data_batch_5", + ] + test_filenames = ["test_batch"] + train_img, train_label = _read_pickle_file( + os.path.join(self.raw_folder, "cifar-10-batches-py"), train_filenames + ) + test_img, test_label = _read_pickle_file( + os.path.join(self.raw_folder, "cifar-10-batches-py"), test_filenames + ) + train_target, test_target = {"label": train_label}, {"label": test_label} + with open( + os.path.join(self.raw_folder, "cifar-10-batches-py", "batches.meta"), "rb" + ) as fle: + data = pickle.load(fle, encoding="latin1") + classes = data["label_names"] + classes_to_labels = {"label": anytree.Node("U", index=[])} + for cls in classes: + make_tree(classes_to_labels["label"], [cls]) + classes_size = {"label": make_flat_index(classes_to_labels["label"])} + return ( + (train_img, train_target), + (test_img, test_target), + (classes_to_labels, classes_size), + ) + + +def _read_pickle_file(path, filenames): + img, label = [], [] + for filename in filenames: + file_path = os.path.join(path, filename) + with open(file_path, "rb") as file: + entry = pickle.load(file, encoding="latin1") + img.append(entry["data"]) + if "labels" in entry: + label.extend(entry["labels"]) + else: + label.extend(entry["fine_labels"]) + # label.extend(entry["labels"]) if "labels" in entry else label.extend( + # entry["fine_labels"] + # ) + img = np.vstack(img).reshape(-1, 3, 32, 32) + img = img.transpose((0, 2, 3, 1)) + return img, label diff --git a/baselines/heterofl/heterofl/datasets/mnist.py b/baselines/heterofl/heterofl/datasets/mnist.py new file mode 100644 index 000000000000..feae2ea987b4 --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/mnist.py @@ -0,0 +1,167 @@ +"""MNIST dataset class, adopted from authors implementation.""" +import codecs +import os + +import anytree +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + +from heterofl.datasets.utils import ( + download_url, + extract_file, + make_classes_counts, + make_flat_index, + make_tree, +) +from heterofl.utils import check_exists, load, makedir_exist_ok, save + + +# pylint: disable=too-many-instance-attributes +class MNIST(Dataset): + """MNIST dataset.""" + + data_name = "MNIST" + file = [ + ( + "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", + "f68b3c2dcbeaaa9fbdd348bbdeb94873", + ), + ( + "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", + "9fb629c4189551a2d022fa330f9573f3", + ), + ( + "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", + "d53e105ee54ea40749a09fcbcd1e9432", + ), + ( + "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", + "ec29112dd5afa0611ce80d1b7f02629c", + ), + ] + + def __init__(self, root, split, subset, transform=None): + self.root = os.path.expanduser(root) + self.split = split + self.subset = subset + self.transform = transform + if not check_exists(self.processed_folder): + self.process() + self.img, self.target = load( + os.path.join(self.processed_folder, "{}.pt".format(self.split)) + ) + self.target = self.target[self.subset] + self.classes_counts = make_classes_counts(self.target) + self.classes_to_labels, self.classes_size = load( + os.path.join(self.processed_folder, "meta.pt") + ) + self.classes_to_labels, self.classes_size = ( + self.classes_to_labels[self.subset], + self.classes_size[self.subset], + ) + + def __getitem__(self, index): + """Get the item with index.""" + img, target = Image.fromarray(self.img[index]), torch.tensor(self.target[index]) + inp = {"img": img, self.subset: target} + if self.transform is not None: + inp = self.transform(inp) + return inp["img"], inp["label"] + + def __len__(self): + """Length of the dataset.""" + return len(self.img) + + @property + def processed_folder(self): + """Return path of processed folder.""" + return os.path.join(self.root, "processed") + + @property + def raw_folder(self): + """Return path of raw folder.""" + return os.path.join(self.root, "raw") + + def process(self): + """Save the dataset accordingly.""" + if not check_exists(self.raw_folder): + self.download() + train_set, test_set, meta = self.make_data() + save(train_set, os.path.join(self.processed_folder, "train.pt")) + save(test_set, os.path.join(self.processed_folder, "test.pt")) + save(meta, os.path.join(self.processed_folder, "meta.pt")) + + def download(self): + """Download and save the dataset accordingly.""" + makedir_exist_ok(self.raw_folder) + for url, md5 in self.file: + filename = os.path.basename(url) + download_url(url, self.raw_folder, filename, md5) + extract_file(os.path.join(self.raw_folder, filename)) + + def __repr__(self): + """Represent CIFAR10 as string.""" + fmt_str = ( + f"Dataset {self.__class__.__name__}\nSize: {self.__len__()}\n" + f"Root: {self.root}\nSplit: {self.split}\nSubset: {self.subset}\n" + f"Transforms: {self.transform.__repr__()}" + ) + return fmt_str + + def make_data(self): + """Make data.""" + train_img = _read_image_file( + os.path.join(self.raw_folder, "train-images-idx3-ubyte") + ) + test_img = _read_image_file( + os.path.join(self.raw_folder, "t10k-images-idx3-ubyte") + ) + train_label = _read_label_file( + os.path.join(self.raw_folder, "train-labels-idx1-ubyte") + ) + test_label = _read_label_file( + os.path.join(self.raw_folder, "t10k-labels-idx1-ubyte") + ) + train_target, test_target = {"label": train_label}, {"label": test_label} + classes_to_labels = {"label": anytree.Node("U", index=[])} + classes = list(map(str, list(range(10)))) + for cls in classes: + make_tree(classes_to_labels["label"], [cls]) + classes_size = {"label": make_flat_index(classes_to_labels["label"])} + return ( + (train_img, train_target), + (test_img, test_target), + (classes_to_labels, classes_size), + ) + + +def _get_int(num): + return int(codecs.encode(num, "hex"), 16) + + +def _read_image_file(path): + with open(path, "rb") as file: + data = file.read() + assert _get_int(data[:4]) == 2051 + length = _get_int(data[4:8]) + num_rows = _get_int(data[8:12]) + num_cols = _get_int(data[12:16]) + parsed = np.frombuffer(data, dtype=np.uint8, offset=16).reshape( + (length, num_rows, num_cols) + ) + return parsed + + +def _read_label_file(path): + with open(path, "rb") as file: + data = file.read() + assert _get_int(data[:4]) == 2049 + length = _get_int(data[4:8]) + parsed = ( + np.frombuffer(data, dtype=np.uint8, offset=8) + .reshape(length) + .astype(np.int64) + ) + return parsed diff --git a/baselines/heterofl/heterofl/datasets/utils.py b/baselines/heterofl/heterofl/datasets/utils.py new file mode 100644 index 000000000000..6b71811ed50d --- /dev/null +++ b/baselines/heterofl/heterofl/datasets/utils.py @@ -0,0 +1,244 @@ +"""Contains utility functions required for datasests. + +Adopted from authors implementation. +""" +import glob +import gzip +import hashlib +import os +import tarfile +import zipfile +from collections import Counter + +import anytree +import numpy as np +from PIL import Image +from six.moves import urllib +from tqdm import tqdm + +from heterofl.utils import makedir_exist_ok + +IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"] + + +def find_classes(drctry): + """Find the classes in a directory.""" + classes = [d.name for d in os.scandir(drctry) if d.is_dir()] + classes.sort() + classes_to_labels = {classes[i]: i for i in range(len(classes))} + return classes_to_labels + + +def pil_loader(path): + """Load image from path using PIL.""" + with open(path, "rb") as file: + img = Image.open(file) + return img.convert("RGB") + + +# def accimage_loader(path): +# """Load image from path using accimage_loader.""" +# import accimage + +# try: +# return accimage.Image(path) +# except IOError: +# return pil_loader(path) + + +def default_loader(path): + """Load image from path using default loader.""" + # if get_image_backend() == "accimage": + # return accimage_loader(path) + + return pil_loader(path) + + +def has_file_allowed_extension(filename, extensions): + """Check whether file possesses any of the extensions listed.""" + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def make_classes_counts(label): + """Count number of classes.""" + label = np.array(label) + if label.ndim > 1: + label = label.sum(axis=tuple(range(1, label.ndim))) + classes_counts = Counter(label) + return classes_counts + + +def _make_bar_updater(pbar): + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def _calculate_md5(path, chunk_size=1024 * 1024): + md5 = hashlib.md5() + with open(path, "rb") as file: + for chunk in iter(lambda: file.read(chunk_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def _check_md5(path, md5, **kwargs): + return md5 == _calculate_md5(path, **kwargs) + + +def _check_integrity(path, md5=None): + if not os.path.isfile(path): + return False + if md5 is None: + return True + return _check_md5(path, md5) + + +def download_url(url, root, filename, md5): + """Download files from the url.""" + path = os.path.join(root, filename) + makedir_exist_ok(root) + if os.path.isfile(path) and _check_integrity(path, md5): + print("Using downloaded and verified file: " + path) + else: + try: + print("Downloading " + url + " to " + path) + urllib.request.urlretrieve( + url, path, reporthook=_make_bar_updater(tqdm(unit="B", unit_scale=True)) + ) + except OSError: + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + path + ) + urllib.request.urlretrieve( + url, + path, + reporthook=_make_bar_updater(tqdm(unit="B", unit_scale=True)), + ) + if not _check_integrity(path, md5): + raise RuntimeError("Not valid downloaded file") + + +def extract_file(src, dest=None, delete=False): + """Extract the file.""" + print("Extracting {}".format(src)) + dest = os.path.dirname(src) if dest is None else dest + filename = os.path.basename(src) + if filename.endswith(".zip"): + with zipfile.ZipFile(src, "r") as zip_f: + zip_f.extractall(dest) + elif filename.endswith(".tar"): + with tarfile.open(src) as tar_f: + tar_f.extractall(dest) + elif filename.endswith(".tar.gz") or filename.endswith(".tgz"): + with tarfile.open(src, "r:gz") as tar_f: + tar_f.extractall(dest) + elif filename.endswith(".gz"): + with open(src.replace(".gz", ""), "wb") as out_f, gzip.GzipFile(src) as zip_f: + out_f.write(zip_f.read()) + if delete: + os.remove(src) + + +def make_data(root, extensions): + """Get all the files in the root directory that follows the given extensions.""" + path = [] + files = glob.glob("{}/**/*".format(root), recursive=True) + for file in files: + if has_file_allowed_extension(file, extensions): + path.append(os.path.normpath(file)) + return path + + +# pylint: disable=dangerous-default-value +def make_img(path, classes_to_labels, extensions=IMG_EXTENSIONS): + """Make image.""" + img, label = [], [] + classes = [] + leaf_nodes = classes_to_labels.leaves + for node in leaf_nodes: + classes.append(node.name) + for cls in sorted(classes): + folder = os.path.join(path, cls) + if not os.path.isdir(folder): + continue + for root, _, filenames in sorted(os.walk(folder)): + for filename in sorted(filenames): + if has_file_allowed_extension(filename, extensions): + cur_path = os.path.join(root, filename) + img.append(cur_path) + label.append( + anytree.find_by_attr(classes_to_labels, cls).flat_index + ) + return img, label + + +def make_tree(root, name, attribute=None): + """Create a tree of name.""" + if len(name) == 0: + return + if attribute is None: + attribute = {} + this_name = name[0] + next_name = name[1:] + this_attribute = {k: attribute[k][0] for k in attribute} + next_attribute = {k: attribute[k][1:] for k in attribute} + this_node = anytree.find_by_attr(root, this_name) + this_index = root.index + [len(root.children)] + if this_node is None: + this_node = anytree.Node( + this_name, parent=root, index=this_index, **this_attribute + ) + make_tree(this_node, next_name, next_attribute) + return + + +def make_flat_index(root, given=None): + """Make flat index for each leaf node in the tree.""" + if given: + classes_size = 0 + for node in anytree.PreOrderIter(root): + if len(node.children) == 0: + node.flat_index = given.index(node.name) + classes_size = ( + given.index(node.name) + 1 + if given.index(node.name) + 1 > classes_size + else classes_size + ) + else: + classes_size = 0 + for node in anytree.PreOrderIter(root): + if len(node.children) == 0: + node.flat_index = classes_size + classes_size += 1 + return classes_size + + +class Compose: + """Custom Compose class.""" + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, inp): + """Apply transforms when called.""" + for transform in self.transforms: + inp["img"] = transform(inp["img"]) + return inp + + def __repr__(self): + """Represent Compose as string.""" + format_string = self.__class__.__name__ + "(" + for transform in self.transforms: + format_string += "\n" + format_string += " {0}".format(transform) + format_string += "\n)" + return format_string diff --git a/baselines/heterofl/heterofl/main.py b/baselines/heterofl/heterofl/main.py new file mode 100644 index 000000000000..3973841cb60e --- /dev/null +++ b/baselines/heterofl/heterofl/main.py @@ -0,0 +1,204 @@ +"""Runs federated learning for given configuration in base.yaml.""" +import pickle +from pathlib import Path + +import flwr as fl +import hydra +import torch +from hydra.core.hydra_config import HydraConfig +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf + +from heterofl import client, models, server +from heterofl.client_manager_heterofl import ClientManagerHeteroFL +from heterofl.dataset import load_datasets +from heterofl.model_properties import get_model_properties +from heterofl.utils import ModelRateManager, get_global_model_rate, preprocess_input + + +# pylint: disable=too-many-locals,protected-access +@hydra.main(config_path="conf", config_name="base.yaml", version_base=None) +def main(cfg: DictConfig) -> None: + """Run the baseline. + + Parameters + ---------- + cfg : DictConfig + An omegaconf object that stores the hydra config. + """ + # print config structured as YAML + print(OmegaConf.to_yaml(cfg)) + torch.manual_seed(cfg.seed) + + data_loaders = {} + + ( + data_loaders["entire_trainloader"], + data_loaders["trainloaders"], + data_loaders["label_split"], + data_loaders["valloaders"], + data_loaders["testloader"], + ) = load_datasets( + "heterofl" if "heterofl" in cfg.strategy._target_ else "fedavg", + config=cfg.dataset, + num_clients=cfg.num_clients, + seed=cfg.seed, + ) + + model_config = preprocess_input(cfg.model, cfg.dataset) + + model_split_rate = None + model_mode = None + client_to_model_rate_mapping = None + model_rate_manager = None + history = None + + if "HeteroFL" in cfg.strategy._target_: + # send this array(client_model_rate_mapping) as + # an argument to client_manager and client + model_split_rate = {"a": 1, "b": 0.5, "c": 0.25, "d": 0.125, "e": 0.0625} + # model_split_mode = cfg.control.model_split_mode + model_mode = cfg.control.model_mode + + client_to_model_rate_mapping = [float(0) for _ in range(cfg.num_clients)] + model_rate_manager = ModelRateManager( + cfg.control.model_split_mode, model_split_rate, model_mode + ) + + model_config["global_model_rate"] = model_split_rate[ + get_global_model_rate(model_mode) + ] + + test_model = models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + track=True, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + ) + + get_model_properties( + model_config, + model_split_rate, + model_mode + "" if model_mode is not None else None, + data_loaders["entire_trainloader"], + cfg.dataset.batch_size.train, + ) + + # prepare function that will be used to spawn each client + client_train_settings = { + "epochs": cfg.num_epochs, + "optimizer": cfg.optim_scheduler.optimizer, + "lr": cfg.optim_scheduler.lr, + "momentum": cfg.optim_scheduler.momentum, + "weight_decay": cfg.optim_scheduler.weight_decay, + "scheduler": cfg.optim_scheduler.scheduler, + "milestones": cfg.optim_scheduler.milestones, + } + + if "clip" in cfg: + client_train_settings["clip"] = cfg.clip + + optim_scheduler_settings = { + "optimizer": cfg.optim_scheduler.optimizer, + "lr": cfg.optim_scheduler.lr, + "momentum": cfg.optim_scheduler.momentum, + "weight_decay": cfg.optim_scheduler.weight_decay, + "scheduler": cfg.optim_scheduler.scheduler, + "milestones": cfg.optim_scheduler.milestones, + } + + client_fn = client.gen_client_fn( + model_config=model_config, + client_to_model_rate_mapping=client_to_model_rate_mapping, + client_train_settings=client_train_settings, + data_loaders=data_loaders, + ) + + evaluate_fn = server.gen_evaluate_fn( + data_loaders, + torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + test_model, + models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + track=False, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + ) + .state_dict() + .keys(), + enable_train_on_train_data=cfg.enable_train_on_train_data_while_testing + if "enable_train_on_train_data_while_testing" in cfg + else True, + ) + client_resources = { + "num_cpus": cfg.client_resources.num_cpus, + "num_gpus": cfg.client_resources.num_gpus if torch.cuda.is_available() else 0, + } + + if "HeteroFL" in cfg.strategy._target_: + strategy_heterofl = instantiate( + cfg.strategy, + model_name=cfg.model.model_name, + net=models.create_model( + model_config, + model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else None, + device="cpu", + ), + optim_scheduler_settings=optim_scheduler_settings, + global_model_rate=model_split_rate[get_global_model_rate(model_mode)] + if model_split_rate is not None + else 1.0, + evaluate_fn=evaluate_fn, + min_available_clients=cfg.num_clients, + ) + + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources=client_resources, + client_manager=ClientManagerHeteroFL( + model_rate_manager, + client_to_model_rate_mapping, + client_label_split=data_loaders["label_split"], + ), + strategy=strategy_heterofl, + ) + else: + strategy_fedavg = instantiate( + cfg.strategy, + # on_fit_config_fn=lambda server_round: { + # "lr": cfg.optim_scheduler.lr + # * pow(cfg.optim_scheduler.lr_decay_rate, server_round) + # }, + evaluate_fn=evaluate_fn, + min_available_clients=cfg.num_clients, + ) + + history = fl.simulation.start_simulation( + client_fn=client_fn, + num_clients=cfg.num_clients, + config=fl.server.ServerConfig(num_rounds=cfg.num_rounds), + client_resources=client_resources, + strategy=strategy_fedavg, + ) + + # save the results + save_path = HydraConfig.get().runtime.output_dir + + # save the results as a python pickle + with open(str(Path(save_path) / "results.pkl"), "wb") as file_handle: + pickle.dump({"history": history}, file_handle, protocol=pickle.HIGHEST_PROTOCOL) + + # save the model + torch.save(test_model.state_dict(), str(Path(save_path) / "model.pth")) + + +if __name__ == "__main__": + main() diff --git a/baselines/heterofl/heterofl/model_properties.py b/baselines/heterofl/heterofl/model_properties.py new file mode 100644 index 000000000000..0739fe4fde22 --- /dev/null +++ b/baselines/heterofl/heterofl/model_properties.py @@ -0,0 +1,123 @@ +"""Determine number of model parameters, space it requires.""" +import numpy as np +import torch +import torch.nn as nn + +from heterofl.models import create_model + + +def get_model_properties( + model_config, model_split_rate, model_mode, data_loader, batch_size +): + """Calculate space occupied & number of parameters of model.""" + model_mode = model_mode.split("-") if model_mode is not None else None + # model = create_model(model_config, model_rate=model_split_rate(i[0])) + + total_flops = 0 + total_model_parameters = 0 + ttl_prcntg = 0 + if model_mode is None: + total_flops = _calculate_model_memory(create_model(model_config), data_loader) + total_model_parameters = _count_parameters(create_model(model_config)) + else: + for i in model_mode: + total_flops += _calculate_model_memory( + create_model(model_config, model_rate=model_split_rate[i[0]]), + data_loader, + ) * int(i[1]) + total_model_parameters += _count_parameters( + create_model(model_config, model_rate=model_split_rate[i[0]]) + ) * int(i[1]) + ttl_prcntg += int(i[1]) + + total_flops = total_flops / ttl_prcntg if ttl_prcntg != 0 else total_flops + total_flops /= batch_size + total_model_parameters = ( + total_model_parameters / ttl_prcntg + if ttl_prcntg != 0 + else total_model_parameters + ) + + space = total_model_parameters * 32.0 / 8 / (1024**2.0) + print("num_of_parameters = ", total_model_parameters / 1000, " K") + print("total_flops = ", total_flops / 1000000, " M") + print("space = ", space) + + return total_model_parameters, total_flops, space + + +def _calculate_model_memory(model, data_loader): + def register_hook(module): + def hook(module, inp, output): + # temp = _make_flops(module, inp, output) + # print(temp) + for _ in module.named_parameters(): + flops.append(_make_flops(module, inp, output)) + + if ( + not isinstance(module, nn.Sequential) + and not isinstance(module, nn.ModuleList) + and not isinstance(module, nn.ModuleDict) + and module != model + ): + hooks.append(module.register_forward_hook(hook)) + + hooks = [] + flops = [] + model.apply(register_hook) + + one_dl = next(iter(data_loader)) + input_dict = {"img": one_dl[0], "label": one_dl[1]} + with torch.no_grad(): + model(input_dict) + + for hook in hooks: + hook.remove() + + return sum(fl for fl in flops) + + +def _count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def _make_flops(module, inp, output): + if isinstance(inp, tuple): + return _make_flops(module, inp[0], output) + if isinstance(output, tuple): + return _make_flops(module, inp, output[0]) + flops = _compute_flops(module, inp, output) + return flops + + +def _compute_flops(module, inp, out): + flops = 0 + if isinstance(module, nn.Conv2d): + flops = _compute_conv2d_flops(module, inp, out) + elif isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)): + flops = np.prod(inp.shape).item() + if isinstance(module, (nn.BatchNorm2d, nn.InstanceNorm2d)) and module.affine: + flops *= 2 + elif isinstance(module, nn.Linear): + flops = np.prod(inp.size()[:-1]).item() * inp.size()[-1] * out.size()[-1] + # else: + # print(f"[Flops]: {type(module).__name__} is not supported!") + return flops + + +def _compute_conv2d_flops(module, inp, out): + batch_size = inp.size()[0] + in_c = inp.size()[1] + out_c, out_h, out_w = out.size()[1:] + groups = module.groups + filters_per_channel = out_c // groups + conv_per_position_flops = ( + module.kernel_size[0] * module.kernel_size[1] * in_c * filters_per_channel + ) + active_elements_count = batch_size * out_h * out_w + total_conv_flops = conv_per_position_flops * active_elements_count + bias_flops = 0 + if module.bias is not None: + bias_flops = out_c * active_elements_count + total_flops = total_conv_flops + bias_flops + return total_flops diff --git a/baselines/heterofl/heterofl/models.py b/baselines/heterofl/heterofl/models.py new file mode 100644 index 000000000000..9426ee8b2789 --- /dev/null +++ b/baselines/heterofl/heterofl/models.py @@ -0,0 +1,839 @@ +"""Conv & resnet18 model architecture, training, testing functions. + +Classes Conv, Block, Resnet18 are adopted from authors implementation. +""" +import copy +from typing import List, OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +from flwr.common import parameters_to_ndarrays +from torch import nn + +from heterofl.utils import make_optimizer + + +class Conv(nn.Module): + """Convolutional Neural Network architecture with sBN.""" + + def __init__( + self, + model_config, + ): + super().__init__() + self.model_config = model_config + + blocks = [ + nn.Conv2d( + model_config["data_shape"][0], model_config["hidden_size"][0], 3, 1, 1 + ), + self._get_scale(), + self._get_norm(0), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ] + for i in range(len(model_config["hidden_size"]) - 1): + blocks.extend( + [ + nn.Conv2d( + model_config["hidden_size"][i], + model_config["hidden_size"][i + 1], + 3, + 1, + 1, + ), + self._get_scale(), + self._get_norm(i + 1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2), + ] + ) + blocks = blocks[:-1] + blocks.extend( + [ + nn.AdaptiveAvgPool2d(1), + nn.Flatten(), + nn.Linear( + model_config["hidden_size"][-1], model_config["classes_size"] + ), + ] + ) + self.blocks = nn.Sequential(*blocks) + + def _get_norm(self, j: int): + """Return the relavant norm.""" + if self.model_config["norm"] == "bn": + norm = nn.BatchNorm2d( + self.model_config["hidden_size"][j], + momentum=None, + track_running_stats=self.model_config["track"], + ) + elif self.model_config["norm"] == "in": + norm = nn.GroupNorm( + self.model_config["hidden_size"][j], self.model_config["hidden_size"][j] + ) + elif self.model_config["norm"] == "ln": + norm = nn.GroupNorm(1, self.model_config["hidden_size"][j]) + elif self.model_config["norm"] == "gn": + norm = nn.GroupNorm(4, self.model_config["hidden_size"][j]) + elif self.model_config["norm"] == "none": + norm = nn.Identity() + else: + raise ValueError("Not valid norm") + + return norm + + def _get_scale(self): + """Return the relavant scaler.""" + if self.model_config["scale"]: + scaler = _Scaler(self.model_config["rate"]) + else: + scaler = nn.Identity() + return scaler + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + # output = {"loss": torch.tensor(0, device=self.device, dtype=torch.float32)} + output = {} + out = self.blocks(input_dict["img"]) + if "label_split" in input_dict and self.model_config["mask"]: + label_mask = torch.zeros( + self.model_config["classes_size"], device=out.device + ) + label_mask[input_dict["label_split"]] = 1 + out = out.masked_fill(label_mask == 0, 0) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +def conv( + model_rate, + model_config, + device="cpu", +): + """Create the Conv model.""" + model_config["hidden_size"] = [ + int(np.ceil(model_rate * x)) for x in model_config["hidden_layers"] + ] + scaler_rate = model_rate / model_config["global_model_rate"] + model_config["rate"] = scaler_rate + model = Conv(model_config) + model.apply(_init_param) + return model.to(device) + + +class Block(nn.Module): + """Block.""" + + expansion = 1 + + def __init__(self, in_planes, planes, stride, model_config): + super().__init__() + if model_config["norm"] == "bn": + n_1 = nn.BatchNorm2d( + in_planes, momentum=None, track_running_stats=model_config["track"] + ) + n_2 = nn.BatchNorm2d( + planes, momentum=None, track_running_stats=model_config["track"] + ) + elif model_config["norm"] == "in": + n_1 = nn.GroupNorm(in_planes, in_planes) + n_2 = nn.GroupNorm(planes, planes) + elif model_config["norm"] == "ln": + n_1 = nn.GroupNorm(1, in_planes) + n_2 = nn.GroupNorm(1, planes) + elif model_config["norm"] == "gn": + n_1 = nn.GroupNorm(4, in_planes) + n_2 = nn.GroupNorm(4, planes) + elif model_config["norm"] == "none": + n_1 = nn.Identity() + n_2 = nn.Identity() + else: + raise ValueError("Not valid norm") + self.n_1 = n_1 + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.n_2 = n_2 + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False + ) + if model_config["scale"]: + self.scaler = _Scaler(model_config["rate"]) + else: + self.scaler = nn.Identity() + + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Conv2d( + in_planes, + self.expansion * planes, + kernel_size=1, + stride=stride, + bias=False, + ) + + def forward(self, x): + """Forward pass of the Block. + + Parameters + ---------- + x : Dict + Dict that contains Input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + out = F.relu(self.n_1(self.scaler(x))) + shortcut = self.shortcut(out) if hasattr(self, "shortcut") else x + out = self.conv1(out) + out = self.conv2(F.relu(self.n_2(self.scaler(out)))) + out += shortcut + return out + + +# pylint: disable=too-many-instance-attributes +class ResNet(nn.Module): + """Implementation of a Residual Neural Network (ResNet) model with sBN.""" + + def __init__( + self, + model_config, + block, + num_blocks, + ): + self.model_config = model_config + super().__init__() + self.in_planes = model_config["hidden_size"][0] + self.conv1 = nn.Conv2d( + model_config["data_shape"][0], + model_config["hidden_size"][0], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + + self.layer1 = self._make_layer( + block, + model_config["hidden_size"][0], + num_blocks[0], + stride=1, + ) + self.layer2 = self._make_layer( + block, + model_config["hidden_size"][1], + num_blocks[1], + stride=2, + ) + self.layer3 = self._make_layer( + block, + model_config["hidden_size"][2], + num_blocks[2], + stride=2, + ) + self.layer4 = self._make_layer( + block, + model_config["hidden_size"][3], + num_blocks[3], + stride=2, + ) + + # self.layers = [layer1, layer2, layer3, layer4] + + if model_config["norm"] == "bn": + n_4 = nn.BatchNorm2d( + model_config["hidden_size"][3] * block.expansion, + momentum=None, + track_running_stats=model_config["track"], + ) + elif model_config["norm"] == "in": + n_4 = nn.GroupNorm( + model_config["hidden_size"][3] * block.expansion, + model_config["hidden_size"][3] * block.expansion, + ) + elif model_config["norm"] == "ln": + n_4 = nn.GroupNorm(1, model_config["hidden_size"][3] * block.expansion) + elif model_config["norm"] == "gn": + n_4 = nn.GroupNorm(4, model_config["hidden_size"][3] * block.expansion) + elif model_config["norm"] == "none": + n_4 = nn.Identity() + else: + raise ValueError("Not valid norm") + self.n_4 = n_4 + if model_config["scale"]: + self.scaler = _Scaler(model_config["rate"]) + else: + self.scaler = nn.Identity() + self.linear = nn.Linear( + model_config["hidden_size"][3] * block.expansion, + model_config["classes_size"], + ) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for strd in strides: + layers.append(block(self.in_planes, planes, strd, self.model_config.copy())) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, input_dict): + """Forward pass of the ResNet. + + Parameters + ---------- + input_dict : Dict + Dict that contains Input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + out = self.conv1(x) + # for layer in self.layers: + # out = layer(out) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = F.relu(self.n_4(self.scaler(out))) + out = F.adaptive_avg_pool2d(out, 1) + out = out.view(out.size(0), -1) + out = self.linear(out) + if "label_split" in input_dict and self.model_config["mask"]: + label_mask = torch.zeros( + self.model_config["classes_size"], device=out.device + ) + label_mask[input_dict["label_split"]] = 1 + out = out.masked_fill(label_mask == 0, 0) + output["score"] = out + output["loss"] = F.cross_entropy(output["score"], input_dict["label"]) + return output + + +def resnet18( + model_rate, + model_config, + device="cpu", +): + """Create the ResNet18 model.""" + model_config["hidden_size"] = [ + int(np.ceil(model_rate * x)) for x in model_config["hidden_layers"] + ] + scaler_rate = model_rate / model_config["global_model_rate"] + model_config["rate"] = scaler_rate + model = ResNet(model_config, block=Block, num_blocks=[1, 1, 1, 2]) + model.apply(_init_param) + return model.to(device) + + +class MLP(nn.Module): + """Multi Layer Perceptron.""" + + def __init__(self): + super().__init__() + self.layer_input = nn.Linear(784, 512) + self.relu = nn.ReLU() + self.dropout = nn.Dropout() + self.layer_hidden1 = nn.Linear(512, 256) + self.layer_hidden2 = nn.Linear(256, 256) + self.layer_hidden3 = nn.Linear(256, 128) + self.layer_out = nn.Linear(128, 10) + self.softmax = nn.Softmax(dim=1) + self.weight_keys = [ + ["layer_input.weight", "layer_input.bias"], + ["layer_hidden1.weight", "layer_hidden1.bias"], + ["layer_hidden2.weight", "layer_hidden2.bias"], + ["layer_hidden3.weight", "layer_hidden3.bias"], + ["layer_out.weight", "layer_out.bias"], + ] + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1]) + x = self.layer_input(x) + x = self.relu(x) + + x = self.layer_hidden1(x) + x = self.relu(x) + + x = self.layer_hidden2(x) + x = self.relu(x) + + x = self.layer_hidden3(x) + x = self.relu(x) + + x = self.layer_out(x) + out = self.softmax(x) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +class CNNCifar(nn.Module): + """Convolutional Neural Network architecture for cifar dataset.""" + + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 100) + self.fc3 = nn.Linear(100, 10) + + self.weight_keys = [ + ["fc1.weight", "fc1.bias"], + ["fc2.weight", "fc2.bias"], + ["fc3.weight", "fc3.bias"], + ["conv2.weight", "conv2.bias"], + ["conv1.weight", "conv1.bias"], + ] + + def forward(self, input_dict): + """Forward pass of the Conv. + + Parameters + ---------- + input_dict : Dict + Conatins input Tensor that will pass through the network. + label of that input to calculate loss. + label_split if masking is required. + + Returns + ------- + Dict + The resulting Tensor after it has passed through the network and the loss. + """ + output = {} + x = input_dict["img"] + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + out = F.log_softmax(x, dim=1) + output["score"] = out + output["loss"] = F.cross_entropy(out, input_dict["label"], reduction="mean") + return output + + +def create_model(model_config, model_rate=None, track=False, device="cpu"): + """Create the model based on the configuration given in hydra.""" + model = None + model_config = model_config.copy() + model_config["track"] = track + + if model_config["model"] == "MLP": + model = MLP() + model.to(device) + elif model_config["model"] == "CNNCifar": + model = CNNCifar() + model.to(device) + elif model_config["model"] == "conv": + model = conv(model_rate=model_rate, model_config=model_config, device=device) + elif model_config["model"] == "resnet18": + model = resnet18( + model_rate=model_rate, model_config=model_config, device=device + ) + return model + + +def _init_param(m_param): + if isinstance(m_param, (nn.BatchNorm2d, nn.InstanceNorm2d)): + m_param.weight.data.fill_(1) + m_param.bias.data.zero_() + elif isinstance(m_param, nn.Linear): + m_param.bias.data.zero_() + return m_param + + +class _Scaler(nn.Module): + def __init__(self, rate): + super().__init__() + self.rate = rate + + def forward(self, inp): + """Forward of Scalar nn.Module.""" + output = inp / self.rate if self.training else inp + return output + + +def get_parameters(net) -> List[np.ndarray]: + """Return the parameters of model as numpy.NDArrays.""" + return [val.cpu().numpy() for _, val in net.state_dict().items()] + + +def set_parameters(net, parameters: List[np.ndarray]): + """Set the model parameters with given parameters.""" + params_dict = zip(net.state_dict().keys(), parameters) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=True) + + +def train(model, train_loader, label_split, settings): + """Train a model with given settings. + + Parameters + ---------- + model : nn.Module + The neural network to train. + train_loader : DataLoader + The DataLoader containing the data to train the network on. + label_split : torch.tensor + Tensor containing the labels of the data. + settings: Dict + Dictionary conatining the information about eopchs, optimizer, + lr, momentum, weight_decay, device to train on. + """ + # criterion = torch.nn.CrossEntropyLoss() + optimizer = make_optimizer( + settings["optimizer"], + model.parameters(), + learning_rate=settings["lr"], + momentum=settings["momentum"], + weight_decay=settings["weight_decay"], + ) + + model.train() + for _ in range(settings["epochs"]): + for images, labels in train_loader: + input_dict = {} + input_dict["img"] = images.to(settings["device"]) + input_dict["label"] = labels.to(settings["device"]) + input_dict["label_split"] = label_split.type(torch.int).to( + settings["device"] + ) + optimizer.zero_grad() + output = model(input_dict) + output["loss"].backward() + if ("clip" not in settings) or ( + "clip" in settings and settings["clip"] is True + ): + torch.nn.utils.clip_grad_norm_(model.parameters(), 1) + optimizer.step() + + +def test(model, test_loader, label_split=None, device="cpu"): + """Evaluate the network on the test set. + + Parameters + ---------- + model : nn.Module + The neural network to test. + test_loader : DataLoader + The DataLoader containing the data to test the network on. + device : torch.device + The device on which the model should be tested, either 'cpu' or 'cuda'. + + Returns + ------- + Tuple[float, float] + The loss and the accuracy of the input model on the given data. + """ + model.eval() + size = len(test_loader.dataset) + num_batches = len(test_loader) + test_loss, correct = 0, 0 + + with torch.no_grad(): + model.train(False) + for images, labels in test_loader: + input_dict = {} + input_dict["img"] = images.to(device) + input_dict["label"] = labels.to(device) + if label_split is not None: + input_dict["label_split"] = label_split.type(torch.int).to(device) + output = model(input_dict) + test_loss += output["loss"].item() + correct += ( + (output["score"].argmax(1) == input_dict["label"]) + .type(torch.float) + .sum() + .item() + ) + + test_loss /= num_batches + correct /= size + return test_loss, correct + + +def param_model_rate_mapping( + model_name, parameters, clients_model_rate, global_model_rate=1 +): + """Map the model rate to subset of global parameters(as list of indices). + + Parameters + ---------- + model_name : str + The name of the neural network of global model. + parameters : Dict + state_dict of the global model. + client_model_rate : List[float] + List of model rates of active clients. + global_model_rate: float + Model rate of the global model. + + Returns + ------- + Dict + model rate to parameters indices relative to global model mapping. + """ + unique_client_model_rate = list(set(clients_model_rate)) + print(unique_client_model_rate) + + if "conv" in model_name: + idx = _mr_to_param_idx_conv( + parameters, unique_client_model_rate, global_model_rate + ) + elif "resnet" in model_name: + idx = _mr_to_param_idx_resnet18( + parameters, unique_client_model_rate, global_model_rate + ) + else: + raise ValueError("Not valid model name") + + # add model rate as key to the params calculated + param_idx_model_rate_mapping = OrderedDict() + for i, _ in enumerate(unique_client_model_rate): + param_idx_model_rate_mapping[unique_client_model_rate[i]] = idx[i] + + return param_idx_model_rate_mapping + + +def _mr_to_param_idx_conv(parameters, unique_client_model_rate, global_model_rate): + idx_i = [None for _ in range(len(unique_client_model_rate))] + idx = [OrderedDict() for _ in range(len(unique_client_model_rate))] + output_weight_name = [k for k in parameters.keys() if "weight" in k][-1] + output_bias_name = [k for k in parameters.keys() if "bias" in k][-1] + for k, val in parameters.items(): + parameter_type = k.split(".")[-1] + for index, _ in enumerate(unique_client_model_rate): + if "weight" in parameter_type or "bias" in parameter_type: + scaler_rate = unique_client_model_rate[index] / global_model_rate + _get_key_k_idx_conv( + idx, + idx_i, + { + "index": index, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + output_names={ + "output_weight_name": output_weight_name, + "output_bias_name": output_bias_name, + }, + scaler_rate=scaler_rate, + ) + else: + pass + return idx + + +def _get_key_k_idx_conv( + idx, + idx_i, + param_info, + output_names, + scaler_rate, +): + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + input_size = param_info["val"].size(1) + output_size = param_info["val"].size(0) + if idx_i[param_info["index"]] is None: + idx_i[param_info["index"]] = torch.arange( + input_size, device=param_info["val"].device + ) + input_idx_i_m = idx_i[param_info["index"]] + if param_info["k"] == output_names["output_weight_name"]: + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + ) + else: + local_output_size = int(np.ceil(output_size * (scaler_rate))) + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + )[:local_output_size] + idx[param_info["index"]][param_info["k"]] = output_idx_i_m, input_idx_i_m + idx_i[param_info["index"]] = output_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + if param_info["k"] == output_names["output_bias_name"]: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + + +def _mr_to_param_idx_resnet18(parameters, unique_client_model_rate, global_model_rate): + idx_i = [None for _ in range(len(unique_client_model_rate))] + idx = [OrderedDict() for _ in range(len(unique_client_model_rate))] + for k, val in parameters.items(): + parameter_type = k.split(".")[-1] + for index, _ in enumerate(unique_client_model_rate): + if "weight" in parameter_type or "bias" in parameter_type: + scaler_rate = unique_client_model_rate[index] / global_model_rate + _get_key_k_idx_resnet18( + idx, + idx_i, + { + "index": index, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + scaler_rate=scaler_rate, + ) + else: + pass + return idx + + +def _get_key_k_idx_resnet18( + idx, + idx_i, + param_info, + scaler_rate, +): + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + input_size = param_info["val"].size(1) + output_size = param_info["val"].size(0) + if "conv1" in param_info["k"] or "conv2" in param_info["k"]: + if idx_i[param_info["index"]] is None: + idx_i[param_info["index"]] = torch.arange( + input_size, device=param_info["val"].device + ) + input_idx_i_m = idx_i[param_info["index"]] + local_output_size = int(np.ceil(output_size * (scaler_rate))) + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + )[:local_output_size] + idx_i[param_info["index"]] = output_idx_i_m + elif "shortcut" in param_info["k"]: + input_idx_i_m = idx[param_info["index"]][ + param_info["k"].replace("shortcut", "conv1") + ][1] + output_idx_i_m = idx_i[param_info["index"]] + elif "linear" in param_info["k"]: + input_idx_i_m = idx_i[param_info["index"]] + output_idx_i_m = torch.arange( + output_size, device=param_info["val"].device + ) + else: + raise ValueError("Not valid k") + idx[param_info["index"]][param_info["k"]] = (output_idx_i_m, input_idx_i_m) + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_size = param_info["val"].size(0) + if "linear" in param_info["k"]: + input_idx_i_m = torch.arange(input_size, device=param_info["val"].device) + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + else: + input_idx_i_m = idx_i[param_info["index"]] + idx[param_info["index"]][param_info["k"]] = input_idx_i_m + + +def param_idx_to_local_params(global_parameters, client_param_idx): + """Get the local parameters from the list of param indices. + + Parameters + ---------- + global_parameters : Dict + The state_dict of global model. + client_param_idx : List + Local parameters indices with respect to global model. + + Returns + ------- + Dict + state dict of local model. + """ + local_parameters = OrderedDict() + for k, val in global_parameters.items(): + parameter_type = k.split(".")[-1] + if "weight" in parameter_type or "bias" in parameter_type: + if "weight" in parameter_type: + if val.dim() > 1: + local_parameters[k] = copy.deepcopy( + val[torch.meshgrid(client_param_idx[k])] + ) + else: + local_parameters[k] = copy.deepcopy(val[client_param_idx[k]]) + else: + local_parameters[k] = copy.deepcopy(val[client_param_idx[k]]) + else: + local_parameters[k] = copy.deepcopy(val) + return local_parameters + + +def get_state_dict_from_param(model, parameters): + """Get the state dict from model & parameters as np.NDarrays. + + Parameters + ---------- + model : nn.Module + The neural network. + parameters : np.NDarray + Parameters of the model as np.NDarrays. + + Returns + ------- + Dict + state dict of model. + """ + # Load the parameters into the model + for param_tensor, param_ndarray in zip( + model.state_dict(), parameters_to_ndarrays(parameters) + ): + model.state_dict()[param_tensor].copy_(torch.from_numpy(param_ndarray)) + # Step 3: Obtain the state_dict of the model + state_dict = model.state_dict() + return state_dict diff --git a/baselines/heterofl/heterofl/server.py b/baselines/heterofl/heterofl/server.py new file mode 100644 index 000000000000..f82db0a59fff --- /dev/null +++ b/baselines/heterofl/heterofl/server.py @@ -0,0 +1,101 @@ +"""Flower Server.""" +import time +from collections import OrderedDict +from typing import Callable, Dict, Optional, Tuple + +import torch +from flwr.common.typing import NDArrays, Scalar +from torch import nn + +from heterofl.models import test +from heterofl.utils import save_model + + +def gen_evaluate_fn( + data_loaders, + device: torch.device, + model: nn.Module, + keys, + enable_train_on_train_data: bool, +) -> Callable[ + [int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]] +]: + """Generate the function for centralized evaluation. + + Parameters + ---------- + data_loaders : + A dictionary containing dataloaders for testing and + label split of each client. + device : torch.device + The device to test the model on. + model : + Model for testing. + keys : + keys of the model that it is trained on. + + Returns + ------- + Callable[ [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]] ] + The centralized evaluation function. + """ + intermediate_keys = keys + + def evaluate( + server_round: int, parameters_ndarrays: NDArrays, config: Dict[str, Scalar] + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + # pylint: disable=unused-argument + """Use the entire test set for evaluation.""" + # if server_round % 5 != 0 and server_round < 395: + # return 1, {} + + net = model + params_dict = zip(intermediate_keys, parameters_ndarrays) + state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) + net.load_state_dict(state_dict, strict=False) + net.to(device) + + if server_round % 100 == 0: + save_model(net, f"model_after_round_{server_round}.pth") + + if enable_train_on_train_data is True: + print("start of testing") + start_time = time.time() + with torch.no_grad(): + net.train(True) + for images, labels in data_loaders["entire_trainloader"]: + input_dict = {} + input_dict["img"] = images.to(device) + input_dict["label"] = labels.to(device) + net(input_dict) + print(f"end of stat, time taken = {time.time() - start_time}") + + local_metrics = {} + local_metrics["loss"] = 0 + local_metrics["accuracy"] = 0 + for i, clnt_tstldr in enumerate(data_loaders["valloaders"]): + client_test_res = test( + net, + clnt_tstldr, + data_loaders["label_split"][i].type(torch.int), + device=device, + ) + local_metrics["loss"] += client_test_res[0] + local_metrics["accuracy"] += client_test_res[1] + + global_metrics = {} + global_metrics["loss"], global_metrics["accuracy"] = test( + net, data_loaders["testloader"], device=device + ) + + # return statistics + print(f"global accuracy = {global_metrics['accuracy']}") + print(f"local_accuracy = {local_metrics['accuracy']}") + return global_metrics["loss"], { + "global_accuracy": global_metrics["accuracy"], + "local_loss": local_metrics["loss"], + "local_accuracy": local_metrics["accuracy"], + } + + return evaluate diff --git a/baselines/heterofl/heterofl/strategy.py b/baselines/heterofl/heterofl/strategy.py new file mode 100644 index 000000000000..70dbd19594df --- /dev/null +++ b/baselines/heterofl/heterofl/strategy.py @@ -0,0 +1,467 @@ +"""Flower strategy for HeteroFL.""" +import copy +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import flwr as fl +import torch +from flwr.common import ( + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + Parameters, + Scalar, + ndarrays_to_parameters, + parameters_to_ndarrays, +) +from flwr.server.client_manager import ClientManager +from flwr.server.client_proxy import ClientProxy +from torch import nn + +from heterofl.client_manager_heterofl import ClientManagerHeteroFL +from heterofl.models import ( + get_parameters, + get_state_dict_from_param, + param_idx_to_local_params, + param_model_rate_mapping, +) +from heterofl.utils import make_optimizer, make_scheduler + + +# pylint: disable=too-many-instance-attributes +class HeteroFL(fl.server.strategy.Strategy): + """HeteroFL strategy. + + Distribute subsets of a global model to clients according to their + + computational complexity and aggregate received models from clients. + """ + + # pylint: disable=too-many-arguments + def __init__( + self, + model_name: str, + net: nn.Module, + optim_scheduler_settings: Dict, + global_model_rate: float = 1.0, + evaluate_fn=None, + fraction_fit: float = 1.0, + fraction_evaluate: float = 1.0, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + ) -> None: + super().__init__() + self.fraction_fit = fraction_fit + self.fraction_evaluate = fraction_evaluate + self.min_fit_clients = min_fit_clients + self.min_evaluate_clients = min_evaluate_clients + self.min_available_clients = min_available_clients + self.evaluate_fn = evaluate_fn + # # created client_to_model_mapping + # self.client_to_model_rate_mapping: Dict[str, ClientProxy] = {} + + self.model_name = model_name + self.net = net + self.global_model_rate = global_model_rate + # info required for configure and aggregate + # to be filled in initialize + self.local_param_model_rate: OrderedDict = OrderedDict() + # to be filled in initialize + self.active_cl_labels: List[torch.tensor] = [] + # to be filled in configure + self.active_cl_mr: OrderedDict = OrderedDict() + # required for scheduling the lr + self.optimizer = make_optimizer( + optim_scheduler_settings["optimizer"], + self.net.parameters(), + learning_rate=optim_scheduler_settings["lr"], + momentum=optim_scheduler_settings["momentum"], + weight_decay=optim_scheduler_settings["weight_decay"], + ) + self.scheduler = make_scheduler( + optim_scheduler_settings["scheduler"], + self.optimizer, + milestones=optim_scheduler_settings["milestones"], + ) + + def __repr__(self) -> str: + """Return a string representation of the HeteroFL object.""" + return "HeteroFL" + + def initialize_parameters( + self, client_manager: ClientManager + ) -> Optional[Parameters]: + """Initialize global model parameters.""" + # self.make_client_to_model_rate_mapping(client_manager) + # net = conv(model_rate = 1) + if not isinstance(client_manager, ClientManagerHeteroFL): + raise ValueError( + "Not valid client manager, use ClientManagerHeterFL instead" + ) + clnt_mngr_heterofl: ClientManagerHeteroFL = client_manager + + ndarrays = get_parameters(self.net) + self.local_param_model_rate = param_model_rate_mapping( + self.model_name, + self.net.state_dict(), + clnt_mngr_heterofl.get_all_clients_to_model_mapping(), + self.global_model_rate, + ) + + if clnt_mngr_heterofl.client_label_split is not None: + self.active_cl_labels = clnt_mngr_heterofl.client_label_split.copy() + + return fl.common.ndarrays_to_parameters(ndarrays) + + def configure_fit( + self, + server_round: int, + parameters: Parameters, + client_manager: ClientManager, + ) -> List[Tuple[ClientProxy, FitIns]]: + """Configure the next round of training.""" + print(f"in configure fit , server round no. = {server_round}") + if not isinstance(client_manager, ClientManagerHeteroFL): + raise ValueError( + "Not valid client manager, use ClientManagerHeterFL instead" + ) + clnt_mngr_heterofl: ClientManagerHeteroFL = client_manager + # Sample clients + # no need to change this + clientts_selection_config = {} + ( + clientts_selection_config["sample_size"], + clientts_selection_config["min_num_clients"], + ) = self.num_fit_clients(clnt_mngr_heterofl.num_available()) + + # for sampling we pass the criterion to select the required clients + clients = clnt_mngr_heterofl.sample( + num_clients=clientts_selection_config["sample_size"], + min_num_clients=clientts_selection_config["min_num_clients"], + ) + + # update client model rate mapping + clnt_mngr_heterofl.update(server_round) + + global_parameters = get_state_dict_from_param(self.net, parameters) + + self.active_cl_mr = OrderedDict() + + # Create custom configs + fit_configurations = [] + learning_rate = self.optimizer.param_groups[0]["lr"] + print(f"lr = {learning_rate}") + for client in clients: + model_rate = clnt_mngr_heterofl.get_client_to_model_mapping(client.cid) + client_param_idx = self.local_param_model_rate[model_rate] + local_param = param_idx_to_local_params( + global_parameters=global_parameters, client_param_idx=client_param_idx + ) + self.active_cl_mr[client.cid] = model_rate + # local param are in the form of state_dict, + # so converting them only to values of tensors + local_param_fitres = [val.cpu() for val in local_param.values()] + fit_configurations.append( + ( + client, + FitIns( + ndarrays_to_parameters(local_param_fitres), + {"lr": learning_rate}, + ), + ) + ) + + self.scheduler.step() + return fit_configurations + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """Aggregate fit results using weighted average. + + Adopted from authors implementation. + """ + print("in aggregate fit") + gl_model = self.net.state_dict() + + param_idx = [] + for res in results: + param_idx.append( + copy.deepcopy( + self.local_param_model_rate[self.active_cl_mr[res[0].cid]] + ) + ) + + local_param_as_parameters = [fit_res.parameters for _, fit_res in results] + local_parameters_as_ndarrays = [ + parameters_to_ndarrays(local_param_as_parameters[i]) + for i in range(len(local_param_as_parameters)) + ] + local_parameters: List[OrderedDict] = [ + OrderedDict() for _ in range(len(local_param_as_parameters)) + ] + for i in range(len(results)): + j = 0 + for k, _ in gl_model.items(): + local_parameters[i][k] = local_parameters_as_ndarrays[i][j] + j += 1 + + if "conv" in self.model_name: + self._aggregate_conv(param_idx, local_parameters, results) + + elif "resnet" in self.model_name: + self._aggregate_resnet18(param_idx, local_parameters, results) + else: + raise ValueError("Not valid model name") + + return ndarrays_to_parameters([v for k, v in gl_model.items()]), {} + + def _aggregate_conv(self, param_idx, local_parameters, results): + gl_model = self.net.state_dict() + count = OrderedDict() + output_bias_name = [k for k in gl_model.keys() if "bias" in k][-1] + output_weight_name = [k for k in gl_model.keys() if "weight" in k][-1] + for k, val in gl_model.items(): + parameter_type = k.split(".")[-1] + count[k] = val.new_zeros(val.size(), dtype=torch.float32) + tmp_v = val.new_zeros(val.size(), dtype=torch.float32) + for clnt, _ in enumerate(local_parameters): + if "weight" in parameter_type or "bias" in parameter_type: + self._agg_layer_conv( + { + "cid": int(results[clnt][0].cid), + "param_idx": param_idx, + "local_parameters": local_parameters, + }, + { + "tmp_v": tmp_v, + "count": count, + }, + { + "clnt": clnt, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + { + "output_weight_name": output_weight_name, + "output_bias_name": output_bias_name, + }, + ) + else: + tmp_v += local_parameters[clnt][k] + count[k] += 1 + tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) + val[count[k] > 0] = tmp_v[count[k] > 0].to(val.dtype) + + def _agg_layer_conv( + self, + clnt_params, + tmp_v_count, + param_info, + output_names, + ): + # pi = param_info + param_idx = clnt_params["param_idx"] + clnt = param_info["clnt"] + k = param_info["k"] + tmp_v = tmp_v_count["tmp_v"] + count = tmp_v_count["count"] + + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + if k == output_names["output_weight_name"]: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = list(param_idx[clnt][k]) + param_idx[clnt][k][0] = param_idx[clnt][k][0][label_split] + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k][label_split] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + else: + if k == output_names["output_bias_name"]: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = param_idx[clnt][k][label_split] + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k][ + label_split + ] + count[k][param_idx[clnt][k]] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + + def _aggregate_resnet18(self, param_idx, local_parameters, results): + gl_model = self.net.state_dict() + count = OrderedDict() + for k, val in gl_model.items(): + parameter_type = k.split(".")[-1] + count[k] = val.new_zeros(val.size(), dtype=torch.float32) + tmp_v = val.new_zeros(val.size(), dtype=torch.float32) + for clnt, _ in enumerate(local_parameters): + if "weight" in parameter_type or "bias" in parameter_type: + self._agg_layer_resnet18( + { + "cid": int(results[clnt][0].cid), + "param_idx": param_idx, + "local_parameters": local_parameters, + }, + tmp_v, + count, + { + "clnt": clnt, + "parameter_type": parameter_type, + "k": k, + "val": val, + }, + ) + else: + tmp_v += local_parameters[clnt][k] + count[k] += 1 + tmp_v[count[k] > 0] = tmp_v[count[k] > 0].div_(count[k][count[k] > 0]) + val[count[k] > 0] = tmp_v[count[k] > 0].to(val.dtype) + + def _agg_layer_resnet18(self, clnt_params, tmp_v, count, param_info): + param_idx = clnt_params["param_idx"] + k = param_info["k"] + clnt = param_info["clnt"] + + if param_info["parameter_type"] == "weight": + if param_info["val"].dim() > 1: + if "linear" in k: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = list(param_idx[clnt][k]) + param_idx[clnt][k][0] = param_idx[clnt][k][0][label_split] + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k][label_split] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[torch.meshgrid(param_idx[clnt][k])] += clnt_params[ + "local_parameters" + ][clnt][k] + count[k][torch.meshgrid(param_idx[clnt][k])] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + else: + if "linear" in k: + label_split = self.active_cl_labels[clnt_params["cid"]] + label_split = label_split.type(torch.int) + param_idx[clnt][k] = param_idx[clnt][k][label_split] + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k][ + label_split + ] + count[k][param_idx[clnt][k]] += 1 + else: + tmp_v[param_idx[clnt][k]] += clnt_params["local_parameters"][clnt][k] + count[k][param_idx[clnt][k]] += 1 + + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + """Configure the next round of evaluation.""" + # if self.fraction_evaluate == 0.0: + # return [] + # config = {} + # evaluate_ins = EvaluateIns(parameters, config) + + # # Sample clients + # sample_size, min_num_clients = self.num_evaluation_clients( + # client_manager.num_available() + # ) + # clients = client_manager.sample( + # num_clients=sample_size, min_num_clients=min_num_clients + # ) + + # global_parameters = get_state_dict_from_param(self.net, parameters) + + # self.active_cl_mr = OrderedDict() + + # # Create custom configs + # evaluate_configurations = [] + # for idx, client in enumerate(clients): + # model_rate = client_manager.get_client_to_model_mapping(client.cid) + # client_param_idx = self.local_param_model_rate[model_rate] + # local_param = + # param_idx_to_local_params(global_parameters, client_param_idx) + # self.active_cl_mr[client.cid] = model_rate + # # local param are in the form of state_dict, + # # so converting them only to values of tensors + # local_param_fitres = [v.cpu() for v in local_param.values()] + # evaluate_configurations.append( + # (client, EvaluateIns(ndarrays_to_parameters(local_param_fitres), {})) + # ) + # return evaluate_configurations + + return [] + + # return self.configure_fit(server_round , parameters , client_manager) + + def aggregate_evaluate( + self, + server_round: int, + results: List[Tuple[ClientProxy, EvaluateRes]], + failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], + ) -> Tuple[Optional[float], Dict[str, Scalar]]: + """Aggregate evaluation losses using weighted average.""" + # if not results: + # return None, {} + + # loss_aggregated = weighted_loss_avg( + # [ + # (evaluate_res.num_examples, evaluate_res.loss) + # for _, evaluate_res in results + # ] + # ) + + # accuracy_aggregated = 0 + # for cp, y in results: + # print(f"{cp.cid}-->{y.metrics['accuracy']}", end=" ") + # accuracy_aggregated += y.metrics["accuracy"] + # accuracy_aggregated /= len(results) + + # metrics_aggregated = {"accuracy": accuracy_aggregated} + # print(f"\npaneer lababdar {metrics_aggregated}") + # return loss_aggregated, metrics_aggregated + + return None, {} + + def evaluate( + self, server_round: int, parameters: Parameters + ) -> Optional[Tuple[float, Dict[str, Scalar]]]: + """Evaluate model parameters using an evaluation function.""" + if self.evaluate_fn is None: + # No evaluation function provided + return None + parameters_ndarrays = parameters_to_ndarrays(parameters) + eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {}) + if eval_res is None: + return None + loss, metrics = eval_res + return loss, metrics + + def num_fit_clients(self, num_available_clients: int) -> Tuple[int, int]: + """Return sample size and required number of clients.""" + num_clients = int(num_available_clients * self.fraction_fit) + return max(num_clients, self.min_fit_clients), self.min_available_clients + + def num_evaluation_clients(self, num_available_clients: int) -> Tuple[int, int]: + """Use a fraction of available clients for evaluation.""" + num_clients = int(num_available_clients * self.fraction_evaluate) + return max(num_clients, self.min_evaluate_clients), self.min_available_clients diff --git a/baselines/heterofl/heterofl/utils.py b/baselines/heterofl/heterofl/utils.py new file mode 100644 index 000000000000..3bcb7f3d8ea7 --- /dev/null +++ b/baselines/heterofl/heterofl/utils.py @@ -0,0 +1,218 @@ +"""Contains utility functions.""" +import errno +import os +from pathlib import Path + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig + + +def preprocess_input(cfg_model, cfg_data): + """Preprocess the input to get input shape, other derivables. + + Parameters + ---------- + cfg_model : DictConfig + Retrieve model-related information from the base.yaml configuration in Hydra. + cfg_data : DictConfig + Retrieve data-related information required to construct the model. + + Returns + ------- + Dict + Dictionary contained derived information from config. + """ + model_config = {} + # if cfg_model.model_name == "conv": + # model_config["model_name"] = + # elif for others... + model_config["model"] = cfg_model.model_name + if cfg_data.dataset_name == "MNIST": + model_config["data_shape"] = [1, 28, 28] + model_config["classes_size"] = 10 + elif cfg_data.dataset_name == "CIFAR10": + model_config["data_shape"] = [3, 32, 32] + model_config["classes_size"] = 10 + + if "hidden_layers" in cfg_model: + model_config["hidden_layers"] = cfg_model.hidden_layers + if "norm" in cfg_model: + model_config["norm"] = cfg_model.norm + if "scale" in cfg_model: + model_config["scale"] = cfg_model.scale + if "mask" in cfg_model: + model_config["mask"] = cfg_model.mask + + return model_config + + +def make_optimizer(optimizer_name, parameters, learning_rate, weight_decay, momentum): + """Make the optimizer with given config. + + Parameters + ---------- + optimizer_name : str + Name of the optimizer. + parameters : Dict + Parameters of the model. + learning_rate: float + Learning rate of the optimizer. + weight_decay: float + weight_decay of the optimizer. + + Returns + ------- + torch.optim.Optimizer + Optimizer. + """ + optimizer = None + if optimizer_name == "SGD": + optimizer = torch.optim.SGD( + parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay + ) + return optimizer + + +def make_scheduler(scheduler_name, optimizer, milestones): + """Make the scheduler with given config. + + Parameters + ---------- + scheduler_name : str + Name of the scheduler. + optimizer : torch.optim.Optimizer + Parameters of the model. + milestones: List[int] + List of epoch indices. Must be increasing. + + Returns + ------- + torch.optim.lr_scheduler.Scheduler + scheduler. + """ + scheduler = None + if scheduler_name == "MultiStepLR": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=milestones + ) + return scheduler + + +def get_global_model_rate(model_mode): + """Give the global model rate from string(cfg.control.model_mode) . + + Parameters + ---------- + model_mode : str + Contains the division of computational complexties among clients. + + Returns + ------- + str + global model computational complexity. + """ + model_mode = "" + model_mode + model_mode = model_mode.split("-")[0][0] + return model_mode + + +class ModelRateManager: + """Control the model rate of clients in case of simulation.""" + + def __init__(self, model_split_mode, model_split_rate, model_mode): + self.model_split_mode = model_split_mode + self.model_split_rate = model_split_rate + self.model_mode = model_mode + self.model_mode = self.model_mode.split("-") + + def create_model_rate_mapping(self, num_users): + """Change the client to model rate mapping accordingly.""" + client_model_rate = [] + + if self.model_split_mode == "fix": + mode_rate, proportion = [], [] + for comp_level_prop in self.model_mode: + mode_rate.append(self.model_split_rate[comp_level_prop[0]]) + proportion.append(int(comp_level_prop[1:])) + num_users_proportion = num_users // sum(proportion) + for i, comp_level in enumerate(mode_rate): + client_model_rate += np.repeat( + comp_level, num_users_proportion * proportion[i] + ).tolist() + client_model_rate = client_model_rate + [ + client_model_rate[-1] for _ in range(num_users - len(client_model_rate)) + ] + # return client_model_rate + + elif self.model_split_mode == "dynamic": + mode_rate, proportion = [], [] + + for comp_level_prop in self.model_mode: + mode_rate.append(self.model_split_rate[comp_level_prop[0]]) + proportion.append(int(comp_level_prop[1:])) + + proportion = (np.array(proportion) / sum(proportion)).tolist() + + rate_idx = torch.multinomial( + torch.tensor(proportion), num_samples=num_users, replacement=True + ).tolist() + client_model_rate = np.array(mode_rate)[rate_idx] + + # return client_model_rate + + else: + raise ValueError("Not valid model split mode") + + return client_model_rate + + +def save_model(model, path): + """To save the model in the given path.""" + # print('in save model') + current_path = HydraConfig.get().runtime.output_dir + model_save_path = Path(current_path) / path + torch.save(model.state_dict(), model_save_path) + + +# """ The following functions(check_exists, makedir_exit_ok, save, load) +# are adopted from authors (of heterofl) implementation.""" + + +def check_exists(path): + """Check if the given path exists.""" + return os.path.exists(path) + + +def makedir_exist_ok(path): + """Create a directory.""" + try: + os.makedirs(path) + except OSError as os_err: + if os_err.errno == errno.EEXIST: + pass + else: + raise + + +def save(inp, path, protocol=2, mode="torch"): + """Save the inp in a given path.""" + dirname = os.path.dirname(path) + makedir_exist_ok(dirname) + if mode == "torch": + torch.save(inp, path, pickle_protocol=protocol) + elif mode == "numpy": + np.save(path, inp, allow_pickle=True) + else: + raise ValueError("Not valid save mode") + + +# pylint: disable=no-else-return +def load(path, mode="torch"): + """Load the file from given path.""" + if mode == "torch": + return torch.load(path, map_location=lambda storage, loc: storage) + elif mode == "numpy": + return np.load(path, allow_pickle=True) + else: + raise ValueError("Not valid save mode") diff --git a/baselines/heterofl/pyproject.toml b/baselines/heterofl/pyproject.toml new file mode 100644 index 000000000000..0f72edf20345 --- /dev/null +++ b/baselines/heterofl/pyproject.toml @@ -0,0 +1,145 @@ +[build-system] +requires = ["poetry-core>=1.4.0"] +build-backend = "poetry.masonry.api" + +[tool.poetry] +name = "heterofl" # <----- Ensure it matches the name of your baseline directory containing all the source code +version = "1.0.0" +description = "HeteroFL : Computation And Communication Efficient Federated Learning For Heterogeneous Clients" +license = "Apache-2.0" +authors = ["M S Chaitanya Kumar ", "The Flower Authors "] +readme = "README.md" +homepage = "https://flower.dev" +repository = "https://github.com/adap/flower" +documentation = "https://flower.dev" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[tool.poetry.dependencies] +python = ">=3.10.0, <3.11.0" +flwr = { extras = ["simulation"], version = "1.5.0" } +hydra-core = "1.3.2" # don't change this +torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.1.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.16.0%2Bcu118-cp310-cp310-linux_x86_64.whl"} +anytree = "^2.12.1" +types-six = "^1.16.21.9" +tqdm = "4.66.1" + +[tool.poetry.dev-dependencies] +isort = "==5.11.5" +black = "==23.1.0" +docformatter = "==1.5.1" +mypy = "==1.4.1" +pylint = "==2.8.2" +flake8 = "==3.9.2" +pytest = "==6.2.4" +pytest-watch = "==4.2.0" +ruff = "==0.0.272" +types-requests = "==2.27.7" +virtualenv = "20.21.0" + +[tool.isort] +line_length = 88 +indent = " " +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true + +[tool.black] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311"] + +[tool.pytest.ini_options] +minversion = "6.2" +addopts = "-qq" +testpaths = [ + "flwr_baselines", +] + +[tool.mypy] +ignore_missing_imports = true +strict = false +plugins = "numpy.typing.mypy_plugin" + +[tool.pylint."MESSAGES CONTROL"] +disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +good-names = "i,j,k,_,x,y,X,Y" +signature-mutators="hydra.main.main" + + +[tool.pylint.typecheck] +generated-members="numpy.*, torch.*, tensorflow.*" + + +[[tool.mypy.overrides]] +module = [ + "importlib.metadata.*", + "importlib_metadata.*", +] +follow_imports = "skip" +follow_imports_for_stubs = true +disallow_untyped_calls = false + +[[tool.mypy.overrides]] +module = "torch.*" +follow_imports = "skip" +follow_imports_for_stubs = true + +[tool.docformatter] +wrap-summaries = 88 +wrap-descriptions = 88 + +[tool.ruff] +target-version = "py38" +line-length = 88 +select = ["D", "E", "F", "W", "B", "ISC", "C4"] +fixable = ["D", "E", "F", "W", "B", "ISC", "C4"] +ignore = ["B024", "B027"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "proto", +] + +[tool.ruff.pydocstyle] +convention = "numpy" diff --git a/baselines/hfedxgboost/README.md b/baselines/hfedxgboost/README.md index 29702496370b..2f31e2c4c584 100644 --- a/baselines/hfedxgboost/README.md +++ b/baselines/hfedxgboost/README.md @@ -11,7 +11,7 @@ dataset: [a9a, cod-rna, ijcnn1, space_ga, cpusmall, YearPredictionMSD] **Paper:** [arxiv.org/abs/2304.07537](https://arxiv.org/abs/2304.07537) -**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Laneearly_stop_patience_rounds: 100 +**Authors:** Chenyang Ma, Xinchi Qiu, Daniel J. Beutel, Nicholas D. Lane **Abstract:** The privacy-sensitive nature of decentralized datasets and the robustness of eXtreme Gradient Boosting (XGBoost) on tabular data raise the need to train XGBoost in the context of federated learning (FL). Existing works on federated XGBoost in the horizontal setting rely on the sharing of gradients, which induce per-node level communication frequency and serious privacy concerns. To alleviate these problems, we develop an innovative framework for horizontal federated XGBoost which does not depend on the sharing of gradients and simultaneously boosts privacy and communication efficiency by making the learning rates of the aggregated tree ensembles are learnable. We conduct extensive evaluations on various classification and regression datasets, showing our approach achieve performance comparable to the state-of-the-art method and effectively improves communication efficiency by lowering both communication rounds and communication overhead by factors ranging from 25x to 700x. diff --git a/datasets/e2e/tensorflow/pyproject.toml b/datasets/e2e/tensorflow/pyproject.toml index 9c5c72c46400..4d7b5f60e856 100644 --- a/datasets/e2e/tensorflow/pyproject.toml +++ b/datasets/e2e/tensorflow/pyproject.toml @@ -9,7 +9,7 @@ description = "Flower Datasets with TensorFlow" authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = "^3.8" +python = ">=3.8,<3.11" flwr-datasets = { path = "./../../", extras = ["vision"] } tensorflow-cpu = "^2.9.1, !=2.11.1" parameterized = "==0.9.0" diff --git a/dev/publish.sh b/dev/publish.sh deleted file mode 100755 index fb4df1694530..000000000000 --- a/dev/publish.sh +++ /dev/null @@ -1,21 +0,0 @@ -#!/bin/bash - -# Copyright 2022 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -set -e -cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/../ - -python -m poetry publish diff --git a/doc/source/how-to-install-flower.rst b/doc/source/how-to-install-flower.rst index 1107f6798b23..ff3dbb605846 100644 --- a/doc/source/how-to-install-flower.rst +++ b/doc/source/how-to-install-flower.rst @@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 `_, but `Pyth Install stable release ---------------------- +Using pip +~~~~~~~~~ + Stable releases are available on `PyPI `_:: python -m pip install flwr @@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed python -m pip install flwr[simulation] +Using conda (or mamba) +~~~~~~~~~~~~~~~~~~~~~~ + +Flower can also be installed from the ``conda-forge`` channel. + +If you have not added ``conda-forge`` to your channels, you will first need to run the following:: + + conda config --add channels conda-forge + conda config --set channel_priority strict + +Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``:: + + conda install flwr + +or with ``mamba``:: + + mamba install flwr + + Verify installation ------------------- diff --git a/doc/source/ref-changelog.md b/doc/source/ref-changelog.md index 7168386eaf0a..5f323bc80baa 100644 --- a/doc/source/ref-changelog.md +++ b/doc/source/ref-changelog.md @@ -6,6 +6,8 @@ - **General updates to Flower Examples** ([#2381](https://github.com/adap/flower/pull/2381)) +- **Retiring MXNet examples** The development of the MXNet fremework has ended and the project is now [archived on GitHub](https://github.com/apache/mxnet). Existing MXNet examples won't receive updates [#2724](https://github.com/adap/flower/pull/2724) + - **Update Flower Baselines** - HFedXGBoost [#2226](https://github.com/adap/flower/pull/2226) @@ -14,6 +16,10 @@ - FedNova [#2179](https://github.com/adap/flower/pull/2179) + - HeteroFL [#2439](https://github.com/adap/flower/pull/2439) + + - FedAvgM [#2246](https://github.com/adap/flower/pull/2246) + ## v1.6.0 (2023-11-28) ### Thanks to our contributors diff --git a/doc/source/tutorial-quickstart-mxnet.rst b/doc/source/tutorial-quickstart-mxnet.rst index 149d060e4c00..ff8d4b2087dd 100644 --- a/doc/source/tutorial-quickstart-mxnet.rst +++ b/doc/source/tutorial-quickstart-mxnet.rst @@ -4,6 +4,8 @@ Quickstart MXNet ================ +.. warning:: MXNet is no longer maintained and has been moved into `Attic `_. As a result, we would encourage you to use other ML frameworks alongise Flower, for example, PyTorch. This tutorial might be removed in future versions of Flower. + .. meta:: :description: Check out this Federated Learning quickstart tutorial for using Flower with MXNet to train a Sequential model on MNIST. diff --git a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb index ce4c2bb63606..bbd916b32375 100644 --- a/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb +++ b/doc/source/tutorial-series-get-started-with-flower-pytorch.ipynb @@ -484,7 +484,7 @@ " min_available_clients=10, # Wait until all 10 clients are available\n", ")\n", "\n", - "# Specify the resources each of your clients need. By default, each \n", + "# Specify the resources each of your clients need. By default, each\n", "# client will be allocated 1x CPU and 0x CPUs\n", "client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n", "if DEVICE.type == \"cuda\":\n", diff --git a/e2e/test_driver.sh b/e2e/test_driver.sh index ca54dbf4852f..32314bd22533 100755 --- a/e2e/test_driver.sh +++ b/e2e/test_driver.sh @@ -16,10 +16,10 @@ esac timeout 2m flower-server $server_arg & sleep 3 -timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & +timeout 2m flower-client client:flower $client_arg --server 127.0.0.1:9092 & sleep 3 -timeout 2m flower-client $client_arg --callable client:flower --server 127.0.0.1:9092 & +timeout 2m flower-client client:flower $client_arg --server 127.0.0.1:9092 & sleep 3 timeout 2m python driver.py & diff --git a/examples/android/README.md b/examples/android/README.md index 7931aa96b0c5..f9f2bb93b8dc 100644 --- a/examples/android/README.md +++ b/examples/android/README.md @@ -54,4 +54,4 @@ poetry run ./run.sh Download and install the `flwr_android_client.apk` on each Android device/emulator. The server currently expects a minimum of 4 Android clients, but it can be changed in the `server.py`. -When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Load Dataset`. This will load the local CIFAR10 dataset in memory. Then press `Setup Connection Channel` which will establish connection with the server. Finally, press `Train Federated!` which will start the federated training. +When the Android app runs, add the client ID (between 1-10), the IP and port of your server, and press `Start`. This will load the local CIFAR10 dataset in memory, establish connection with the server, and start the federated training. To abort the federated learning process, press `Stop`. You can clear and refresh the log messages by pressing `Clear` and `Refresh` buttons respectively. diff --git a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py index f5c76ab6dc99..f8124b9353f7 100644 --- a/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py +++ b/examples/flower-simulation-step-by-step-pytorch/Part-I/main.py @@ -24,7 +24,7 @@ def main(cfg: DictConfig): save_path = HydraConfig.get().runtime.output_dir ## 2. Prepare your dataset - # When simulating FL workloads we have a lot of freedom on how the FL clients behave, + # When simulating FL runs we have a lot of freedom on how the FL clients behave, # what data they have, how much data, etc. This is not possible in real FL settings. # In simulation you'd often encounter two types of dataset: # * naturally partitioned, that come pre-partitioned by user id (e.g. FEMNIST, @@ -91,7 +91,7 @@ def main(cfg: DictConfig): "num_gpus": 0.0, }, # (optional) controls the degree of parallelism of your simulation. # Lower resources per client allow for more clients to run concurrently - # (but need to be set taking into account the compute/memory footprint of your workload) + # (but need to be set taking into account the compute/memory footprint of your run) # `num_cpus` is an absolute number (integer) indicating the number of threads a client should be allocated # `num_gpus` is a ratio indicating the portion of gpu memory that a client needs. ) diff --git a/examples/mt-pytorch-callable/README.md b/examples/mt-pytorch-callable/README.md index 65ef000c26f2..120e28098344 100644 --- a/examples/mt-pytorch-callable/README.md +++ b/examples/mt-pytorch-callable/README.md @@ -33,13 +33,13 @@ flower-server --insecure In a new terminal window, start the first long-running Flower client: ```bash -flower-client --callable client:flower +flower-client --insecure client:flower ``` In yet another new terminal window, start the second long-running Flower client: ```bash -flower-client --callable client:flower +flower-client --insecure client:flower ``` ## Start the Driver script diff --git a/examples/mt-pytorch-callable/client.py b/examples/mt-pytorch-callable/client.py index 6f9747784ae0..4195a714ca89 100644 --- a/examples/mt-pytorch-callable/client.py +++ b/examples/mt-pytorch-callable/client.py @@ -108,7 +108,7 @@ def client_fn(cid: str): return FlowerClient().to_client() -# To run this: `flower-client --callable client:flower` +# To run this: `flower-client client:flower` flower = fl.flower.Flower( client_fn=client_fn, ) diff --git a/examples/mt-pytorch/driver.py b/examples/mt-pytorch/driver.py index fed760f021af..ad4d5e1caabe 100644 --- a/examples/mt-pytorch/driver.py +++ b/examples/mt-pytorch/driver.py @@ -54,13 +54,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK driver.connect() -create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( - req=driver_pb2.CreateWorkloadRequest() +create_run_res: driver_pb2.CreateRunResponse = driver.create_run( + req=driver_pb2.CreateRunRequest() ) # -------------------------------------------------------------------------- Driver SDK -workload_id = create_workload_res.workload_id -print(f"Created workload id {workload_id}") +run_id = create_run_res.run_id +print(f"Created run id {run_id}") history = History() for server_round in range(num_rounds): @@ -93,7 +93,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # loop and wait until enough client nodes are available. while True: # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) + get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) # ---------------------------------------------------------------------- Driver SDK get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( @@ -125,7 +125,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: new_task_ins = task_pb2.TaskIns( task_id="", # Do not set, will be created and set by the DriverAPI group_id="", - workload_id=workload_id, + run_id=run_id, task=task_pb2.Task( producer=node_pb2.Node( node_id=0, diff --git a/examples/mxnet-from-centralized-to-federated/README.md b/examples/mxnet-from-centralized-to-federated/README.md index 839d3b16a1cf..2c3f240d8978 100644 --- a/examples/mxnet-from-centralized-to-federated/README.md +++ b/examples/mxnet-from-centralized-to-federated/README.md @@ -1,5 +1,7 @@ # MXNet: From Centralized To Federated +> Note the MXNet project has ended, and is now in [Attic](https://attic.apache.org/projects/mxnet.html). The MXNet GitHub has also [been archived](https://github.com/apache/mxnet). As a result, this example won't be receiving more updates. Using MXNet is no longer recommnended. + This example demonstrates how an already existing centralized MXNet-based machine learning project can be federated with Flower. This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet project. diff --git a/examples/mxnet-from-centralized-to-federated/pyproject.toml b/examples/mxnet-from-centralized-to-federated/pyproject.toml index a0d31f76ebdd..952683eb90f6 100644 --- a/examples/mxnet-from-centralized-to-federated/pyproject.toml +++ b/examples/mxnet-from-centralized-to-federated/pyproject.toml @@ -10,7 +10,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" -# flwr = { path = "../../", develop = true } # Development -mxnet = "1.6.0" +flwr = "1.6.0" +mxnet = "1.9.1" numpy = "1.23.1" diff --git a/examples/mxnet-from-centralized-to-federated/requirements.txt b/examples/mxnet-from-centralized-to-federated/requirements.txt index 73060e27c70c..8dd6f7150dfd 100644 --- a/examples/mxnet-from-centralized-to-federated/requirements.txt +++ b/examples/mxnet-from-centralized-to-federated/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0,<2.0 -mxnet==1.6.0 +flwr==1.6.0 +mxnet==1.9.1 numpy==1.23.1 diff --git a/examples/pytorch-from-centralized-to-federated/cifar.py b/examples/pytorch-from-centralized-to-federated/cifar.py index a374909c33b2..e8f3ec3fd724 100644 --- a/examples/pytorch-from-centralized-to-federated/cifar.py +++ b/examples/pytorch-from-centralized-to-federated/cifar.py @@ -73,10 +73,10 @@ def apply_transforms(batch): def train( - net: Net, - trainloader: torch.utils.data.DataLoader, - epochs: int, - device: torch.device, # pylint: disable=no-member + net: Net, + trainloader: torch.utils.data.DataLoader, + epochs: int, + device: torch.device, # pylint: disable=no-member ) -> None: """Train the network.""" # Define loss and optimizer @@ -110,9 +110,9 @@ def train( def test( - net: Net, - testloader: torch.utils.data.DataLoader, - device: torch.device, # pylint: disable=no-member + net: Net, + testloader: torch.utils.data.DataLoader, + device: torch.device, # pylint: disable=no-member ) -> Tuple[float, float]: """Validate the network on the entire test set.""" # Define loss and metrics diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index df4da7c11cff..61c7e7f762b3 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -24,10 +24,10 @@ class CifarClient(fl.client.NumPyClient): """Flower client implementing CIFAR-10 image classification using PyTorch.""" def __init__( - self, - model: cifar.Net, - trainloader: DataLoader, - testloader: DataLoader, + self, + model: cifar.Net, + trainloader: DataLoader, + testloader: DataLoader, ) -> None: self.model = model self.trainloader = trainloader @@ -61,7 +61,7 @@ def set_parameters(self, parameters: List[np.ndarray]) -> None: self.model.load_state_dict(state_dict, strict=True) def fit( - self, parameters: List[np.ndarray], config: Dict[str, str] + self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[List[np.ndarray], int, Dict]: # Set model parameters, train model, return updated model parameters self.set_parameters(parameters) @@ -69,7 +69,7 @@ def fit( return self.get_parameters(config={}), len(self.trainloader.dataset), {} def evaluate( - self, parameters: List[np.ndarray], config: Dict[str, str] + self, parameters: List[np.ndarray], config: Dict[str, str] ) -> Tuple[float, int, Dict]: # Set model parameters, evaluate model on local test dataset, return result self.set_parameters(parameters) diff --git a/examples/quickstart-mxnet/README.md b/examples/quickstart-mxnet/README.md index 930cec5acdfd..37e01ef2707c 100644 --- a/examples/quickstart-mxnet/README.md +++ b/examples/quickstart-mxnet/README.md @@ -1,5 +1,7 @@ # Flower Example using MXNet +> Note the MXNet project has ended, and is now in [Attic](https://attic.apache.org/projects/mxnet.html). The MXNet GitHub has also [been archived](https://github.com/apache/mxnet). As a result, this example won't be receiving more updates. Using MXNet is no longer recommnended. + This example demonstrates how to run a MXNet machine learning project federated with Flower. This introductory example for Flower uses MXNet, but you're not required to be a MXNet expert to run the example. The example will help you to understand how Flower can be used to build federated learning use cases based on an existing MXNet projects. diff --git a/examples/quickstart-mxnet/pyproject.toml b/examples/quickstart-mxnet/pyproject.toml index a0d31f76ebdd..952683eb90f6 100644 --- a/examples/quickstart-mxnet/pyproject.toml +++ b/examples/quickstart-mxnet/pyproject.toml @@ -10,7 +10,6 @@ authors = ["The Flower Authors "] [tool.poetry.dependencies] python = ">=3.8,<3.11" -flwr = ">=1.0,<2.0" -# flwr = { path = "../../", develop = true } # Development -mxnet = "1.6.0" +flwr = "1.6.0" +mxnet = "1.9.1" numpy = "1.23.1" diff --git a/examples/quickstart-mxnet/requirements.txt b/examples/quickstart-mxnet/requirements.txt index 73060e27c70c..8dd6f7150dfd 100644 --- a/examples/quickstart-mxnet/requirements.txt +++ b/examples/quickstart-mxnet/requirements.txt @@ -1,3 +1,3 @@ -flwr>=1.0,<2.0 -mxnet==1.6.0 +flwr==1.6.0 +mxnet==1.9.1 numpy==1.23.1 diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index 8e07494b6492..1dabd5732b9b 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -10,6 +10,7 @@ disable_progress_bar() + class FlowerClient(fl.client.NumPyClient): def __init__(self, model, train_loader, val_loader, test_loader): self.model = model @@ -55,7 +56,6 @@ def _set_parameters(model, parameters): def main() -> None: - parser = argparse.ArgumentParser(description="Flower") parser.add_argument( "--node-id", diff --git a/examples/quickstart-pytorch-lightning/mnist.py b/examples/quickstart-pytorch-lightning/mnist.py index d32a0afe2d1e..95342f4fb9b3 100644 --- a/examples/quickstart-pytorch-lightning/mnist.py +++ b/examples/quickstart-pytorch-lightning/mnist.py @@ -86,16 +86,20 @@ def load_data(partition): # 60 % for the federated train and 20 % for the federated validation (both in fit) partition_train_valid = partition_full["train"].train_test_split(train_size=0.75) trainloader = DataLoader( - partition_train_valid["train"], batch_size=32, - shuffle=True, collate_fn=collate_fn, num_workers=1 + partition_train_valid["train"], + batch_size=32, + shuffle=True, + collate_fn=collate_fn, + num_workers=1, ) valloader = DataLoader( - partition_train_valid["test"], batch_size=32, - collate_fn=collate_fn, num_workers=1 + partition_train_valid["test"], + batch_size=32, + collate_fn=collate_fn, + num_workers=1, ) testloader = DataLoader( - partition_full["test"], batch_size=32, - collate_fn=collate_fn, num_workers=1 + partition_full["test"], batch_size=32, collate_fn=collate_fn, num_workers=1 ) return trainloader, valloader, testloader diff --git a/examples/quickstart-sklearn-tabular/client.py b/examples/quickstart-sklearn-tabular/client.py index 88f654d4398e..5dc0e88b3c75 100644 --- a/examples/quickstart-sklearn-tabular/client.py +++ b/examples/quickstart-sklearn-tabular/client.py @@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"test_accuracy": accuracy} # Start Flower client - fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client()) + fl.client.start_client( + server_address="0.0.0.0:8080", client=IrisClient().to_client() + ) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py index d9f795766f6d..f5871f1b44e4 100644 --- a/examples/secaggplus-mt/driver.py +++ b/examples/secaggplus-mt/driver.py @@ -23,7 +23,8 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: task_pb2.TaskIns( task_id="", # Do not set, will be created and set by the DriverAPI group_id="", - workload_id=workload_id, + run_id=run_id, + run_id=run_id, task=merge( task, task_pb2.Task( @@ -84,13 +85,13 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # -------------------------------------------------------------------------- Driver SDK driver.connect() -create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload( - req=driver_pb2.CreateWorkloadRequest() +create_run_res: driver_pb2.CreateRunResponse = driver.create_run( + req=driver_pb2.CreateRunRequest() ) # -------------------------------------------------------------------------- Driver SDK -workload_id = create_workload_res.workload_id -print(f"Created workload id {workload_id}") +run_id = create_run_res.run_id +print(f"Created run id {run_id}") history = History() for server_round in range(num_rounds): @@ -119,7 +120,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # loop and wait until enough client nodes are available. while True: # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(workload_id=workload_id) + get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) # ---------------------------------------------------------------------- Driver SDK get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( diff --git a/pyproject.toml b/pyproject.toml index adadba711787..cab083b32325 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,21 +81,21 @@ rest = ["requests", "starlette", "uvicorn"] [tool.poetry.group.dev.dependencies] types-dataclasses = "==0.6.6" types-protobuf = "==3.19.18" -types-requests = "==2.31.0.2" -types-setuptools = "==68.2.0.0" +types-requests = "==2.31.0.10" +types-setuptools = "==69.0.0.20240115" clang-format = "==17.0.4" -isort = "==5.12.0" +isort = "==5.13.2" black = { version = "==23.10.1", extras = ["jupyter"] } docformatter = "==1.7.5" mypy = "==1.6.1" -pylint = "==2.13.9" +pylint = "==3.0.3" flake8 = "==5.0.4" pytest = "==7.4.3" pytest-cov = "==4.1.0" pytest-watch = "==4.2.0" grpcio-tools = "==1.48.2" mypy-protobuf = "==3.2.0" -jupyterlab = "==4.0.8" +jupyterlab = "==4.0.9" rope = "==1.11.0" semver = "==3.0.2" sphinx = "==6.2.1" @@ -109,7 +109,7 @@ furo = "==2023.9.10" sphinx-reredirects = "==0.1.3" nbsphinx = "==0.9.3" nbstripout = "==0.6.1" -ruff = "==0.1.4" +ruff = "==0.1.9" sphinx-argparse = "==0.4.0" pipreqs = "==0.4.13" mdformat-gfm = "==0.3.5" @@ -137,7 +137,7 @@ line-length = 88 target-version = ["py38", "py39", "py310", "py311"] [tool.pylint."MESSAGES CONTROL"] -disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias" +disable = "duplicate-code,too-few-public-methods,useless-import-alias" [tool.pytest.ini_options] minversion = "6.2" @@ -184,7 +184,7 @@ target-version = "py38" line-length = 88 select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"] -ignore = ["B024", "B027"] +ignore = ["B024", "B027", "D205", "D209"] exclude = [ ".bzr", ".direnv", diff --git a/src/docker/client/Dockerfile b/src/docker/client/Dockerfile new file mode 100644 index 000000000000..0755a7989281 --- /dev/null +++ b/src/docker/client/Dockerfile @@ -0,0 +1,8 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. + +ARG BASE_REPOSITORY=flwr/base +ARG BASE_IMAGE_TAG +FROM $BASE_REPOSITORY:$BASE_IMAGE_TAG + +ARG FLWR_VERSION +RUN python -m pip install -U --no-cache-dir flwr[rest]==${FLWR_VERSION} diff --git a/src/kotlin/flwr/src/main/AndroidManifest.xml b/src/kotlin/flwr/src/main/AndroidManifest.xml index 8bdb7e14b389..3cb3262db448 100644 --- a/src/kotlin/flwr/src/main/AndroidManifest.xml +++ b/src/kotlin/flwr/src/main/AndroidManifest.xml @@ -1,4 +1,5 @@ - + + diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index eb948217a4de..bc0062c4a51f 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -21,8 +21,8 @@ import "flwr/proto/node.proto"; import "flwr/proto/task.proto"; service Driver { - // Request workload_id - rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {} + // Request run_id + rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {} // Return a set of nodes rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {} @@ -34,12 +34,12 @@ service Driver { rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {} } -// CreateWorkload -message CreateWorkloadRequest {} -message CreateWorkloadResponse { sint64 workload_id = 1; } +// CreateRun +message CreateRunRequest {} +message CreateRunResponse { sint64 run_id = 1; } // GetNodes messages -message GetNodesRequest { sint64 workload_id = 1; } +message GetNodesRequest { sint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 2205ef2815c8..ad71d7ea3811 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -36,14 +36,14 @@ message Task { message TaskIns { string task_id = 1; string group_id = 2; - sint64 workload_id = 3; + sint64 run_id = 3; Task task = 4; } message TaskRes { string task_id = 1; string group_id = 2; - sint64 workload_id = 3; + sint64 run_id = 3; Task task = 4; } diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 3448e18e20c5..91fa5468ae75 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -74,10 +74,10 @@ def run_client() -> None: print(args.root_certificates) print(args.server) - print(args.callable_dir) + print(args.dir) print(args.callable) - callable_dir = args.callable_dir + callable_dir = args.dir if callable_dir is not None: sys.path.insert(0, callable_dir) @@ -101,6 +101,10 @@ def _parse_args_client() -> argparse.ArgumentParser: description="Start a long-running Flower client", ) + parser.add_argument( + "callable", + help="For example: `client:flower` or `project.package.module:wrapper.flower`", + ) parser.add_argument( "--insecure", action="store_true", @@ -120,13 +124,10 @@ def _parse_args_client() -> argparse.ArgumentParser: help="Server address", ) parser.add_argument( - "--callable", - help="For example: `client:flower` or `project.package.module:wrapper.flower`", - ) - parser.add_argument( - "--callable-dir", + "--dir", default="", - help="Add specified directory to the PYTHONPATH and load callable from there." + help="Add specified directory to the PYTHONPATH and load Flower " + "callable from there." " Default: current working directory.", ) @@ -137,10 +138,12 @@ def _check_actionable_client( client: Optional[Client], client_fn: Optional[ClientFn] ) -> None: if client_fn is None and client is None: - raise Exception("Both `client_fn` and `client` are `None`, but one is required") + raise ValueError( + "Both `client_fn` and `client` are `None`, but one is required" + ) if client_fn is not None and client is not None: - raise Exception( + raise ValueError( "Both `client_fn` and `client` are provided, but only one is allowed" ) @@ -149,6 +152,7 @@ def _check_actionable_client( # pylint: disable=too-many-branches # pylint: disable=too-many-locals # pylint: disable=too-many-statements +# pylint: disable=too-many-arguments def start_client( *, server_address: str, @@ -298,7 +302,7 @@ def single_client_factory( cid: str, # pylint: disable=unused-argument ) -> Client: if client is None: # Added this to keep mypy happy - raise Exception( + raise ValueError( "Both `client_fn` and `client` are `None`, but one is required" ) return client # Always return the same instance @@ -348,7 +352,7 @@ def _load_app() -> Flower: break # Register state - node_state.register_workloadstate(workload_id=task_ins.workload_id) + node_state.register_runstate(run_id=task_ins.run_id) # Load app app: Flower = load_flower_callable_fn() @@ -356,16 +360,14 @@ def _load_app() -> Flower: # Handle task message fwd_msg: Fwd = Fwd( task_ins=task_ins, - state=node_state.retrieve_workloadstate( - workload_id=task_ins.workload_id - ), + state=node_state.retrieve_runstate(run_id=task_ins.run_id), ) bwd_msg: Bwd = app(fwd=fwd_msg) # Update node state - node_state.update_workloadstate( - workload_id=bwd_msg.task_res.workload_id, - workload_state=bwd_msg.state, + node_state.update_runstate( + run_id=bwd_msg.task_res.run_id, + run_state=bwd_msg.state, ) # Send diff --git a/src/py/flwr/client/app_test.py b/src/py/flwr/client/app_test.py index 7ef6410debad..56d6308a0fe2 100644 --- a/src/py/flwr/client/app_test.py +++ b/src/py/flwr/client/app_test.py @@ -41,19 +41,19 @@ class PlainClient(Client): def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, ins: GetParametersIns) -> GetParametersRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit(self, ins: FitIns) -> FitRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate(self, ins: EvaluateIns) -> EvaluateRes: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() class NeedsWrappingClient(NumPyClient): @@ -61,23 +61,23 @@ class NeedsWrappingClient(NumPyClient): def get_properties(self, config: Config) -> Dict[str, Scalar]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def get_parameters(self, config: Config) -> NDArrays: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def fit( self, parameters: NDArrays, config: Config ) -> Tuple[NDArrays, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def evaluate( self, parameters: NDArrays, config: Config ) -> Tuple[float, int, Dict[str, Scalar]]: """Raise an Exception because this method is not expected to be called.""" - raise Exception() + raise NotImplementedError() def test_to_client_with_client() -> None: diff --git a/src/py/flwr/client/client.py b/src/py/flwr/client/client.py index 280e0a8ca989..54b53296fd2f 100644 --- a/src/py/flwr/client/client.py +++ b/src/py/flwr/client/client.py @@ -19,7 +19,7 @@ from abc import ABC -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common import ( Code, EvaluateIns, @@ -38,7 +38,7 @@ class Client(ABC): """Abstract base class for Flower clients.""" - state: WorkloadState + state: RunState def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes: """Return set of client's properties. @@ -141,12 +141,12 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: metrics={}, ) - def get_state(self) -> WorkloadState: - """Get the workload state from this client.""" + def get_state(self) -> RunState: + """Get the run state from this client.""" return self.state - def set_state(self, state: WorkloadState) -> None: - """Apply a workload state to this client.""" + def set_state(self, state: RunState) -> None: + """Apply a run state to this client.""" self.state = state def to_client(self) -> Client: diff --git a/src/py/flwr/client/dpfedavg_numpy_client.py b/src/py/flwr/client/dpfedavg_numpy_client.py index 41b4d676df43..c39b89b31da3 100644 --- a/src/py/flwr/client/dpfedavg_numpy_client.py +++ b/src/py/flwr/client/dpfedavg_numpy_client.py @@ -117,16 +117,16 @@ def fit( update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)] if "dpfedavg_clip_norm" not in config: - raise Exception("Clipping threshold not supplied by the server.") + raise KeyError("Clipping threshold not supplied by the server.") if not isinstance(config["dpfedavg_clip_norm"], float): - raise Exception("Clipping threshold should be a floating point value.") + raise TypeError("Clipping threshold should be a floating point value.") # Clipping update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"]) if "dpfedavg_noise_stddev" in config: if not isinstance(config["dpfedavg_noise_stddev"], float): - raise Exception( + raise TypeError( "Scale of noise to be added should be a floating point value." ) # Noising @@ -138,7 +138,7 @@ def fit( # Calculating value of norm indicator bit, required for adaptive clipping if "dpfedavg_adaptive_clip_enabled" in config: if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool): - raise Exception( + raise TypeError( "dpfedavg_adaptive_clip_enabled should be a boolean-valued flag." ) metrics["dpfedavg_norm_bit"] = not clipped diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 335d28e72828..481f32c77859 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -119,7 +119,7 @@ def receive() -> TaskIns: return TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/client/message_handler/message_handler.py b/src/py/flwr/client/message_handler/message_handler.py index 0f3070cfb01a..3f30db2a4ea2 100644 --- a/src/py/flwr/client/message_handler/message_handler.py +++ b/src/py/flwr/client/message_handler/message_handler.py @@ -28,9 +28,9 @@ get_server_message_from_task_ins, wrap_client_message_in_task_res, ) +from flwr.client.run_state import RunState from flwr.client.secure_aggregation import SecureAggregationHandler from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState from flwr.common import serde from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage @@ -79,16 +79,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]: def handle( - client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns -) -> Tuple[TaskRes, WorkloadState]: + client_fn: ClientFn, state: RunState, task_ins: TaskIns +) -> Tuple[TaskRes, RunState]: """Handle incoming TaskIns from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : WorkloadState - A dataclass storing the state for the workload being executed by the client. + state : RunState + A dataclass storing the state for the run being executed by the client. task_ins: TaskIns The task instruction coming from the server, to be processed by the client. @@ -112,7 +112,7 @@ def handle( task_res = TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( ancestry=[], sa=SecureAggregation(named_values=serde.named_values_to_proto(res)), @@ -126,16 +126,16 @@ def handle( def handle_legacy_message( - client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage -) -> Tuple[ClientMessage, WorkloadState]: + client_fn: ClientFn, state: RunState, server_msg: ServerMessage +) -> Tuple[ClientMessage, RunState]: """Handle incoming messages from the server. Parameters ---------- client_fn : ClientFn A callable that instantiates a Client. - state : WorkloadState - A dataclass storing the state for the workload being executed by the client. + state : RunState + A dataclass storing the state for the run being executed by the client. server_msg: ServerMessage The message coming from the server, to be processed by the client. diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index d7f410d81fc0..cd810ae220e9 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -18,8 +18,8 @@ import uuid from flwr.client import Client +from flwr.client.run_state import RunState from flwr.client.typing import ClientFn -from flwr.client.workload_state import WorkloadState from flwr.common import ( EvaluateIns, EvaluateRes, @@ -121,7 +121,7 @@ def test_client_without_get_properties() -> None: task_ins: TaskIns = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -136,7 +136,7 @@ def test_client_without_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=WorkloadState(state={}), + state=RunState(state={}), task_ins=task_ins, ) @@ -152,7 +152,7 @@ def test_client_without_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, ) ) # pylint: disable=no-member @@ -189,7 +189,7 @@ def test_client_with_get_properties() -> None: task_ins = TaskIns( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=0, anonymous=True), consumer=Node(node_id=0, anonymous=True), @@ -204,7 +204,7 @@ def test_client_with_get_properties() -> None: ) task_res, _ = handle( client_fn=_get_client_fn(client), - state=WorkloadState(state={}), + state=RunState(state={}), task_ins=task_ins, ) @@ -220,7 +220,7 @@ def test_client_with_get_properties() -> None: TaskRes( task_id=str(uuid.uuid4()), group_id="", - workload_id=0, + run_id=0, ) ) # pylint: disable=no-member diff --git a/src/py/flwr/client/message_handler/task_handler.py b/src/py/flwr/client/message_handler/task_handler.py index fc24539998c0..3599e1dfb254 100644 --- a/src/py/flwr/client/message_handler/task_handler.py +++ b/src/py/flwr/client/message_handler/task_handler.py @@ -70,7 +70,7 @@ def validate_task_res(task_res: TaskRes) -> bool: Returns ------- is_valid: bool - True if the `task_id`, `group_id`, and `workload_id` fields in TaskRes + True if the `task_id`, `group_id`, and `run_id` fields in TaskRes and the `producer`, `consumer`, and `ancestry` fields in its sub-message Task are not initialized accidentally elsewhere, False otherwise. @@ -80,11 +80,10 @@ def validate_task_res(task_res: TaskRes) -> bool: initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()} # Check if certain fields are already initialized - # pylint: disable-next=too-many-boolean-expressions - if ( + if ( # pylint: disable-next=too-many-boolean-expressions "task_id" in initialized_fields_in_task_res or "group_id" in initialized_fields_in_task_res - or "workload_id" in initialized_fields_in_task_res + or "run_id" in initialized_fields_in_task_res or "producer" in initialized_fields_in_task or "consumer" in initialized_fields_in_task or "ancestry" in initialized_fields_in_task @@ -129,7 +128,7 @@ def wrap_client_message_in_task_res(client_message: ClientMessage) -> TaskRes: return TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task(ancestry=[], legacy_client_message=client_message), ) @@ -139,7 +138,7 @@ def configure_task_res( ) -> TaskRes: """Set the metadata of a TaskRes. - Fill `group_id` and `workload_id` in TaskRes + Fill `group_id` and `run_id` in TaskRes and `producer`, `consumer`, and `ancestry` in Task in TaskRes. `producer` in Task in TaskRes will remain unchanged/unset. @@ -152,7 +151,7 @@ def configure_task_res( task_res = TaskRes( task_id="", # This will be generated by the server group_id=ref_task_ins.group_id, - workload_id=ref_task_ins.workload_id, + run_id=ref_task_ins.run_id, task=task_res.task, ) # pylint: disable-next=no-member diff --git a/src/py/flwr/client/message_handler/task_handler_test.py b/src/py/flwr/client/message_handler/task_handler_test.py index 21f3a2ead98a..748ef63e72ef 100644 --- a/src/py/flwr/client/message_handler/task_handler_test.py +++ b/src/py/flwr/client/message_handler/task_handler_test.py @@ -92,7 +92,7 @@ def test_validate_task_res() -> None: assert not validate_task_res(task_res) task_res.Clear() - task_res.workload_id = 61016 + task_res.run_id = 61016 assert not validate_task_res(task_res) task_res.Clear() diff --git a/src/py/flwr/client/middleware/utils_test.py b/src/py/flwr/client/middleware/utils_test.py index 9a2d888a5ecd..aa4358be5a51 100644 --- a/src/py/flwr/client/middleware/utils_test.py +++ b/src/py/flwr/client/middleware/utils_test.py @@ -18,8 +18,8 @@ import unittest from typing import List +from flwr.client.run_state import RunState from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer -from flwr.client.workload_state import WorkloadState from flwr.proto.task_pb2 import TaskIns, TaskRes from .utils import make_ffn @@ -45,7 +45,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable: def app(fwd: Fwd) -> Bwd: footprint.append(name) fwd.task_ins.task_id += f"{name}" - return Bwd(task_res=TaskRes(task_id=name), state=WorkloadState({})) + return Bwd(task_res=TaskRes(task_id=name), state=RunState({})) return app @@ -66,7 +66,7 @@ def test_multiple_middlewares(self) -> None: # Execute wrapped_app = make_ffn(mock_app, mock_middleware_layers) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res + task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res # Assert trace = mock_middleware_names + ["app"] @@ -86,11 +86,11 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd: footprint.append("filter") fwd.task_ins.task_id += "filter" # Skip calling app - return Bwd(task_res=TaskRes(task_id="filter"), state=WorkloadState({})) + return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({})) # Execute wrapped_app = make_ffn(mock_app, [filter_layer]) - task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res + task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res # Assert self.assertEqual(footprint, ["filter"]) diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py index ee4f70dc4dca..0a29be511806 100644 --- a/src/py/flwr/client/node_state.py +++ b/src/py/flwr/client/node_state.py @@ -17,34 +17,32 @@ from typing import Any, Dict -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState class NodeState: - """State of a node where client nodes execute workloads.""" + """State of a node where client nodes execute runs.""" def __init__(self) -> None: self._meta: Dict[str, Any] = {} # holds metadata about the node - self.workload_states: Dict[int, WorkloadState] = {} + self.run_states: Dict[int, RunState] = {} - def register_workloadstate(self, workload_id: int) -> None: - """Register new workload state for this node.""" - if workload_id not in self.workload_states: - self.workload_states[workload_id] = WorkloadState({}) + def register_runstate(self, run_id: int) -> None: + """Register new run state for this node.""" + if run_id not in self.run_states: + self.run_states[run_id] = RunState({}) - def retrieve_workloadstate(self, workload_id: int) -> WorkloadState: - """Get workload state given a workload_id.""" - if workload_id in self.workload_states: - return self.workload_states[workload_id] + def retrieve_runstate(self, run_id: int) -> RunState: + """Get run state given a run_id.""" + if run_id in self.run_states: + return self.run_states[run_id] raise RuntimeError( - f"WorkloadState for workload_id={workload_id} doesn't exist." - " A workload must be registered before it can be retrieved or updated " + f"RunState for run_id={run_id} doesn't exist." + " A run must be registered before it can be retrieved or updated " " by a client." ) - def update_workloadstate( - self, workload_id: int, workload_state: WorkloadState - ) -> None: - """Update workload state.""" - self.workload_states[workload_id] = workload_state + def update_runstate(self, run_id: int, run_state: RunState) -> None: + """Update run state.""" + self.run_states[run_id] = run_state diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py index d9f9ae7db3b0..7a6bfcd31f08 100644 --- a/src/py/flwr/client/node_state_tests.py +++ b/src/py/flwr/client/node_state_tests.py @@ -16,11 +16,11 @@ from flwr.client.node_state import NodeState -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.proto.task_pb2 import TaskIns -def _run_dummy_task(state: WorkloadState) -> WorkloadState: +def _run_dummy_task(state: RunState) -> RunState: if "counter" in state.state: state.state["counter"] += "1" else: @@ -29,31 +29,31 @@ def _run_dummy_task(state: WorkloadState) -> WorkloadState: return state -def test_multiworkload_in_node_state() -> None: +def test_multirun_in_node_state() -> None: """Test basic NodeState logic.""" # Tasks to perform - tasks = [TaskIns(workload_id=w_id) for w_id in [0, 1, 1, 2, 3, 2, 1, 5]] - # the "tasks" is to count how many times each workload is executed + tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]] + # the "tasks" is to count how many times each run is executed expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} # NodeState node_state = NodeState() for task in tasks: - w_id = task.workload_id + run_id = task.run_id # Register - node_state.register_workloadstate(workload_id=w_id) + node_state.register_runstate(run_id=run_id) - # Get workload state - state = node_state.retrieve_workloadstate(workload_id=w_id) + # Get run state + state = node_state.retrieve_runstate(run_id=run_id) # Run "task" updated_state = _run_dummy_task(state) - # Update workload state - node_state.update_workloadstate(workload_id=w_id, workload_state=updated_state) + # Update run state + node_state.update_runstate(run_id=run_id, run_state=updated_state) # Verify values - for w_id, state in node_state.workload_states.items(): - assert state.state["counter"] == expected_values[w_id] + for run_id, state in node_state.run_states.items(): + assert state.state["counter"] == expected_values[run_id] diff --git a/src/py/flwr/client/numpy_client.py b/src/py/flwr/client/numpy_client.py index 8b0893ea30aa..d67fb90512d4 100644 --- a/src/py/flwr/client/numpy_client.py +++ b/src/py/flwr/client/numpy_client.py @@ -19,7 +19,7 @@ from typing import Callable, Dict, Tuple from flwr.client.client import Client -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common import ( Config, NDArrays, @@ -70,7 +70,7 @@ class NumPyClient(ABC): """Abstract base class for Flower clients using NumPy.""" - state: WorkloadState + state: RunState def get_properties(self, config: Config) -> Dict[str, Scalar]: """Return a client's set of properties. @@ -174,12 +174,12 @@ def evaluate( _ = (self, parameters, config) return 0.0, 0, {} - def get_state(self) -> WorkloadState: - """Get the workload state from this client.""" + def get_state(self) -> RunState: + """Get the run state from this client.""" return self.state - def set_state(self, state: WorkloadState) -> None: - """Apply a workload state to this client.""" + def set_state(self, state: RunState) -> None: + """Apply a run state to this client.""" self.state = state def to_client(self) -> Client: @@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT) # Return FitRes parameters_prime, num_examples, metrics = results @@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: and isinstance(results[1], int) and isinstance(results[2], dict) ): - raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) + raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE) # Return EvaluateRes loss, num_examples, metrics = results @@ -278,12 +278,12 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes: ) -def _get_state(self: Client) -> WorkloadState: +def _get_state(self: Client) -> RunState: """Return state of underlying NumPyClient.""" return self.numpy_client.get_state() # type: ignore -def _set_state(self: Client, state: WorkloadState) -> None: +def _set_state(self: Client, state: RunState) -> None: """Apply state to underlying NumPyClient.""" self.numpy_client.set_state(state) # type: ignore diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index d22b246dbd61..87b06dd0be4e 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -143,6 +143,7 @@ def create_node() -> None: }, data=create_node_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -185,6 +186,7 @@ def delete_node() -> None: }, data=delete_node_req_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]: }, data=pull_task_ins_req_bytes, verify=verify, + timeout=None, ) # Check status code and headers @@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None: }, data=push_task_res_request_bytes, verify=verify, + timeout=None, ) state[KEY_TASK_INS] = None diff --git a/src/py/flwr/client/workload_state.py b/src/py/flwr/client/run_state.py similarity index 88% rename from src/py/flwr/client/workload_state.py rename to src/py/flwr/client/run_state.py index 42ae2a925f47..c2755eb995eb 100644 --- a/src/py/flwr/client/workload_state.py +++ b/src/py/flwr/client/run_state.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Workload state.""" +"""Run state.""" from dataclasses import dataclass from typing import Dict @dataclass -class WorkloadState: - """State of a workload executed by a client node.""" +class RunState: + """State of a run executed by a client node.""" state: Dict[str, str] diff --git a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py index efbb00a9d916..4b74c1ace3de 100644 --- a/src/py/flwr/client/secure_aggregation/secaggplus_handler.py +++ b/src/py/flwr/client/secure_aggregation/secaggplus_handler.py @@ -333,7 +333,7 @@ def _share_keys( # Check if the size is larger than threshold if len(state.public_keys_dict) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") # Check if all public keys are unique pk_list: List[bytes] = [] @@ -341,14 +341,14 @@ def _share_keys( pk_list.append(pk1) pk_list.append(pk2) if len(set(pk_list)) != len(pk_list): - raise Exception("Some public keys are identical") + raise ValueError("Some public keys are identical") # Check if public keys of this client are correct in the dictionary if ( state.public_keys_dict[state.sid][0] != state.pk1 or state.public_keys_dict[state.sid][1] != state.pk2 ): - raise Exception( + raise ValueError( "Own public keys are displayed in dict incorrectly, should not happen!" ) @@ -393,7 +393,7 @@ def _collect_masked_input( ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST]) srcs = cast(List[int], named_values[KEY_SOURCE_LIST]) if len(ciphertexts) + 1 < state.threshold: - raise Exception("Not enough available neighbour clients.") + raise ValueError("Not enough available neighbour clients.") # Decrypt ciphertexts, verify their sources, and store shares. for src, ciphertext in zip(srcs, ciphertexts): @@ -409,7 +409,7 @@ def _collect_masked_input( f"from {actual_src} instead of {src}." ) if dst != state.sid: - ValueError( + raise ValueError( f"Client {state.sid}: received an encrypted message" f"for Client {dst} from Client {src}." ) @@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str, # Send private mask seed share for every avaliable client (including itclient) # Send first private key share for building pairwise mask for every dropped client if len(active_sids) < state.threshold: - raise Exception("Available neighbours number smaller than threshold") + raise ValueError("Available neighbours number smaller than threshold") sids, shares = [], [] sids += active_sids diff --git a/src/py/flwr/client/typing.py b/src/py/flwr/client/typing.py index 2dd368bf6d08..1652ee57674a 100644 --- a/src/py/flwr/client/typing.py +++ b/src/py/flwr/client/typing.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.proto.task_pb2 import TaskIns, TaskRes from .client import Client as Client @@ -28,7 +28,7 @@ class Fwd: """.""" task_ins: TaskIns - state: WorkloadState + state: RunState @dataclass @@ -36,7 +36,7 @@ class Bwd: """.""" task_res: TaskRes - state: WorkloadState + state: RunState FlowerCallable = Callable[[Fwd], Bwd] diff --git a/src/py/flwr/common/parametersrecord.py b/src/py/flwr/common/parametersrecord.py new file mode 100644 index 000000000000..3d40c0488baa --- /dev/null +++ b/src/py/flwr/common/parametersrecord.py @@ -0,0 +1,110 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ParametersRecord and Array.""" + + +from dataclasses import dataclass, field +from typing import List, Optional, OrderedDict + + +@dataclass +class Array: + """Array type. + + A dataclass containing serialized data from an array-like or tensor-like object + along with some metadata about it. + + Parameters + ---------- + dtype : str + A string representing the data type of the serialised object (e.g. `np.float32`) + + shape : List[int] + A list representing the shape of the unserialized array-like object. This is + used to deserialize the data (depending on the serialization method) or simply + as a metadata field. + + stype : str + A string indicating the type of serialisation mechanism used to generate the + bytes in `data` from an array-like or tensor-like object. + + data: bytes + A buffer of bytes containing the data. + """ + + dtype: str + shape: List[int] + stype: str + data: bytes + + +@dataclass +class ParametersRecord: + """Parameters record. + + A dataclass storing named Arrays in order. This means that it holds entries as an + OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to + PyTorch's state_dict, but holding serialised tensors instead. + """ + + keep_input: bool + data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array]) + + def __init__( + self, + array_dict: Optional[OrderedDict[str, Array]] = None, + keep_input: bool = False, + ) -> None: + """Construct a ParametersRecord object. + + Parameters + ---------- + array_dict : Optional[OrderedDict[str, Array]] + A dictionary that stores serialized array-like or tensor-like objects. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + dictionary immediately after adding them to the record. If False, the + dictionary passed to `set_parameters()` will be empty once exiting from that + function. This is the desired behaviour when working with very large + models/tensors/arrays. However, if you plan to continue working with your + parameters after adding it to the record, set this flag to True. When set + to True, the data is duplicated in memory. + """ + self.keep_input = keep_input + self.data = OrderedDict() + if array_dict: + self.set_parameters(array_dict) + + def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None: + """Add parameters to record. + + Parameters + ---------- + array_dict : OrderedDict[str, Array] + A dictionary that stores serialized array-like or tensor-like objects. + """ + if any(not isinstance(k, str) for k in array_dict.keys()): + raise TypeError(f"Not all keys are of valid type. Expected {str}") + if any(not isinstance(v, Array) for v in array_dict.values()): + raise TypeError(f"Not all values are of valid type. Expected {Array}") + + if self.keep_input: + # Copy + self.data = OrderedDict(array_dict) + else: + # Add entries to dataclass without duplicating memory + for key in list(array_dict.keys()): + self.data[key] = array_dict[key] + del array_dict[key] diff --git a/src/py/flwr/common/recordset.py b/src/py/flwr/common/recordset.py new file mode 100644 index 000000000000..dc723a2cea86 --- /dev/null +++ b/src/py/flwr/common/recordset.py @@ -0,0 +1,75 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet.""" + +from dataclasses import dataclass, field +from typing import Dict + +from .parametersrecord import ParametersRecord + + +@dataclass +class MetricsRecord: + """Metrics record.""" + + +@dataclass +class ConfigsRecord: + """Configs record.""" + + +@dataclass +class RecordSet: + """Definition of RecordSet.""" + + parameters: Dict[str, ParametersRecord] = field(default_factory=dict) + metrics: Dict[str, MetricsRecord] = field(default_factory=dict) + configs: Dict[str, ConfigsRecord] = field(default_factory=dict) + + def set_parameters(self, name: str, record: ParametersRecord) -> None: + """Add a ParametersRecord.""" + self.parameters[name] = record + + def get_parameters(self, name: str) -> ParametersRecord: + """Get a ParametesRecord.""" + return self.parameters[name] + + def del_parameters(self, name: str) -> None: + """Delete a ParametersRecord.""" + del self.parameters[name] + + def set_metrics(self, name: str, record: MetricsRecord) -> None: + """Add a MetricsRecord.""" + self.metrics[name] = record + + def get_metrics(self, name: str) -> MetricsRecord: + """Get a MetricsRecord.""" + return self.metrics[name] + + def del_metrics(self, name: str) -> None: + """Delete a MetricsRecord.""" + del self.metrics[name] + + def set_configs(self, name: str, record: ConfigsRecord) -> None: + """Add a ConfigsRecord.""" + self.configs[name] = record + + def get_configs(self, name: str) -> ConfigsRecord: + """Get a ConfigsRecord.""" + return self.configs[name] + + def del_configs(self, name: str) -> None: + """Delete a ConfigsRecord.""" + del self.configs[name] diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py new file mode 100644 index 000000000000..90c06dcdb109 --- /dev/null +++ b/src/py/flwr/common/recordset_test.py @@ -0,0 +1,147 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet tests.""" + + +from typing import Callable, List, OrderedDict, Type, Union + +import numpy as np +import pytest + +from .parameter import ndarrays_to_parameters, parameters_to_ndarrays +from .parametersrecord import Array, ParametersRecord +from .recordset_utils import ( + parameters_to_parametersrecord, + parametersrecord_to_parameters, +) +from .typing import NDArray, NDArrays, Parameters + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +def ndarray_to_array(ndarray: NDArray) -> Array: + """Represent NumPy ndarray as Array.""" + return Array( + data=ndarray.tobytes(), + dtype=str(ndarray.dtype), + stype="numpy.ndarray.tobytes", + shape=list(ndarray.shape), + ) + + +def test_ndarray_to_array() -> None: + """Test creation of Array object from NumPy ndarray.""" + shape = (2, 7, 9) + arr = np.eye(*shape) + + array = ndarray_to_array(arr) + + arr_ = np.frombuffer(buffer=array.data, dtype=array.dtype).reshape(array.shape) + + assert np.array_equal(arr, arr_) + + +def test_parameters_to_array_and_back() -> None: + """Test conversion between legacy Parameters and Array.""" + ndarrays = get_ndarrays() + + # Array represents a single array, unlike Paramters, which represent a + # list of arrays + ndarray = ndarrays[0] + + parameters = ndarrays_to_parameters([ndarray]) + + array = Array( + data=parameters.tensors[0], dtype="", stype=parameters.tensor_type, shape=[] + ) + + parameters = Parameters(tensors=[array.data], tensor_type=array.stype) + + ndarray_ = parameters_to_ndarrays(parameters=parameters)[0] + + assert np.array_equal(ndarray, ndarray_) + + +def test_parameters_to_parametersrecord_and_back() -> None: + """Test conversion between legacy Parameters and ParametersRecords.""" + ndarrays = get_ndarrays() + + parameters = ndarrays_to_parameters(ndarrays) + + params_record = parameters_to_parametersrecord(parameters=parameters) + + parameters_ = parametersrecord_to_parameters(params_record) + + ndarrays_ = parameters_to_ndarrays(parameters=parameters_) + + for arr, arr_ in zip(ndarrays, ndarrays_): + assert np.array_equal(arr, arr_) + + +def test_set_parameters_while_keeping_intputs() -> None: + """Tests keep_input functionality in ParametersRecord.""" + # Adding parameters to a record that doesn't erase entries in the input `array_dict` + p_record = ParametersRecord(keep_input=True) + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.set_parameters(array_dict) + + # Creating a second parametersrecord passing the same `array_dict` (not erased) + p_record_2 = ParametersRecord(array_dict) + assert p_record.data == p_record_2.data + + # Now it should be empty (the second ParametersRecord wasn't flagged to keep it) + assert len(array_dict) == 0 + + +def test_set_parameters_with_correct_types() -> None: + """Test adding dictionary of Arrays to ParametersRecord.""" + p_record = ParametersRecord() + array_dict = OrderedDict( + {str(i): ndarray_to_array(ndarray) for i, ndarray in enumerate(get_ndarrays())} + ) + p_record.set_parameters(array_dict) + + +@pytest.mark.parametrize( + "key_type, value_fn", + [ + (str, lambda x: x), # correct key, incorrect value + (str, lambda x: x.tolist()), # correct key, incorrect value + (int, ndarray_to_array), # incorrect key, correct value + (int, lambda x: x), # incorrect key, incorrect value + (int, lambda x: x.tolist()), # incorrect key, incorrect value + ], +) +def test_set_parameters_with_incorrect_types( + key_type: Type[Union[int, str]], + value_fn: Callable[[NDArray], Union[NDArray, List[float]]], +) -> None: + """Test adding dictionary of unsupported types to ParametersRecord.""" + p_record = ParametersRecord() + + array_dict = { + key_type(i): value_fn(ndarray) for i, ndarray in enumerate(get_ndarrays()) + } + + with pytest.raises(TypeError): + p_record.set_parameters(array_dict) # type: ignore diff --git a/src/py/flwr/common/recordset_utils.py b/src/py/flwr/common/recordset_utils.py new file mode 100644 index 000000000000..c1e724fa2758 --- /dev/null +++ b/src/py/flwr/common/recordset_utils.py @@ -0,0 +1,87 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + + +from typing import OrderedDict + +from .parametersrecord import Array, ParametersRecord +from .typing import Parameters + + +def parametersrecord_to_parameters( + record: ParametersRecord, keep_input: bool = False +) -> Parameters: + """Convert ParameterRecord to legacy Parameters. + + Warning: Because `Arrays` in `ParametersRecord` encode more information of the + array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it + might not be possible to reconstruct such data structures from `Parameters` objects + alone. Additional information or metadta must be provided from elsewhere. + + Parameters + ---------- + record : ParametersRecord + The record to be conveted into Parameters. + keep_input : bool (default: False) + A boolean indicating whether entries in the record should be deleted from the + input dictionary immediately after adding them to the record. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.data.keys()): + parameters.tensors.append(record.data[key].data) + + if not keep_input: + del record.data[key] + + return parameters + + +def parameters_to_parametersrecord( + parameters: Parameters, keep_input: bool = False +) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + Because there is no concept of names in the legacy Parameters, arbitrary keys will + be used when constructing the ParametersRecord. Similarly, the shape and data type + won't be recorded in the Array objects. + + Parameters + ---------- + parameters : Parameters + Parameters object to be represented as a ParametersRecord. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + Parameters object (i.e. a list of serialized NumPy arrays) immediately after + adding them to the record. + """ + tensor_type = parameters.tensor_type + + p_record = ParametersRecord() + + num_arrays = len(parameters.tensors) + for idx in range(num_arrays): + if keep_input: + tensor = parameters.tensors[idx] + else: + tensor = parameters.tensors.pop(0) + p_record.set_parameters( + OrderedDict( + {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + ) + ) + + return p_record diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index a60fff57e7bf..5441e766983a 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -156,6 +156,7 @@ class RetryInvoker: >>> invoker.invoke(my_func, arg1, arg2, kw1=kwarg1) """ + # pylint: disable-next=too-many-arguments def __init__( self, wait_factory: Callable[[], Generator[float, None, None]], diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index c8c73e87e04a..59f5387b0a07 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -59,7 +59,9 @@ def server_message_to_proto(server_message: typing.ServerMessage) -> ServerMessa server_message.evaluate_ins, ) ) - raise Exception("No instruction set in ServerMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ServerMessage, cannot serialize to ProtoBuf" + ) def server_message_from_proto( @@ -91,7 +93,7 @@ def server_message_from_proto( server_message_proto.evaluate_ins, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) @@ -125,7 +127,9 @@ def client_message_to_proto(client_message: typing.ClientMessage) -> ClientMessa client_message.evaluate_res, ) ) - raise Exception("No instruction set in ClientMessage, cannot serialize to ProtoBuf") + raise ValueError( + "No instruction set in ClientMessage, cannot serialize to ProtoBuf" + ) def client_message_from_proto( @@ -157,7 +161,7 @@ def client_message_from_proto( client_message_proto.evaluate_res, ) ) - raise Exception( + raise ValueError( "Unsupported instruction in ClientMessage, cannot deserialize from ProtoBuf" ) @@ -474,7 +478,7 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar: if isinstance(scalar, str): return Scalar(string=scalar) - raise Exception( + raise ValueError( f"Accepted types: {bool, bytes, float, int, str} (but not {type(scalar)})" ) @@ -518,7 +522,7 @@ def _check_value(value: typing.Value) -> None: for element in value: if isinstance(element, data_type): continue - raise Exception( + raise TypeError( f"Inconsistent type: the types of elements in the list must " f"be the same (expected {data_type}, but got {type(element)})." ) diff --git a/src/py/flwr/driver/app.py b/src/py/flwr/driver/app.py index 3cb8652365d8..987b4a31981b 100644 --- a/src/py/flwr/driver/app.py +++ b/src/py/flwr/driver/app.py @@ -170,8 +170,8 @@ def update_client_manager( and dead nodes will be removed from the ClientManager via `client_manager.unregister()`. """ - # Request for workload_id - workload_id = driver.create_workload(driver_pb2.CreateWorkloadRequest()).workload_id + # Request for run_id + run_id = driver.create_run(driver_pb2.CreateRunRequest()).run_id # Loop until the driver is disconnected registered_nodes: Dict[int, DriverClientProxy] = {} @@ -181,7 +181,7 @@ def update_client_manager( if driver.stub is None: break get_nodes_res = driver.get_nodes( - req=driver_pb2.GetNodesRequest(workload_id=workload_id) + req=driver_pb2.GetNodesRequest(run_id=run_id) ) all_node_ids = {node.node_id for node in get_nodes_res.nodes} dead_nodes = set(registered_nodes).difference(all_node_ids) @@ -199,7 +199,7 @@ def update_client_manager( node_id=node_id, driver=driver, anonymous=False, - workload_id=workload_id, + run_id=run_id, ) if client_manager.register(client_proxy): registered_nodes[node_id] = client_proxy diff --git a/src/py/flwr/driver/app_test.py b/src/py/flwr/driver/app_test.py index 91b4fd30bc4b..82747e5afb2c 100644 --- a/src/py/flwr/driver/app_test.py +++ b/src/py/flwr/driver/app_test.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Flower Driver app tests.""" -# pylint: disable=no-self-use import threading @@ -22,7 +21,7 @@ from unittest.mock import MagicMock from flwr.driver.app import update_client_manager -from flwr.proto.driver_pb2 import CreateWorkloadResponse, GetNodesResponse +from flwr.proto.driver_pb2 import CreateRunResponse, GetNodesResponse from flwr.proto.node_pb2 import Node from flwr.server.client_manager import SimpleClientManager @@ -43,7 +42,7 @@ def test_simple_client_manager_update(self) -> None: ] driver = MagicMock() driver.stub = "driver stub" - driver.create_workload.return_value = CreateWorkloadResponse(workload_id=1) + driver.create_run.return_value = CreateRunResponse(run_id=1) driver.get_nodes.return_value = GetNodesResponse(nodes=expected_nodes) client_manager = SimpleClientManager() lock = threading.Lock() @@ -76,7 +75,7 @@ def test_simple_client_manager_update(self) -> None: driver.stub = None # Assert - driver.create_workload.assert_called_once() + driver.create_run.assert_called_once() assert node_ids == {node.node_id for node in expected_nodes} assert updated_node_ids == {node.node_id for node in expected_updated_nodes} diff --git a/src/py/flwr/driver/driver.py b/src/py/flwr/driver/driver.py index f1a7c6663c11..9f96cc46ce1e 100644 --- a/src/py/flwr/driver/driver.py +++ b/src/py/flwr/driver/driver.py @@ -19,7 +19,7 @@ from flwr.driver.grpc_driver import DEFAULT_SERVER_ADDRESS_DRIVER, GrpcDriver from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, + CreateRunRequest, GetNodesRequest, PullTaskResRequest, PushTaskInsRequest, @@ -54,37 +54,37 @@ def __init__( self.addr = driver_service_address self.certificates = certificates self.grpc_driver: Optional[GrpcDriver] = None - self.workload_id: Optional[int] = None + self.run_id: Optional[int] = None self.node = Node(node_id=0, anonymous=True) - def _get_grpc_driver_and_workload_id(self) -> Tuple[GrpcDriver, int]: + def _get_grpc_driver_and_run_id(self) -> Tuple[GrpcDriver, int]: # Check if the GrpcDriver is initialized - if self.grpc_driver is None or self.workload_id is None: - # Connect and create workload + if self.grpc_driver is None or self.run_id is None: + # Connect and create run self.grpc_driver = GrpcDriver( driver_service_address=self.addr, certificates=self.certificates ) self.grpc_driver.connect() - res = self.grpc_driver.create_workload(CreateWorkloadRequest()) - self.workload_id = res.workload_id + res = self.grpc_driver.create_run(CreateRunRequest()) + self.run_id = res.run_id - return self.grpc_driver, self.workload_id + return self.grpc_driver, self.run_id def get_nodes(self) -> List[Node]: """Get node IDs.""" - grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() + grpc_driver, run_id = self._get_grpc_driver_and_run_id() # Call GrpcDriver method - res = grpc_driver.get_nodes(GetNodesRequest(workload_id=workload_id)) + res = grpc_driver.get_nodes(GetNodesRequest(run_id=run_id)) return list(res.nodes) def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: """Schedule tasks.""" - grpc_driver, workload_id = self._get_grpc_driver_and_workload_id() + grpc_driver, run_id = self._get_grpc_driver_and_run_id() - # Set workload_id + # Set run_id for task_ins in task_ins_list: - task_ins.workload_id = workload_id + task_ins.run_id = run_id # Call GrpcDriver method res = grpc_driver.push_task_ins(PushTaskInsRequest(task_ins_list=task_ins_list)) @@ -92,7 +92,7 @@ def push_task_ins(self, task_ins_list: List[TaskIns]) -> List[str]: def pull_task_res(self, task_ids: Iterable[str]) -> List[TaskRes]: """Get task results.""" - grpc_driver, _ = self._get_grpc_driver_and_workload_id() + grpc_driver, _ = self._get_grpc_driver_and_run_id() # Call GrpcDriver method res = grpc_driver.pull_task_res( diff --git a/src/py/flwr/driver/driver_client_proxy.py b/src/py/flwr/driver/driver_client_proxy.py index 6d60fc49159b..6c15acb9ebde 100644 --- a/src/py/flwr/driver/driver_client_proxy.py +++ b/src/py/flwr/driver/driver_client_proxy.py @@ -31,13 +31,11 @@ class DriverClientProxy(ClientProxy): """Flower client proxy which delegates work using the Driver API.""" - def __init__( - self, node_id: int, driver: GrpcDriver, anonymous: bool, workload_id: int - ): + def __init__(self, node_id: int, driver: GrpcDriver, anonymous: bool, run_id: int): super().__init__(str(node_id)) self.node_id = node_id self.driver = driver - self.workload_id = workload_id + self.run_id = run_id self.anonymous = anonymous def get_properties( @@ -106,7 +104,7 @@ def _send_receive_msg( task_ins = task_pb2.TaskIns( task_id="", group_id="", - workload_id=self.workload_id, + run_id=self.run_id, task=task_pb2.Task( producer=node_pb2.Node( node_id=0, diff --git a/src/py/flwr/driver/driver_client_proxy_test.py b/src/py/flwr/driver/driver_client_proxy_test.py index 82b5b46d7810..e7fb088dbf57 100644 --- a/src/py/flwr/driver/driver_client_proxy_test.py +++ b/src/py/flwr/driver/driver_client_proxy_test.py @@ -52,7 +52,7 @@ def test_get_properties(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id=0, + run_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( get_properties_res=ClientMessage.GetPropertiesRes( @@ -64,7 +64,7 @@ def test_get_properties(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) request_properties: Config = {"tensor_type": "str"} ins: flwr.common.GetPropertiesIns = flwr.common.GetPropertiesIns( @@ -88,7 +88,7 @@ def test_get_parameters(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id=0, + run_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( get_parameters_res=ClientMessage.GetParametersRes( @@ -100,7 +100,7 @@ def test_get_parameters(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) get_parameters_ins = GetParametersIns(config={}) @@ -123,7 +123,7 @@ def test_fit(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id=0, + run_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( fit_res=ClientMessage.FitRes( @@ -136,7 +136,7 @@ def test_fit(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) parameters = flwr.common.ndarrays_to_parameters([np.ones((2, 2))]) ins: flwr.common.FitIns = flwr.common.FitIns(parameters, {}) @@ -160,7 +160,7 @@ def test_evaluate(self) -> None: task_pb2.TaskRes( task_id="554bd3c8-8474-4b93-a7db-c7bec1bf0012", group_id="", - workload_id=0, + run_id=0, task=task_pb2.Task( legacy_client_message=ClientMessage( evaluate_res=ClientMessage.EvaluateRes( @@ -172,7 +172,7 @@ def test_evaluate(self) -> None: ] ) client = DriverClientProxy( - node_id=1, driver=self.driver, anonymous=True, workload_id=0 + node_id=1, driver=self.driver, anonymous=True, run_id=0 ) parameters = flwr.common.Parameters(tensors=[], tensor_type="np") evaluate_ins: flwr.common.EvaluateIns = flwr.common.EvaluateIns(parameters, {}) diff --git a/src/py/flwr/driver/driver_test.py b/src/py/flwr/driver/driver_test.py index 820018788a8f..8f75bbf78362 100644 --- a/src/py/flwr/driver/driver_test.py +++ b/src/py/flwr/driver/driver_test.py @@ -33,9 +33,9 @@ class TestDriver(unittest.TestCase): def setUp(self) -> None: """Initialize mock GrpcDriver and Driver instance before each test.""" mock_response = Mock() - mock_response.workload_id = 61016 + mock_response.run_id = 61016 self.mock_grpc_driver = Mock() - self.mock_grpc_driver.create_workload.return_value = mock_response + self.mock_grpc_driver.create_run.return_value = mock_response self.patcher = patch( "flwr.driver.driver.GrpcDriver", return_value=self.mock_grpc_driver ) @@ -47,27 +47,27 @@ def tearDown(self) -> None: self.patcher.stop() def test_check_and_init_grpc_driver_already_initialized(self) -> None: - """Test that GrpcDriver doesn't initialize if workload is created.""" + """Test that GrpcDriver doesn't initialize if run is created.""" # Prepare self.driver.grpc_driver = self.mock_grpc_driver - self.driver.workload_id = 61016 + self.driver.run_id = 61016 # Execute # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() + self.driver._get_grpc_driver_and_run_id() # Assert self.mock_grpc_driver.connect.assert_not_called() def test_check_and_init_grpc_driver_needs_initialization(self) -> None: - """Test GrpcDriver initialization when workload is not created.""" + """Test GrpcDriver initialization when run is not created.""" # Execute # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() + self.driver._get_grpc_driver_and_run_id() # Assert self.mock_grpc_driver.connect.assert_called_once() - self.assertEqual(self.driver.workload_id, 61016) + self.assertEqual(self.driver.run_id, 61016) def test_get_nodes(self) -> None: """Test retrieval of nodes.""" @@ -85,7 +85,7 @@ def test_get_nodes(self) -> None: self.assertEqual(len(args), 1) self.assertEqual(len(kwargs), 0) self.assertIsInstance(args[0], GetNodesRequest) - self.assertEqual(args[0].workload_id, 61016) + self.assertEqual(args[0].run_id, 61016) self.assertEqual(nodes, mock_response.nodes) def test_push_task_ins(self) -> None: @@ -107,7 +107,7 @@ def test_push_task_ins(self) -> None: self.assertIsInstance(args[0], PushTaskInsRequest) self.assertEqual(task_ids, mock_response.task_ids) for task_ins in args[0].task_ins_list: - self.assertEqual(task_ins.workload_id, 61016) + self.assertEqual(task_ins.run_id, 61016) def test_pull_task_res_with_given_task_ids(self) -> None: """Test pulling task results with specific task IDs.""" @@ -136,9 +136,10 @@ def test_del_with_initialized_driver(self) -> None: """Test cleanup behavior when Driver is initialized.""" # Prepare # pylint: disable-next=protected-access - self.driver._get_grpc_driver_and_workload_id() + self.driver._get_grpc_driver_and_run_id() # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert @@ -147,6 +148,7 @@ def test_del_with_initialized_driver(self) -> None: def test_del_with_uninitialized_driver(self) -> None: """Test cleanup behavior when Driver is not initialized.""" # Execute + # pylint: disable-next=unnecessary-dunder-call self.driver.__del__() # Assert diff --git a/src/py/flwr/driver/grpc_driver.py b/src/py/flwr/driver/grpc_driver.py index 7dd0a0f501c5..627b95cdb1b4 100644 --- a/src/py/flwr/driver/grpc_driver.py +++ b/src/py/flwr/driver/grpc_driver.py @@ -24,8 +24,8 @@ from flwr.common.grpc import create_channel from flwr.common.logger import log from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - CreateWorkloadResponse, + CreateRunRequest, + CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -84,15 +84,15 @@ def disconnect(self) -> None: channel.close() log(INFO, "[Driver] Disconnected") - def create_workload(self, req: CreateWorkloadRequest) -> CreateWorkloadResponse: - """Request for workload ID.""" + def create_run(self, req: CreateRunRequest) -> CreateRunResponse: + """Request for run ID.""" # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API - res: CreateWorkloadResponse = self.stub.CreateWorkload(request=req) + res: CreateRunResponse = self.stub.CreateRun(request=req) return res def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: @@ -100,7 +100,7 @@ def get_nodes(self, req: GetNodesRequest) -> GetNodesResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: GetNodesResponse = self.stub.GetNodes(request=req) @@ -111,7 +111,7 @@ def push_task_ins(self, req: PushTaskInsRequest) -> PushTaskInsResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call gRPC Driver API res: PushTaskInsResponse = self.stub.PushTaskIns(request=req) @@ -122,7 +122,7 @@ def pull_task_res(self, req: PullTaskResRequest) -> PullTaskResResponse: # Check if channel is open if self.stub is None: log(ERROR, ERROR_MESSAGE_DRIVER_NOT_CONNECTED) - raise Exception("`GrpcDriver` instance not connected") + raise ConnectionError("`GrpcDriver` instance not connected") # Call Driver API res: PullTaskResResponse = self.stub.PullTaskRes(request=req) diff --git a/src/py/flwr/proto/driver_pb2.py b/src/py/flwr/proto/driver_pb2.py index c138507e03e9..615bf4672afa 100644 --- a/src/py/flwr/proto/driver_pb2.py +++ b/src/py/flwr/proto/driver_pb2.py @@ -16,31 +16,31 @@ from flwr.proto import task_pb2 as flwr_dot_proto_dot_task__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x17\n\x15\x43reateWorkloadRequest\"-\n\x16\x43reateWorkloadResponse\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x12\"&\n\x0fGetNodesRequest\x12\x13\n\x0bworkload_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xd0\x02\n\x06\x44river\x12Y\n\x0e\x43reateWorkload\x12!.flwr.proto.CreateWorkloadRequest\x1a\".flwr.proto.CreateWorkloadResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17\x66lwr/proto/driver.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x15\x66lwr/proto/task.proto\"\x12\n\x10\x43reateRunRequest\"#\n\x11\x43reateRunResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"!\n\x0fGetNodesRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\x12\"3\n\x10GetNodesResponse\x12\x1f\n\x05nodes\x18\x01 \x03(\x0b\x32\x10.flwr.proto.Node\"@\n\x12PushTaskInsRequest\x12*\n\rtask_ins_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskIns\"\'\n\x13PushTaskInsResponse\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"F\n\x12PullTaskResRequest\x12\x1e\n\x04node\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x10\n\x08task_ids\x18\x02 \x03(\t\"A\n\x13PullTaskResResponse\x12*\n\rtask_res_list\x18\x01 \x03(\x0b\x32\x13.flwr.proto.TaskRes2\xc1\x02\n\x06\x44river\x12J\n\tCreateRun\x12\x1c.flwr.proto.CreateRunRequest\x1a\x1d.flwr.proto.CreateRunResponse\"\x00\x12G\n\x08GetNodes\x12\x1b.flwr.proto.GetNodesRequest\x1a\x1c.flwr.proto.GetNodesResponse\"\x00\x12P\n\x0bPushTaskIns\x12\x1e.flwr.proto.PushTaskInsRequest\x1a\x1f.flwr.proto.PushTaskInsResponse\"\x00\x12P\n\x0bPullTaskRes\x12\x1e.flwr.proto.PullTaskResRequest\x1a\x1f.flwr.proto.PullTaskResResponse\"\x00\x62\x06proto3') -_CREATEWORKLOADREQUEST = DESCRIPTOR.message_types_by_name['CreateWorkloadRequest'] -_CREATEWORKLOADRESPONSE = DESCRIPTOR.message_types_by_name['CreateWorkloadResponse'] +_CREATERUNREQUEST = DESCRIPTOR.message_types_by_name['CreateRunRequest'] +_CREATERUNRESPONSE = DESCRIPTOR.message_types_by_name['CreateRunResponse'] _GETNODESREQUEST = DESCRIPTOR.message_types_by_name['GetNodesRequest'] _GETNODESRESPONSE = DESCRIPTOR.message_types_by_name['GetNodesResponse'] _PUSHTASKINSREQUEST = DESCRIPTOR.message_types_by_name['PushTaskInsRequest'] _PUSHTASKINSRESPONSE = DESCRIPTOR.message_types_by_name['PushTaskInsResponse'] _PULLTASKRESREQUEST = DESCRIPTOR.message_types_by_name['PullTaskResRequest'] _PULLTASKRESRESPONSE = DESCRIPTOR.message_types_by_name['PullTaskResResponse'] -CreateWorkloadRequest = _reflection.GeneratedProtocolMessageType('CreateWorkloadRequest', (_message.Message,), { - 'DESCRIPTOR' : _CREATEWORKLOADREQUEST, +CreateRunRequest = _reflection.GeneratedProtocolMessageType('CreateRunRequest', (_message.Message,), { + 'DESCRIPTOR' : _CREATERUNREQUEST, '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateWorkloadRequest) + # @@protoc_insertion_point(class_scope:flwr.proto.CreateRunRequest) }) -_sym_db.RegisterMessage(CreateWorkloadRequest) +_sym_db.RegisterMessage(CreateRunRequest) -CreateWorkloadResponse = _reflection.GeneratedProtocolMessageType('CreateWorkloadResponse', (_message.Message,), { - 'DESCRIPTOR' : _CREATEWORKLOADRESPONSE, +CreateRunResponse = _reflection.GeneratedProtocolMessageType('CreateRunResponse', (_message.Message,), { + 'DESCRIPTOR' : _CREATERUNRESPONSE, '__module__' : 'flwr.proto.driver_pb2' - # @@protoc_insertion_point(class_scope:flwr.proto.CreateWorkloadResponse) + # @@protoc_insertion_point(class_scope:flwr.proto.CreateRunResponse) }) -_sym_db.RegisterMessage(CreateWorkloadResponse) +_sym_db.RegisterMessage(CreateRunResponse) GetNodesRequest = _reflection.GeneratedProtocolMessageType('GetNodesRequest', (_message.Message,), { 'DESCRIPTOR' : _GETNODESREQUEST, @@ -88,22 +88,22 @@ if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _CREATEWORKLOADREQUEST._serialized_start=85 - _CREATEWORKLOADREQUEST._serialized_end=108 - _CREATEWORKLOADRESPONSE._serialized_start=110 - _CREATEWORKLOADRESPONSE._serialized_end=155 - _GETNODESREQUEST._serialized_start=157 - _GETNODESREQUEST._serialized_end=195 - _GETNODESRESPONSE._serialized_start=197 - _GETNODESRESPONSE._serialized_end=248 - _PUSHTASKINSREQUEST._serialized_start=250 - _PUSHTASKINSREQUEST._serialized_end=314 - _PUSHTASKINSRESPONSE._serialized_start=316 - _PUSHTASKINSRESPONSE._serialized_end=355 - _PULLTASKRESREQUEST._serialized_start=357 - _PULLTASKRESREQUEST._serialized_end=427 - _PULLTASKRESRESPONSE._serialized_start=429 - _PULLTASKRESRESPONSE._serialized_end=494 - _DRIVER._serialized_start=497 - _DRIVER._serialized_end=833 + _CREATERUNREQUEST._serialized_start=85 + _CREATERUNREQUEST._serialized_end=103 + _CREATERUNRESPONSE._serialized_start=105 + _CREATERUNRESPONSE._serialized_end=140 + _GETNODESREQUEST._serialized_start=142 + _GETNODESREQUEST._serialized_end=175 + _GETNODESRESPONSE._serialized_start=177 + _GETNODESRESPONSE._serialized_end=228 + _PUSHTASKINSREQUEST._serialized_start=230 + _PUSHTASKINSREQUEST._serialized_end=294 + _PUSHTASKINSRESPONSE._serialized_start=296 + _PUSHTASKINSRESPONSE._serialized_end=335 + _PULLTASKRESREQUEST._serialized_start=337 + _PULLTASKRESREQUEST._serialized_end=407 + _PULLTASKRESRESPONSE._serialized_start=409 + _PULLTASKRESRESPONSE._serialized_end=474 + _DRIVER._serialized_start=477 + _DRIVER._serialized_end=798 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/driver_pb2.pyi b/src/py/flwr/proto/driver_pb2.pyi index 8b940972cb6d..8dc254a55e8c 100644 --- a/src/py/flwr/proto/driver_pb2.pyi +++ b/src/py/flwr/proto/driver_pb2.pyi @@ -13,34 +13,34 @@ import typing_extensions DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -class CreateWorkloadRequest(google.protobuf.message.Message): - """CreateWorkload""" +class CreateRunRequest(google.protobuf.message.Message): + """CreateRun""" DESCRIPTOR: google.protobuf.descriptor.Descriptor def __init__(self, ) -> None: ... -global___CreateWorkloadRequest = CreateWorkloadRequest +global___CreateRunRequest = CreateRunRequest -class CreateWorkloadResponse(google.protobuf.message.Message): +class CreateRunResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - workload_id: builtins.int = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... -global___CreateWorkloadResponse = CreateWorkloadResponse + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... +global___CreateRunResponse = CreateRunResponse class GetNodesRequest(google.protobuf.message.Message): """GetNodes messages""" DESCRIPTOR: google.protobuf.descriptor.Descriptor - WORKLOAD_ID_FIELD_NUMBER: builtins.int - workload_id: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int + run_id: builtins.int def __init__(self, *, - workload_id: builtins.int = ..., + run_id: builtins.int = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id"]) -> None: ... global___GetNodesRequest = GetNodesRequest class GetNodesResponse(google.protobuf.message.Message): diff --git a/src/py/flwr/proto/driver_pb2_grpc.py b/src/py/flwr/proto/driver_pb2_grpc.py index ea33b843d945..ac6815023ebd 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.py +++ b/src/py/flwr/proto/driver_pb2_grpc.py @@ -14,10 +14,10 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.CreateWorkload = channel.unary_unary( - '/flwr.proto.Driver/CreateWorkload', - request_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString, - response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString, + self.CreateRun = channel.unary_unary( + '/flwr.proto.Driver/CreateRun', + request_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + response_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, ) self.GetNodes = channel.unary_unary( '/flwr.proto.Driver/GetNodes', @@ -39,8 +39,8 @@ def __init__(self, channel): class DriverServicer(object): """Missing associated documentation comment in .proto file.""" - def CreateWorkload(self, request, context): - """Request workload_id + def CreateRun(self, request, context): + """Request run_id """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -70,10 +70,10 @@ def PullTaskRes(self, request, context): def add_DriverServicer_to_server(servicer, server): rpc_method_handlers = { - 'CreateWorkload': grpc.unary_unary_rpc_method_handler( - servicer.CreateWorkload, - request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.FromString, - response_serializer=flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.SerializeToString, + 'CreateRun': grpc.unary_unary_rpc_method_handler( + servicer.CreateRun, + request_deserializer=flwr_dot_proto_dot_driver__pb2.CreateRunRequest.FromString, + response_serializer=flwr_dot_proto_dot_driver__pb2.CreateRunResponse.SerializeToString, ), 'GetNodes': grpc.unary_unary_rpc_method_handler( servicer.GetNodes, @@ -101,7 +101,7 @@ class Driver(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def CreateWorkload(request, + def CreateRun(request, target, options=(), channel_credentials=None, @@ -111,9 +111,9 @@ def CreateWorkload(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateWorkload', - flwr_dot_proto_dot_driver__pb2.CreateWorkloadRequest.SerializeToString, - flwr_dot_proto_dot_driver__pb2.CreateWorkloadResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/flwr.proto.Driver/CreateRun', + flwr_dot_proto_dot_driver__pb2.CreateRunRequest.SerializeToString, + flwr_dot_proto_dot_driver__pb2.CreateRunResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/src/py/flwr/proto/driver_pb2_grpc.pyi b/src/py/flwr/proto/driver_pb2_grpc.pyi index 1b10d71e943d..43cf45f39b25 100644 --- a/src/py/flwr/proto/driver_pb2_grpc.pyi +++ b/src/py/flwr/proto/driver_pb2_grpc.pyi @@ -8,10 +8,10 @@ import grpc class DriverStub: def __init__(self, channel: grpc.Channel) -> None: ... - CreateWorkload: grpc.UnaryUnaryMultiCallable[ - flwr.proto.driver_pb2.CreateWorkloadRequest, - flwr.proto.driver_pb2.CreateWorkloadResponse] - """Request workload_id""" + CreateRun: grpc.UnaryUnaryMultiCallable[ + flwr.proto.driver_pb2.CreateRunRequest, + flwr.proto.driver_pb2.CreateRunResponse] + """Request run_id""" GetNodes: grpc.UnaryUnaryMultiCallable[ flwr.proto.driver_pb2.GetNodesRequest, @@ -31,11 +31,11 @@ class DriverStub: class DriverServicer(metaclass=abc.ABCMeta): @abc.abstractmethod - def CreateWorkload(self, - request: flwr.proto.driver_pb2.CreateWorkloadRequest, + def CreateRun(self, + request: flwr.proto.driver_pb2.CreateRunRequest, context: grpc.ServicerContext, - ) -> flwr.proto.driver_pb2.CreateWorkloadResponse: - """Request workload_id""" + ) -> flwr.proto.driver_pb2.CreateRunResponse: + """Request run_id""" pass @abc.abstractmethod diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py index 6d8cf8fd3656..ba0e2e3f5218 100644 --- a/src/py/flwr/proto/task_pb2.py +++ b/src/py/flwr/proto/task_pb2.py @@ -16,7 +16,7 @@ from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"a\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"a\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x13\n\x0bworkload_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xbe\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12)\n\x02sa\x18\x07 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xf3\x03\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12\x33\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x1c.flwr.proto.Value.DoubleListH\x00\x12\x33\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x1c.flwr.proto.Value.Sint64ListH\x00\x12/\n\tbool_list\x18\x17 \x01(\x0b\x32\x1a.flwr.proto.Value.BoolListH\x00\x12\x33\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x1c.flwr.proto.Value.StringListH\x00\x12\x31\n\nbytes_list\x18\x19 \x01(\x0b\x32\x1b.flwr.proto.Value.BytesListH\x00\x1a\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\x1a\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\x1a\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\x1a\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\x1a\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3') @@ -126,23 +126,23 @@ _TASK._serialized_start=89 _TASK._serialized_end=407 _TASKINS._serialized_start=409 - _TASKINS._serialized_end=506 - _TASKRES._serialized_start=508 - _TASKRES._serialized_end=605 - _VALUE._serialized_start=608 - _VALUE._serialized_end=1107 - _VALUE_DOUBLELIST._serialized_start=963 - _VALUE_DOUBLELIST._serialized_end=989 - _VALUE_SINT64LIST._serialized_start=991 - _VALUE_SINT64LIST._serialized_end=1017 - _VALUE_BOOLLIST._serialized_start=1019 - _VALUE_BOOLLIST._serialized_end=1043 - _VALUE_STRINGLIST._serialized_start=1045 - _VALUE_STRINGLIST._serialized_end=1071 - _VALUE_BYTESLIST._serialized_start=1073 - _VALUE_BYTESLIST._serialized_end=1098 - _SECUREAGGREGATION._serialized_start=1110 - _SECUREAGGREGATION._serialized_end=1270 - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_start=1201 - _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_end=1270 + _TASKINS._serialized_end=501 + _TASKRES._serialized_start=503 + _TASKRES._serialized_end=595 + _VALUE._serialized_start=598 + _VALUE._serialized_end=1097 + _VALUE_DOUBLELIST._serialized_start=953 + _VALUE_DOUBLELIST._serialized_end=979 + _VALUE_SINT64LIST._serialized_start=981 + _VALUE_SINT64LIST._serialized_end=1007 + _VALUE_BOOLLIST._serialized_start=1009 + _VALUE_BOOLLIST._serialized_end=1033 + _VALUE_STRINGLIST._serialized_start=1035 + _VALUE_STRINGLIST._serialized_end=1061 + _VALUE_BYTESLIST._serialized_start=1063 + _VALUE_BYTESLIST._serialized_end=1088 + _SECUREAGGREGATION._serialized_start=1100 + _SECUREAGGREGATION._serialized_end=1260 + _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_start=1191 + _SECUREAGGREGATION_NAMEDVALUESENTRY._serialized_end=1260 # @@protoc_insertion_point(module_scope) diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi index 7cf96cb61edf..f40a66ef98d1 100644 --- a/src/py/flwr/proto/task_pb2.pyi +++ b/src/py/flwr/proto/task_pb2.pyi @@ -59,44 +59,44 @@ class TaskIns(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_ID_FIELD_NUMBER: builtins.int GROUP_ID_FIELD_NUMBER: builtins.int - WORKLOAD_ID_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: builtins.int + run_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: builtins.int = ..., + run_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","task",b"task","task_id",b"task_id","workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskIns = TaskIns class TaskRes(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor TASK_ID_FIELD_NUMBER: builtins.int GROUP_ID_FIELD_NUMBER: builtins.int - WORKLOAD_ID_FIELD_NUMBER: builtins.int + RUN_ID_FIELD_NUMBER: builtins.int TASK_FIELD_NUMBER: builtins.int task_id: typing.Text group_id: typing.Text - workload_id: builtins.int + run_id: builtins.int @property def task(self) -> global___Task: ... def __init__(self, *, task_id: typing.Text = ..., group_id: typing.Text = ..., - workload_id: builtins.int = ..., + run_id: builtins.int = ..., task: typing.Optional[global___Task] = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["task",b"task"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","task",b"task","task_id",b"task_id","workload_id",b"workload_id"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["group_id",b"group_id","run_id",b"run_id","task",b"task","task_id",b"task_id"]) -> None: ... global___TaskRes = TaskRes class Value(google.protobuf.message.Message): diff --git a/src/py/flwr/server/driver/driver_servicer.py b/src/py/flwr/server/driver/driver_servicer.py index f96b3b1262ac..546ebd884ca9 100644 --- a/src/py/flwr/server/driver/driver_servicer.py +++ b/src/py/flwr/server/driver/driver_servicer.py @@ -24,8 +24,8 @@ from flwr.common.logger import log from flwr.proto import driver_pb2_grpc from flwr.proto.driver_pb2 import ( - CreateWorkloadRequest, - CreateWorkloadResponse, + CreateRunRequest, + CreateRunResponse, GetNodesRequest, GetNodesResponse, PullTaskResRequest, @@ -51,20 +51,20 @@ def GetNodes( """Get available nodes.""" log(INFO, "DriverServicer.GetNodes") state: State = self.state_factory.state() - all_ids: Set[int] = state.get_nodes(request.workload_id) + all_ids: Set[int] = state.get_nodes(request.run_id) nodes: List[Node] = [ Node(node_id=node_id, anonymous=False) for node_id in all_ids ] return GetNodesResponse(nodes=nodes) - def CreateWorkload( - self, request: CreateWorkloadRequest, context: grpc.ServicerContext - ) -> CreateWorkloadResponse: - """Create workload ID.""" - log(INFO, "DriverServicer.CreateWorkload") + def CreateRun( + self, request: CreateRunRequest, context: grpc.ServicerContext + ) -> CreateRunResponse: + """Create run ID.""" + log(INFO, "DriverServicer.CreateRun") state: State = self.state_factory.state() - workload_id = state.create_workload() - return CreateWorkloadResponse(workload_id=workload_id) + run_id = state.create_run() + return CreateRunResponse(run_id=run_id) def PushTaskIns( self, request: PushTaskInsRequest, context: grpc.ServicerContext diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py index 6ae38ea3d805..4e68499f018d 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge.py @@ -113,7 +113,7 @@ def _transition(self, next_status: Status) -> None: ): self._status = next_status else: - raise Exception(f"Invalid transition: {self._status} to {next_status}") + raise ValueError(f"Invalid transition: {self._status} to {next_status}") self._cv.notify_all() @@ -129,7 +129,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._raise_if_closed() if self._status != Status.AWAITING_INS_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._ins_wrapper = ins_wrapper # Write self._transition(Status.INS_WRAPPER_AVAILABLE) @@ -146,7 +146,7 @@ def request(self, ins_wrapper: InsWrapper) -> ResWrapper: self._transition(Status.AWAITING_INS_WRAPPER) if res_wrapper is None: - raise Exception("ResWrapper can not be None") + raise ValueError("ResWrapper can not be None") return res_wrapper @@ -170,7 +170,7 @@ def ins_wrapper_iterator(self) -> Iterator[InsWrapper]: self._transition(Status.AWAITING_RES_WRAPPER) if ins_wrapper is None: - raise Exception("InsWrapper can not be None") + raise ValueError("InsWrapper can not be None") yield ins_wrapper @@ -180,7 +180,7 @@ def set_res_wrapper(self, res_wrapper: ResWrapper) -> None: self._raise_if_closed() if self._status != Status.AWAITING_RES_WRAPPER: - raise Exception("This should not happen") + raise ValueError("This should not happen") self._res_wrapper = res_wrapper # Write self._transition(Status.RES_WRAPPER_AVAILABLE) diff --git a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py index 18a2144072ed..bcfbe6e6fac8 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/grpc_bridge_test.py @@ -70,6 +70,7 @@ def test_workflow_successful() -> None: _ = next(ins_wrapper_iterator) bridge.set_res_wrapper(ResWrapper(client_message=ClientMessage())) except Exception as exception: + # pylint: disable-next=broad-exception-raised raise Exception from exception # Wait until worker_thread is finished diff --git a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py index 1c737d31c7fc..0fa6f82a89b5 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py +++ b/src/py/flwr/server/fleet/grpc_bidi/ins_scheduler.py @@ -166,6 +166,6 @@ def _call_client_proxy( evaluate_res_proto = serde.evaluate_res_to_proto(res=evaluate_res) return ClientMessage(evaluate_res=evaluate_res_proto) - raise Exception( + raise ValueError( "Unsupported instruction in ServerMessage, cannot deserialize from ProtoBuf" ) diff --git a/src/py/flwr/server/fleet/message_handler/message_handler_test.py b/src/py/flwr/server/fleet/message_handler/message_handler_test.py index 25fd822492f2..bb2205e26b18 100644 --- a/src/py/flwr/server/fleet/message_handler/message_handler_test.py +++ b/src/py/flwr/server/fleet/message_handler/message_handler_test.py @@ -109,7 +109,7 @@ def test_push_task_res() -> None: TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task(), ), ], diff --git a/src/py/flwr/server/server_test.py b/src/py/flwr/server/server_test.py index 63ec1021ff5c..9b5c03aeeaf9 100644 --- a/src/py/flwr/server/server_test.py +++ b/src/py/flwr/server/server_test.py @@ -47,14 +47,14 @@ class SuccessClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise an error because this method is not expected to be called.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception because this method is not expected to be called.""" - raise Exception() + """Raise a error because this method is not expected to be called.""" + raise NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: """Simulate fit by returning a success FitRes with simple set of weights.""" @@ -87,26 +87,26 @@ class FailingClient(ClientProxy): def get_properties( self, ins: GetPropertiesIns, timeout: Optional[float] ) -> GetPropertiesRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def get_parameters( self, ins: GetParametersIns, timeout: Optional[float] ) -> GetParametersRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def fit(self, ins: FitIns, timeout: Optional[float]) -> FitRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def evaluate(self, ins: EvaluateIns, timeout: Optional[float]) -> EvaluateRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def reconnect(self, ins: ReconnectIns, timeout: Optional[float]) -> DisconnectRes: - """Raise an Exception to simulate failure in the client.""" - raise Exception() + """Raise a NotImplementedError to simulate failure in the client.""" + raise NotImplementedError() def test_fit_clients() -> None: diff --git a/src/py/flwr/server/state/in_memory_state.py b/src/py/flwr/server/state/in_memory_state.py index 384839b7461f..f8352fcfb091 100644 --- a/src/py/flwr/server/state/in_memory_state.py +++ b/src/py/flwr/server/state/in_memory_state.py @@ -32,7 +32,7 @@ class InMemoryState(State): def __init__(self) -> None: self.node_ids: Set[int] = set() - self.workload_ids: Set[int] = set() + self.run_ids: Set[int] = set() self.task_ins_store: Dict[UUID, TaskIns] = {} self.task_res_store: Dict[UUID, TaskRes] = {} @@ -43,9 +43,9 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: if any(errors): log(ERROR, errors) return None - # Validate workload_id - if task_ins.workload_id not in self.workload_ids: - log(ERROR, "`workload_id` is invalid") + # Validate run_id + if task_ins.run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") return None # Create task_id, created_at and ttl @@ -104,9 +104,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: log(ERROR, errors) return None - # Validate workload_id - if task_res.workload_id not in self.workload_ids: - log(ERROR, "`workload_id` is invalid") + # Validate run_id + if task_res.run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") return None # Create task_id, created_at and ttl @@ -199,25 +199,25 @@ def delete_node(self, node_id: int) -> None: raise ValueError(f"Node {node_id} not found") self.node_ids.remove(node_id) - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Return all available client nodes. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ - if workload_id not in self.workload_ids: + if run_id not in self.run_ids: return set() return self.node_ids - def create_workload(self) -> int: - """Create one workload.""" - # Sample a random int64 as workload_id - workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + def create_run(self) -> int: + """Create one run.""" + # Sample a random int64 as run_id + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) - if workload_id not in self.workload_ids: - self.workload_ids.add(workload_id) - return workload_id - log(ERROR, "Unexpected workload creation failure.") + if run_id not in self.run_ids: + self.run_ids.add(run_id) + return run_id + log(ERROR, "Unexpected run creation failure.") return 0 diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index f3ff60f370e9..26f326819971 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -37,9 +37,9 @@ ); """ -SQL_CREATE_TABLE_WORKLOAD = """ -CREATE TABLE IF NOT EXISTS workload( - workload_id INTEGER UNIQUE +SQL_CREATE_TABLE_RUN = """ +CREATE TABLE IF NOT EXISTS run( + run_id INTEGER UNIQUE ); """ @@ -47,7 +47,7 @@ CREATE TABLE IF NOT EXISTS task_ins( task_id TEXT UNIQUE, group_id TEXT, - workload_id INTEGER, + run_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -58,7 +58,7 @@ ancestry TEXT, legacy_server_message BLOB, legacy_client_message BLOB, - FOREIGN KEY(workload_id) REFERENCES workload(workload_id) + FOREIGN KEY(run_id) REFERENCES run(run_id) ); """ @@ -67,7 +67,7 @@ CREATE TABLE IF NOT EXISTS task_res( task_id TEXT UNIQUE, group_id TEXT, - workload_id INTEGER, + run_id INTEGER, producer_anonymous BOOLEAN, producer_node_id INTEGER, consumer_anonymous BOOLEAN, @@ -78,7 +78,7 @@ ancestry TEXT, legacy_server_message BLOB, legacy_client_message BLOB, - FOREIGN KEY(workload_id) REFERENCES workload(workload_id) + FOREIGN KEY(run_id) REFERENCES run(run_id) ); """ @@ -119,7 +119,7 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur = self.conn.cursor() # Create each table if not exists queries - cur.execute(SQL_CREATE_TABLE_WORKLOAD) + cur.execute(SQL_CREATE_TABLE_RUN) cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) @@ -134,7 +134,7 @@ def query( ) -> List[Dict[str, Any]]: """Execute a SQL query.""" if self.conn is None: - raise Exception("State is not initialized.") + raise AttributeError("State is not initialized.") if data is None: data = [] @@ -198,12 +198,12 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_ins VALUES({columns});" - # Only invalid workload_id can trigger IntegrityError. + # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. try: self.query(query, data) except sqlite3.IntegrityError: - log(ERROR, "`workload` is invalid") + log(ERROR, "`run` is invalid") return None return task_id @@ -333,12 +333,12 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: columns = ", ".join([f":{key}" for key in data[0]]) query = f"INSERT INTO task_res VALUES({columns});" - # Only invalid workload_id can trigger IntegrityError. + # Only invalid run_id can trigger IntegrityError. # This may need to be changed in the future version with more integrity checks. try: self.query(query, data) except sqlite3.IntegrityError: - log(ERROR, "`workload` is invalid") + log(ERROR, "`run` is invalid") return None return task_id @@ -459,7 +459,7 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None: """ if self.conn is None: - raise Exception("State not intitialized") + raise AttributeError("State not intitialized") with self.conn: self.conn.execute(query_1, data) @@ -485,17 +485,17 @@ def delete_node(self, node_id: int) -> None: query = "DELETE FROM node WHERE node_id = :node_id;" self.query(query, {"node_id": node_id}) - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ - # Validate workload ID - query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" - if self.query(query, (workload_id,))[0]["COUNT(*)"] == 0: + # Validate run ID + query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" + if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: return set() # Get nodes @@ -504,19 +504,19 @@ def get_nodes(self, workload_id: int) -> Set[int]: result: Set[int] = {row["node_id"] for row in rows} return result - def create_workload(self) -> int: - """Create one workload and store it in state.""" - # Sample a random int64 as workload_id - workload_id: int = int.from_bytes(os.urandom(8), "little", signed=True) + def create_run(self) -> int: + """Create one run and store it in state.""" + # Sample a random int64 as run_id + run_id: int = int.from_bytes(os.urandom(8), "little", signed=True) # Check conflicts - query = "SELECT COUNT(*) FROM workload WHERE workload_id = ?;" - # If workload_id does not exist - if self.query(query, (workload_id,))[0]["COUNT(*)"] == 0: - query = "INSERT INTO workload VALUES(:workload_id);" - self.query(query, {"workload_id": workload_id}) - return workload_id - log(ERROR, "Unexpected workload creation failure.") + query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" + # If run_id does not exist + if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: + query = "INSERT INTO run VALUES(:run_id);" + self.query(query, {"run_id": run_id}) + return run_id + log(ERROR, "Unexpected run creation failure.") return 0 @@ -537,7 +537,7 @@ def task_ins_to_dict(task_msg: TaskIns) -> Dict[str, Any]: result = { "task_id": task_msg.task_id, "group_id": task_msg.group_id, - "workload_id": task_msg.workload_id, + "run_id": task_msg.run_id, "producer_anonymous": task_msg.task.producer.anonymous, "producer_node_id": task_msg.task.producer.node_id, "consumer_anonymous": task_msg.task.consumer.anonymous, @@ -559,7 +559,7 @@ def task_res_to_dict(task_msg: TaskRes) -> Dict[str, Any]: result = { "task_id": task_msg.task_id, "group_id": task_msg.group_id, - "workload_id": task_msg.workload_id, + "run_id": task_msg.run_id, "producer_anonymous": task_msg.task.producer.anonymous, "producer_node_id": task_msg.task.producer.node_id, "consumer_anonymous": task_msg.task.consumer.anonymous, @@ -584,7 +584,7 @@ def dict_to_task_ins(task_dict: Dict[str, Any]) -> TaskIns: result = TaskIns( task_id=task_dict["task_id"], group_id=task_dict["group_id"], - workload_id=task_dict["workload_id"], + run_id=task_dict["run_id"], task=Task( producer=Node( node_id=task_dict["producer_node_id"], @@ -612,7 +612,7 @@ def dict_to_task_res(task_dict: Dict[str, Any]) -> TaskRes: result = TaskRes( task_id=task_dict["task_id"], group_id=task_dict["group_id"], - workload_id=task_dict["workload_id"], + run_id=task_dict["run_id"], task=Task( producer=Node( node_id=task_dict["producer_node_id"], diff --git a/src/py/flwr/server/state/sqlite_state_test.py b/src/py/flwr/server/state/sqlite_state_test.py index da8fead1438e..a3f899386011 100644 --- a/src/py/flwr/server/state/sqlite_state_test.py +++ b/src/py/flwr/server/state/sqlite_state_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Test for utility functions.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import unittest @@ -27,11 +27,11 @@ class SqliteStateTest(unittest.TestCase): def test_ins_res_to_dict(self) -> None: """Check if all required keys are included in return value.""" # Prepare - ins_res = create_task_ins(consumer_node_id=1, anonymous=True, workload_id=0) + ins_res = create_task_ins(consumer_node_id=1, anonymous=True, run_id=0) expected_keys = [ "task_id", "group_id", - "workload_id", + "run_id", "producer_anonymous", "producer_node_id", "consumer_anonymous", diff --git a/src/py/flwr/server/state/state.py b/src/py/flwr/server/state/state.py index fd8bbc8e8e25..7ab3b6bc0848 100644 --- a/src/py/flwr/server/state/state.py +++ b/src/py/flwr/server/state/state.py @@ -43,7 +43,7 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: If `task_ins.task.consumer.anonymous` is `False`, then `task_ins.task.consumer.node_id` MUST be set (not 0) - If `task_ins.workload_id` is invalid, then + If `task_ins.run_id` is invalid, then storing the `task_ins` MUST fail. """ @@ -92,7 +92,7 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: If `task_res.task.consumer.anonymous` is `False`, then `task_res.task.consumer.node_id` MUST be set (not 0) - If `task_res.workload_id` is invalid, then + If `task_res.run_id` is invalid, then storing the `task_res` MUST fail. """ @@ -140,15 +140,15 @@ def delete_node(self, node_id: int) -> None: """Remove `node_id` from state.""" @abc.abstractmethod - def get_nodes(self, workload_id: int) -> Set[int]: + def get_nodes(self, run_id: int) -> Set[int]: """Retrieve all currently stored node IDs as a set. Constraints ----------- - If the provided `workload_id` does not exist or has no matching nodes, + If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ @abc.abstractmethod - def create_workload(self) -> int: - """Create one workload.""" + def create_run(self) -> int: + """Create one run.""" diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 59299451c3d8..204b4ba97b5f 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Tests all state implemenations have to conform to.""" -# pylint: disable=no-self-use, invalid-name, disable=R0904 +# pylint: disable=invalid-name, disable=R0904 import tempfile import unittest @@ -66,9 +66,9 @@ def test_store_task_ins_one(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) assert task_ins.task.created_at == "" # pylint: disable=no-member @@ -108,15 +108,15 @@ def test_store_and_delete_tasks(self) -> None: # Prepare consumer_node_id = 1 state = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins_0 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) task_ins_1 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) task_ins_2 = create_task_ins( - consumer_node_id=consumer_node_id, anonymous=False, workload_id=workload_id + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id ) # Insert three TaskIns @@ -136,7 +136,7 @@ def test_store_and_delete_tasks(self) -> None: producer_node_id=100, anonymous=False, ancestry=[str(task_id_0)], - workload_id=workload_id, + run_id=run_id, ) _ = state.store_task_res(task_res=task_res_0) @@ -147,7 +147,7 @@ def test_store_and_delete_tasks(self) -> None: producer_node_id=100, anonymous=False, ancestry=[str(task_id_1)], - workload_id=workload_id, + run_id=run_id, ) _ = state.store_task_res(task_res=task_res_1) @@ -182,10 +182,8 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute task_ins_uuid = state.store_task_ins(task_ins) @@ -199,10 +197,8 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -215,10 +211,8 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: """Store identity TaskIns and fail retrieving it as anonymous.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -231,10 +225,8 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: """Store identity TaskIns and retrieve it.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute task_ins_uuid = state.store_task_ins(task_ins) @@ -250,10 +242,8 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: """Fail retrieving delivered task.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_ins = create_task_ins( - consumer_node_id=1, anonymous=False, workload_id=workload_id - ) + run_id = state.create_run() + task_ins = create_task_ins(consumer_node_id=1, anonymous=False, run_id=run_id) # Execute _ = state.store_task_ins(task_ins) @@ -278,13 +268,11 @@ def test_get_task_ins_limit_throws_for_limit_zero(self) -> None: with self.assertRaises(AssertionError): state.get_task_ins(node_id=1, limit=0) - def test_task_ins_store_invalid_workload_id_and_fail(self) -> None: - """Store TaskIns with invalid workload_id and fail.""" + def test_task_ins_store_invalid_run_id_and_fail(self) -> None: + """Store TaskIns with invalid run_id and fail.""" # Prepare state: State = self.state_factory() - task_ins = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=61016 - ) + task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=61016) # Execute task_id = state.store_task_ins(task_ins) @@ -297,13 +285,13 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_ins_id = uuid4() task_res = create_task_res( producer_node_id=0, anonymous=True, ancestry=[str(task_ins_id)], - workload_id=workload_id, + run_id=run_id, ) # Execute @@ -318,10 +306,10 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() # Execute - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert assert len(retrieved_node_ids) == 0 @@ -330,13 +318,13 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() node_ids = [] # Execute for _ in range(10): node_ids.append(state.create_node()) - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert for i in retrieved_node_ids: @@ -346,26 +334,26 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() node_id = state.create_node() # Execute state.delete_node(node_id) - retrieved_node_ids = state.get_nodes(workload_id) + retrieved_node_ids = state.get_nodes(run_id) # Assert assert len(retrieved_node_ids) == 0 - def test_get_nodes_invalid_workload_id(self) -> None: - """Test retrieving all node_ids with invalid workload_id.""" + def test_get_nodes_invalid_run_id(self) -> None: + """Test retrieving all node_ids with invalid run_id.""" # Prepare state: State = self.state_factory() - state.create_workload() - invalid_workload_id = 61016 + state.create_run() + invalid_run_id = 61016 state.create_node() # Execute - retrieved_node_ids = state.get_nodes(invalid_workload_id) + retrieved_node_ids = state.get_nodes(invalid_run_id) # Assert assert len(retrieved_node_ids) == 0 @@ -374,13 +362,9 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() - task_0 = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) - task_1 = create_task_ins( - consumer_node_id=0, anonymous=True, workload_id=workload_id - ) + run_id = state.create_run() + task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) + task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Store two tasks state.store_task_ins(task_0) @@ -396,12 +380,12 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: State = self.state_factory() - workload_id = state.create_workload() + run_id = state.create_run() task_0 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], workload_id=workload_id + producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) task_1 = create_task_res( - producer_node_id=0, anonymous=True, ancestry=["1"], workload_id=workload_id + producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id ) # Store two tasks @@ -418,7 +402,7 @@ def test_num_task_res(self) -> None: def create_task_ins( consumer_node_id: int, anonymous: bool, - workload_id: int, + run_id: int, delivered_at: str = "", ) -> TaskIns: """Create a TaskIns for testing.""" @@ -429,7 +413,7 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id=workload_id, + run_id=run_id, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), @@ -446,13 +430,13 @@ def create_task_res( producer_node_id: int, anonymous: bool, ancestry: List[str], - workload_id: int, + run_id: int, ) -> TaskRes: """Create a TaskRes for testing.""" task_res = TaskRes( task_id="", group_id="", - workload_id=workload_id, + run_id=run_id, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/server/strategy/aggregate.py b/src/py/flwr/server/strategy/aggregate.py index 63926f2eaa51..c668b55eebe6 100644 --- a/src/py/flwr/server/strategy/aggregate.py +++ b/src/py/flwr/server/strategy/aggregate.py @@ -20,13 +20,14 @@ import numpy as np -from flwr.common import NDArray, NDArrays +from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays +from flwr.server.client_proxy import ClientProxy def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute weighted average.""" # Calculate the total number of examples used during training - num_examples_total = sum([num_examples for _, num_examples in results]) + num_examples_total = sum(num_examples for (_, num_examples) in results) # Create a list of weights, each multiplied by the related number of examples weighted_weights = [ @@ -41,6 +42,31 @@ def aggregate(results: List[Tuple[NDArrays, int]]) -> NDArrays: return weights_prime +def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: + """Compute in-place weighted average.""" + # Count total examples + num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results) + + # Compute scaling factors for each result + scaling_factors = [ + fit_res.num_examples / num_examples_total for _, fit_res in results + ] + + # Let's do in-place aggregation + # Get first result, then add up each other + params = [ + scaling_factors[0] * x for x in parameters_to_ndarrays(results[0][1].parameters) + ] + for i, (_, fit_res) in enumerate(results[1:]): + res = ( + scaling_factors[i + 1] * x + for x in parameters_to_ndarrays(fit_res.parameters) + ) + params = [reduce(np.add, layer_updates) for layer_updates in zip(params, res)] + + return params + + def aggregate_median(results: List[Tuple[NDArrays, int]]) -> NDArrays: """Compute median.""" # Create a list of weights and ignore the number of examples @@ -69,9 +95,9 @@ def aggregate_krum( # For each client, take the n-f-2 closest parameters vectors num_closest = max(1, len(weights) - num_malicious - 2) closest_indices = [] - for i, _ in enumerate(distance_matrix): + for distance in distance_matrix: closest_indices.append( - np.argsort(distance_matrix[i])[1 : num_closest + 1].tolist() # noqa: E203 + np.argsort(distance)[1 : num_closest + 1].tolist() # noqa: E203 ) # Compute the score for each client, that is the sum of the distances @@ -176,7 +202,7 @@ def aggregate_bulyan( def weighted_loss_avg(results: List[Tuple[int, float]]) -> float: """Aggregate evaluation results obtained from multiple clients.""" - num_total_evaluation_examples = sum([num_examples for num_examples, _ in results]) + num_total_evaluation_examples = sum(num_examples for (num_examples, _) in results) weighted_losses = [num_examples * loss for num_examples, loss in results] return sum(weighted_losses) / num_total_evaluation_examples @@ -207,9 +233,9 @@ def _compute_distances(weights: List[NDArrays]) -> NDArray: """ flat_w = np.array([np.concatenate(p, axis=None).ravel() for p in weights]) distance_matrix = np.zeros((len(weights), len(weights))) - for i, _ in enumerate(flat_w): - for j, _ in enumerate(flat_w): - delta = flat_w[i] - flat_w[j] + for i, flat_w_i in enumerate(flat_w): + for j, flat_w_j in enumerate(flat_w): + delta = flat_w_i - flat_w_j norm = np.linalg.norm(delta) distance_matrix[i, j] = norm**2 return distance_matrix diff --git a/src/py/flwr/server/strategy/dpfedavg_adaptive.py b/src/py/flwr/server/strategy/dpfedavg_adaptive.py index 3269735e9d73..8b3278cc9ba0 100644 --- a/src/py/flwr/server/strategy/dpfedavg_adaptive.py +++ b/src/py/flwr/server/strategy/dpfedavg_adaptive.py @@ -91,7 +91,7 @@ def _update_clip_norm(self, results: List[Tuple[ClientProxy, FitRes]]) -> None: norm_bit_set_count = 0 for client_proxy, fit_res in results: if "dpfedavg_norm_bit" not in fit_res.metrics: - raise Exception( + raise KeyError( f"Indicator bit not returned by client with id {client_proxy.cid}." ) if fit_res.metrics["dpfedavg_norm_bit"]: diff --git a/src/py/flwr/server/strategy/dpfedavg_fixed.py b/src/py/flwr/server/strategy/dpfedavg_fixed.py index 0154cfd79fc5..f2f1c206f3de 100644 --- a/src/py/flwr/server/strategy/dpfedavg_fixed.py +++ b/src/py/flwr/server/strategy/dpfedavg_fixed.py @@ -46,11 +46,11 @@ def __init__( self.num_sampled_clients = num_sampled_clients if clip_norm <= 0: - raise Exception("The clipping threshold should be a positive value.") + raise ValueError("The clipping threshold should be a positive value.") self.clip_norm = clip_norm if noise_multiplier < 0: - raise Exception("The noise multiplier should be a non-negative value.") + raise ValueError("The noise multiplier should be a non-negative value.") self.noise_multiplier = noise_multiplier self.server_side_noising = server_side_noising diff --git a/src/py/flwr/server/strategy/fedavg.py b/src/py/flwr/server/strategy/fedavg.py index c93c8cb8b83e..e4b126823fb6 100644 --- a/src/py/flwr/server/strategy/fedavg.py +++ b/src/py/flwr/server/strategy/fedavg.py @@ -37,7 +37,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy -from .aggregate import aggregate, weighted_loss_avg +from .aggregate import aggregate, aggregate_inplace, weighted_loss_avg from .strategy import Strategy WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """ @@ -107,6 +107,7 @@ def __init__( initial_parameters: Optional[Parameters] = None, fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + inplace: bool = True, ) -> None: super().__init__() @@ -128,6 +129,7 @@ def __init__( self.initial_parameters = initial_parameters self.fit_metrics_aggregation_fn = fit_metrics_aggregation_fn self.evaluate_metrics_aggregation_fn = evaluate_metrics_aggregation_fn + self.inplace = inplace def __repr__(self) -> str: """Compute a string representation of the strategy.""" @@ -226,12 +228,18 @@ def aggregate_fit( if not self.accept_failures and failures: return None, {} - # Convert results - weights_results = [ - (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) - for _, fit_res in results - ] - parameters_aggregated = ndarrays_to_parameters(aggregate(weights_results)) + if self.inplace: + # Does in-place weighted average of results + aggregated_ndarrays = aggregate_inplace(results) + else: + # Convert results + weights_results = [ + (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) + for _, fit_res in results + ] + aggregated_ndarrays = aggregate(weights_results) + + parameters_aggregated = ndarrays_to_parameters(aggregated_ndarrays) # Aggregate custom metrics if aggregation fn was provided metrics_aggregated = {} diff --git a/src/py/flwr/server/strategy/fedavg_android.py b/src/py/flwr/server/strategy/fedavg_android.py index e890f7216020..6678b7ced114 100644 --- a/src/py/flwr/server/strategy/fedavg_android.py +++ b/src/py/flwr/server/strategy/fedavg_android.py @@ -234,12 +234,10 @@ def parameters_to_ndarrays(self, parameters: Parameters) -> NDArrays: """Convert parameters object to NumPy weights.""" return [self.bytes_to_ndarray(tensor) for tensor in parameters.tensors] - # pylint: disable=R0201 def ndarray_to_bytes(self, ndarray: NDArray) -> bytes: """Serialize NumPy array to bytes.""" return ndarray.tobytes() - # pylint: disable=R0201 def bytes_to_ndarray(self, tensor: bytes) -> NDArray: """Deserialize NumPy array from bytes.""" ndarray_deserialized = np.frombuffer(tensor, dtype=np.float32) diff --git a/src/py/flwr/server/strategy/fedavg_test.py b/src/py/flwr/server/strategy/fedavg_test.py index 947736f4a571..e62eaa5c5832 100644 --- a/src/py/flwr/server/strategy/fedavg_test.py +++ b/src/py/flwr/server/strategy/fedavg_test.py @@ -15,6 +15,16 @@ """FedAvg tests.""" +from typing import List, Tuple, Union +from unittest.mock import MagicMock + +import numpy as np +from numpy.testing import assert_allclose + +from flwr.common import Code, FitRes, Status, parameters_to_ndarrays +from flwr.common.parameter import ndarrays_to_parameters +from flwr.server.client_proxy import ClientProxy + from .fedavg import FedAvg @@ -120,3 +130,51 @@ def test_fedavg_num_evaluation_clients_minimum() -> None: # Assert assert expected == actual + + +def test_inplace_aggregate_fit_equivalence() -> None: + """Test aggregate_fit equivalence between FedAvg and its inplace version.""" + # Prepare + weights0_0 = np.random.randn(100, 64) + weights0_1 = np.random.randn(314, 628, 3) + weights1_0 = np.random.randn(100, 64) + weights1_1 = np.random.randn(314, 628, 3) + + results: List[Tuple[ClientProxy, FitRes]] = [ + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights0_0, weights0_1]), + num_examples=1, + metrics={}, + ), + ), + ( + MagicMock(), + FitRes( + status=Status(code=Code.OK, message="Success"), + parameters=ndarrays_to_parameters([weights1_0, weights1_1]), + num_examples=5, + metrics={}, + ), + ), + ] + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = [] + + fedavg_reference = FedAvg(inplace=False) + fedavg_inplace = FedAvg() + + # Execute + reference, _ = fedavg_reference.aggregate_fit(1, results, failures) + assert reference + inplace, _ = fedavg_inplace.aggregate_fit(1, results, failures) + assert inplace + + # Convert to NumPy to check similarity + reference_np = parameters_to_ndarrays(reference) + inplace_np = parameters_to_ndarrays(inplace) + + # Assert + for ref, inp in zip(reference_np, inplace_np): + assert_allclose(ref, inp) diff --git a/src/py/flwr/server/strategy/fedmedian.py b/src/py/flwr/server/strategy/fedmedian.py index 7a5bf1425b44..17e979d92beb 100644 --- a/src/py/flwr/server/strategy/fedmedian.py +++ b/src/py/flwr/server/strategy/fedmedian.py @@ -36,7 +36,7 @@ class FedMedian(FedAvg): - """Configurable FedAvg with Momentum strategy implementation.""" + """Configurable FedMedian strategy implementation.""" def __repr__(self) -> str: """Compute a string representation of the strategy.""" diff --git a/src/py/flwr/server/strategy/qfedavg.py b/src/py/flwr/server/strategy/qfedavg.py index 94a67fbcbfae..758e8e608e9f 100644 --- a/src/py/flwr/server/strategy/qfedavg.py +++ b/src/py/flwr/server/strategy/qfedavg.py @@ -185,7 +185,7 @@ def norm_grad(grad_list: NDArrays) -> float: hs_ffl = [] if self.pre_weights is None: - raise Exception("QffedAvg pre_weights are None in aggregate_fit") + raise AttributeError("QffedAvg pre_weights are None in aggregate_fit") weights_before = self.pre_weights eval_result = self.evaluate( diff --git a/src/py/flwr/server/utils/validator_test.py b/src/py/flwr/server/utils/validator_test.py index cab51fbf46de..6627cc9a7887 100644 --- a/src/py/flwr/server/utils/validator_test.py +++ b/src/py/flwr/server/utils/validator_test.py @@ -135,7 +135,7 @@ def create_task_ins( task = TaskIns( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( delivered_at=delivered_at, producer=Node(node_id=0, anonymous=True), @@ -162,7 +162,7 @@ def create_task_res( task_res = TaskRes( task_id="", group_id="", - workload_id=0, + run_id=0, task=Task( producer=Node(node_id=producer_node_id, anonymous=anonymous), consumer=Node(node_id=0, anonymous=True), diff --git a/src/py/flwr/simulation/app.py b/src/py/flwr/simulation/app.py index c519f5a551f0..6a18a258ac60 100644 --- a/src/py/flwr/simulation/app.py +++ b/src/py/flwr/simulation/app.py @@ -314,18 +314,30 @@ def update_resources(f_stop: threading.Event) -> None: log(ERROR, traceback.format_exc()) log( ERROR, - "Your simulation crashed :(. This could be because of several reasons." + "Your simulation crashed :(. This could be because of several reasons. " "The most common are: " + "\n\t > Sometimes, issues in the simulation code itself can cause crashes. " + "It's always a good idea to double-check your code for any potential bugs " + "or inconsistencies that might be contributing to the problem. " + "For example: " + "\n\t\t - You might be using a class attribute in your clients that " + "hasn't been defined." + "\n\t\t - There could be an incorrect method call to a 3rd party library " + "(e.g., PyTorch)." + "\n\t\t - The return types of methods in your clients/strategies might be " + "incorrect." "\n\t > Your system couldn't fit a single VirtualClient: try lowering " "`client_resources`." "\n\t > All the actors in your pool crashed. This could be because: " "\n\t\t - You clients hit an out-of-memory (OOM) error and actors couldn't " "recover from it. Try launching your simulation with more generous " "`client_resources` setting (i.e. it seems %s is " - "not enough for your workload). Use fewer concurrent actors. " + "not enough for your run). Use fewer concurrent actors. " "\n\t\t - You were running a multi-node simulation and all worker nodes " "disconnected. The head node might still be alive but cannot accommodate " - "any actor with resources: %s.", + "any actor with resources: %s." + "\nTake a look at the Flower simulation examples for guidance " + ".", client_resources, client_resources, ) diff --git a/src/py/flwr/simulation/ray_transport/ray_actor.py b/src/py/flwr/simulation/ray_transport/ray_actor.py index 640817910396..38af3f08daa2 100644 --- a/src/py/flwr/simulation/ray_transport/ray_actor.py +++ b/src/py/flwr/simulation/ray_transport/ray_actor.py @@ -27,7 +27,7 @@ from flwr import common from flwr.client import Client, ClientFn -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common.logger import log from flwr.simulation.ray_transport.utils import check_clientfn_returns_client @@ -61,9 +61,9 @@ def run( client_fn: ClientFn, job_fn: JobFn, cid: str, - state: WorkloadState, - ) -> Tuple[str, ClientRes, WorkloadState]: - """Run a client workload.""" + state: RunState, + ) -> Tuple[str, ClientRes, RunState]: + """Run a client run.""" # Execute tasks and return result # return also cid which is needed to ensure results # from the pool are correctly assigned to each ClientProxy @@ -79,12 +79,12 @@ def run( except Exception as ex: client_trace = traceback.format_exc() message = ( - "\n\tSomething went wrong when running your client workload." + "\n\tSomething went wrong when running your client run." "\n\tClient " + cid + " crashed when the " + self.__class__.__name__ - + " was running its workload." + + " was running its run." "\n\tException triggered on the client side: " + client_trace, ) raise ClientException(str(message)) from ex @@ -94,7 +94,7 @@ def run( @ray.remote class DefaultActor(VirtualClientEngineActor): - """A Ray Actor class that runs client workloads. + """A Ray Actor class that runs client runs. Parameters ---------- @@ -237,10 +237,8 @@ def add_actors_to_pool(self, num_actors: int) -> None: self._idle_actors.extend(new_actors) self.num_actors += num_actors - def submit( - self, fn: Any, value: Tuple[ClientFn, JobFn, str, WorkloadState] - ) -> None: - """Take idle actor and assign it a client workload. + def submit(self, fn: Any, value: Tuple[ClientFn, JobFn, str, RunState]) -> None: + """Take idle actor and assign it a client run. Submit a job to an actor by first removing it from the list of idle actors, then check if this actor was flagged to be removed from the pool @@ -257,7 +255,7 @@ def submit( self._cid_to_future[cid]["future"] = future_key def submit_client_job( - self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, WorkloadState] + self, actor_fn: Any, job: Tuple[ClientFn, JobFn, str, RunState] ) -> None: """Submit a job while tracking client ids.""" _, _, cid, _ = job @@ -297,7 +295,7 @@ def _is_future_ready(self, cid: str) -> bool: return self._cid_to_future[cid]["ready"] # type: ignore - def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]: + def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, RunState]: """Fetch result and updated state for a VirtualClient from Object Store. The job submitted by the ClientProxy interfacing with client with cid=cid is @@ -307,7 +305,7 @@ def _fetch_future_result(self, cid: str) -> Tuple[ClientRes, WorkloadState]: future: ObjectRef[Any] = self._cid_to_future[cid]["future"] # type: ignore res_cid, res, updated_state = ray.get( future - ) # type: (str, ClientRes, WorkloadState) + ) # type: (str, ClientRes, RunState) except ray.exceptions.RayActorError as ex: log(ERROR, ex) if hasattr(ex, "actor_id"): @@ -411,7 +409,7 @@ def process_unordered_future(self, timeout: Optional[float] = None) -> None: def get_client_result( self, cid: str, timeout: Optional[float] - ) -> Tuple[ClientRes, WorkloadState]: + ) -> Tuple[ClientRes, RunState]: """Get result from VirtualClient with specific cid.""" # Loop until all jobs submitted to the pool are completed. Break early # if the result for the ClientProxy calling this method is ready @@ -423,5 +421,5 @@ def get_client_result( break # Fetch result belonging to the VirtualClient calling this method - # Return both result from tasks and (potentially) updated workload state + # Return both result from tasks and (potentially) updated run state return self._fetch_future_result(cid) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py index c6a63298dae6..5c05850dfd2f 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy.py @@ -132,16 +132,16 @@ def __init__( self.proxy_state = NodeState() def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: - # The VCE is not exposed to TaskIns, it won't handle multilple workloads - # For the time being, fixing workload_id is a small compromise + # The VCE is not exposed to TaskIns, it won't handle multilple runs + # For the time being, fixing run_id is a small compromise # This will be one of the first points to address integrating VCE + DriverAPI - workload_id = 0 + run_id = 0 # Register state - self.proxy_state.register_workloadstate(workload_id=workload_id) + self.proxy_state.register_runstate(run_id=run_id) # Retrieve state - state = self.proxy_state.retrieve_workloadstate(workload_id=workload_id) + state = self.proxy_state.retrieve_runstate(run_id=run_id) try: self.actor_pool.submit_client_job( @@ -151,14 +151,12 @@ def _submit_job(self, job_fn: JobFn, timeout: Optional[float]) -> ClientRes: res, updated_state = self.actor_pool.get_client_result(self.cid, timeout) # Update state - self.proxy_state.update_workloadstate( - workload_id=workload_id, workload_state=updated_state - ) + self.proxy_state.update_runstate(run_id=run_id, run_state=updated_state) except Exception as ex: if self.actor_pool.num_actors == 0: # At this point we want to stop the simulation. - # since no more client workloads will be executed + # since no more client runs will be executed log(ERROR, "ActorPool is empty!!!") log(ERROR, traceback.format_exc()) log(ERROR, ex) diff --git a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py index b87418b671d3..9df71635b949 100644 --- a/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py +++ b/src/py/flwr/simulation/ray_transport/ray_client_proxy_test.py @@ -22,7 +22,7 @@ import ray from flwr.client import Client, NumPyClient -from flwr.client.workload_state import WorkloadState +from flwr.client.run_state import RunState from flwr.common import Code, GetPropertiesRes, Status from flwr.simulation.ray_transport.ray_actor import ( ClientRes, @@ -46,7 +46,7 @@ def get_dummy_client(cid: str) -> Client: return DummyClient(cid).to_client() -# A dummy workload +# A dummy run def job_fn(cid: str) -> JobFn: # pragma: no cover """Construct a simple job with cid dependency.""" @@ -112,22 +112,22 @@ def test_cid_consistency_one_at_a_time() -> None: ray.shutdown() -def test_cid_consistency_all_submit_first_workload_consistency() -> None: +def test_cid_consistency_all_submit_first_run_consistency() -> None: """Test that ClientProxies get the result of client job they submit. All jobs are submitted at the same time. Then fetched one at a time. This also tests - NodeState (at each Proxy) and WorkloadState basic functionality. + NodeState (at each Proxy) and RunState basic functionality. """ proxies, _ = prep() - workload_id = 0 + run_id = 0 # submit all jobs (collect later) shuffle(proxies) for prox in proxies: # Register state - prox.proxy_state.register_workloadstate(workload_id=workload_id) + prox.proxy_state.register_runstate(run_id=run_id) # Retrieve state - state = prox.proxy_state.retrieve_workloadstate(workload_id=workload_id) + state = prox.proxy_state.retrieve_runstate(run_id=run_id) job = job_fn(prox.cid) prox.actor_pool.submit_client_job( @@ -139,12 +139,12 @@ def test_cid_consistency_all_submit_first_workload_consistency() -> None: shuffle(proxies) for prox in proxies: res, updated_state = prox.actor_pool.get_client_result(prox.cid, timeout=None) - prox.proxy_state.update_workloadstate(workload_id, workload_state=updated_state) + prox.proxy_state.update_runstate(run_id, run_state=updated_state) res = cast(GetPropertiesRes, res) assert int(prox.cid) * pi == res.properties["result"] assert ( str(int(prox.cid) * pi) - == prox.proxy_state.retrieve_workloadstate(workload_id).state["result"] + == prox.proxy_state.retrieve_runstate(run_id).state["result"] ) ray.shutdown() @@ -162,7 +162,7 @@ def test_cid_consistency_without_proxies() -> None: job = job_fn(cid) pool.submit_client_job( lambda a, c_fn, j_fn, cid_, state: a.run.remote(c_fn, j_fn, cid_, state), - (get_dummy_client, job, cid, WorkloadState(state={})), + (get_dummy_client, job, cid, RunState(state={})), ) # fetch results one at a time diff --git a/src/py/flwr/simulation/ray_transport/utils.py b/src/py/flwr/simulation/ray_transport/utils.py index c8e6aa6cbe21..41aa8049eaf0 100644 --- a/src/py/flwr/simulation/ray_transport/utils.py +++ b/src/py/flwr/simulation/ray_transport/utils.py @@ -37,7 +37,7 @@ def enable_tf_gpu_growth() -> None: # the same GPU. # Luckily we can disable this behavior by enabling memory growth # on the GPU. In this way, VRAM allocated to the processes grows based - # on the needs for the workload. (this is for instance the default + # on the needs for the run. (this is for instance the default # behavior in PyTorch) # While this behavior is critical for Actors, you'll likely need it # as well in your main process (where the server runs and might evaluate diff --git a/src/py/flwr_experimental/ops/__init__.py b/src/py/flwr_experimental/ops/__init__.py index b56c757e0207..bad31028e68c 100644 --- a/src/py/flwr_experimental/ops/__init__.py +++ b/src/py/flwr_experimental/ops/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. # ============================================================================== """Flower ops provides an opinionated way to provision necessary compute -infrastructure for running Flower workloads.""" +infrastructure for running Flower runs."""